def _threecouplinglayers(in_channels, mid_channels, num_blocks, mask_class):
    """Return three CouplingLayers whose masks go normal, reversed, normal.

    Args:
        in_channels: channel count forwarded to each CouplingLayer.
        mid_channels: hidden width forwarded to each CouplingLayer.
        num_blocks: block count forwarded to each CouplingLayer.
        mask_class: mask factory, called as ``mask_class(reverse_mask=...)``.

    Returns:
        A list of three CouplingLayer instances, in construction order.
    """
    return [
        CouplingLayer(in_channels, mid_channels, num_blocks,
                      mask_class(reverse_mask=flip))
        for flip in (False, True, False)
    ]
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=None):
    """Build a single-scale flow of three horizontally-masked couplings.

    The body is: addZslot(), passThrough(iLogits()), three passThrough-wrapped
    CouplingLayers with MaskHorizontal(reverse_mask=False/True/False), and
    FlatJoin(), all composed with iSequential.

    Args:
        in_channels, mid_channels, num_blocks: forwarded to every CouplingLayer.
        init_zeros, st_type, use_batch_norm: forwarded by keyword to every
            CouplingLayer.
        num_scales, img_shape, skip, latent_dim: accepted for signature
            compatibility but not used by this constructor.
    """
    super(RealNVPMaskHorizontal3Layers, self).__init__()
    stages = [addZslot(), passThrough(iLogits())]
    for flip in (False, True, False):
        coupling = CouplingLayer(in_channels, mid_channels, num_blocks,
                                 MaskHorizontal(reverse_mask=flip),
                                 init_zeros=init_zeros, st_type=st_type,
                                 use_batch_norm=use_batch_norm)
        stages.append(passThrough(coupling))
    stages.append(FlatJoin())
    self.body = iSequential(*stages)
def _threecouplinglayers(in_channels, mid_channels, num_blocks, mask_class,
                         init_zeros=False, st_type='resnet', use_batch_norm=True,
                         img_width=28, skip=True, latent_dim=100):
    """Return three CouplingLayers (masks normal, reversed, normal).

    Every layer receives the same keyword options.

    Args:
        in_channels, mid_channels, num_blocks: forwarded positionally.
        mask_class: mask factory, called as ``mask_class(reverse_mask=...)``.
        init_zeros, st_type, use_batch_norm, img_width, skip, latent_dim:
            forwarded by keyword to each CouplingLayer.

    Returns:
        A list of three CouplingLayer instances, in construction order.
    """
    shared_options = dict(init_zeros=init_zeros, st_type=st_type,
                          use_batch_norm=use_batch_norm, img_width=img_width,
                          skip=skip, latent_dim=latent_dim)
    couplings = []
    for flip in (False, True, False):
        couplings.append(CouplingLayer(in_channels, mid_channels, num_blocks,
                                       mask_class(reverse_mask=flip),
                                       **shared_options))
    return couplings
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=100):
    """Build a multi-scale RealNVP flow whose couplings receive the image width.

    Each scale starts with three checkerboard-masked couplings. On the last
    scale one extra reversed checkerboard coupling follows. Every other
    scale instead gets a SqueezeLayer(2), three channel-wise couplings on
    ``4 * in_channels``, and ``keepChannels(2 * in_channels)``. The stack
    is wrapped as addZslot(), passThrough(iLogits()), ..., FlatJoin().

    Args:
        num_scales: number of scales to build.
        in_channels: channels entering the first scale; doubled after each
            non-final scale.
        mid_channels: hidden width given to CouplingLayer; doubled at every
            squeeze unless ``st_type == 'autoencoder'``.
        num_blocks: forwarded to every CouplingLayer.
        init_zeros, st_type, use_batch_norm, skip, latent_dim: forwarded
            unchanged to every CouplingLayer.
        img_shape: 3-tuple; its last entry is the initial image width, which
            is halved after each squeeze.
    """
    super(RealNVP, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    # NOTE(review): not populated in this constructor; presumably filled or
    # read elsewhere — confirm before removing.
    self.output_shapes = []
    _, _, img_width = img_shape
    for scale in range(num_scales):
        in_couplings = self._threecouplinglayers(
            in_channels, mid_channels, num_blocks, MaskCheckerboard,
            init_zeros, st_type, use_batch_norm, img_width, skip, latent_dim)
        layers.append(passThrough(*in_couplings))

        if scale == num_scales - 1:
            # Pass the coupling options by keyword, the same way
            # _threecouplinglayers does. Passing six options positionally
            # silently depends on CouplingLayer's parameter order.
            layers.append(passThrough(CouplingLayer(
                in_channels, mid_channels, num_blocks,
                MaskCheckerboard(reverse_mask=True),
                init_zeros=init_zeros, st_type=st_type,
                use_batch_norm=use_batch_norm, img_width=img_width,
                skip=skip, latent_dim=latent_dim)))
        else:
            layers.append(passThrough(SqueezeLayer(2)))
            # Track the squeeze by halving the width handed to later couplings.
            img_width = img_width // 2
            if st_type != 'autoencoder':
                # In the autoencoder case we probably want the bottleneck
                # size to be fixed, so the hidden width is not doubled.
                mid_channels *= 2
            out_couplings = self._threecouplinglayers(
                4 * in_channels, mid_channels, num_blocks, MaskChannelwise,
                init_zeros, st_type, use_batch_norm, img_width, skip, latent_dim)
            layers.append(passThrough(*out_couplings))
            layers.append(keepChannels(2 * in_channels))
            in_channels *= 2
    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8):
    """Build a multi-scale RealNVP flow.

    Each scale adds three checkerboard couplings. The final scale then adds
    one reversed checkerboard coupling. Earlier scales instead add a
    SqueezeLayer(2), three channel-wise couplings (4x channels, 2x hidden
    width) and keepChannels(2x channels). Channel count and hidden width
    both double between scales.

    Args:
        num_scales: number of scales to build.
        in_channels: channels entering the first scale.
        mid_channels: hidden width of the first scale's couplings.
        num_blocks: forwarded to every coupling layer.
    """
    super(RealNVP, self).__init__()
    modules = [addZslot(), passThrough(iLogits())]
    channels, hidden = in_channels, mid_channels
    for scale in range(num_scales):
        modules.append(passThrough(*self._threecouplinglayers(
            channels, hidden, num_blocks, MaskCheckerboard)))
        if scale == num_scales - 1:
            modules.append(passThrough(CouplingLayer(
                channels, hidden, num_blocks,
                MaskCheckerboard(reverse_mask=True))))
            break
        modules.append(passThrough(SqueezeLayer(2)))
        modules.append(passThrough(*self._threecouplinglayers(
            4 * channels, 2 * hidden, num_blocks, MaskChannelwise)))
        modules.append(keepChannels(2 * channels))
        channels, hidden = 2 * channels, 2 * hidden
    modules.append(FlatJoin())
    self.body = iSequential(*modules)
def _threecouplinglayers(in_channels, mid_channels, num_blocks, mask_class):
    """Return three (ActNorm2d, iConv1x1, CouplingLayer) triples, flattened.

    The coupling masks go normal, reversed, normal.

    Args:
        in_channels: channel count for every layer.
        mid_channels, num_blocks: forwarded to each CouplingLayer.
        mask_class: mask factory, called as ``mask_class(reverse_mask=...)``.

    Returns:
        A flat list of nine layers, in construction order.
    """
    stack = []
    for flip in (False, True, False):
        stack.extend([
            ActNorm2d(in_channels),
            iConv1x1(in_channels),
            CouplingLayer(in_channels, mid_channels, num_blocks,
                          mask_class(reverse_mask=flip)),
        ])
    return stack
def _glow_step(in_channels, mid_channels, num_blocks):
    """Return one flow step: ActNorm, iConv1x1, then a channel-wise coupling.

    Args:
        in_channels: channel count for every layer.
        mid_channels, num_blocks: forwarded to the CouplingLayer.

    Returns:
        A list of the three layers, in order.
    """
    normalization = ActNorm(in_channels)
    channel_mixer = iConv1x1(in_channels)
    coupling = CouplingLayer(in_channels, mid_channels, num_blocks,
                             MaskChannelwise(reverse_mask=False))
    return [normalization, channel_mixer, coupling]
def _glow_step(in_channels, mid_channels, actnorm_scale, st_type, num_layers):
    """Return one flow step: ActNorm2d, InvertibleConv1x1, channel-wise coupling.

    Args:
        in_channels: channel count for every layer.
        mid_channels: hidden width forwarded to the CouplingLayer.
        actnorm_scale: second argument to ActNorm2d.
        st_type: forwarded by keyword to the CouplingLayer.
        num_layers: passed as the CouplingLayer's third positional argument.

    Returns:
        A list of the three layers, in order.
    """
    step = [ActNorm2d(in_channels, actnorm_scale)]
    step.append(InvertibleConv1x1(in_channels))
    channel_mask = MaskChannelwise(reverse_mask=False)
    step.append(CouplingLayer(in_channels, mid_channels, num_layers,
                              channel_mask, st_type=st_type))
    return step
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=None):
    """Build a flow of eight quadrant-masked couplings cycling 0->1->2->3->0.

    The quadrant cycle (0->1, 1->2, 2->3, 3->0) is traversed twice. The
    couplings sit between addZslot()/passThrough(iLogits()) and FlatJoin().

    Args:
        in_channels, mid_channels, num_blocks: forwarded to every CouplingLayer.
        init_zeros, st_type, use_batch_norm: forwarded by keyword to every
            CouplingLayer.
        num_scales, img_shape, skip, latent_dim: accepted for signature
            compatibility but not used by this constructor.
    """
    super(RealNVPCycleMask, self).__init__()
    modules = [addZslot(), passThrough(iLogits())]
    for step in range(8):
        source = step % 4
        mask = MaskQuadrant(input_quadrant=source,
                            output_quadrant=(source + 1) % 4)
        modules.append(passThrough(CouplingLayer(
            in_channels, mid_channels, num_blocks, mask,
            init_zeros=init_zeros, st_type=st_type,
            use_batch_norm=use_batch_norm)))
    modules.append(FlatJoin())
    self.body = iSequential(*modules)
def _threecouplinglayers(in_channels, mid_channels, num_blocks, mask_class, num_classes=10):
    """Return three conditioned flow steps followed by an iCategoricalFiLM layer.

    Each step is passThrough(iConv1x1, CouplingLayer, ActNorm2d). The masks
    go reversed, normal, reversed, which is the opposite parity of the other
    helpers in this file.

    Args:
        in_channels: channel count for every layer.
        mid_channels, num_blocks: forwarded to each CouplingLayer.
        mask_class: mask factory, called as ``mask_class(reverse_mask=...)``.
        num_classes: first argument to iCategoricalFiLM.

    Returns:
        A list of three passThrough steps plus a trailing iCategoricalFiLM.
    """
    steps = [
        passThrough(
            iConv1x1(in_channels),
            CouplingLayer(in_channels, mid_channels, num_blocks,
                          mask_class(reverse_mask=flip)),
            ActNorm2d(in_channels),
        )
        for flip in (True, False, True)
    ]
    # NOTE(review): the FiLM layer is appended unwrapped (no passThrough),
    # unlike the steps above — presumably intentional; confirm.
    steps.append(iCategoricalFiLM(num_classes, in_channels))
    return steps
def __init__(self, in_channels=1, mid_channels=64, num_blocks=4):
    """Build the fixed two-squeeze RealNVP used for MNIST-sized inputs.

    Twice: three checkerboard couplings, SqueezeLayer(2), three channel-wise
    couplings on 4x the channels, then keepChannels(2x channels), with the
    channel count doubling each time. Four final checkerboard couplings
    (normal, reversed, normal, reversed) follow. Every coupling uses the
    same ``mid_channels`` and ``num_blocks``.

    Args:
        in_channels: channels of the input image.
        mid_channels: hidden width of every CouplingLayer.
        num_blocks: forwarded to every CouplingLayer.
    """
    super(RealNVPMNIST, self).__init__()

    def couplings(channels, mask_cls, reverse_flags):
        # One passThrough-wrapped CouplingLayer per flag, in flag order.
        return [
            passThrough(CouplingLayer(channels, mid_channels, num_blocks,
                                      mask_cls(reverse_mask=flag)))
            for flag in reverse_flags
        ]

    alternating = (False, True, False)
    modules = [addZslot(), passThrough(iLogits())]
    channels = in_channels
    for _ in range(2):
        modules += couplings(channels, MaskCheckerboard, alternating)
        modules.append(passThrough(SqueezeLayer(2)))
        modules += couplings(4 * channels, MaskChannelwise, alternating)
        modules.append(keepChannels(2 * channels))
        channels *= 2
    modules += couplings(channels, MaskCheckerboard, alternating + (True,))
    modules.append(FlatJoin())
    self.body = iSequential(*modules)