def __init__(self, in_channels, mid_channels, num_blocks, num_mixtures, dropout,
             use_attn=True, context_channels=0, in_lambda=None, out_lambda=None):
    """Convolutional attention network producing mixture parameters.

    Args:
        in_channels: number of input channels.
        mid_channels: width of the intermediate conv/attention blocks.
        num_blocks: number of ConvAttnBlock layers.
        num_mixtures: mixture count; the head emits
            in_channels * (2 + 3 * num_mixtures) output channels.
        dropout: dropout rate forwarded to each ConvAttnBlock.
        use_attn: whether the blocks use attention.
        context_channels: extra channels concatenated to the input.
        in_lambda / out_lambda: optional elementwise callables applied at the
            input / output boundary.
    """
    super(TransformerNet, self).__init__()
    self.in_lambda = LambdaLayer(in_lambda) if in_lambda else None
    self.conv1 = Conv2d(in_channels + context_channels, mid_channels,
                        kernel_size=3, padding=1)
    attn_blocks = [ConvAttnBlock(mid_channels, dropout, use_attn)
                   for _ in range(num_blocks)]
    self.attn_blocks = nn.ModuleList(attn_blocks)
    self.conv2 = Conv2d(mid_channels, in_channels * (2 + 3 * num_mixtures),
                        kernel_size=3, padding=1)
    # BUG FIX: the original assigned `layers.append(LambdaLayer(out_lambda))`
    # — no `layers` list exists in this method (NameError whenever
    # out_lambda was supplied), and list.append returns None regardless.
    # Wrap the lambda directly, mirroring in_lambda above.
    self.out_lambda = LambdaLayer(out_lambda) if out_lambda else None
def __init__(self, in_size, out_size, mid_channels, activation='relu',
             in_lambda=None, out_lambda=None):
    """Zero-init conv -> resizing conv -> zero-init conv.

    Optional elementwise lambdas wrap the stack; an activation (if any) is
    inserted after each of the first two convolutions.
    """
    def maybe_act():
        # Fresh activation module per call site, or nothing when disabled.
        if activation is None:
            return []
        return [act_module(activation, allow_concat=False)]

    modules = [LambdaLayer(in_lambda)] if in_lambda else []
    modules.append(Conv2dZeros(in_size[0], mid_channels))
    modules += maybe_act()
    modules.append(Conv2dResize(in_size=(mid_channels, in_size[1], in_size[2]),
                                out_size=out_size))
    modules += maybe_act()
    modules.append(Conv2dZeros(out_size[0], out_size[0]))
    if out_lambda:
        modules.append(LambdaLayer(out_lambda))
    super(ResizeConvNet, self).__init__(*modules)
def __init__(self, in_channels, out_channels, mid_channels, num_layers=1,
             activation='relu', weight_norm=True, in_lambda=None,
             out_lambda=None):
    """Conv stack: input conv, num_layers 1x1 convs, zero-init output conv.

    An activation (if any) precedes each 1x1 conv and the output conv;
    optional elementwise lambdas wrap the whole stack.
    """
    def act():
        # Activation inserted only when one is configured.
        if activation is not None:
            modules.append(act_module(activation, allow_concat=True))

    modules = []
    if in_lambda:
        modules.append(LambdaLayer(in_lambda))
    modules.append(Conv2d(in_channels, mid_channels, weight_norm=weight_norm))
    for _ in range(num_layers):
        act()
        modules.append(Conv2d(mid_channels, mid_channels, kernel_size=(1, 1),
                              weight_norm=weight_norm))
    act()
    modules.append(Conv2dZeros(mid_channels, out_channels))
    if out_lambda:
        modules.append(LambdaLayer(out_lambda))
    super(ConvNet, self).__init__(*modules)
def __init__(self, input_size, output_size, hidden_units, activation='relu',
             in_lambda=None, out_lambda=None):
    """Multi-layer perceptron: Linear/activation pairs, then an output Linear.

    Args:
        input_size: input feature dimension.
        output_size: output feature dimension.
        hidden_units: list of hidden widths; a bare int is treated as a
            single hidden layer, and an empty list yields one direct
            input->output Linear (both previously crashed: TypeError on the
            int, IndexError from `hidden_units[-1]` on the empty list).
        activation: activation name passed to act_module; None disables
            activations.
        in_lambda / out_lambda: optional elementwise callables at the
            boundaries.
    """
    # Accept a bare int for a single hidden layer (backward compatible:
    # list inputs behave exactly as before).
    if not isinstance(hidden_units, list):
        hidden_units = [hidden_units]
    layers = []
    if in_lambda:
        layers.append(LambdaLayer(in_lambda))
    if hidden_units:
        for in_size, out_size in zip([input_size] + hidden_units[:-1],
                                     hidden_units):
            layers.append(nn.Linear(in_size, out_size))
            if activation is not None:
                layers.append(act_module(activation))
        layers.append(nn.Linear(hidden_units[-1], output_size))
    else:
        # No hidden layers: single affine map.
        layers.append(nn.Linear(input_size, output_size))
    if out_lambda:
        layers.append(LambdaLayer(out_lambda))
    super(MLP, self).__init__(*layers)
def __init__(self, image_shape, output_dim, num_bits, autoregressive_order='cwh',
             d_model=512, nhead=8, num_layers=6, dim_feedforward=2048,
             dropout=0.1, activation="relu", kdim=None, vdim=None,
             attn_bias=True, output_bias=True, checkpoint_blocks=False,
             in_lambda=lambda x: x, out_lambda=lambda x: x):
    """Decoder-only transformer over 2d images of discrete values."""
    super(DecoderOnlyTransformer2d, self).__init__()
    self.image_shape = torch.Size(image_shape)
    self.autoregressive_order = autoregressive_order
    self.d_model = d_model
    self.num_layers = num_layers

    # Input path: elementwise lambda, token embedding over 2**num_bits
    # symbols, then a 2d positional encoding.
    embedding = nn.Embedding(2 ** num_bits, d_model)
    positional = PositionalEncodingImage(image_shape=image_shape,
                                         embedding_dim=d_model)
    self.encode = nn.Sequential(LambdaLayer(in_lambda), embedding, positional)

    # Image <-> sequence plumbing plus the one-step autoregressive shift.
    self.im2seq = Image2Seq(autoregressive_order, image_shape)
    self.seq2im = Seq2Image(autoregressive_order, image_shape)
    self.ar_shift = AutoregressiveShift(d_model)

    self.transformer = DecoderOnlyTransformer(
        d_model=d_model,
        nhead=nhead,
        num_layers=num_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        kdim=kdim,
        vdim=vdim,
        attn_bias=attn_bias,
        checkpoint_blocks=checkpoint_blocks,
    )

    # Output head: linear projection followed by an elementwise lambda.
    self.out_linear = nn.Linear(d_model, output_dim, bias=output_bias)
    self.out_lambda = LambdaLayer(out_lambda)
    self._reset_parameters()
def test_surjection_is_well_behaved(self):
    """VariationalDequantization behaves as a surjection for several bit depths."""
    batch_size, shape = 10, [8, 4, 4]
    for num_bits in (2, 5, 8):
        with self.subTest(num_bits=num_bits):
            high = 2 ** num_bits
            x = torch.randint(0, high, (batch_size,) + torch.Size(shape))
            # Conditioner rescales the discrete input to [-1, 1] before the conv.
            conditioner = nn.Sequential(
                LambdaLayer(lambda t: 2 * t.float() / (high - 1) - 1),
                nn.Conv2d(shape[0], 2 * shape[0], kernel_size=3, padding=1))
            encoder = ConditionalInverseFlow(
                base_dist=DiagonalNormal(shape),
                transforms=[ConditionalAffineBijection(conditioner), Sigmoid()])
            surjection = VariationalDequantization(encoder, num_bits=num_bits)
            self.assert_surjection_is_well_behaved(surjection, x,
                                                   z_shape=(batch_size, *shape),
                                                   z_dtype=torch.float)
def __init__(self, channels, context_channels=None, num_blocks=1, dropout=0.0,
             in_lambda=None, out_lambda=None):
    """Stack of num_blocks GatedConv layers with optional boundary lambdas."""
    assert num_blocks > 0
    modules = []
    if in_lambda:
        modules.append(LambdaLayer(in_lambda))
    modules.extend(
        GatedConv(channels, context_channels=context_channels, dropout=dropout)
        for _ in range(num_blocks))
    if out_lambda:
        modules.append(LambdaLayer(out_lambda))
    super(GatedConvNet, self).__init__(*modules)
def test_layer_is_well_behaved(self):
    """LambdaLayer applies the wrapped function to its input."""
    batch_size = 10
    shape = (6,)
    x = torch.randn(batch_size, *shape)
    module = LambdaLayer(lambda x: 2 * x - 1)
    self.assert_layer_is_well_behaved(module, x)
    y = module(x)
    # BUG FIX: assertEqual(y, 2*x-1) evaluates bool(y == 2*x-1), and a
    # multi-element comparison tensor raises "Boolean value of Tensor with
    # more than one element is ambiguous". Compare tensors explicitly.
    self.assertTrue(torch.equal(y, 2 * x - 1))
def __init__(self, input_size, output_size, hidden_units, activation='relu',
             in_lambda=None, out_lambda=None):
    """Multi-layer perceptron: Linear/activation pairs, then an output Linear.

    Args:
        input_size: input feature dimension.
        output_size: output feature dimension.
        hidden_units: list of hidden widths; a bare int means one hidden
            layer, an empty list means a single input->output Linear.
        activation: activation name for act_module; None disables activations.
        in_lambda / out_lambda: optional elementwise callables at the
            boundaries.
    """
    # Idiom fix: `isinstance(...) == False` -> `not isinstance(...)`;
    # behavior unchanged.
    if not isinstance(hidden_units, list):
        hidden_units = [hidden_units]
    layers = []
    if in_lambda:
        layers.append(LambdaLayer(in_lambda))
    if hidden_units:  # same as len(hidden_units) > 0 for lists
        for in_size, out_size in zip([input_size] + hidden_units[:-1],
                                     hidden_units):
            layers.append(nn.Linear(in_size, out_size))
            if activation is not None:
                layers.append(act_module(activation))
        layers.append(nn.Linear(hidden_units[-1], output_size))
    else:
        layers.append(nn.Linear(input_size, output_size))
    if out_lambda:
        layers.append(LambdaLayer(out_lambda))
    super(MLP, self).__init__(*layers)
def __init__(self, l_input=50, d_input=1, d_output=2, d_model=512, nhead=8,
             num_layers=6, dim_feedforward=512, dropout=0.1, activation="gelu",
             kdim=None, vdim=None, attn_bias=True, checkpoint_blocks=False,
             in_lambda=lambda x: x, out_lambda=lambda x: x):
    """Dense transformer with 1d positional encoding and linear in/out heads."""
    super(PositionalDenseTransformer, self).__init__()
    # A single block, cloned num_layers times below.
    block = DenseTransformerBlock(d_model=d_model,
                                  nhead=nhead,
                                  dim_feedforward=dim_feedforward,
                                  dropout=dropout,
                                  activation=activation,
                                  kdim=kdim,
                                  vdim=vdim,
                                  attn_bias=attn_bias,
                                  checkpoint=checkpoint_blocks)
    # Input path: elementwise lambda -> linear lift to d_model -> positions.
    self.in_lambda = LambdaLayer(in_lambda)
    self.in_linear = nn.Linear(d_input, d_model)
    self.encode = PositionalEncoding1d(l_input, d_model)
    self.layers = _get_clones(block, num_layers)
    # Output path: layer norm -> linear head -> elementwise lambda.
    self.out_norm = nn.LayerNorm(d_model)
    self.out_linear = nn.Linear(d_model, d_output)
    self.out_lambda = LambdaLayer(out_lambda)
    self.num_layers = num_layers
    self.d_model = d_model
    self.nhead = nhead
    self._reset_parameters()
def setUp(self):
    """Fixture: a ConditionalCategorical over (batch, size) class indices."""
    self.batch_size = 16
    self.num_classes = 5
    self.size = 10
    self.x = torch.randint(self.num_classes, [self.batch_size, self.size])
    self.context = torch.randn(self.batch_size, 20)
    # Reshape the flat logits to (batch, size, num_classes).
    to_logits = LambdaLayer(
        lambda t: t.reshape(self.batch_size, self.size, self.num_classes))
    net = nn.Sequential(nn.Linear(20, self.size * self.num_classes), to_logits)
    self.distribution = ConditionalCategorical(net=net)
def __init__(self, in_channels, out_shape, mid_channels=None, batch_norm=True,
             in_lambda=None, out_lambda=None):
    """Conv decoder wrapper with optional boundary lambdas.

    Args:
        in_channels: number of input channels.
        out_shape: target output shape passed to _ConvDecoder.
        mid_channels: non-empty list of intermediate widths; defaults to
            [256, 128, 64].
        batch_norm: forwarded to _ConvDecoder.
        in_lambda / out_lambda: optional elementwise callables at the
            boundaries.
    """
    # Avoid the shared-mutable-default pitfall: the old default
    # `mid_channels=[256, 128, 64]` was a single list object shared by all
    # calls. Same effective default, created per call.
    if mid_channels is None:
        mid_channels = [256, 128, 64]
    assert isinstance(mid_channels, list) and len(mid_channels) > 0, \
        f"mid_channels={mid_channels}"
    layers = []
    if in_lambda:
        layers.append(LambdaLayer(in_lambda))
    layers.append(_ConvDecoder(in_channels,
                               out_shape=out_shape,
                               mid_channels=mid_channels,
                               batch_norm=batch_norm,
                               init_weights=True))
    if out_lambda:
        layers.append(LambdaLayer(out_lambda))
    super(ConvDecoderNet, self).__init__(*layers)
def __init__(self, in_channels, out_channels, mid_channels=None, max_pool=True,
             batch_norm=True, in_lambda=None, out_lambda=None):
    """Conv encoder wrapper with optional boundary lambdas.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mid_channels: non-empty list of intermediate widths; defaults to
            [64, 128, 256].
        max_pool / batch_norm: forwarded to _ConvEncoder.
        in_lambda / out_lambda: optional elementwise callables at the
            boundaries.
    """
    # Avoid the shared-mutable-default pitfall: the old default
    # `mid_channels=[64, 128, 256]` was a single list object shared by all
    # calls. Same effective default, created per call.
    if mid_channels is None:
        mid_channels = [64, 128, 256]
    assert isinstance(mid_channels, list) and len(mid_channels) > 0, \
        f"mid_channels={mid_channels}"
    layers = []
    if in_lambda:
        layers.append(LambdaLayer(in_lambda))
    layers.append(_ConvEncoder(in_channels,
                               out_channels,
                               mid_channels=mid_channels,
                               max_pool=max_pool,
                               batch_norm=batch_norm))
    if out_lambda:
        layers.append(LambdaLayer(out_lambda))
    super(ConvEncoderNet, self).__init__(*layers)
def __init__(self, data_shape, num_bits, num_steps, num_context, num_blocks,
             mid_channels, depth, growth, dropout, gated_conv):
    """Conditional flow for dequantization: a DenseBlock context network
    feeding num_steps (Conv1x1, ConditionalCoupling) pairs, finished by a
    channel shuffle, unsqueeze and sigmoid.
    """
    # Context network: rescale inputs from {0..2**num_bits-1} to [-1, 1],
    # then produce num_context feature maps at half resolution (stride-2 conv).
    # NOTE(review): the two DenseBlocks here use hard-coded depth=4/growth=16
    # while the `depth`/`growth` arguments are only forwarded to the coupling
    # layers below — looks intentional, but confirm.
    context_net = nn.Sequential(
        LambdaLayer(lambda x: 2*x.float()/(2**num_bits-1)-1),
        DenseBlock(in_channels=data_shape[0],
                   out_channels=mid_channels,
                   depth=4, growth=16,
                   dropout=dropout, gated_conv=gated_conv, zero_init=False),
        nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0),
        DenseBlock(in_channels=mid_channels,
                   out_channels=num_context,
                   depth=4, growth=16,
                   dropout=dropout, gated_conv=gated_conv, zero_init=False))

    transforms = []
    # Squeezed shape the flow operates on: 4x channels, half spatial dims.
    sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
    for i in range(num_steps):
        transforms.extend([
            Conv1x1(sample_shape[0]),
            ConditionalCoupling(in_channels=sample_shape[0],
                                num_context=num_context,
                                num_blocks=num_blocks,
                                mid_channels=mid_channels,
                                depth=depth,
                                growth=growth,
                                dropout=dropout,
                                gated_conv=gated_conv)
        ])

    # Final shuffle of channels, squeeze and sigmoid
    transforms.extend([Conv1x1(sample_shape[0]),
                       Unsqueeze2d(),
                       Sigmoid()
                       ])
    super(DequantizationFlow, self).__init__(base_dist=ConvNormal2d(sample_shape),
                                             transforms=transforms,
                                             context_init=context_net)
def __init__(self, in_channels, num_params, filters=128, num_blocks=15,
             output_filters=1024, kernel_size=3, kernel_size_in=7,
             init_transforms=lambda x: 2 * x - 1):
    """Masked-convolution network emitting num_params values per input channel."""
    modules = [LambdaLayer(init_transforms)]
    # First layer uses mask_type 'A'; all later masked convs use 'B'.
    modules.append(MaskedConv2d(in_channels, 2 * filters,
                                kernel_size=kernel_size_in,
                                padding=kernel_size_in // 2,
                                mask_type='A',
                                data_channels=in_channels))
    for _ in range(num_blocks):
        modules.append(MaskedResidualBlock2d(filters,
                                             data_channels=in_channels,
                                             kernel_size=kernel_size))
    modules.append(nn.ReLU(True))
    modules.append(MaskedConv2d(2 * filters, output_filters, kernel_size=1,
                                mask_type='B', data_channels=in_channels))
    modules.append(nn.ReLU(True))
    modules.append(MaskedConv2d(output_filters, num_params * in_channels,
                                kernel_size=1, mask_type='B',
                                data_channels=in_channels))
    modules.append(ElementwiseParams2d(num_params))
    super(PixelCNN, self).__init__(*modules)
def test_range(self):
    """Inverse of VariationalDequantization lands in [0, 255] for 8 bits."""
    batch_size, shape = 10, [8]
    z = torch.randn(batch_size, *shape)
    # Conditioner rescales the 8-bit input to [-1, 1] before the linear map.
    conditioner = nn.Sequential(
        LambdaLayer(lambda t: 2 * t.float() / 255 - 1),
        nn.Linear(shape[0], 2 * shape[0]))
    encoder = ConditionalInverseFlow(
        base_dist=DiagonalNormal(shape),
        transforms=[ConditionalAffineBijection(conditioner), Sigmoid()])
    surjection = VariationalDequantization(encoder, num_bits=8)
    x = surjection.inverse(z)
    self.assertTrue(x.min() >= 0)
    self.assertTrue(x.max() <= 255)
C = args.context_size # Size of context assert D % 2 == 0, 'Only even dimension supported currently.' assert C % D == 0, 'context_size needs to be multiple of num_dims.' # Decoder if args.num_bits is not None: transforms = [Logit()] for _ in range(args.num_flows): net = nn.Sequential( MLP(C + D // 2, P * D // 2, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P)) context_net = nn.Sequential( LambdaLayer(lambda x: 2 * x.float() / (2**args.num_bits - 1) - 1), MLP(D, C, hidden_units=args.hidden_units, activation=args.activation)) if args.affine: transforms.append( ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn( args.scale_fn))) else: transforms.append( ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net)) if args.actnorm: transforms.append(ActNormBijection(D))
def __init__(self, data_shape, num_bits, num_steps, coupling_network,
             num_context, num_blocks, mid_channels, depth, growth=None,
             dropout=None, gated_conv=None, num_mixtures=None):
    """Conditional dequantization flow whose context and coupling networks
    are selected by `coupling_network` ("densenet", "transformer" or "conv");
    raises ValueError for any other value.
    """
    #context_network_type = "conv"
    context_network_type = coupling_network
    # Context network: rescale inputs from {0..2**num_bits-1} to [-1, 1] and
    # produce num_context feature maps at half resolution (stride-2 conv).
    if context_network_type == "densenet":
        context_net = nn.Sequential(
            LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
            DenseBlock(in_channels=data_shape[0],
                       out_channels=mid_channels,
                       depth=depth,
                       growth=growth,
                       dropout=dropout,
                       gated_conv=gated_conv,
                       zero_init=False),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2,
                      padding=0),
            DenseBlock(in_channels=mid_channels,
                       out_channels=num_context,
                       depth=depth,
                       growth=growth,
                       dropout=dropout,
                       gated_conv=gated_conv,
                       zero_init=False))
    elif context_network_type == "transformer":
        # Conv stem, then num_blocks attention-free ConvAttnBlocks.
        layers = [
            LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
            Conv2d(in_channels=data_shape[0], out_channels=mid_channels,
                   kernel_size=3, padding=1),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2,
                      padding=0)
        ]
        for i in range(num_blocks):
            layers.append(ConvAttnBlock(channels=mid_channels,
                                        dropout=0.0,
                                        use_attn=False,
                                        context_channels=None))
        layers.append(Conv2d(in_channels=mid_channels,
                             out_channels=num_context,
                             kernel_size=3, padding=1))
        context_net = nn.Sequential(*layers)
    elif context_network_type == "conv":
        context_net = nn.Sequential(
            LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
            Conv2d(data_shape[0], mid_channels // 2, kernel_size=3, stride=1),
            nn.Conv2d(mid_channels // 2, mid_channels, kernel_size=2, stride=2,
                      padding=0),
            GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0),
            Conv2dZeros(in_channels=mid_channels, out_channels=num_context))
    else:
        raise ValueError(
            f"Unknown dequantization context_network_type type: {context_network_type}"
        )

    # layer transformations of the dequantization flow
    transforms = []
    # Squeezed shape the flow operates on: 4x channels, half spatial dims.
    sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
    for i in range(num_steps):
        #transforms.append(ActNormBijection2d(sample_shape[0]))
        transforms.extend([Conv1x1(sample_shape[0])])
        if coupling_network in ["conv", "densenet"]:
            transforms.append(
                ConditionalCoupling(in_channels=sample_shape[0],
                                    num_context=num_context,
                                    num_blocks=num_blocks,
                                    mid_channels=mid_channels,
                                    depth=depth,
                                    growth=growth,
                                    dropout=dropout,
                                    gated_conv=gated_conv,
                                    coupling_network=coupling_network))
        elif coupling_network == "transformer":
            transforms.append(
                ConditionalMixtureCoupling(in_channels=sample_shape[0],
                                           num_context=num_context,
                                           mid_channels=mid_channels,
                                           num_mixtures=num_mixtures,
                                           num_blocks=num_blocks,
                                           dropout=dropout,
                                           use_attn=False))
        else:
            raise ValueError(
                f"Unknown dequantization coupling network type: {coupling_network}"
            )

    # Final shuffle of channels, squeeze and sigmoid
    transforms.extend([Conv1x1(sample_shape[0]), Unsqueeze2d(), Sigmoid()])
    super(DequantizationFlow, self).__init__(base_dist=ConvNormal2d(sample_shape),
                                             transforms=transforms,
                                             context_init=context_net)
D = args.num_dims # Number of data dimensions P = 2 if args.affine else 1 # Number of elementwise parameters C = args.context_size # Size of context I = D // 2 O = D // 2 + D % 2 # Decoder if args.num_bits is not None: transforms = [Logit()] for _ in range(args.num_flows): net = nn.Sequential(MLP(C+I, P*O, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P)) context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1), MLP(D, C, hidden_units=args.hidden_units, activation=args.activation)) if args.affine: transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I)) else: transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I)) if args.actnorm: transforms.append(ActNormBijection(D)) if args.permutation == 'reverse': transforms.append(Reverse(D)) elif args.permutation == 'shuffle': transforms.append(Shuffle(D)) transforms.pop() decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device) # Flow transforms = [] for _ in range(args.num_flows): net = nn.Sequential(MLP(I, P*O,