class _UpsampleConv3dBase(_nn.ConvNormAct3d):
    """Upsample the input, then run the inherited ConvNormAct3d forward path.

    The upsampling stage is built from the class attribute
    ``upsample_constructor``. Because it is looked up through ``self``, a
    subclass can override it to change how the upsampling layer is created.
    """

    # Factory for the upsampling stage: torch.nn.Upsample doubling the
    # spatial size with trilinear interpolation and align_corners=True.
    upsample_constructor = _nn.LayerPartial(
        _torch.nn.Upsample,
        scale_factor=2,
        align_corners=True,
        mode='trilinear',
    )

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to ConvNormAct3d.
        super(_UpsampleConv3dBase, self).__init__(*args, **kwargs)
        self.upsampling_layer = self.upsample_constructor()
        # Mirror the convolution's output channel count on this module.
        # Assumes the parent __init__ created ``self.convolution_layer``;
        # TODO confirm against ConvNormAct3d.
        self.out_channels = self.convolution_layer.out_channels

    def main_forward(self, x):
        # Upsample first, then hand the enlarged tensor to the parent's path.
        upsampled = self.upsampling_layer(x)
        return super(_UpsampleConv3dBase, self).main_forward(upsampled)
class InstanceNorm3d(_BINormCommon):
    """_BINormCommon variant whose normalization layer is torch.nn.InstanceNorm3d."""

    # Build the wrapped layer from torch.nn.InstanceNorm3d using the keyword
    # arguments shared by all _BINormCommon subclasses (default_kwargs).
    layer_constructor = _nn.LayerPartial(
        _torch.nn.InstanceNorm3d,
        **_BINormCommon.default_kwargs
    )
class BatchNorm2d(_BINormCommon):
    """_BINormCommon variant whose normalization layer is torch.nn.BatchNorm2d."""

    # Build the wrapped layer from torch.nn.BatchNorm2d using the keyword
    # arguments shared by all _BINormCommon subclasses (default_kwargs).
    layer_constructor = _nn.LayerPartial(
        _torch.nn.BatchNorm2d,
        **_BINormCommon.default_kwargs
    )