Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
class TwoWayReLUFunction(torch.autograd.Function):
    """ReLU in the forward pass with a sigmoid-gated gradient in the backward pass.

    The forward output is exactly ``F.relu(z)``. The backward pass does not use
    the ReLU derivative; it scales the incoming gradient by
    ``sigmoid(z / temperature)`` instead.
    """

    @staticmethod
    def forward(ctx, z, temperature=1.0):
        """Return ``relu(z)`` and stash what the backward pass needs.

        Args:
            ctx: autograd context object supplied by ``Function.apply``.
            z: input tensor.
            temperature: divisor applied to ``z`` inside the backward gate.

        Returns:
            ``F.relu(z)``.
        """
        ctx.save_for_backward(z)
        # Plain Python value, so it is kept as an attribute rather than saved
        # through save_for_backward.
        ctx.temperature = temperature
        return F.relu(z)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale ``grad_output`` by ``sigmoid(z / temperature)``.

        Returns:
            A pair ``(grad_z, None)``; ``None`` is the gradient slot for the
            ``temperature`` argument of ``forward``.
        """
        (z,) = ctx.saved_tensors
        gate = torch.sigmoid(z / ctx.temperature)
        return grad_output * gate, None
class TwoWayReLU(nn.Module):
    """Module wrapper that routes its input through ``TwoWayReLUFunction``.

    Args:
        temperature: value handed to ``TwoWayReLUFunction.apply`` as its
            second argument on every forward call.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, x):
        """Apply ``TwoWayReLUFunction`` to ``x`` with the stored temperature."""
        activation = TwoWayReLUFunction.apply
        return activation(x, self.temperature)

    def extra_repr(self):
        """Expose the temperature in the module's printed representation."""
        return f"temperature={self.temperature}"
class SoftMaxPool2d(nn.MaxPool2d):
    """Softmax-weighted 2D pooling configured like ``nn.MaxPool2d``.

    Every pooling window is reduced to ``sum(x * softmax(x / temperature))``
    taken over the positions inside the window.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.MaxPool2d.__init__``.
            ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may each
            be an int or a pair of ints.
        temperature: divisor applied to the window values before the softmax.

    Note:
        ``ceil_mode`` and ``return_indices`` are stored by the base class but
        are not used by this forward pass; the output size is always the
        floor-mode size.
    """

    def __init__(self, *args, temperature=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a ``(rows, cols)`` pair, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            rows, cols = value
            return rows, cols
        return value, value

    def forward(self, x):
        """Pool a 4-D input ``x`` (unpacked as ``B, C, H, W``).

        Returns:
            Tensor of shape ``(B, C, out_h, out_w)``.
        """
        batch, channels, height, width = x.shape
        k_h, k_w = self._as_pair(self.kernel_size)
        s_h, s_w = self._as_pair(self.stride)
        p_h, p_w = self._as_pair(self.padding)
        d_h, d_w = self._as_pair(self.dilation)
        unfold_args = dict(
            kernel_size=(k_h, k_w),
            dilation=(d_h, d_w),
            padding=(p_h, p_w),
            stride=(s_h, s_w),
        )

        # Unfold input to patches: (B, C, window positions, number of windows).
        patches = F.unfold(x, **unfold_args).view(batch, channels, k_h * k_w, -1)
        logits = patches / self.temperature

        if p_h or p_w:
            # F.unfold pads with zeros, which would otherwise receive softmax
            # weight. Unfolding a tensor of ones marks the real (non-padding)
            # positions so padding can be excluded, matching the implicit
            # -inf padding of max pooling. A window consisting only of padding
            # (padding larger than max pooling permits) yields NaN.
            real = F.unfold(x.new_ones(1, 1, height, width), **unfold_args)
            real = real.view(1, 1, k_h * k_w, -1).bool()
            logits = logits.masked_fill(~real, float("-inf"))

        # Softmax pooling over the positions of each window. Padded entries of
        # `patches` are zero and carry zero weight, so they contribute nothing.
        weights = F.softmax(logits, dim=2)
        pooled = (patches * weights).sum(dim=2)

        # Reshape back to image; same floor-mode size formula F.unfold uses.
        out_h = (height + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1
        out_w = (width + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1
        return pooled.view(batch, channels, out_h, out_w)

    def extra_repr(self):
        """Append the temperature to the base class's representation."""
        ret = super().extra_repr()
        return f"{ret}, temperature={self.temperature}"
class SurrogateSoftMaxPool2d(SoftMaxPool2d):
    """Hard max pooling in the forward pass, soft-pooling gradient in the backward pass.

    The returned tensor has the values of ``F.max_pool2d`` while gradients flow
    only through the ``SoftMaxPool2d`` output.
    """

    def forward(self, x):
        """Pool ``x``.

        Returns:
            The pooled tensor, or ``(pooled, indices)`` when the module was
            built with ``return_indices=True``.
        """
        soft = super().forward(x)
        result = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
        # With return_indices=True, F.max_pool2d returns (values, indices);
        # calling .detach() on that tuple would fail, so unpack it first.
        if self.return_indices:
            hard, indices = result
        else:
            hard, indices = result, None

        # NOTE(review): the soft branch computes a floor-mode output size, so
        # ceil_mode=True can make `hard` and `soft` disagree in shape — confirm
        # before enabling it.
        # `soft - soft.detach()` is zero in value but carries soft's gradient,
        # while the detached `hard` contributes the value and no gradient.
        out = hard.detach() + (soft - soft.detach())
        if self.return_indices:
            return out, indices
        return out