ldcast_code / ldcast /models /utils.py
weatherforecast1024's picture
Upload folder using huggingface_hub
d2f661a verified
import torch
from torch import nn
def normalization(channels, norm_type="group", num_groups=32):
    """Build the normalization layer selected by ``norm_type``.

    Args:
        channels: Number of channels of the tensor to be normalized.
        norm_type: ``"batch"`` for ``nn.BatchNorm3d``, ``"group"`` for
            ``nn.GroupNorm``; any falsy value (``None``, ``""``, ...) or the
            string ``"none"`` in any letter case disables normalization.
        num_groups: Number of groups, used only when ``norm_type`` is
            ``"group"``. ``nn.GroupNorm`` requires ``channels`` to be
            divisible by it.

    Returns:
        An ``nn.Module``: the requested normalization layer, or
        ``nn.Identity()`` when normalization is disabled.

    Raises:
        NotImplementedError: If ``norm_type`` is not a recognized value.
    """
    if norm_type == "batch":
        return nn.BatchNorm3d(channels)
    if norm_type == "group":
        return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
    if not norm_type:
        return nn.Identity()
    # Only strings can spell "none"; the isinstance guard keeps other truthy
    # values from failing on a missing ``.lower`` so they reach the intended
    # NotImplementedError below.
    if isinstance(norm_type, str) and norm_type.lower() == "none":
        return nn.Identity()
    raise NotImplementedError(norm_type)
def activation(act_type="swish"):
    """Build the activation module selected by ``act_type``.

    Args:
        act_type: One of ``"swish"`` (``nn.SiLU``), ``"gelu"``
            (``nn.GELU``), ``"relu"`` (``nn.ReLU``) or ``"tanh"``
            (``nn.Tanh``). Any falsy value (``None``, ``""``, ...) selects
            ``nn.Identity``, i.e. no activation.

    Returns:
        A freshly constructed ``nn.Module`` for the requested activation.

    Raises:
        NotImplementedError: If ``act_type`` is truthy but not one of the
            names listed above.
    """
    # (name, constructor) pairs, compared in order with ``==`` so that
    # arbitrary (even unhashable) inputs behave as in a plain if/elif chain.
    known_activations = (
        ("swish", nn.SiLU),
        ("gelu", nn.GELU),
        ("relu", nn.ReLU),
        ("tanh", nn.Tanh),
    )
    for name, build in known_activations:
        if act_type == name:
            return build()

    if act_type:
        raise NotImplementedError(act_type)
    return nn.Identity()