Spaces:
Running on Zero
Running on Zero
File size: 2,151 Bytes
ed46d32 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import torch
import torch.nn.functional as F
from torch import nn
class TwoWayReLUFunction(torch.autograd.Function):
    """Custom autograd function: ReLU forward, sigmoid-gated backward.

    The forward pass returns ``relu(z)``. The backward pass does not use the
    ReLU derivative; it scales the incoming gradient element-wise by
    ``sigmoid(z / temperature)``.
    """

    @staticmethod
    def forward(ctx, z, temperature=1.0):
        """Compute ``relu(z)`` and remember ``z`` and ``temperature``.

        Args:
            ctx: Autograd context used to carry state to ``backward``.
            z: Input tensor.
            temperature: Divisor applied to ``z`` inside the backward sigmoid.

        Returns:
            ``relu(z)``.
        """
        ctx.temperature = temperature
        ctx.save_for_backward(z)
        activated = F.relu(z)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(z, temperature)``.

        The gradient for ``z`` is ``grad_output * sigmoid(z / temperature)``;
        ``temperature`` receives no gradient (``None``).
        """
        (pre_activation,) = ctx.saved_tensors
        soft_gate = torch.sigmoid(pre_activation / ctx.temperature)
        grad_z = grad_output * soft_gate
        return grad_z, None
class TwoWayReLU(nn.Module):
    """Module wrapper that applies ``TwoWayReLUFunction`` to its input.

    Args:
        temperature: Value forwarded to ``TwoWayReLUFunction.apply`` as its
            second argument on every call.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, x):
        """Apply ``TwoWayReLUFunction`` to ``x`` with the stored temperature."""
        temp = self.temperature
        return TwoWayReLUFunction.apply(x, temp)

    def extra_repr(self):
        """Show the temperature in the module's printed representation."""
        temp = self.temperature
        return f"temperature={temp}"
class SoftMaxPool2d(nn.MaxPool2d):
    """Softmax-weighted 2-D pooling with the ``nn.MaxPool2d`` constructor.

    Each output value is the softmax-weighted average of the values in its
    pooling window, where the weights are ``softmax(window / temperature)``.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may each be an
    int or an ``(h, w)`` pair. ``ceil_mode`` and ``return_indices`` are
    accepted by the constructor but are not used by ``forward``.

    Args:
        *args: Positional arguments forwarded to ``nn.MaxPool2d``.
        temperature: Divisor applied to the window values before the softmax.
        **kwargs: Keyword arguments forwarded to ``nn.MaxPool2d``.
    """

    def __init__(self, *args, temperature=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature

    @staticmethod
    def _as_pair(value):
        """Expand an int to ``(value, value)``; return a sequence as a tuple."""
        if isinstance(value, int):
            return value, value
        return tuple(value)

    def forward(self, x):
        """Pool a 4-D tensor ``x`` (unpacked as ``B, C, H, W``).

        Returns:
            Tensor of shape ``(B, C, out_H, out_W)``.
        """
        B, C, H, W = x.shape
        # Normalise every hyper-parameter to an (h, w) pair so that int and
        # tuple arguments are handled the same way.
        kH, kW = self._as_pair(self.kernel_size)
        sH, sW = self._as_pair(self.stride)
        pH, pW = self._as_pair(self.padding)
        dH, dW = self._as_pair(self.dilation)

        # Unfold input to patches: (B, C * kH * kW, L) -> (B, C, kH * kW, L).
        # NOTE(review): F.unfold is documented to pad with zeros, so padded
        # positions take part in the softmax as value 0 (unlike max pooling).
        x_unf = F.unfold(
            x,
            kernel_size=(kH, kW),
            dilation=(dH, dW),
            padding=(pH, pW),
            stride=(sH, sW),
        )
        x_unf = x_unf.view(B, C, kH * kW, -1)

        # Softmax pooling over the positions inside each window.
        weights = F.softmax(x_unf / self.temperature, dim=2)
        pooled = (x_unf * weights).sum(dim=2)

        # Reshape back to image; a dilated kernel spans d * (k - 1) + 1 pixels.
        out_H = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
        out_W = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1
        return pooled.view(B, C, out_H, out_W)

    def extra_repr(self):
        """Append the temperature to the parent's representation string."""
        ret = super().extra_repr()
        return f"{ret}, temperature={self.temperature}"
class SurrogateSoftMaxPool2d(SoftMaxPool2d):
    """Max pooling in the forward pass with the softmax-pooling gradient.

    The returned values equal ``F.max_pool2d(x, ...)``. The gradient flows
    only through the parent class's softmax-pooled output, because the hard
    result is detached and the soft result is added and subtracted so that
    its value cancels.
    """

    def forward(self, x):
        """Pool ``x``.

        Returns:
            The pooled tensor, or ``(pooled, indices)`` when
            ``self.return_indices`` is true.
        """
        soft = super().forward(x)
        pooled = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
        # With return_indices=True, max_pool2d yields (output, indices);
        # calling .detach() on that tuple would fail, so unpack it first.
        if self.return_indices:
            hard, indices = pooled
        else:
            hard, indices = pooled, None

        # Straight-through combination: value of `hard`, gradient of `soft`.
        out = hard.detach() + (soft - soft.detach())
        if self.return_indices:
            return out, indices
        return out
|