class _UpsampleConv3dBase(_nn.ConvNormAct3d):
    """Subclass of ``_nn.ConvNormAct3d`` that upsamples its input first.

    Each forward pass feeds the input through ``upsampling_layer`` and then
    hands the result to the inherited ``main_forward``. The upsampler is
    built from the class attribute ``upsample_constructor``. By default that
    wraps ``torch.nn.Upsample`` with ``scale_factor=2``,
    ``mode='trilinear'`` and ``align_corners=True``. Because it is a class
    attribute, a subclass can replace it.
    """

    # Factory for the upsampling module; invoked once per instance in __init__.
    upsample_constructor = _nn.LayerPartial(
        _torch.nn.Upsample,
        scale_factor=2,
        align_corners=True,
        mode='trilinear'
    )

    def __init__(self, *args, **kwargs):
        # Every constructor argument is forwarded unchanged to the parent.
        super(_UpsampleConv3dBase, self).__init__(*args, **kwargs)
        self.upsampling_layer = self.upsample_constructor()
        # Expose the convolution's output channel count on the block itself.
        conv = self.convolution_layer
        self.out_channels = conv.out_channels

    def main_forward(self, x):
        """Upsample ``x``, then apply the parent's ``main_forward``."""
        upsampled = self.upsampling_layer(x)
        return super(_UpsampleConv3dBase, self).main_forward(upsampled)
class InstanceNorm3d(_BINormCommon):
    """Wrap ``torch.nn.InstanceNorm3d`` using ``_BINormCommon.default_kwargs``."""

    layer_constructor = _nn.LayerPartial(
        _torch.nn.InstanceNorm3d,
        **_BINormCommon.default_kwargs
    )
class BatchNorm2d(_BINormCommon):
    """Wrap ``torch.nn.BatchNorm2d`` using ``_BINormCommon.default_kwargs``."""

    layer_constructor = _nn.LayerPartial(
        _torch.nn.BatchNorm2d,
        **_BINormCommon.default_kwargs
    )