def __init__(self, dim, padding_type, norm_layer, use_naive, use_bias, coupling):
    """Create two identical conv branches and wrap them in a ReversibleBlock.

    Args:
        dim: channel count of the block input; each branch is built for
            ``dim // 2`` channels.
        padding_type: forwarded unchanged to ``self.build_conv_block``.
        norm_layer: normalization factory forwarded to ``self.build_conv_block``.
        use_naive: when truthy, the ReversibleBlock is created with
            ``keep_input=True`` and implementation 2 for both forward and
            backward passes; otherwise it uses ReversibleBlock's defaults.
        use_bias: forwarded unchanged to ``self.build_conv_block``.
        coupling: coupling mode forwarded to ``ReversibleBlock``.
    """
    super(ReversibleResnetBlock, self).__init__()
    half = dim // 2
    # Built in order: the first branch goes in ReversibleBlock's first slot.
    branches = [
        self.build_conv_block(half, padding_type, norm_layer, use_bias)
        for _ in range(2)
    ]
    extra = (
        dict(keep_input=True, implementation_fwd=2, implementation_bwd=2)
        if use_naive
        else {}
    )
    self.rev_block = ReversibleBlock(branches[0], branches[1], coupling, **extra)
class RevBlock3d(nn.Module):
    """Additive-coupling reversible block with two 3-D convolutional branches.

    ``forward`` and ``inverse`` delegate to the wrapped ``ReversibleBlock``.
    The branches are also kept as ``self.F`` and ``self.G``.
    """

    def __init__(self, dim, use_bias, norm_layer, use_naive):
        """Build both branches and the reversible wrapper.

        Args:
            dim: channel count of the block input; each branch is built for
                ``dim // 2`` channels (presumably because ReversibleBlock
                splits the input in half along channels -- confirm).
            use_bias: currently ignored, see NOTE below.
            norm_layer: normalization factory, called as ``norm_layer(dim // 2)``.
            use_naive: when truthy, the ReversibleBlock is created with
                ``keep_input=True`` and implementation 2 for forward/backward.
        """
        super(RevBlock3d, self).__init__()
        # NOTE(review): ``use_bias`` is accepted but both branches are built
        # with bias=True. Honouring the flag would change the parameter set
        # (and break loading existing checkpoints), so it is left as-is
        # pending a deliberate decision.
        self.F = self.build_conv_block(dim // 2, True, norm_layer)
        self.G = self.build_conv_block(dim // 2, True, norm_layer)

        # Wrap the branches stored on ``self``. Bare ``F``/``G`` names are not
        # defined in this scope (or resolve to an unrelated global such as
        # ``torch.nn.functional``).
        if use_naive:
            self.rev_block = ReversibleBlock(self.F, self.G, 'additive',
                                             keep_input=True,
                                             implementation_fwd=2,
                                             implementation_bwd=2)
        else:
            self.rev_block = ReversibleBlock(self.F, self.G, 'additive')

    def build_conv_block(self, dim, use_bias, norm_layer):
        """Return ``pad -> Conv3d -> norm -> ReLU -> pad -> ZeroInit`` as a Sequential.

        Both convolutions use a 3x3x3 kernel with ``padding=0``, preceded by
        ``nn.ReplicationPad3d(1)``, and map ``dim`` -> ``dim`` channels.
        ``ZeroInit`` is defined elsewhere; it is called like a conv layer here
        (presumably a zero-initialised convolution -- confirm).
        """
        conv_block = []
        conv_block += [nn.ReplicationPad3d(1)]
        conv_block += [
            nn.Conv3d(dim, dim, kernel_size=3, padding=0, bias=use_bias)
        ]
        conv_block += [norm_layer(dim)]
        conv_block += [nn.ReLU(True)]
        conv_block += [nn.ReplicationPad3d(1)]
        conv_block += [
            ZeroInit(dim, dim, kernel_size=3, padding=0, bias=use_bias)
        ]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the reversible block to ``x``."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the reversible block (delegates to ``rev_block.inverse``)."""
        return self.rev_block.inverse(x)
class ReversibleResnetBlock(nn.Module):
    """Reversible residual block made of two identical 2-D conv branches.

    ``forward`` and ``inverse`` delegate to the wrapped ``ReversibleBlock``.
    """

    def __init__(self, dim, padding_type, norm_layer, use_naive, use_bias, coupling):
        super(ReversibleResnetBlock, self).__init__()
        half = dim // 2
        branch_f = self.build_conv_block(half, padding_type, norm_layer, use_bias)
        branch_g = self.build_conv_block(half, padding_type, norm_layer, use_bias)
        options = {}
        if use_naive:
            # Keep the input and pin implementation 2 for both passes.
            options = {'keep_input': True,
                       'implementation_fwd': 2,
                       'implementation_bwd': 2}
        self.rev_block = ReversibleBlock(branch_f, branch_g, coupling, **options)

    def build_conv_block(self, dim, padding_type, norm_layer, use_bias):
        """Return ``[pad] Conv -> norm -> ReLU -> [pad] Conv -> norm`` as a Sequential.

        Args:
            dim: input and output channel count of both 3x3 convolutions.
            padding_type: ``'reflect'`` / ``'replicate'`` insert an explicit
                padding layer of 1 before each conv; ``'zero'`` uses the
                conv's own ``padding=1`` instead.
            norm_layer: factory called as ``norm_layer(dim)``.
            use_bias: passed as ``bias`` to both convolutions.

        Raises:
            NotImplementedError: for any other ``padding_type``.
        """
        def conv_stage():
            # Optional padding layer followed by the 3x3 conv; the pad layer
            # is created before the conv, matching the layer order.
            if padding_type == 'reflect':
                stage, conv_pad = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                stage, conv_pad = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                stage, conv_pad = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            stage.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias))
            return stage

        layers = conv_stage() + [norm_layer(dim), nn.ReLU(True)]
        layers += conv_stage() + [norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the reversible block to ``x``."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the reversible block (delegates to ``rev_block.inverse``)."""
        return self.rev_block.inverse(x)
def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):
    """Build a reversible bottleneck, or a plain one when the shape changes.

    With no ``downsample`` module and ``stride == 1``, two half-width
    ``BottleneckSub`` branches are wrapped in a ``ReversibleBlock`` and stored
    as ``self.revblock``. Otherwise a single full-width ``BottleneckSub`` is
    stored as ``self.bottleneck_sub``. In both cases ``downsample`` and
    ``stride`` are kept on the instance.
    """
    super(RevBottleneck, self).__init__()
    reversible = downsample is None and stride == 1
    if not reversible:
        self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation)
    else:
        # Half-width branches; presumably ReversibleBlock splits channels
        # between them -- confirm against ReversibleBlock.
        half_in, half_out = inplanes // 2, planes // 2
        first = BottleneckSub(half_in, half_out, stride, noactivation)
        second = BottleneckSub(half_in, half_out, stride, noactivation)
        self.revblock = ReversibleBlock(first, second)
    self.downsample = downsample
    self.stride = stride
class ReversibleConvBlock(nn.Module):
    """Reversible block whose two branches are single conv-norm-ReLU stages.

    Args:
        dim: channel count of the block input; each branch is built for
            ``dim // 2`` channels.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: factory called as ``norm_layer(dim // 2)``.
        use_dropout: accepted for signature compatibility; currently unused.
        use_bias: passed as ``bias`` to the convolutions.
        coupling: coupling mode forwarded to ``ReversibleBlock``.
        kernel_size: conv kernel size; must be a positive odd integer.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, coupling, kernel_size=3):
        super(ReversibleConvBlock, self).__init__()
        branch_f = self.build_conv_block(dim // 2, padding_type, norm_layer, use_dropout, use_bias, kernel_size)
        branch_g = self.build_conv_block(dim // 2, padding_type, norm_layer, use_dropout, use_bias, kernel_size)
        self.rev_block = ReversibleBlock(branch_f, branch_g, coupling)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel_size):
        """Return ``[pad] -> Conv2d -> norm -> ReLU`` as an ``nn.Sequential``.

        The padding is ``kernel_size // 2`` on each side: either an explicit
        reflection/replication layer or the conv's own zero padding. For an
        odd kernel, this keeps the spatial size unchanged.

        Raises:
            ValueError: if ``kernel_size`` is not a positive odd integer. An
                even kernel with this symmetric padding grows each spatial
                dimension by one, so the branch output would not match its
                input's shape (which the reversible coupling presumably
                requires).
            NotImplementedError: for an unknown ``padding_type``.
        """
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError('kernel_size must be a positive odd integer, got %r' % (kernel_size,))
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(kernel_size // 2)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(kernel_size // 2)]
        elif padding_type == 'zero':
            p = kernel_size // 2
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the reversible block to ``x``."""
        return self.rev_block(x)

    def inverse(self, x):
        """Invert the reversible block (delegates to ``rev_block.inverse``)."""
        return self.rev_block.inverse(x)
def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):
    """Build a reversible basic block, or a plain one when the shape changes.

    With no ``downsample`` module and ``stride == 1``, two half-width
    ``BasicBlockSub`` branches are wrapped in a ``ReversibleBlock`` and stored
    as ``self.revblock``. Otherwise a single full-width ``BasicBlockSub`` is
    stored as ``self.basicblock_sub``. In both cases ``downsample`` and
    ``stride`` are kept on the instance.
    """
    super(RevBasicBlock, self).__init__()
    reversible = downsample is None and stride == 1
    if not reversible:
        self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)
    else:
        # Half-width branches; presumably ReversibleBlock splits channels
        # between them -- confirm against ReversibleBlock.
        half_in, half_out = inplanes // 2, planes // 2
        first = BasicBlockSub(half_in, half_out, stride, noactivation)
        second = BasicBlockSub(half_in, half_out, stride, noactivation)
        self.revblock = ReversibleBlock(first, second)
    self.downsample = downsample
    self.stride = stride