Beispiel #1
0
    def __init__(self,
                 in_channels,
                 mid_channels,
                 num_blocks,
                 num_mixtures,
                 dropout,
                 use_attn=True,
                 context_channels=0,
                 in_lambda=None,
                 out_lambda=None):
        """Build the TransformerNet sub-modules.

        Assembles an input 3x3 conv, `num_blocks` ConvAttnBlocks, and an
        output 3x3 conv producing `in_channels * (2 + 3 * num_mixtures)`
        channels (mixture parameters per input channel), with optional
        elementwise pre/post lambdas.

        Args:
            in_channels: channels of the input tensor.
            mid_channels: hidden channel width.
            num_blocks: number of ConvAttnBlocks.
            num_mixtures: mixtures parameterized per input channel.
            dropout: dropout rate passed to each ConvAttnBlock.
            use_attn: whether ConvAttnBlocks use attention.
            context_channels: extra channels concatenated to the input.
            in_lambda: optional callable applied before conv1.
            out_lambda: optional callable applied after conv2.
        """
        super(TransformerNet, self).__init__()

        self.in_lambda = LambdaLayer(in_lambda) if in_lambda else None

        self.conv1 = Conv2d(in_channels + context_channels,
                            mid_channels,
                            kernel_size=3,
                            padding=1)
        attn_blocks = [
            ConvAttnBlock(mid_channels, dropout, use_attn)
            for _ in range(num_blocks)
        ]
        self.attn_blocks = nn.ModuleList(attn_blocks)

        self.conv2 = Conv2d(mid_channels,
                            in_channels * (2 + 3 * num_mixtures),
                            kernel_size=3,
                            padding=1)
        # BUG FIX: the original assigned `layers.append(LambdaLayer(...))`,
        # but `layers` is never defined in this method (NameError when
        # out_lambda is provided) and list.append returns None regardless.
        # Mirror the in_lambda handling above instead.
        self.out_lambda = LambdaLayer(out_lambda) if out_lambda else None
Beispiel #2
0
    def __init__(self,
                 in_size,
                 out_size,
                 mid_channels,
                 activation='relu',
                 in_lambda=None,
                 out_lambda=None):
        """Sequential conv net that resizes `in_size` feature maps to `out_size`.

        Pipeline: optional in_lambda -> zero-init conv to mid_channels ->
        activation -> resizing conv to out_size -> activation -> zero-init
        conv -> optional out_lambda.
        """
        modules = []
        if in_lambda:
            modules.append(LambdaLayer(in_lambda))

        modules.append(Conv2dZeros(in_size[0], mid_channels))
        if activation is not None:
            modules.append(act_module(activation, allow_concat=False))

        # The resize conv sees mid_channels at the original spatial size.
        resize_in = (mid_channels, in_size[1], in_size[2])
        modules.append(Conv2dResize(in_size=resize_in, out_size=out_size))
        if activation is not None:
            modules.append(act_module(activation, allow_concat=False))

        modules.append(Conv2dZeros(out_size[0], out_size[0]))
        if out_lambda:
            modules.append(LambdaLayer(out_lambda))

        super(ResizeConvNet, self).__init__(*modules)
Beispiel #3
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 num_layers=1,
                 activation='relu',
                 weight_norm=True,
                 in_lambda=None,
                 out_lambda=None):
        """Sequential conv net: input conv, `num_layers` 1x1 convs with
        activations, and a zero-initialized output conv, optionally wrapped
        in elementwise lambda layers."""
        modules = []
        if in_lambda:
            modules.append(LambdaLayer(in_lambda))
        modules.append(Conv2d(in_channels, mid_channels,
                              weight_norm=weight_norm))

        for _ in range(num_layers):
            if activation is not None:
                modules.append(act_module(activation, allow_concat=True))
            modules.append(Conv2d(mid_channels,
                                  mid_channels,
                                  kernel_size=(1, 1),
                                  weight_norm=weight_norm))

        if activation is not None:
            modules.append(act_module(activation, allow_concat=True))
        # Zero-init output conv so the net starts as the identity map.
        modules.append(Conv2dZeros(mid_channels, out_channels))
        if out_lambda:
            modules.append(LambdaLayer(out_lambda))

        super(ConvNet, self).__init__(*modules)
Beispiel #4
0
    def __init__(self, input_size, output_size, hidden_units, activation='relu', in_lambda=None, out_lambda=None):
        """Multi-layer perceptron: Linear+activation per hidden layer, then a
        final Linear to `output_size`, optionally wrapped in lambda layers.

        Note: `hidden_units` must be a non-empty list.
        """
        modules = []
        if in_lambda:
            modules.append(LambdaLayer(in_lambda))
        # Chain layer widths: input_size -> h0 -> h1 -> ... -> h_last.
        widths = [input_size] + hidden_units
        for idx in range(len(hidden_units)):
            modules.append(nn.Linear(widths[idx], widths[idx + 1]))
            modules.append(act_module(activation))
        modules.append(nn.Linear(hidden_units[-1], output_size))
        if out_lambda:
            modules.append(LambdaLayer(out_lambda))

        super(MLP, self).__init__(*modules)
Beispiel #5
0
    def __init__(self,
                 image_shape,
                 output_dim,
                 num_bits,
                 autoregressive_order='cwh',
                 d_model=512,
                 nhead=8,
                 num_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 kdim=None,
                 vdim=None,
                 attn_bias=True,
                 output_bias=True,
                 checkpoint_blocks=False,
                 in_lambda=lambda x: x,
                 out_lambda=lambda x: x):
        """Build a decoder-only transformer over image-shaped discrete inputs.

        Constructs: an encoding pipeline (in_lambda, token embedding over a
        2**num_bits vocabulary, image positional encoding), image<->sequence
        reshapers for the given autoregressive order, an autoregressive
        shift, a DecoderOnlyTransformer stack, and a final linear projection
        to `output_dim`.  How they are composed is defined in `forward`.
        """
        super(DecoderOnlyTransformer2d, self).__init__()
        self.image_shape = torch.Size(image_shape)
        self.autoregressive_order = autoregressive_order
        self.d_model = d_model
        self.num_layers = num_layers

        # Encoding layers: optional elementwise preprocessing, embedding of
        # the 2**num_bits symbol vocabulary, then image positional encoding.
        self.encode = nn.Sequential(
            LambdaLayer(in_lambda), nn.Embedding(2**num_bits, d_model),
            PositionalEncodingImage(image_shape=image_shape,
                                    embedding_dim=d_model))

        # Image <-> sequence reshaping in `autoregressive_order`; ar_shift
        # presumably shifts the sequence by one step for autoregressive
        # conditioning — confirm in forward.
        self.im2seq = Image2Seq(autoregressive_order, image_shape)
        self.seq2im = Seq2Image(autoregressive_order, image_shape)
        self.ar_shift = AutoregressiveShift(d_model)

        self.transformer = DecoderOnlyTransformer(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            kdim=kdim,
            vdim=vdim,
            attn_bias=attn_bias,
            checkpoint_blocks=checkpoint_blocks)

        # Per-position projection from d_model to output_dim, plus optional
        # elementwise postprocessing.
        self.out_linear = nn.Linear(d_model, output_dim, bias=output_bias)
        self.out_lambda = LambdaLayer(out_lambda)

        self._reset_parameters()
    def test_surjection_is_well_behaved(self):
        """VariationalDequantization passes the well-behavedness checks for
        several bit depths."""
        batch_size = 10
        shape = [8, 4, 4]

        for num_bits in [2, 5, 8]:
            with self.subTest(num_bits=num_bits):
                # Integer inputs in [0, 2**num_bits).
                x = torch.randint(0, 2**num_bits,
                                  (batch_size, ) + torch.Size(shape))

                # Conditioning net: rescale ints to [-1, 1], then a 3x3 conv
                # producing 2x the channels (affine shift/scale parameters).
                cond_net = nn.Sequential(
                    LambdaLayer(lambda x: 2 * x.float() /
                                (2**num_bits - 1) - 1),
                    nn.Conv2d(shape[0],
                              2 * shape[0],
                              kernel_size=3,
                              padding=1))
                encoder = ConditionalInverseFlow(
                    base_dist=DiagonalNormal(shape),
                    transforms=[
                        ConditionalAffineBijection(cond_net),
                        Sigmoid()
                    ])

                surjection = VariationalDequantization(encoder,
                                                       num_bits=num_bits)
                self.assert_surjection_is_well_behaved(
                    surjection,
                    x,
                    z_shape=(batch_size, *shape),
                    z_dtype=torch.float)
Beispiel #7
0
 def __init__(self,
              channels,
              context_channels=None,
              num_blocks=1,
              dropout=0.0,
              in_lambda=None,
              out_lambda=None):
     """Sequential net of `num_blocks` GatedConv blocks, optionally wrapped
     in elementwise lambda layers.

     Args:
         channels: channel width of every GatedConv block.
         context_channels: optional conditioning channels for each block.
         num_blocks: number of GatedConv blocks; must be positive.
         dropout: dropout rate passed to each block.
         in_lambda: optional callable applied before the blocks.
         out_lambda: optional callable applied after the blocks.
     """
     assert num_blocks > 0
     layers = []
     if in_lambda: layers.append(LambdaLayer(in_lambda))
     layers += [
         GatedConv(channels,
                   context_channels=context_channels,
                   dropout=dropout) for i in range(num_blocks)
     ]
     if out_lambda: layers.append(LambdaLayer(out_lambda))
     super(GatedConvNet, self).__init__(*layers)
Beispiel #8
0
    def test_layer_is_well_behaved(self):
        """LambdaLayer applies its wrapped function to the input tensor."""
        batch_size = 10
        shape = (6,)
        x = torch.randn(batch_size, *shape)

        module = LambdaLayer(lambda x: 2 * x - 1)
        self.assert_layer_is_well_behaved(module, x)

        y = module(x)
        # NOTE(review): assertEqual on multi-element tensors only works if
        # the test base class provides tensor-aware equality — confirm.
        self.assertEqual(y, 2 * x - 1)
Beispiel #9
0
    def __init__(self, input_size, output_size, hidden_units, activation='relu', in_lambda=None, out_lambda=None):
        """Multi-layer perceptron with optional hidden layers.

        Args:
            input_size: input feature dimension.
            output_size: output feature dimension.
            hidden_units: list of hidden widths, or a single int, or an
                empty list for a plain Linear map.
            activation: activation name passed to act_module, or None.
            in_lambda: optional callable applied before the layers.
            out_lambda: optional callable applied after the layers.
        """
        # Accept a bare int for `hidden_units` by wrapping it in a list
        # (idiomatic `not isinstance(...)` instead of `... == False`).
        if not isinstance(hidden_units, list):
            hidden_units = [hidden_units]

        layers = []
        if in_lambda:
            layers.append(LambdaLayer(in_lambda))

        if hidden_units:
            # input_size -> h0 -> ... -> h_last, activation between layers.
            for in_size, out_size in zip([input_size] + hidden_units[:-1],
                                         hidden_units):
                layers.append(nn.Linear(in_size, out_size))
                if activation is not None:
                    layers.append(act_module(activation))
            layers.append(nn.Linear(hidden_units[-1], output_size))
        else:
            # No hidden layers: a single affine map.
            layers.append(nn.Linear(input_size, output_size))

        if out_lambda:
            layers.append(LambdaLayer(out_lambda))

        super(MLP, self).__init__(*layers)
Beispiel #10
0
    def __init__(self,
                 l_input=50,
                 d_input=1,
                 d_output=2,
                 d_model=512,
                 nhead=8,
                 num_layers=6,
                 dim_feedforward=512,
                 dropout=0.1,
                 activation="gelu",
                 kdim=None,
                 vdim=None,
                 attn_bias=True,
                 checkpoint_blocks=False,
                 in_lambda=lambda x: x,
                 out_lambda=lambda x: x):
        """Build a positional dense transformer for length-`l_input` sequences.

        Constructs: in_lambda, a linear embedding d_input -> d_model, a 1d
        positional encoding, `num_layers` clones of one DenseTransformerBlock,
        a LayerNorm, a linear head d_model -> d_output, and out_lambda.  Their
        composition is defined in `forward`.
        """
        super(PositionalDenseTransformer, self).__init__()

        # Prototype block; _get_clones deep-copies it num_layers times below.
        decoder_layer = DenseTransformerBlock(d_model=d_model,
                                              nhead=nhead,
                                              dim_feedforward=dim_feedforward,
                                              dropout=dropout,
                                              activation=activation,
                                              kdim=kdim,
                                              vdim=vdim,
                                              attn_bias=attn_bias,
                                              checkpoint=checkpoint_blocks)

        self.in_lambda = LambdaLayer(in_lambda)
        self.in_linear = nn.Linear(d_input, d_model)
        self.encode = PositionalEncoding1d(l_input, d_model)
        self.layers = _get_clones(decoder_layer, num_layers)
        self.out_norm = nn.LayerNorm(d_model)
        self.out_linear = nn.Linear(d_model, d_output)
        self.out_lambda = LambdaLayer(out_lambda)

        self.num_layers = num_layers
        self.d_model = d_model
        self.nhead = nhead

        self._reset_parameters()
    def setUp(self):
        """Create a batch of categorical data, a context tensor and a
        ConditionalCategorical distribution for the tests."""
        self.batch_size = 16
        self.num_classes = 5
        self.size = 10
        # x: integer labels in [0, num_classes), shape (batch_size, size).
        self.x = torch.randint(self.num_classes, [self.batch_size, self.size])
        self.context = torch.randn(self.batch_size, 20)

        # Maps a (batch, 20) context to logits of shape
        # (batch_size, size, num_classes).  NOTE(review): the reshape
        # hard-codes self.batch_size, so this net only accepts full batches.
        net = nn.Sequential(
            nn.Linear(20, self.size * self.num_classes),
            LambdaLayer(lambda x: x.reshape(self.batch_size, self.size, self.
                                            num_classes)))
        self.distribution = ConditionalCategorical(net=net)
Beispiel #12
0
    def __init__(self,
                 in_channels,
                 out_shape,
                 mid_channels=None,
                 batch_norm=True,
                 in_lambda=None,
                 out_lambda=None):
        """Sequential wrapper around _ConvDecoder with optional lambdas.

        Args:
            in_channels: channels of the latent input.
            out_shape: target output shape for the decoder.
            mid_channels: non-empty list of intermediate widths; defaults
                to [256, 128, 64] when None.
            batch_norm: whether the decoder uses batch norm.
            in_lambda: optional callable applied before the decoder.
            out_lambda: optional callable applied after the decoder.
        """
        # Avoid a shared mutable default argument; None selects the
        # original default of [256, 128, 64].
        if mid_channels is None:
            mid_channels = [256, 128, 64]
        assert isinstance(
            mid_channels,
            list) and len(mid_channels) > 0, f"mid_channels={mid_channels}"

        layers = []
        if in_lambda:
            layers.append(LambdaLayer(in_lambda))
        layers.append(
            _ConvDecoder(in_channels,
                         out_shape=out_shape,
                         mid_channels=mid_channels,
                         batch_norm=batch_norm,
                         init_weights=True))
        if out_lambda:
            layers.append(LambdaLayer(out_lambda))
        super(ConvDecoderNet, self).__init__(*layers)
Beispiel #13
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=None,
                 max_pool=True,
                 batch_norm=True,
                 in_lambda=None,
                 out_lambda=None):
        """Sequential wrapper around _ConvEncoder with optional lambdas.

        Args:
            in_channels: channels of the input image.
            out_channels: channels of the encoded output.
            mid_channels: non-empty list of intermediate widths; defaults
                to [64, 128, 256] when None.
            max_pool: whether the encoder downsamples with max pooling.
            batch_norm: whether the encoder uses batch norm.
            in_lambda: optional callable applied before the encoder.
            out_lambda: optional callable applied after the encoder.
        """
        # Avoid a shared mutable default argument; None selects the
        # original default of [64, 128, 256].
        if mid_channels is None:
            mid_channels = [64, 128, 256]
        assert isinstance(
            mid_channels,
            list) and len(mid_channels) > 0, f"mid_channels={mid_channels}"

        layers = []
        if in_lambda:
            layers.append(LambdaLayer(in_lambda))
        layers.append(
            _ConvEncoder(in_channels,
                         out_channels,
                         mid_channels=mid_channels,
                         max_pool=max_pool,
                         batch_norm=batch_norm))
        if out_lambda:
            layers.append(LambdaLayer(out_lambda))
        super(ConvEncoderNet, self).__init__(*layers)
    def __init__(self, data_shape, num_bits, num_steps, num_context,
                 num_blocks, mid_channels, depth, growth, dropout, gated_conv):
        """Conditional dequantization flow.

        Builds a context network over the quantized input, then `num_steps`
        (Conv1x1, ConditionalCoupling) flow steps on the squeezed sample
        shape, finished by a channel shuffle, unsqueeze and sigmoid.
        """
        # Context net: rescale ints in [0, 2**num_bits - 1] to [-1, 1], then
        # DenseBlocks around a stride-2 conv (spatial downsample by 2).
        # NOTE(review): these DenseBlocks hard-code depth=4 and growth=16,
        # ignoring the `depth`/`growth` arguments (which are used by the
        # couplings below) — confirm this is intentional.
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**num_bits-1)-1),
                                    DenseBlock(in_channels=data_shape[0],
                                               out_channels=mid_channels,
                                               depth=4,
                                               growth=16,
                                               dropout=dropout,
                                               gated_conv=gated_conv,
                                               zero_init=False),
                                    nn.Conv2d(mid_channels, mid_channels, kernel_size=2, stride=2, padding=0),
                                    DenseBlock(in_channels=mid_channels,
                                               out_channels=num_context,
                                               depth=4,
                                               growth=16,
                                               dropout=dropout,
                                               gated_conv=gated_conv,
                                               zero_init=False))



        transforms = []
        # Squeezed shape: 4x the channels, half the spatial resolution.
        sample_shape = (data_shape[0] * 4, data_shape[1] // 2, data_shape[2] // 2)
        for i in range(num_steps):
            transforms.extend([
                Conv1x1(sample_shape[0]),
                ConditionalCoupling(in_channels=sample_shape[0],
                                    num_context=num_context,
                                    num_blocks=num_blocks,
                                    mid_channels=mid_channels,
                                    depth=depth,
                                    growth=growth,
                                    dropout=dropout,
                                    gated_conv=gated_conv)
            ])

        # Final shuffle of channels, squeeze and sigmoid
        transforms.extend([Conv1x1(sample_shape[0]),
                           Unsqueeze2d(),
                           Sigmoid()
                          ])

        super(DequantizationFlow, self).__init__(base_dist=ConvNormal2d(sample_shape),
                                                 transforms=transforms,
                                                 context_init=context_net)
Beispiel #15
0
    def __init__(self,
                 in_channels,
                 num_params,
                 filters=128,
                 num_blocks=15,
                 output_filters=1024,
                 kernel_size=3,
                 kernel_size_in=7,
                 init_transforms=lambda x: 2 * x - 1):
        """PixelCNN: masked-conv stack producing `num_params` parameters per
        input channel, preceded by an elementwise input transform."""
        # Input transform, then a type-'A' masked conv (no self-connection).
        layers = [LambdaLayer(init_transforms)]
        layers.append(
            MaskedConv2d(in_channels,
                         2 * filters,
                         kernel_size=kernel_size_in,
                         padding=kernel_size_in // 2,
                         mask_type='A',
                         data_channels=in_channels))

        # Residual trunk.
        for _ in range(num_blocks):
            layers.append(
                MaskedResidualBlock2d(filters,
                                      data_channels=in_channels,
                                      kernel_size=kernel_size))

        # Two 1x1 type-'B' masked convs with ReLUs form the output head.
        layers.append(nn.ReLU(True))
        layers.append(
            MaskedConv2d(2 * filters,
                         output_filters,
                         kernel_size=1,
                         mask_type='B',
                         data_channels=in_channels))
        layers.append(nn.ReLU(True))
        layers.append(
            MaskedConv2d(output_filters,
                         num_params * in_channels,
                         kernel_size=1,
                         mask_type='B',
                         data_channels=in_channels))
        layers.append(ElementwiseParams2d(num_params))

        super(PixelCNN, self).__init__(*layers)
    def test_range(self):
        """Inverting VariationalDequantization maps latents into [0, 255]
        for num_bits=8."""
        batch_size = 10
        shape = [8]

        z = torch.randn(batch_size, *shape)

        # Conditioning net: rescale ints to [-1, 1], then a linear layer
        # producing 2x the features (affine shift/scale parameters).
        cond_net = nn.Sequential(
            LambdaLayer(lambda x: 2 * x.float() / 255 - 1),
            nn.Linear(shape[0], 2 * shape[0]))
        encoder = ConditionalInverseFlow(
            base_dist=DiagonalNormal(shape),
            transforms=[
                ConditionalAffineBijection(cond_net),
                Sigmoid()
            ])

        surjection = VariationalDequantization(encoder, num_bits=8)
        x = surjection.inverse(z)
        self.assertTrue(x.min() >= 0)
        self.assertTrue(x.max() <= 255)
Beispiel #17
0
C = args.context_size  # Size of context

# NOTE(review): D (num dims) and P (elementwise params) are defined earlier
# in this script, outside this excerpt.
assert D % 2 == 0, 'Only even dimension supported currently.'
assert C % D == 0, 'context_size needs to be multiple of num_dims.'

# Decoder
# Variational-dequantization decoder: a stack of conditional couplings on
# top of a Logit transform, built only for discrete (num_bits) data.
if args.num_bits is not None:
    transforms = [Logit()]
    for _ in range(args.num_flows):
        # Coupling net: context (C) + conditioned half (D//2) in, P
        # parameters per remaining dim out.
        net = nn.Sequential(
            MLP(C + D // 2,
                P * D // 2,
                hidden_units=args.hidden_units,
                activation=args.activation), ElementwiseParams(P))
        # Context net: rescale ints in [0, 2**num_bits - 1] to [-1, 1] and
        # embed the full D-dim input into a C-dim context.
        context_net = nn.Sequential(
            LambdaLayer(lambda x: 2 * x.float() / (2**args.num_bits - 1) - 1),
            MLP(D,
                C,
                hidden_units=args.hidden_units,
                activation=args.activation))
        if args.affine:
            transforms.append(
                ConditionalAffineCouplingBijection(coupling_net=net,
                                                   context_net=context_net,
                                                   scale_fn=scale_fn(
                                                       args.scale_fn)))
        else:
            transforms.append(
                ConditionalAdditiveCouplingBijection(coupling_net=net,
                                                     context_net=context_net))
        if args.actnorm: transforms.append(ActNormBijection(D))
Beispiel #18
0
    def __init__(self,
                 data_shape,
                 num_bits,
                 num_steps,
                 coupling_network,
                 num_context,
                 num_blocks,
                 mid_channels,
                 depth,
                 growth=None,
                 dropout=None,
                 gated_conv=None,
                 num_mixtures=None):
        """Conditional dequantization flow with a configurable context net.

        `coupling_network` selects both the context-network architecture
        ("densenet", "transformer" or "conv") and the coupling type used in
        each flow step; unrecognized values raise ValueError.
        """
        #context_network_type = "conv"
        # The context network mirrors the coupling-network choice.
        context_network_type = coupling_network

        if context_network_type == "densenet":
            # Rescale ints in [0, 2**num_bits - 1] to [-1, 1], then
            # DenseBlocks around a stride-2 conv (spatial downsample by 2).
            context_net = nn.Sequential(
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                DenseBlock(in_channels=data_shape[0],
                           out_channels=mid_channels,
                           depth=depth,
                           growth=growth,
                           dropout=dropout,
                           gated_conv=gated_conv,
                           zero_init=False),
                nn.Conv2d(mid_channels,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0),
                DenseBlock(in_channels=mid_channels,
                           out_channels=num_context,
                           depth=depth,
                           growth=growth,
                           dropout=dropout,
                           gated_conv=gated_conv,
                           zero_init=False))

        elif context_network_type == "transformer":
            # Conv stem with a stride-2 downsample, attention-free
            # ConvAttnBlocks, and a final 3x3 conv to num_context channels.
            layers = [
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                Conv2d(in_channels=data_shape[0],
                       out_channels=mid_channels,
                       kernel_size=3,
                       padding=1),
                nn.Conv2d(mid_channels,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0)
            ]
            for i in range(num_blocks):
                layers.append(
                    ConvAttnBlock(channels=mid_channels,
                                  dropout=0.0,
                                  use_attn=False,
                                  context_channels=None))
            layers.append(
                Conv2d(in_channels=mid_channels,
                       out_channels=num_context,
                       kernel_size=3,
                       padding=1))
            context_net = nn.Sequential(*layers)

        elif context_network_type == "conv":
            # Plain conv stack: stem, stride-2 downsample, gated conv net,
            # zero-init conv to num_context channels.
            context_net = nn.Sequential(
                LambdaLayer(lambda x: 2 * x.float() / (2**num_bits - 1) - 1),
                Conv2d(data_shape[0],
                       mid_channels // 2,
                       kernel_size=3,
                       stride=1),
                nn.Conv2d(mid_channels // 2,
                          mid_channels,
                          kernel_size=2,
                          stride=2,
                          padding=0),
                GatedConvNet(channels=mid_channels, num_blocks=2, dropout=0.0),
                Conv2dZeros(in_channels=mid_channels,
                            out_channels=num_context))
        else:
            raise ValueError(
                f"Unknown dequantization context_network_type type: {context_network_type}"
            )

        # layer transformations of the dequantization flow
        transforms = []
        # Squeezed shape: 4x the channels, half the spatial resolution.
        sample_shape = (data_shape[0] * 4, data_shape[1] // 2,
                        data_shape[2] // 2)
        for i in range(num_steps):
            #transforms.append(ActNormBijection2d(sample_shape[0]))
            transforms.extend([Conv1x1(sample_shape[0])])

            if coupling_network in ["conv", "densenet"]:
                transforms.append(
                    ConditionalCoupling(in_channels=sample_shape[0],
                                        num_context=num_context,
                                        num_blocks=num_blocks,
                                        mid_channels=mid_channels,
                                        depth=depth,
                                        growth=growth,
                                        dropout=dropout,
                                        gated_conv=gated_conv,
                                        coupling_network=coupling_network))
            elif coupling_network == "transformer":
                transforms.append(
                    ConditionalMixtureCoupling(in_channels=sample_shape[0],
                                               num_context=num_context,
                                               mid_channels=mid_channels,
                                               num_mixtures=num_mixtures,
                                               num_blocks=num_blocks,
                                               dropout=dropout,
                                               use_attn=False))
            else:
                raise ValueError(
                    f"Unknown dequantization coupling network type: {coupling_network}"
                )

        # Final shuffle of channels, squeeze and sigmoid
        transforms.extend([Conv1x1(sample_shape[0]), Unsqueeze2d(), Sigmoid()])
        super(DequantizationFlow,
              self).__init__(base_dist=ConvNormal2d(sample_shape),
                             transforms=transforms,
                             context_init=context_net)
Beispiel #19
0
D = args.num_dims # Number of data dimensions
P = 2 if args.affine else 1 # Number of elementwise parameters
C = args.context_size # Size of context

# Coupling split: condition on the first I dims, transform the remaining O
# (O absorbs the extra dimension when D is odd).
I = D // 2
O = D // 2 + D % 2

# Decoder
# Variational-dequantization decoder, built only for discrete (num_bits) data.
if args.num_bits is not None:
    transforms = [Logit()]
    for _ in range(args.num_flows):
        # Coupling net: context (C) + conditioned dims (I) in, P parameters
        # for each of the O transformed dims out.
        net = nn.Sequential(MLP(C+I, P*O,
                                hidden_units=args.hidden_units,
                                activation=args.activation),
                            ElementwiseParams(P))
        # Context net: rescale ints in [0, 2**args.num_bits - 1] to [-1, 1]
        # and embed the full D-dim input into a C-dim context.
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1),
                                    MLP(D, C,
                                        hidden_units=args.hidden_units,
                                        activation=args.activation))
        if args.affine: transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
        else:           transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I))
        if args.actnorm: transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':   transforms.append(Reverse(D))
        elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
    # Drop the last appended transform (a trailing permutation/actnorm).
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(I, P*O,