File size: 2,151 Bytes
ed46d32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn.functional as F
from torch import nn


class TwoWayReLUFunction(torch.autograd.Function):
    """ReLU in the forward pass with a sigmoid-gated surrogate gradient.

    Forward returns ``relu(z)``. Backward replaces ReLU's hard 0/1 derivative
    with the smooth gate ``sigmoid(z / temperature)``. Inputs below zero
    therefore still receive a (smaller) gradient. ``temperature`` is treated as
    a constant: no gradient is returned for it.
    """

    @staticmethod
    def forward(ctx, z, temperature=1.0):
        # Stash the temperature (a plain Python value) and the pre-activation
        # tensor that the backward gate is computed from.
        ctx.temperature = temperature
        ctx.save_for_backward(z)
        activated = F.relu(z)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, = ctx.saved_tensors
        # Smooth stand-in for d relu / dz; lower temperature -> sharper step.
        surrogate_slope = torch.sigmoid(saved_input / ctx.temperature)
        grad_z = grad_output * surrogate_slope
        # Second slot corresponds to `temperature`, which is not differentiated.
        return grad_z, None


class TwoWayReLU(nn.Module):
    """``nn.Module`` wrapper around ``TwoWayReLUFunction``.

    Args:
        temperature: passed unchanged to ``TwoWayReLUFunction.apply`` on every
            call; it scales the input to the sigmoid gradient gate.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def extra_repr(self):
        # Shown inside the module's repr, e.g. "TwoWayReLU(temperature=1.0)".
        return "temperature={}".format(self.temperature)

    def forward(self, x):
        apply_fn = TwoWayReLUFunction.apply
        return apply_fn(x, self.temperature)


class SoftMaxPool2d(nn.MaxPool2d):
    """Differentiable "soft" 2-D max pooling.

    Every pooling window with elements ``w`` is reduced to
    ``sum(softmax(w / temperature) * w)``. Low temperatures approach
    ``nn.MaxPool2d``; high temperatures approach an unweighted window mean.

    Takes the same constructor arguments as ``nn.MaxPool2d``. ``kernel_size``,
    ``stride``, ``padding`` and ``dilation`` may each be an int or an
    ``(h, w)`` pair, and ``ceil_mode`` is supported. A keyword-only
    ``temperature`` is added. Output shapes follow ``nn.MaxPool2d``.

    Padded positions are excluded from the softmax. This matches max pooling's
    implicit ``-inf`` padding rather than letting zeros compete with real
    values. A window that samples only padding (possible with some dilation
    settings) yields NaN. ``return_indices`` is stored by the base class but
    unused here, since a soft pool has no single argmax.

    Input is ``(B, C, H, W)`` or unbatched ``(C, H, W)``.
    """

    def __init__(self, *args, temperature=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature

    @staticmethod
    def _pair(value):
        """Return an int or a 2-sequence as an ``(h, w)`` tuple of ints."""
        if isinstance(value, (tuple, list)):
            h, w = value
            return int(h), int(w)
        return int(value), int(value)

    @staticmethod
    def _out_size(size, kernel, stride, padding, dilation, ceil_mode):
        """Pooled length along one axis, using ``nn.MaxPool2d``'s shape formula."""
        span = size + 2 * padding - dilation * (kernel - 1) - 1
        if not ceil_mode:
            return span // stride + 1
        out = -(-span // stride) + 1
        # As in PyTorch: drop a last window that would start in the right padding.
        if (out - 1) * stride >= size + padding:
            out -= 1
        return out

    def forward(self, x):
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        B, C, H, W = x.shape
        kH, kW = self._pair(self.kernel_size)
        sH, sW = self._pair(self.stride)
        pH, pW = self._pair(self.padding)
        dH, dW = self._pair(self.dilation)

        out_H = self._out_size(H, kH, sH, pH, dH, self.ceil_mode)
        out_W = self._out_size(W, kW, sW, pW, dW, self.ceil_mode)
        if out_H < 1 or out_W < 1:
            raise ValueError(
                f"SoftMaxPool2d: input of spatial size {(H, W)} is too small for "
                f"this kernel/stride/padding/dilation (output {(out_H, out_W)})"
            )

        # Extent of the padded input touched by the out_H x out_W windows. With
        # ceil_mode this can exceed H + 2 * padding, so the bottom/right padding
        # is derived from it. Anything past it is never read by a window and is
        # cropped, so unfold yields exactly out_H * out_W windows.
        need_H = (out_H - 1) * sH + dH * (kH - 1) + 1
        need_W = (out_W - 1) * sW + dW * (kW - 1) + 1
        # F.pad order for the last two dims: (left, right, top, bottom).
        pad = (pW, max(need_W - W - pW, 0), pH, max(need_H - H - pH, 0))

        def windows(t):
            """Zero-pad, crop and unfold ``t`` into (N, C, kH * kW, out_H * out_W)."""
            if any(pad):
                t = F.pad(t, pad)
            t = t[..., :need_H, :need_W]
            cols = F.unfold(t, (kH, kW), dilation=(dH, dW), stride=(sH, sW))
            # unfold lays out each channel's kH * kW patch entries contiguously.
            return cols.view(t.shape[0], t.shape[1], kH * kW, out_H * out_W)

        x_unf = windows(x)
        logits = x_unf / self.temperature
        if any(pad):
            # Max pooling pads with -inf: mask padded slots out of the softmax
            # instead of letting the zero padding compete with real values.
            valid = windows(torch.ones(1, 1, H, W, dtype=x.dtype, device=x.device))
            logits = logits.masked_fill(valid == 0, float("-inf"))

        # Softmax pooling over the window positions.
        weights = F.softmax(logits, dim=2)
        # Padded slots hold 0 in x_unf and get weight 0, so they contribute nothing.
        pooled = (x_unf * weights).sum(dim=2)
        out = pooled.view(B, C, out_H, out_W)
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        ret = super().extra_repr()
        return f"{ret}, temperature={self.temperature}"


class SurrogateSoftMaxPool2d(SoftMaxPool2d):
    """Hard max pooling forward, soft-max-pooling gradient backward.

    Straight-through estimator. The returned values are exactly those of
    ``F.max_pool2d`` with this module's settings, while gradients flow only
    through the ``SoftMaxPool2d`` path (``super().forward``). The two paths are
    summed, so the soft path's output shape must match ``F.max_pool2d``'s for
    the configured arguments.

    When ``return_indices`` is true, returns ``(output, indices)`` like
    ``nn.MaxPool2d``; the indices come from the hard max pool.
    """

    def forward(self, x):
        soft = super().forward(x)
        # The hard path only supplies forward values; don't record a graph for it.
        with torch.no_grad():
            hard = F.max_pool2d(
                x,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                ceil_mode=self.ceil_mode,
                return_indices=self.return_indices,
            )
        indices = None
        if self.return_indices:
            # F.max_pool2d returns an (output, indices) tuple in this mode.
            hard, indices = hard

        # soft - soft.detach() is exactly zero in value but carries soft's
        # gradient, so the forward value equals `hard`.
        out = hard + (soft - soft.detach())
        return (out, indices) if self.return_indices else out