# SemanticPullbacks / lib / modules.py
# Maciej Satkiewicz
# app works locally
# ed46d32
import torch
import torch.nn.functional as F
from torch import nn
class TwoWayReLUFunction(torch.autograd.Function):
    """ReLU on the forward pass, sigmoid-gated gradient on the backward pass.

    Forward returns ``relu(z)``. Backward scales the incoming gradient by
    ``sigmoid(z / temperature)`` instead of ReLU's 0/1 step derivative.
    """

    @staticmethod
    def forward(ctx, z, temperature=1.0):
        """Return ``relu(z)`` and stash ``z`` and ``temperature`` for backward."""
        ctx.temperature = temperature
        ctx.save_for_backward(z)
        return torch.relu(z)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * sigmoid(z / temperature)`` for ``z``.

        The second returned value is ``None``: ``temperature`` receives no
        gradient.
        """
        pre_activation = ctx.saved_tensors[0]
        soft_gate = torch.sigmoid(pre_activation / ctx.temperature)
        return grad_output * soft_gate, None
class TwoWayReLU(nn.Module):
    """Module wrapper that applies ``TwoWayReLUFunction`` with a stored temperature."""

    def __init__(self, temperature=1.0):
        super().__init__()
        # Passed unchanged to TwoWayReLUFunction on every forward call.
        self.temperature = temperature

    def forward(self, x):
        """Run ``TwoWayReLUFunction`` on ``x`` with this module's temperature."""
        temperature = self.temperature
        return TwoWayReLUFunction.apply(x, temperature)

    def extra_repr(self):
        """Return the temperature setting shown inside the module's repr."""
        return f"temperature={self.temperature}"
class SoftMaxPool2d(nn.MaxPool2d):
    """Softmax-weighted average over each pooling window.

    Accepts the same constructor arguments as ``nn.MaxPool2d`` plus a
    keyword-only ``temperature``. Each output element is
    ``sum(softmax(window / temperature) * window)``.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may each be an
    int or an ``(h, w)`` pair. ``ceil_mode`` and ``return_indices`` are not
    consulted by ``forward``. Padding is applied by ``F.unfold``, which
    zero-fills, so padded positions enter the softmax with value 0.
    """

    def __init__(self, *args, temperature=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as an ``(h, w)`` pair, duplicating a scalar."""
        if isinstance(value, (tuple, list)):
            h, w = value
            return h, w
        return value, value

    def forward(self, x):
        """Pool a ``(B, C, H, W)`` tensor and return ``(B, C, out_H, out_W)``."""
        B, C, H, W = x.shape
        kH, kW = self._as_pair(self.kernel_size)
        sH, sW = self._as_pair(self.stride)
        pH, pW = self._as_pair(self.padding)
        dH, dW = self._as_pair(self.dilation)

        # Unfold input to patches: (B, C * kH * kW, L), L = number of windows.
        x_unf = F.unfold(
            x,
            kernel_size=(kH, kW),
            dilation=(dH, dW),
            padding=(pH, pW),
            stride=(sH, sW),
        )
        x_unf = x_unf.view(B, C, kH * kW, -1)

        # Softmax pooling over the positions inside each window.
        weights = F.softmax(x_unf / self.temperature, dim=2)
        pooled = (x_unf * weights).sum(dim=2)

        # Reshape back to image; same window-count formula F.unfold uses.
        out_H = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
        out_W = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1
        return pooled.view(B, C, out_H, out_W)

    def extra_repr(self):
        """Return the parent's repr fields followed by the temperature."""
        ret = super().extra_repr()
        return f"{ret}, temperature={self.temperature}"
class SurrogateSoftMaxPool2d(SoftMaxPool2d):
    """Hard max pooling in value, soft-max pooling in gradient.

    The forward value is ``hard.detach() + (soft - soft.detach())``: the
    bracketed term is numerically zero, so the value is the hard max-pool
    result, while autograd only sees ``soft`` and back-propagates through the
    softmax-weighted pooling of the parent class.
    """

    def forward(self, x):
        """Return the surrogate-pooled tensor.

        When ``return_indices`` is set, returns ``(output, indices)`` with the
        indices produced by ``F.max_pool2d``.
        """
        soft = super().forward(x)
        hard = F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
        if self.return_indices:
            # F.max_pool2d returns (values, indices) here; a tuple has no
            # .detach(), so split it before building the surrogate output.
            hard, indices = hard
            return hard.detach() + (soft - soft.detach()), indices
        return hard.detach() + (soft - soft.detach())