import torch
import numpy as np
import logging
log = logging.getLogger(__name__)
class ModuleListMerge(torch.nn.Sequential):
    def forward(self, input):
        return [module.forward(input) for module in self._modules.values()]

class ModuleConcat(ModuleListMerge):
    def forward(self, input):
        tt_list = super().forward(input)
        assert len(set([tuple(list(tt.size())[2:]) for tt in tt_list])) == 1, (
            "Sizes of tensors must match except in dimension 1. "
            "\n{}\n got tensor sizes: \n{}\n".format(
                self, [tt.size() for tt in tt_list]
            )
        )
        return torch.cat(tt_list, dim=1)

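# Example (illustrative, not from the original source): concatenate the
# outputs of parallel branches along the channel dimension (dim 1).
# Spatial sizes must match; channel counts may differ.
# >>> branches = ModuleConcat(
# ...     torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
# ...     torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
# ... )
# >>> branches(torch.randn(2, 3, 16, 16)).shape
# torch.Size([2, 12, 16, 16])
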
def _check_size_match(self, tt_list):
    # Module-level helper: callers pass the module as `self` for the message.
    assert (
        len(set([tuple(tt.size()) for tt in tt_list])) == 1
    ), "Sizes of tensors must match. \n{}\n got tensor sizes: \n{}\n".format(
        self, [tt.size() for tt in tt_list]
    )

def element_wise_sum(tt_list):
    return torch.sum(torch.stack(tt_list), dim=0)

class ModuleSum(ModuleListMerge):
    def forward(self, input):
        tt_list = super().forward(input)
        _check_size_match(self, tt_list)
        return element_wise_sum(tt_list)

def element_wise_average(tt_list):
    return torch.mean(torch.stack(tt_list), dim=0)


class ModuleAvg(ModuleListMerge):
    def forward(self, input):
        tt_list = super().forward(input)
        _check_size_match(self, tt_list)
        return element_wise_average(tt_list)

def element_wise_prod(tt_list):
    return torch.prod(torch.stack(tt_list), dim=0)


class ModuleProd(ModuleListMerge):
    def forward(self, input):
        tt_list = super().forward(input)
        _check_size_match(self, tt_list)
        return element_wise_prod(tt_list)

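# Example (illustrative): ModuleSum, ModuleAvg, and ModuleProd merge
# same-shaped branch outputs element-wise. Here both branches are the
# identity (TensorIdentity is defined later in this module).
# >>> merged = ModuleSum(TensorIdentity(), TensorIdentity())
# >>> torch.equal(merged(torch.ones(2, 3)), 2 * torch.ones(2, 3))
# True
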
class StatModule(torch.nn.Module):
    def __init__(self, dim, keepdim=False):
        if isinstance(dim, list):
            dim = tuple(dim)
        if isinstance(dim, int):
            dim = (dim,)
        assert isinstance(dim, tuple)
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

class Pool1dMixIn:
    def __init__(self, keepdim=False):
        super().__init__(dim=(2,), keepdim=keepdim)


class Pool2dMixIn:
    def __init__(self, keepdim=False):
        super().__init__(dim=(3, 2), keepdim=keepdim)


class Pool3dMixIn:
    def __init__(self, keepdim=False):
        super().__init__(dim=(4, 3, 2), keepdim=keepdim)

class TensorMean(StatModule):
    def forward(self, input):
        return torch.mean(input, dim=self.dim, keepdim=self.keepdim)


class TensorGlobalAvgPool1d(Pool1dMixIn, TensorMean):
    pass


class TensorGlobalAvgPool2d(Pool2dMixIn, TensorMean):
    pass


class TensorGlobalAvgPool3d(Pool3dMixIn, TensorMean):
    pass


class TensorSum(StatModule):
    def forward(self, input):
        return torch.sum(input, dim=self.dim, keepdim=self.keepdim)


class TensorGlobalSumPool1d(Pool1dMixIn, TensorSum):
    pass


class TensorGlobalSumPool2d(Pool2dMixIn, TensorSum):
    pass


class TensorGlobalSumPool3d(Pool3dMixIn, TensorSum):
    pass

class TensorMax(StatModule, torch.nn.Module):
    def forward(self, input):
        return tensor_max(input, dim=self.dim, keepdim=self.keepdim)


def tensor_max(input, dim, keepdim=False):
    if isinstance(dim, int):
        return torch.max(input, dim=dim, keepdim=keepdim)[0]
    if isinstance(dim, tuple):
        dim = list(dim)
    for d in dim:
        input = torch.max(input, dim=d, keepdim=keepdim)[0]
    return input

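# Note: with keepdim=False the dims are reduced one at a time, so they must
# be listed in descending order (as the Pool*dMixIn classes do) to keep the
# remaining indices valid. Illustrative example:
# >>> tensor_max(torch.arange(24.0).reshape(2, 3, 4), dim=(2, 1)).shape
# torch.Size([2])
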
class TensorGlobalMaxPool1d(Pool1dMixIn, TensorMax):
    pass


class TensorGlobalMaxPool2d(Pool2dMixIn, TensorMax):
    pass


class TensorGlobalMaxPool3d(Pool3dMixIn, TensorMax):
    pass

class TensorMin(StatModule, torch.nn.Module):
    def forward(self, input):
        return tensor_min(input, dim=self.dim, keepdim=self.keepdim)


def tensor_min(input, dim, keepdim=False):
    if isinstance(dim, int):
        return torch.min(input, dim=dim, keepdim=keepdim)[0]
    if isinstance(dim, tuple):
        dim = list(dim)
    for d in dim:
        input = torch.min(input, dim=d, keepdim=keepdim)[0]
    return input

class TensorGlobalMinPool1d(Pool1dMixIn, TensorMin):
    pass


class TensorGlobalMinPool2d(Pool2dMixIn, TensorMin):
    pass


class TensorGlobalMinPool3d(Pool3dMixIn, TensorMin):
    pass

class TensorRange(StatModule, torch.nn.Module):
    def forward(self, input):
        return tensor_max(input, dim=self.dim, keepdim=self.keepdim) - tensor_min(
            input, dim=self.dim, keepdim=self.keepdim
        )


class TensorGlobalRangePool1d(Pool1dMixIn, TensorRange):
    pass


class TensorGlobalRangePool2d(Pool2dMixIn, TensorRange):
    pass


class TensorGlobalRangePool3d(Pool3dMixIn, TensorRange):
    pass

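# Example (illustrative): the global pooling modules collapse every spatial
# dimension, leaving (batch, channels). All Avg/Sum/Max/Min/Range variants
# behave the same shape-wise.
# >>> TensorGlobalAvgPool2d()(torch.ones(2, 3, 4, 5)).shape
# torch.Size([2, 3])
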
def to_array(input):
    if not isinstance(input, (tuple, list)):
        input = [input]
    return np.array(input)


def as_tuple(x):
    return tuple(x) if isinstance(x, (list, np.ndarray)) else x

def setup_conv_params(
    kernel_size=1,
    dilation=None,
    padding=None,
    stride=None,
    raise_error=False,
    *args,
    **kwargs
):
    kwargs["kernel_size"] = as_tuple(kernel_size)
    if dilation is not None:
        kwargs["dilation"] = as_tuple(dilation)
    if padding is None:
        # Derive "same"-style padding: dilation * (kernel_size - 1) // 2
        # per spatial dimension.
        d = to_array(dilation or 1)
        k = to_array(kernel_size)
        p, r = np.divmod(d * (k - 1), 2)
        if raise_error and r.any():
            raise ValueError(
                "Invalid combination of kernel_size: {}, dilation: {}. "
                "If dilation is odd, kernel_size must be odd.".format(
                    kernel_size, dilation
                )
            )
        kwargs["padding"] = tuple(p)
    else:
        kwargs["padding"] = as_tuple(padding)
    if stride is not None:
        kwargs["stride"] = as_tuple(stride)
    return args, kwargs

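# Example (illustrative): when padding is omitted, "same"-style padding is
# filled in from kernel_size and dilation.
# >>> _, kw = setup_conv_params(kernel_size=5, dilation=2)
# >>> kw["padding"] == (4,)  # 2 * (5 - 1) // 2
# True
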
batchnorm_dict = {
    "1": torch.nn.BatchNorm1d,
    "2": torch.nn.BatchNorm2d,
    "3": torch.nn.BatchNorm3d,
}

class ModuleConvWrap(torch.nn.Sequential):
    core = None

    def __init__(self, batchnorm=None, activation=None, *args, **kwargs):
        args, kwargs = setup_conv_params(*args, **kwargs)
        module = self.core(*args, **kwargs)
        modules = [module]
        if batchnorm:
            if len(args) >= 2:
                out_channels = args[1]
            else:
                out_channels = kwargs["out_channels"]
            # e.g. "Conv2d"[-2] -> "2" selects the matching BatchNorm class.
            dim_str = self.core.__name__[-2]
            batchnorm_obj = batchnorm_dict[dim_str]
            if isinstance(batchnorm, dict):
                batchnorm_module = batchnorm_obj(num_features=out_channels, **batchnorm)
            else:
                batchnorm_module = batchnorm_obj(num_features=out_channels)
            modules.append(batchnorm_module)
        if activation:
            if isinstance(activation, str):
                activation = getattr(torch.nn, activation)()
            modules.append(activation)
        super().__init__(*modules)

class TensorConv1d(ModuleConvWrap):
    core = torch.nn.Conv1d


class TensorConv2d(ModuleConvWrap):
    core = torch.nn.Conv2d


class TensorConv3d(ModuleConvWrap):
    core = torch.nn.Conv3d


class TensorMaxPool1d(ModuleConvWrap):
    core = torch.nn.MaxPool1d


class TensorMaxPool2d(ModuleConvWrap):
    core = torch.nn.MaxPool2d


class TensorMaxPool3d(ModuleConvWrap):
    core = torch.nn.MaxPool3d


class TensorAvgPool1d(ModuleConvWrap):
    core = torch.nn.AvgPool1d


class TensorAvgPool2d(ModuleConvWrap):
    core = torch.nn.AvgPool2d


class TensorAvgPool3d(ModuleConvWrap):
    core = torch.nn.AvgPool3d

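# Example (illustrative): a Conv2d block with inferred "same" padding,
# BatchNorm2d, and a ReLU activation, assembled into one Sequential.
# kernel_size is given per-dimension so the derived padding is a 2-tuple.
# >>> block = TensorConv2d(
# ...     in_channels=3, out_channels=8, kernel_size=[3, 3],
# ...     batchnorm=True, activation="ReLU",
# ... )
# >>> block(torch.randn(2, 3, 16, 16)).shape
# torch.Size([2, 8, 16, 16])
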
class ModuleBottleneck2d(torch.nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        mid_channels=None,
        batch_norm=None,
        activation=None,
        **kwargs
    ):
        mid_channels = mid_channels or in_channels // 2 or 1
        batch_norm = batch_norm or TensorSkip()
        activation = activation or TensorSkip()
        super().__init__(
            TensorConv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                **kwargs
            ),
            batch_norm,
            activation,
            TensorConv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                **kwargs
            ),
            batch_norm,
            activation,
            TensorConv2d(
                in_channels=mid_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                **kwargs
            ),
        )

class TensorSkip(torch.nn.Module):
    def forward(self, input):
        return input


class TensorIdentity(torch.nn.Module):
    def forward(self, input):
        return input

class ModuleConcatSkip(ModuleConcat):
    def __init__(self, *modules):
        super().__init__(TensorIdentity(), torch.nn.Sequential(*modules))


class ModuleSumSkip(ModuleSum):
    def __init__(self, *modules):
        super().__init__(TensorIdentity(), torch.nn.Sequential(*modules))

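# Example (illustrative): ModuleSumSkip is a residual connection; the output
# is input + Sequential(*modules)(input), so shapes must be preserved.
# >>> res = ModuleSumSkip(torch.nn.Linear(4, 4))
# >>> res(torch.randn(2, 4)).shape
# torch.Size([2, 4])
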
class TensorForward(torch.nn.Module):
    def __init__(self, func=None):
        super().__init__()
        func = func or (lambda x: x)
        assert callable(func)
        self._func = func

    def forward(self, input):
        return self._func(input)

class TensorConstantLinear(torch.nn.Module):
    def __init__(self, weight=1, bias=0):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, input):
        return self.weight * input + self.bias

class TensorExp(torch.nn.Module):
    def forward(self, input):
        return torch.exp(input)


class TensorLog(torch.nn.Module):
    def forward(self, input):
        return torch.log(input)


class TensorFlatten(torch.nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class TensorSqueeze(torch.nn.Module):
    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        if self.dim is None:
            return torch.squeeze(input)
        return torch.squeeze(input, dim=self.dim)


class TensorUnsqueeze(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.unsqueeze(input, dim=self.dim)

class TensorSlice(torch.nn.Module):
    def __init__(self, start=0, end=None, step=1):
        super().__init__()
        self.start = start
        self.end = end
        self.step = step

    def forward(self, input):
        return input[:, self.start : (self.end or input.shape[1]) : self.step, ...]

def step_binary(input, output_size, compare=torch.ge):
    index_1dtt = input.type(dtype=torch.long)
    h_1dtt = torch.arange(output_size)
    h_2dtt = compare(h_1dtt.reshape(1, -1), index_1dtt.reshape(-1, 1))
    return h_2dtt

class StepBinary(torch.nn.Module):
    def __init__(self, size, desc=False, compare=None, dtype=None):
        super().__init__()
        assert isinstance(size, int)
        self.out_size = size
        if compare is None:
            assert isinstance(desc, bool)
            desc_dict = {False: torch.ge, True: torch.le}
            compare = desc_dict.get(desc)
        else:
            assert not desc, "'desc' and 'compare' cannot be specified together."
        self.compare = compare
        self.dtype = dtype

    def forward(self, input):
        output = step_binary(input, self.out_size, self.compare)
        dtype = self.dtype or input.type()
        return output.type(dtype=dtype)

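# Example (illustrative): each integer input value i is mapped to a row of
# indicators [0 >= i, 1 >= i, ..., size - 1 >= i] (torch.ge by default).
# >>> sb = StepBinary(size=4, dtype=torch.int64)
# >>> sb(torch.tensor([0, 2]))
# tensor([[1, 1, 1, 1],
#         [0, 0, 1, 1]])
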
class TensorNearestPad(torch.nn.Module):
    def __init__(self, lower=1, upper=1):
        super().__init__()
        assert isinstance(lower, int) and lower >= 0
        assert isinstance(upper, int) and upper >= 0
        self.lower = lower
        self.upper = upper

    def forward(self, input):
        return torch.cat(
            [
                input[:, :1].expand(-1, self.lower),
                input,
                input[:, -1:].expand(-1, self.upper),
            ],
            dim=1,
        )

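# Example (illustrative): replicate the edge values along dim 1.
# >>> TensorNearestPad(lower=2, upper=1)(torch.tensor([[1, 2, 3]]))
# tensor([[1, 1, 1, 2, 3, 3]])
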
class TensorCumsum(torch.nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.cumsum(input, dim=self.dim)


class TensorClamp(torch.nn.Module):
    def __init__(self, min=None, max=None):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, input):
        return torch.clamp(input, min=self.min, max=self.max)


class TensorClampMax(torch.nn.Module):
    def __init__(self, max=None):
        super().__init__()
        self.max = max

    def forward(self, input):
        return torch.clamp_max(input, max=self.max)


class TensorClampMin(torch.nn.Module):
    def __init__(self, min=None):
        super().__init__()
        self.min = min

    def forward(self, input):
        return torch.clamp_min(input, min=self.min)

class TensorProba(torch.nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        total = torch.sum(input, dim=self.dim, keepdim=True)
        return input / total

def nl_loss(input, *args, **kwargs):
    return torch.nn.functional.nll_loss(input.log(), *args, **kwargs)


class NLLoss(torch.nn.NLLLoss):
    """Negative log likelihood loss that accepts probabilities (rather than
    log probabilities) as input.

    Cross entropy loss can be computed in three equivalent ways:

    - NLLoss preceded by torch.nn.Softmax
    - torch.nn.NLLLoss preceded by torch.nn.LogSoftmax
    - torch.nn.CrossEntropyLoss (applied directly to logits)
    """

    def forward(self, input, target):
        return super().forward(input.log(), target)

class CrossEntropyLoss2d(torch.nn.CrossEntropyLoss):
    def forward(self, input, target):
        input_hw = list(input.shape)[-2:]
        target_hw = list(target.shape)[-2:]
        if input_hw != target_hw:
            input = torch.nn.functional.interpolate(
                input, size=target_hw, mode="bilinear", align_corners=True
            )
        input_4dtt = to_channel_last_tensor(input)
        input_2dtt = input_4dtt.reshape(-1, input_4dtt.shape[-1])
        target_1dtt = target.reshape(-1)
        return super().forward(input_2dtt, target_1dtt)

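# Example (illustrative): per-pixel cross entropy for segmentation logits;
# the input is resized to the target's spatial size if they differ.
# >>> criterion = CrossEntropyLoss2d()
# >>> logits = torch.randn(2, 5, 8, 8)          # (N, C, H, W)
# >>> target = torch.randint(0, 5, (2, 8, 8))   # (N, H, W)
# >>> criterion(logits, target).dim()
# 0
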
_to_channel_last_dict = {3: (-2, -1, -3), 4: (0, -2, -1, -3)}


def to_channel_last_tensor(a):
    if a.ndim in {3, 4}:
        return a.permute(*_to_channel_last_dict.get(a.ndim))
    return a


_to_channel_first_dict = {3: (-1, -3, -2), 4: (0, -1, -3, -2)}


def to_channel_first_tensor(a):
    if a.ndim in {3, 4}:
        return a.permute(*_to_channel_first_dict.get(a.ndim))
    return a
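

# Example (illustrative): round-trip between channel-first (NCHW) and
# channel-last (NHWC) layouts.
# >>> x = torch.randn(2, 3, 4, 5)           # NCHW
# >>> to_channel_last_tensor(x).shape       # NHWC
# torch.Size([2, 4, 5, 3])
# >>> to_channel_first_tensor(to_channel_last_tensor(x)).shape
# torch.Size([2, 3, 4, 5])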