def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8):
    """Assemble the multi-scale RealNVP flow and store it in ``self.body``.

    Every scale begins with ``self._threecouplinglayers`` over
    ``MaskCheckerboard``. A non-final scale then adds ``SqueezeLayer(2)``,
    ``self._threecouplinglayers`` over ``MaskChannelwise`` at
    ``4 * in_channels`` channels and ``2 * mid_channels`` hidden width, and
    ``keepChannels(2 * in_channels)``. Both widths are then doubled for the
    next scale. The final scale instead ends with one reverse-mask
    checkerboard ``CouplingLayer``.

    Args:
        num_scales: number of scales to build.
        in_channels: channel count handed to the first scale's couplings.
        mid_channels: hidden width for the first scale's couplings.
        num_blocks: forwarded unchanged to every coupling constructor.
    """
    super(RealNVP, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    final_scale = num_scales - 1
    for scale in range(num_scales):
        checker = self._threecouplinglayers(in_channels, mid_channels, num_blocks, MaskCheckerboard)
        layers.append(passThrough(*checker))

        if scale == final_scale:
            closing = CouplingLayer(in_channels, mid_channels, num_blocks,
                                    MaskCheckerboard(reverse_mask=True))
            layers.append(passThrough(closing))
            continue

        layers.append(passThrough(SqueezeLayer(2)))
        # The couplings after the squeeze run on 4x the channels and 2x the hidden width.
        channelwise = self._threecouplinglayers(4 * in_channels, 2 * mid_channels,
                                                num_blocks, MaskChannelwise)
        layers.append(passThrough(*channelwise))
        layers.append(keepChannels(2 * in_channels))
        in_channels, mid_channels = 2 * in_channels, 2 * mid_channels

    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=100):
    """Assemble a configurable multi-scale RealNVP flow into ``self.body``.

    The scale layout matches the basic RealNVP: checkerboard couplings first.
    Non-final scales then add a squeeze, channel-wise couplings at
    ``4 * in_channels``, and ``keepChannels(2 * in_channels)``. The last
    scale ends with a reverse-mask checkerboard coupling. Every coupling
    constructor also receives
    ``init_zeros, st_type, use_batch_norm, img_width, skip, latent_dim``
    positionally. ``img_width`` starts at ``img_shape[2]`` and is halved
    after each squeeze.

    Args:
        num_scales: number of scales to build.
        in_channels: channel count for the first scale.
        mid_channels: initial hidden width. It is doubled at each squeeze
            unless ``st_type == 'autoencoder'``.
        num_blocks: forwarded to every coupling constructor.
        init_zeros, st_type, use_batch_norm, skip, latent_dim: forwarded
            unchanged to every coupling constructor.
        img_shape: a 3-element shape. Only its last entry (the width) is used.
    """
    super(RealNVP, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    self.output_shapes = []
    _channels, _height, img_width = img_shape

    def coupling_extras(width):
        # Trailing positional arguments shared by every coupling constructor below.
        return (init_zeros, st_type, use_batch_norm, width, skip, latent_dim)

    final_scale = num_scales - 1
    for scale in range(num_scales):
        checker = self._threecouplinglayers(in_channels, mid_channels, num_blocks,
                                            MaskCheckerboard, *coupling_extras(img_width))
        layers.append(passThrough(*checker))

        if scale == final_scale:
            closing = CouplingLayer(in_channels, mid_channels, num_blocks,
                                    MaskCheckerboard(reverse_mask=True),
                                    *coupling_extras(img_width))
            layers.append(passThrough(closing))
            continue

        layers.append(passThrough(SqueezeLayer(2)))
        img_width = img_width // 2
        # For the autoencoder st_type the bottleneck size is probably meant to
        # stay fixed, so only the other st_types widen mid_channels here.
        if st_type != 'autoencoder':
            mid_channels *= 2
        channelwise = self._threecouplinglayers(4 * in_channels, mid_channels, num_blocks,
                                                MaskChannelwise, *coupling_extras(img_width))
        layers.append(passThrough(*channelwise))
        layers.append(keepChannels(2 * in_channels))
        in_channels *= 2

    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=None):
    """Build a single-scale flow of eight checkerboard coupling layers.

    The masks alternate ``reverse_mask`` False, True, False, ... Each
    ``CouplingLayer`` receives ``init_zeros``, ``st_type``,
    ``use_batch_norm`` and ``img_width=img_shape[2]`` as keywords. The
    couplings sit between ``addZslot()``/``iLogits`` and ``FlatJoin()``.

    ``num_scales``, ``skip`` and ``latent_dim`` are accepted but not used here.
    NOTE(review): the multi-scale RealNVP forwards ``skip``/``latent_dim`` to
    its couplings and this class does not. Confirm that is intended.
    """
    super(RealNVP8Layers, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    for position in range(8):
        mask = MaskCheckerboard(reverse_mask=position % 2 == 1)
        coupling = CouplingLayer(in_channels, mid_channels, num_blocks, mask,
                                 init_zeros=init_zeros, st_type=st_type,
                                 use_batch_norm=use_batch_norm, img_width=img_shape[2])
        layers.append(passThrough(coupling))
    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def iResBlockConv(outer_channels, inner_channels):
    """Return a convolutional invertible-residual block followed by ActNorm2d.

    The residual branch is Swish -> SpectralNormConv2d(3x3) -> Swish ->
    SpectralNormConv2d(1x1) -> Swish -> SpectralNormConv2d(3x3). It maps
    ``outer_channels`` -> ``inner_channels`` -> ``inner_channels`` ->
    ``outer_channels`` and is wrapped in ``iResBlock(..., n_dist='poisson')``.

    Args:
        outer_channels: channel count entering and leaving the block.
        inner_channels: hidden channel count inside the residual branch.
    """
    # Every spectral-norm conv uses the same tolerances, coefficient
    # (presumably the target norm bound; confirm in SpectralNormConv2d) and stride.
    spectral = dict(atol=0.001, rtol=0.001, coeff=0.98, stride=1)
    branch = nn.Sequential(
        Swish(),
        SpectralNormConv2d(outer_channels, inner_channels, 3, padding=1, **spectral),
        Swish(),
        SpectralNormConv2d(inner_channels, inner_channels, 1, padding=0, **spectral),
        Swish(),
        SpectralNormConv2d(inner_channels, outer_channels, 3, padding=1, **spectral),
    )
    residual = iResBlock(branch, n_dist='poisson')
    return iSequential(residual, ActNorm2d(outer_channels))
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8):
    """Assemble a multi-scale Glow flow and store it in ``self.body``.

    Each scale adds ``num_in`` glow steps at ``in_channels``, then
    ``SqueezeLayer(2)``, then ``num_out`` glow steps at ``4 * in_channels``
    channels and ``2 * mid_channels`` hidden width, then
    ``keepChannels(2 * in_channels)``. Both widths are then doubled for the
    next scale. As in RealNVP, the final scale gets one extra pre-squeeze step
    (``num_in = 4``) and no post-squeeze steps (``num_out = 0``).

    Args:
        num_scales: number of scales to build.
        in_channels: channel count for the first scale's steps.
        mid_channels: hidden width for the first scale's steps.
        num_blocks: forwarded unchanged to ``self._glow_step``.
    """
    super(Glow, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    for scale in range(num_scales):
        is_final = scale == num_scales - 1

        num_in = 4 if is_final else 3
        for _ in range(num_in):
            layers.append(
                passThrough(*self._glow_step(in_channels, mid_channels, num_blocks)))

        layers.append(passThrough(SqueezeLayer(2)))

        # Bug fix: this loop used to run ``num_in`` times, leaving ``num_out``
        # unused and adding 4 unintended post-squeeze steps on the final scale.
        num_out = 0 if is_final else 3
        for _ in range(num_out):
            layers.append(
                passThrough(*self._glow_step(4 * in_channels, 2 * mid_channels, num_blocks)))

        # NOTE(review): on the final scale the squeeze and split still run with
        # no steps between them, as in the original layout. Confirm whether they
        # should be skipped entirely there, as RealNVP does.
        layers.append(keepChannels(2 * in_channels))
        in_channels *= 2
        mid_channels *= 2
    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def __init__(self, in_channels=1, mid_channels=64, num_blocks=4):
    """Build the fixed two-squeeze RealNVP used for MNIST into ``self.body``.

    Layout, where every coupling is wrapped in ``passThrough``:
      3 checkerboard couplings @ in_channels        (masks F, T, F)
      SqueezeLayer(2)
      3 channel-wise couplings @ 4 * in_channels    (masks F, T, F)
      keepChannels(2 * in_channels)
      3 checkerboard couplings @ 2 * in_channels    (masks F, T, F)
      SqueezeLayer(2)
      3 channel-wise couplings @ 8 * in_channels    (masks F, T, F)
      keepChannels(4 * in_channels)
      4 checkerboard couplings @ 4 * in_channels    (masks F, T, F, T)
    All couplings share ``mid_channels`` and ``num_blocks``.
    """
    super(RealNVPMNIST, self).__init__()

    def couplings(channels, mask_cls, reverse_flags):
        # One passThrough(CouplingLayer) per flag, built in flag order.
        return [
            passThrough(CouplingLayer(channels, mid_channels, num_blocks,
                                      mask_cls(reverse_mask=flag)))
            for flag in reverse_flags
        ]

    three = (False, True, False)
    four = (False, True, False, True)
    self.body = iSequential(
        addZslot(),
        passThrough(iLogits()),
        *couplings(in_channels, MaskCheckerboard, three),
        passThrough(SqueezeLayer(2)),
        *couplings(4 * in_channels, MaskChannelwise, three),
        keepChannels(2 * in_channels),
        *couplings(2 * in_channels, MaskCheckerboard, three),
        passThrough(SqueezeLayer(2)),
        *couplings(8 * in_channels, MaskChannelwise, three),
        keepChannels(4 * in_channels),
        *couplings(4 * in_channels, MaskCheckerboard, four),
        FlatJoin(),
    )
def __init__(self, in_dim=2, num_coupling_layers=6, hidden_dim=256, num_layers=2,
             init_zeros=False, dropout=False):
    """Stack ``num_coupling_layers`` tabular coupling layers into ``self.body``.

    Even-indexed layers use ``MaskTabular(reverse_mask=False)`` and odd-indexed
    layers use ``reverse_mask=True``. Every layer receives ``in_dim``,
    ``hidden_dim``, ``num_layers``, ``init_zeros`` and ``dropout``.
    """
    super(RealNVPTabular, self).__init__()
    stack = []
    for index in range(num_coupling_layers):
        mask = MaskTabular(reverse_mask=index % 2 == 1)
        stack.append(CouplingLayerTabular(in_dim, hidden_dim, num_layers, mask,
                                          init_zeros=init_zeros, dropout=dropout))
    self.body = iSequential(*stack)
def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8,
             init_zeros=False, st_type='resnet', use_batch_norm=True,
             img_shape=(1, 28, 28), skip=True, latent_dim=None):
    """Build a flow of eight quadrant-mask couplings that cycle 0->1->2->3->0 twice.

    Each ``CouplingLayer`` uses ``MaskQuadrant(input_quadrant=q,
    output_quadrant=(q + 1) % 4)`` for q = 0..3, and the cycle is repeated
    two times. Only ``init_zeros``, ``st_type`` and ``use_batch_norm`` are
    forwarded. ``num_scales``, ``img_shape``, ``skip`` and ``latent_dim``
    are accepted but unused.
    NOTE(review): unlike RealNVP8Layers, no ``img_width`` is passed. Confirm
    that is fine for every supported ``st_type``.
    """
    super(RealNVPCycleMask, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    for _cycle in range(2):
        for source in range(4):
            mask = MaskQuadrant(input_quadrant=source, output_quadrant=(source + 1) % 4)
            coupling = CouplingLayer(in_channels, mid_channels, num_blocks, mask,
                                     init_zeros=init_zeros, st_type=st_type,
                                     use_batch_norm=use_batch_norm)
            layers.append(passThrough(coupling))
    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def iResBlockLinear(outer_channels, inner_channels):
    """Return a fully-connected invertible-residual block followed by ActNorm1d.

    The residual branch is three SpectralNormLinear layers, each preceded by
    Swish. It maps ``outer_channels`` -> ``inner_channels`` ->
    ``inner_channels`` -> ``outer_channels`` and is wrapped in
    ``iResBlock(..., n_dist='poisson')``.

    Args:
        outer_channels: feature count entering and leaving the block.
        inner_channels: hidden feature count inside the residual branch.
    """
    # Shared spectral-norm settings (coeff presumably bounds the layer norm;
    # confirm in SpectralNormLinear).
    spectral = dict(atol=0.001, rtol=0.001, coeff=0.98)
    branch = nn.Sequential(
        Swish(),
        SpectralNormLinear(outer_channels, inner_channels, **spectral),
        Swish(),
        SpectralNormLinear(inner_channels, inner_channels, **spectral),
        Swish(),
        SpectralNormLinear(inner_channels, outer_channels, **spectral),
    )
    residual = iResBlock(branch, n_dist='poisson')
    return iSequential(residual, ActNorm1d(outer_channels))
def __init__(self, image_shape, mid_channels=64, num_scales=2,
             num_coupling_layers_per_scale=8, num_layers=3, actnorm_scale=1.,
             multi_scale=True, st_type='glow_convnet'):
    """Assemble a squeeze-first multi-scale Glow flow into ``self.body``.

    Each scale squeezes by 2 (channels x4, height and width floor-halved),
    then applies ``num_coupling_layers_per_scale`` glow steps. If
    ``multi_scale`` is set and the scale is not the last, it keeps half of
    the channels. ``self.output_shapes`` records ``[-1, C, H, W]`` after the
    squeeze, after the steps, and after each split.

    Args:
        image_shape: ``(C, H, W)`` of the input.
        mid_channels, actnorm_scale, st_type, num_layers: forwarded as
            keywords to ``self._glow_step``.
        num_scales: number of squeeze stages.
        num_coupling_layers_per_scale: glow steps added per scale.
        multi_scale: whether to factor out half the channels between scales.
    """
    super(Glow, self).__init__()
    layers = [addZslot(), passThrough(iLogits())]
    self.output_shapes = []
    channels, height, width = image_shape
    final_scale = num_scales - 1
    for scale in range(num_scales):
        channels, height, width = channels * 4, height // 2, width // 2
        layers.append(passThrough(SqueezeLayer(downscale_factor=2)))
        self.output_shapes.append([-1, channels, height, width])

        layers.extend(
            passThrough(*self._glow_step(in_channels=channels,
                                         mid_channels=mid_channels,
                                         actnorm_scale=actnorm_scale,
                                         st_type=st_type,
                                         num_layers=num_layers))
            for _ in range(num_coupling_layers_per_scale)
        )
        self.output_shapes.append([-1, channels, height, width])

        if multi_scale and scale < final_scale:
            kept = channels // 2
            layers.append(keepChannels(kept))
            self.output_shapes.append([-1, kept, height, width])
            channels = kept
    layers.append(FlatJoin())
    self.body = iSequential(*layers)
def __init__(self, in_channels=3, num_classes=10, k=512, num_per_block=16):
    """Build an invertible-residual flow with three squeeze stages and a dense tail.

    Stage channel counts are ``in_channels * 4``, ``* 16`` and ``* 64``. Each
    stage holds ``num_per_block`` ``iResBlockConv`` blocks of inner width
    ``k``. The result is flattened and passed through 4 ``iResBlockLinear``
    blocks of inner width ``k // 4``.

    Fix: the flattened size (and the prior's dimension) was hard-coded to
    ``3 * 32 * 32``. It ignored ``in_channels``, so any non-3-channel input
    built mismatched linear blocks. It is now ``in_channels * 32 * 32``,
    which is unchanged for the default ``in_channels=3``. Like the original
    constant, this assumes 32x32 inputs.

    Args:
        in_channels: channels of the input image.
        num_classes: stored on ``self.num_classes``.
        k: inner width of the convolutional blocks (the linear ones use ``k // 4``).
        num_per_block: convolutional blocks per squeeze stage.
    """
    super().__init__()
    self.num_classes = num_classes
    # Three squeezes (x4 channels each, per the channel counts below) turn a
    # (C, 32, 32) input into C * 64 channels over 4x4, i.e. C * 32 * 32 values.
    flat_dim = in_channels * 32 * 32
    self.flow = iSequential(
        # The original had an iLogits() stage here, commented out.
        SqueezeLayer(),
        *[iResBlockConv(in_channels * 4, k) for _ in range(num_per_block)],
        SqueezeLayer(),
        *[iResBlockConv(in_channels * 16, k) for _ in range(num_per_block)],
        SqueezeLayer(),
        *[iResBlockConv(in_channels * 64, k) for _ in range(num_per_block)],
        Flatten(),
        *[iResBlockLinear(flat_dim, k // 4) for _ in range(4)],
    )
    self.k = k
    self.prior = lambda device: StandardNormal(flat_dim, device)
def __init__(self, in_channels=1, mid_channels=64, num_blocks=4):
    """Build the fixed two-squeeze Glow used for MNIST into ``self.body``.

    Layout, where each step is ``passThrough(*self._glow_step(...))``:
      3 steps @ in_channels      with MaskCheckerboard
      SqueezeLayer(2)
      2 steps @ 4 * in_channels  with MaskChannelwise
      keepChannels(2 * in_channels)
      3 steps @ 2 * in_channels  with MaskCheckerboard
      SqueezeLayer(2)
      2 steps @ 8 * in_channels  with MaskChannelwise
      keepChannels(4 * in_channels)
      3 steps @ 4 * in_channels  with MaskCheckerboard
      1 step  @ 4 * in_channels  with MaskChannelwise
    All steps share ``mid_channels`` and ``num_blocks``.
    """
    super(GlowMNIST, self).__init__()

    def steps(channels, mask_cls, count):
        # ``count`` consecutive glow steps, each wrapped in passThrough.
        return [
            passThrough(*self._glow_step(channels, mid_channels, num_blocks, mask_cls))
            for _ in range(count)
        ]

    self.body = iSequential(
        addZslot(),
        passThrough(iLogits()),
        *steps(in_channels, MaskCheckerboard, 3),
        passThrough(SqueezeLayer(2)),
        *steps(4 * in_channels, MaskChannelwise, 2),
        keepChannels(2 * in_channels),
        *steps(2 * in_channels, MaskCheckerboard, 3),
        passThrough(SqueezeLayer(2)),
        *steps(8 * in_channels, MaskChannelwise, 2),
        keepChannels(4 * in_channels),
        *steps(4 * in_channels, MaskCheckerboard, 3),
        *steps(4 * in_channels, MaskChannelwise, 1),
        FlatJoin(),
    )