| import torch | |
| from torch import nn | |
| def normalization(channels, norm_type="group", num_groups=32): | |
| if norm_type == "batch": | |
| return nn.BatchNorm3d(channels) | |
| elif norm_type == "group": | |
| return nn.GroupNorm(num_groups=num_groups, num_channels=channels) | |
| elif (not norm_type) or (norm_type.tolower() == 'none'): | |
| return nn.Identity() | |
| else: | |
| raise NotImplementedError(norm) | |
| def activation(act_type="swish"): | |
| if act_type == "swish": | |
| return nn.SiLU() | |
| elif act_type == "gelu": | |
| return nn.GELU() | |
| elif act_type == "relu": | |
| return nn.ReLU() | |
| elif act_type == "tanh": | |
| return nn.Tanh() | |
| elif not act_type: | |
| return nn.Identity() | |
| else: | |
| raise NotImplementedError(act_type) | |