Example n. 1
0
    def __init__(self, dim, padding_type, norm_layer, use_naive, use_bias,
                 coupling):
        """Couple two half-width conv stacks inside a ReversibleBlock.

        The channel dimension is split in two; each half gets its own
        conv stack (F and G) built by ``build_conv_block``.
        """
        super(ReversibleResnetBlock, self).__init__()
        half = dim // 2
        f_net = self.build_conv_block(half, padding_type, norm_layer, use_bias)
        g_net = self.build_conv_block(half, padding_type, norm_layer, use_bias)

        if not use_naive:
            self.rev_block = ReversibleBlock(f_net, g_net, coupling)
        else:
            # NOTE(review): "naive" mode pins implementation 2 for both
            # passes and sets keep_input=True — presumably the reference
            # (memory-keeping) implementation; confirm against memcnn docs.
            self.rev_block = ReversibleBlock(f_net,
                                             g_net,
                                             coupling,
                                             keep_input=True,
                                             implementation_fwd=2,
                                             implementation_bwd=2)
Example n. 2
0
class RevBlock3d(nn.Module):
    """Reversible 3D residual block.

    Splits the channel dimension in half and couples two small 3D conv
    stacks (``self.F`` and ``self.G``) additively via ``ReversibleBlock``,
    so activations can be reconstructed in the backward pass.
    """

    def __init__(self, dim, use_bias, norm_layer, use_naive):
        super(RevBlock3d, self).__init__()
        # Each sub-network operates on half of the channels.
        # NOTE(review): the constructor's use_bias argument is ignored here;
        # bias is hard-coded to True in both sub-blocks — confirm intent.
        self.F = self.build_conv_block(dim // 2, True, norm_layer)
        self.G = self.build_conv_block(dim // 2, True, norm_layer)
        if use_naive:
            # BUG FIX: the original passed the bare names F and G, which are
            # undefined in this scope (NameError at construction time).
            self.rev_block = ReversibleBlock(self.F,
                                             self.G,
                                             'additive',
                                             keep_input=True,
                                             implementation_fwd=2,
                                             implementation_bwd=2)
        else:
            # BUG FIX: the original mixed a bare, undefined F with self.G.
            self.rev_block = ReversibleBlock(self.F, self.G, 'additive')

    def build_conv_block(self, dim, use_bias, norm_layer):
        """Return a pad-conv-norm-ReLU-pad-ZeroInit stack on ``dim`` channels.

        Replication padding of 1 before each 3x3x3 conv keeps the spatial
        size unchanged.  The final conv is a project-local ``ZeroInit``
        layer (presumably zero-initialized so the block starts as identity
        — confirm against its definition).
        """
        conv_block = []
        conv_block += [nn.ReplicationPad3d(1)]
        conv_block += [
            nn.Conv3d(dim, dim, kernel_size=3, padding=0, bias=use_bias)
        ]
        conv_block += [norm_layer(dim)]
        conv_block += [nn.ReLU(True)]
        conv_block += [nn.ReplicationPad3d(1)]
        conv_block += [
            ZeroInit(dim, dim, kernel_size=3, padding=0, bias=use_bias)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the reversible coupling."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the coupling, reconstructing the block's input."""
        return self.rev_block.inverse(x)
Example n. 3
0
class ReversibleResnetBlock(nn.Module):
    """Reversible 2D ResNet block.

    Splits the channel dimension in half and couples two conv-norm-ReLU-
    conv-norm stacks (F and G) through a ``ReversibleBlock`` with the
    given coupling type.
    """

    def __init__(self, dim, padding_type, norm_layer, use_naive, use_bias,
                 coupling):
        super(ReversibleResnetBlock, self).__init__()
        # Each sub-network operates on half of the channels.
        F = self.build_conv_block(dim // 2, padding_type, norm_layer, use_bias)
        G = self.build_conv_block(dim // 2, padding_type, norm_layer, use_bias)

        if use_naive:
            # NOTE(review): implementation 2 with keep_input=True is
            # presumably the memory-keeping reference path — confirm
            # against the ReversibleBlock (memcnn) documentation.
            self.rev_block = ReversibleBlock(F,
                                             G,
                                             coupling,
                                             keep_input=True,
                                             implementation_fwd=2,
                                             implementation_bwd=2)
        else:
            self.rev_block = ReversibleBlock(F, G, coupling)

    @staticmethod
    def _padding(padding_type):
        """Return (pre-padding layers, conv padding) for ``padding_type``.

        'reflect'/'replicate' use an explicit padding layer before the
        conv; 'zero' lets the conv pad itself.  Raises NotImplementedError
        for any other value.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' %
                                  padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_bias):
        """Build a conv-norm-ReLU-conv-norm stack on ``dim`` channels.

        Each 3x3 conv is preceded by padding of 1 (explicit layer or
        conv-internal, per ``padding_type``) so spatial size is preserved.
        """
        # The padding-selection logic was duplicated inline twice in the
        # original; it is factored into _padding for consistency.
        pad_layers, p = self._padding(padding_type)
        conv_block = list(pad_layers)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True)
        ]

        pad_layers, p = self._padding(padding_type)
        conv_block += list(pad_layers)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the reversible coupling."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the coupling, reconstructing the block's input."""
        return self.rev_block.inverse(x)
Example n. 4
0
 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):
     """Build a reversible bottleneck when the shape is preserved, else a
     plain bottleneck sub-block."""
     super(RevBottleneck, self).__init__()
     shape_preserving = downsample is None and stride == 1
     if shape_preserving:
         # Channels are split in half; each half gets its own bottleneck.
         half_in, half_out = inplanes // 2, planes // 2
         gm = BottleneckSub(half_in, half_out, stride, noactivation)
         fm = BottleneckSub(half_in, half_out, stride, noactivation)
         self.revblock = ReversibleBlock(gm, fm)
     else:
         # Stride/downsample changes the tensor shape, which a reversible
         # coupling cannot represent — fall back to a normal sub-block.
         self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation)
     self.downsample = downsample
     self.stride = stride
Example n. 5
0
class ReversibleConvBlock(nn.Module):
    """Reversible conv block: two half-width conv-norm-ReLU stages coupled
    through a ``ReversibleBlock`` with the given coupling type."""

    def __init__(self,
                 dim,
                 padding_type,
                 norm_layer,
                 use_dropout,
                 use_bias,
                 coupling,
                 kernel_size=3):
        super(ReversibleConvBlock, self).__init__()
        half = dim // 2
        f_net = self.build_conv_block(half, padding_type, norm_layer,
                                      use_dropout, use_bias, kernel_size)
        g_net = self.build_conv_block(half, padding_type, norm_layer,
                                      use_dropout, use_bias, kernel_size)
        self.rev_block = ReversibleBlock(f_net, g_net, coupling)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,
                         use_bias, kernel_size):
        """Build a single conv-norm-ReLU stage on ``dim`` channels.

        Padding of kernel_size // 2 (explicit layer or conv-internal,
        depending on ``padding_type``) keeps the spatial size unchanged.
        Raises NotImplementedError for unknown padding types.

        NOTE(review): use_dropout is accepted but never used — confirm
        whether a dropout layer was intended here.
        """
        margin = kernel_size // 2
        layers = []
        conv_padding = 0
        if padding_type == 'reflect':
            layers.append(nn.ReflectionPad2d(margin))
        elif padding_type == 'replicate':
            layers.append(nn.ReplicationPad2d(margin))
        elif padding_type == 'zero':
            conv_padding = margin
        else:
            raise NotImplementedError('padding [%s] is not implemented' %
                                      padding_type)

        layers.append(nn.Conv2d(dim,
                                dim,
                                kernel_size=kernel_size,
                                padding=conv_padding,
                                bias=use_bias))
        layers.append(norm_layer(dim))
        layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the reversible coupling."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the coupling, reconstructing the block's input."""
        return self.rev_block.inverse(x)
Example n. 6
0
 def __init__(self,
              inplanes,
              planes,
              stride=1,
              downsample=None,
              noactivation=False):
     """Build a reversible basic block when the shape is preserved, else a
     plain basic sub-block."""
     super(RevBasicBlock, self).__init__()
     shape_preserving = downsample is None and stride == 1
     if shape_preserving:
         # Channels are split in half; each half gets its own sub-block.
         half_in, half_out = inplanes // 2, planes // 2
         Gm = BasicBlockSub(half_in, half_out, stride, noactivation)
         Fm = BasicBlockSub(half_in, half_out, stride, noactivation)
         self.revblock = ReversibleBlock(Gm, Fm)
     else:
         # Stride/downsample changes the tensor shape, which a reversible
         # coupling cannot represent — fall back to a normal sub-block.
         self.basicblock_sub = BasicBlockSub(inplanes, planes, stride,
                                             noactivation)
     self.downsample = downsample
     self.stride = stride