Example 1
    def test_bijection_is_well_behaved(self):
        batch_size = 10
        shape = (3, 8, 8)
        x = torch.randn(batch_size, *shape)
        net_spatial = nn.Sequential(
            SpatialMaskedConv2d(3,
                                3 * 2,
                                kernel_size=3,
                                padding=1,
                                mask_type='A'), ElementwiseParams2d(2))
        net = nn.Sequential(
            MaskedConv2d(3, 3 * 2, kernel_size=3, padding=1, mask_type='A'),
            ElementwiseParams2d(2))

        self.eps = 1e-6
        for autoregressive_order in ['raster_cwh', 'raster_wh']:
            with self.subTest(autoregressive_order=autoregressive_order):
                if autoregressive_order == 'raster_cwh':
                    autoreg_net = net
                elif autoregressive_order == 'raster_wh':
                    autoreg_net = net_spatial
                bijection = AffineAutoregressiveBijection2d(
                    autoreg_net, autoregressive_order=autoregressive_order)
                self.assert_bijection_is_well_behaved(bijection,
                                                      x,
                                                      z_shape=(batch_size,
                                                               *shape))
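Note: assert_bijection_is_well_behaved is defined in the shared test base class and is not shown in these snippets. Judging from how the tests call it (and from the explicit z, _ = bijection.forward(x) checks in later examples), it presumably amounts to a round-trip check of roughly this form, written here with the names from the test above (a sketch under that assumption, not the actual helper):

    z, ldj = bijection.forward(x)        # forward returns (z, log_det)
    x_rec = bijection.inverse(z)         # inverse should reconstruct the input
    assert z.shape == (batch_size, *shape)
    assert ldj.shape == (batch_size,)    # assumed: one log-det per batch element
    assert torch.allclose(x, x_rec, atol=self.eps)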
Example 2
    def __init__(self, in_channels, num_context, num_blocks, mid_channels,
                 depth, growth, dropout, gated_conv, coupling_network):

        assert in_channels % 2 == 0

        if coupling_network == "densenet":
            net = nn.Sequential(
                DenseNet(in_channels=in_channels // 2 + num_context,
                         out_channels=in_channels,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=growth,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))
        elif coupling_network == "conv":
            net = nn.Sequential(
                ConvNet(in_channels=in_channels // 2 + num_context,
                        out_channels=in_channels,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(ConditionalCoupling,
              self).__init__(coupling_net=net, scale_fn=scale_fn("tanh_exp"))
Example 3
    def test_shape(self):
        module = ElementwiseParams2d(3)
        y = module(self.x)
        expected_shape = (10, 2, 4, 4, 3)
        self.assertEqual(y.shape, expected_shape)

        module = ElementwiseParams2d(2)
        y = module(self.x)
        expected_shape = (10, 3, 4, 4, 2)
        self.assertEqual(y.shape, expected_shape)
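Both assertions pin down the same behaviour: for the two expected shapes to hold, the input self.x must have shape (10, 6, 4, 4), and ElementwiseParams2d(P) regroups the 6 channels into 6 // P data channels carrying P parameters each, moving the parameter axis last. A minimal, library-free sketch of that shape transform for the 'sequential' grouping (a hypothetical reimplementation for illustration, not the library code):

    import torch

    def elementwise_params_2d_sequential(x, num_params):
        # (B, C * P, H, W) -> (B, C, P, H, W) -> (B, C, H, W, P)
        b, cp, h, w = x.shape
        assert cp % num_params == 0
        c = cp // num_params
        return x.reshape(b, c, num_params, h, w).permute(0, 1, 3, 4, 2)

    x = torch.randn(10, 6, 4, 4)
    assert elementwise_params_2d_sequential(x, 3).shape == (10, 2, 4, 4, 3)
    assert elementwise_params_2d_sequential(x, 2).shape == (10, 3, 4, 4, 2)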
Example 4
    def __init__(self,
                 x_size,
                 y_size,
                 coupling_network,
                 mid_channels,
                 depth,
                 num_blocks=None,
                 dropout=None,
                 gated_conv=None,
                 checkerboard=False,
                 flip=False):

        if checkerboard:
            in_channels = y_size[0] + x_size[0]
            out_channels = y_size[0] * 2
            split_dim = 3
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
        else:
            in_channels = y_size[0] // 2 + x_size[0]
            out_channels = y_size[0]
            split_dim = 1
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
            assert y_size[0] % 2 == 0, \
                f"High-resolution has shape {y_size} with channels not evenly divisible"

        if coupling_network == "densenet":
            coupling_net = nn.Sequential(
                DenseNet(in_channels=in_channels,
                         out_channels=out_channels,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=mid_channels,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))

        elif coupling_network == "conv":
            coupling_net = nn.Sequential(
                ConvNet(in_channels=in_channels,
                        out_channels=out_channels,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        weight_norm=True,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))

        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                         scale_fn=scale_fn("tanh_exp"),
                                         split_dim=split_dim,
                                         flip=flip)
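For concreteness, the two branches above resolve as follows for a hypothetical y_size = (4, 32, 32). This is plain arithmetic on the expressions in the constructor; the net input/output shapes assume the coupling concatenates the conditioned half of y with the context x along the channel axis:

    # channel split (checkerboard=False), x_size = (3, 32, 32):
    #   in_channels  = 4 // 2 + 3 = 5    -> net input  (B, 5, 32, 32)
    #   out_channels = 4                 -> net output (B, 4, 32, 32)
    #   ElementwiseParams2d(2)           -> (B, 2, 32, 32, 2): 2 parameters per transformed channel
    #
    # checkerboard split, x_size = (3, 32, 16), i.e. the context has half the width:
    #   in_channels  = 4 + 3 = 7         -> net input  (B, 7, 32, 16)
    #   out_channels = 4 * 2 = 8         -> net output (B, 8, 32, 16)
    #   ElementwiseParams2d(2)           -> (B, 4, 32, 16, 2), with split_dim=3 splitting along width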
Example 5
    def __init__(self,
                 in_channels,
                 num_context,
                 num_blocks,
                 mid_channels,
                 depth,
                 dropout,
                 gated_conv,
                 coupling_network,
                 checkerboard=False,
                 flip=False):

        if checkerboard:
            num_in = in_channels + num_context
            num_out = in_channels * 2
            split_dim = 3
        else:
            num_in = in_channels // 2 + num_context
            num_out = in_channels
            split_dim = 1

        assert in_channels % 2 == 0 or split_dim != 1, f"in_channels = {in_channels} not evenly divisible"

        if coupling_network == "densenet":
            net = nn.Sequential(
                DenseNet(in_channels=num_in,
                         out_channels=num_out,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=mid_channels,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))
        elif coupling_network == "conv":
            net = nn.Sequential(
                ConvNet(in_channels=num_in,
                        out_channels=num_out,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(ConditionalCoupling,
              self).__init__(coupling_net=net,
                             scale_fn=scale_fn("tanh_exp"),
                             split_dim=split_dim,
                             flip=flip)
Example 6
    def test_bijection_is_well_behaved(self):
        num_bins = 16
        num_mix = 8
        batch_size = 10
        elementwise_params = 3 * num_mix

        self.eps = 1e-6
        for shape in [(6, ), (6, 4, 4)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition):
                    x = torch.rand(batch_size, *shape)
                    if num_condition is None:
                        if len(shape) == 1:
                            net = nn.Sequential(
                                nn.Linear(3, 3 * elementwise_params),
                                ElementwiseParams(elementwise_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(3,
                                          3 * elementwise_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(elementwise_params))
                    else:
                        if len(shape) == 1:
                            net = nn.Sequential(
                                nn.Linear(1, 5 * elementwise_params),
                                ElementwiseParams(elementwise_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(1,
                                          5 * elementwise_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(elementwise_params))
                    bijection = CensoredLogisticMixtureCouplingBijection(
                        net,
                        num_mixtures=num_mix,
                        num_bins=num_bins,
                        num_condition=num_condition)
                    self.assert_bijection_is_well_behaved(bijection,
                                                          x,
                                                          z_shape=(batch_size,
                                                                   *shape))
                    z, _ = bijection.forward(x)
                    if num_condition is None:
                        self.assertEqual(x[:, :3], z[:, :3])
                    else:
                        self.assertEqual(x[:, :1], z[:, :1])
Example 7
    def test_bijection_is_well_behaved(self):
        batch_size = 10

        self.eps = 5e-6
        for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
            for shape in [(6, ), (6, 8, 8)]:
                for num_condition in [None, 1]:
                    with self.subTest(shape=shape,
                                      num_condition=num_condition,
                                      scale_str=scale_str):
                        x = torch.randn(batch_size, *shape)
                        context = torch.randn(batch_size, *shape)
                        if num_condition is None:
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(3 + 6, 3 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(3 + 6,
                                              3 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        else:
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(1 + 6, 5 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(1 + 6,
                                              5 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        bijection = ConditionalAffineCouplingBijection(
                            net,
                            num_condition=num_condition,
                            scale_fn=scale_fn(scale_str))
                        self.assert_bijection_is_well_behaved(
                            bijection,
                            x,
                            context,
                            z_shape=(batch_size, *shape))
                        z, _ = bijection.forward(x, context)
                        if num_condition is None:
                            self.assertEqual(x[:, :3], z[:, :3])
                        else:
                            self.assertEqual(x[:, :1], z[:, :1])
Example 8
    def __init__(self,
                 x_size,
                 y_size,
                 mid_channels,
                 num_blocks,
                 num_mixtures,
                 dropout,
                 checkerboard=False,
                 flip=False):

        if checkerboard:
            in_channels = y_size[0]
            split_dim = 3
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
        else:
            in_channels = y_size[0] // 2
            split_dim = 1
            assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
            assert y_size[0] % 2 == 0, \
                f"High-resolution has shape {y_size} with channels not evenly divisible"

        coupling_net = nn.Sequential(
            TransformerNet(in_channels=in_channels,
                           context_channels=x_size[0],
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(SRMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                                num_mixtures=num_mixtures,
                                                scale_fn=scale_fn("tanh_exp"),
                                                split_dim=split_dim,
                                                flip=flip)
Example 9
    def __init__(self,
                 in_channels,
                 mid_channels,
                 num_mixtures,
                 num_blocks,
                 dropout,
                 checkerboard=False,
                 flip=False):

        if checkerboard:
            num_in = in_channels
            split_dim = 3
        else:
            num_in = in_channels // 2
            split_dim = 1

        net = nn.Sequential(
            TransformerNet(in_channels=num_in,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(MixtureCoupling, self).__init__(coupling_net=net,
                                              num_mixtures=num_mixtures,
                                              scale_fn=scale_fn("tanh_exp"),
                                              split_dim=split_dim,
                                              flip=flip)
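The 2 + num_mixtures * 3 passed to ElementwiseParams2d in the mixture couplings above matches a Flow++-style parameterization: one shift and one (pre-activation) scale per element, plus an unnormalized weight, a mean and a log-scale for each of the K mixture components. The factor itself comes from the code; the slot ordering in the sketch below is an assumption for illustration only:

    import torch

    K = 4                                            # num_mixtures (hypothetical value)
    params = torch.randn(10, 3, 8, 8, 2 + 3 * K)     # shape produced by ElementwiseParams2d(2 + K * 3)
    shift, log_scale = params[..., 0], params[..., 1]    # affine slots (assumed ordering)
    logit_w = params[..., 2:2 + K]                       # mixture weights (unnormalized)
    means = params[..., 2 + K:2 + 2 * K]                 # mixture means
    log_s = params[..., 2 + 2 * K:]                      # mixture log-scales
    assert logit_w.shape == means.shape == log_s.shape == (10, 3, 8, 8, K)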
Example 10
def net(in_channels):  # `k` (the number of mixture components) comes from the enclosing scope
    return nn.Sequential(
        TransformerNet(in_channels // 2,
                       mid_channels=16,
                       num_blocks=2,
                       num_mixtures=k,
                       dropout=0.2), ElementwiseParams2d(2 + k * 3))
Example 11
 def net(channels):
     return nn.Sequential(ConvNet(in_channels=channels//2,
                                  out_channels=(channels - channels//2) * 2,
                                  mid_channels=64,
                                  num_layers=1,
                                  activation='relu'),
                          ElementwiseParams2d(2))
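The (channels - channels // 2) * 2 output size is what keeps this net correct for odd channel counts: channels // 2 channels are fed to the ConvNet as conditioning input and the remaining channels - channels // 2 channels each receive two parameters. For example (arithmetic only):

    channels = 5
    cond = channels // 2               # 2 channels go into the ConvNet
    transformed = channels - cond      # 3 channels get transformed
    assert transformed * 2 == 6        # ConvNet output channels
    # ElementwiseParams2d(2) then reshapes (B, 6, H, W) into (B, 3, H, W, 2)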
Example 12
 def net(channels):
     return nn.Sequential(ConvNet(input_size=channels//2,
                                  output_size=channels,
                                  hidden_units=64,
                                  num_layers=1,
                                  activation='relu'),
                          ElementwiseParams2d(2))
Example 13
    def test_mode(self):
        module = ElementwiseParams2d(2, mode='interleaved')
        y = module(self.x)
        self.assertEqual(y[:, 0],
                         torch.stack([self.x[:, 0], self.x[:, 3]], dim=-1))
        self.assertEqual(y[:, 1],
                         torch.stack([self.x[:, 1], self.x[:, 4]], dim=-1))
        self.assertEqual(y[:, 2],
                         torch.stack([self.x[:, 2], self.x[:, 5]], dim=-1))

        module = ElementwiseParams2d(2, mode='sequential')
        y = module(self.x)
        self.assertEqual(y[:, 0],
                         torch.stack([self.x[:, 0], self.x[:, 1]], dim=-1))
        self.assertEqual(y[:, 1],
                         torch.stack([self.x[:, 2], self.x[:, 3]], dim=-1))
        self.assertEqual(y[:, 2],
                         torch.stack([self.x[:, 4], self.x[:, 5]], dim=-1))
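The two modes differ only in how the six input channels are grouped: 'sequential' takes consecutive channels as the parameters of one data channel, while 'interleaved' strides across the channel axis. A library-free sketch consistent with the assertions above (hypothetical reimplementation for illustration):

    import torch

    def group_params_2d(x, num_params, mode):
        b, cp, h, w = x.shape
        c = cp // num_params
        if mode == 'sequential':
            # channel layout: [c0_p0, c0_p1, c1_p0, c1_p1, c2_p0, c2_p1]
            y = x.reshape(b, c, num_params, h, w)
        else:
            # 'interleaved' layout: [c0_p0, c1_p0, c2_p0, c0_p1, c1_p1, c2_p1]
            y = x.reshape(b, num_params, c, h, w).transpose(1, 2)
        return y.permute(0, 1, 3, 4, 2)  # (B, C, H, W, P)

    x = torch.arange(6.).reshape(1, 6, 1, 1)
    assert torch.equal(group_params_2d(x, 2, 'sequential')[0, :, 0, 0],
                       torch.tensor([[0., 1.], [2., 3.], [4., 5.]]))
    assert torch.equal(group_params_2d(x, 2, 'interleaved')[0, :, 0, 0],
                       torch.tensor([[0., 3.], [1., 4.], [2., 5.]]))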
Example 14
    def test_bijection_is_well_behaved(self):
        num_bins = 16
        batch_size = 10

        num_params = 2 * num_bins + 1

        self.eps = 5e-3
        for shape in [(6, ), (6, 8, 8)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition):
                    x = torch.rand(batch_size, *shape)
                    if num_condition is None:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(3, 3 * num_params),
                                                ElementwiseParams(num_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(3,
                                          3 * num_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(num_params))
                    else:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(1, 5 * num_params),
                                                ElementwiseParams(num_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(1,
                                          5 * num_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(num_params))
                    bijection = QuadraticSplineCouplingBijection(
                        net, num_bins=num_bins, num_condition=num_condition)
                    self.assert_bijection_is_well_behaved(bijection,
                                                          x,
                                                          z_shape=(batch_size,
                                                                   *shape))
                    z, _ = bijection.forward(x)
                    if num_condition is None:
                        self.assertEqual(x[:, :3], z[:, :3])
                    else:
                        self.assertEqual(x[:, :1], z[:, :1])
Example 15
def net(channels):
    return nn.Sequential(
        DenseNet(in_channels=channels // 2,
                 out_channels=channels,
                 num_blocks=1,
                 mid_channels=64,
                 depth=8,
                 growth=16,
                 dropout=0.0,
                 gated_conv=True,
                 zero_init=True), ElementwiseParams2d(2))
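Nets built like this are what the Coupling classes above hand to their base class together with scale_fn('tanh_exp'). A minimal hand-written sketch of that consumption pattern (illustration only: a plain Conv2d stands in for the DenseNet, the split/affine step is written out by hand instead of calling the library's coupling class, and a plain exp replaces scale_fn('tanh_exp')):

    import torch
    import torch.nn as nn

    channels = 6
    # stand-in for net(channels): maps (B, 3, H, W) -> (B, 6, H, W)
    conv = nn.Conv2d(channels // 2, channels, kernel_size=3, padding=1)

    x = torch.randn(10, channels, 8, 8)
    x1, x2 = x.chunk(2, dim=1)                 # condition on x1, transform x2
    p = conv(x1)
    # mimic ElementwiseParams2d(2) with sequential grouping: (B, 6, H, W) -> (B, 3, H, W, 2)
    p = p.reshape(10, channels // 2, 2, 8, 8).permute(0, 1, 3, 4, 2)
    shift, log_scale = p[..., 0], p[..., 1]
    z2 = x2 * torch.exp(log_scale) + shift
    z = torch.cat([x1, z2], dim=1)
    assert z.shape == x.shape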
Example 16
    def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks,
                 dropout):

        net = nn.Sequential(
            TransformerNet(in_channels // 2,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(MixtureCoupling, self).__init__(coupling_net=net,
                                              num_mixtures=num_mixtures,
                                              scale_fn=scale_fn("tanh_exp"))
Example 17
def net(channels):
    return nn.Sequential(
        #       DenseNet(in_channels=channels//2,
        #                                 out_channels=channels,
        #                                 num_blocks=1,
        #                                 mid_channels=64,
        #                                 depth=8,
        #                                 growth=16,
        #                                 dropout=0.0,
        #                                 gated_conv=True,
        #                                 zero_init=True),
        # nn.Conv1d(channels//2, channels, 1, bias=False),
        Net(in_channels=channels // 2, out_channels=channels, mid_channels=64),
        ElementwiseParams2d(2))
Example 18
    def __init__(self, in_channels, num_blocks, mid_channels, depth, growth, dropout, gated_conv):

        assert in_channels % 2 == 0

        net = nn.Sequential(DenseNet(in_channels=in_channels//2,
                                     out_channels=in_channels,
                                     num_blocks=num_blocks,
                                     mid_channels=mid_channels,
                                     depth=depth,
                                     growth=growth,
                                     dropout=dropout,
                                     gated_conv=gated_conv,
                                     zero_init=True),
                            ElementwiseParams2d(2, mode='sequential'))
        super(Coupling, self).__init__(coupling_net=net)
Example 19
    def __init__(self,
                 in_channels,
                 num_params,
                 filters=128,
                 num_blocks=15,
                 output_filters=1024,
                 kernel_size=3,
                 kernel_size_in=7,
                 init_transforms=lambda x: 2 * x - 1):

        layers = [LambdaLayer(init_transforms),
                  MaskedConv2d(in_channels, 2 * filters, kernel_size=kernel_size_in,
                               padding=kernel_size_in // 2, mask_type='A', data_channels=in_channels)]
        layers += [MaskedResidualBlock2d(filters, data_channels=in_channels, kernel_size=kernel_size)
                   for _ in range(num_blocks)]
        layers += [nn.ReLU(True),
                   MaskedConv2d(2 * filters, output_filters, kernel_size=1,
                                mask_type='B', data_channels=in_channels),
                   nn.ReLU(True),
                   MaskedConv2d(output_filters, num_params * in_channels, kernel_size=1,
                                mask_type='B', data_channels=in_channels),
                   ElementwiseParams2d(num_params)]

        super(PixelCNN, self).__init__(*layers)
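The final 1x1 masked convolution emits num_params * in_channels channels, which the trailing ElementwiseParams2d(num_params) regroups into one parameter vector per pixel and per data channel. A quick shape check with hypothetical sizes, reusing the reshape semantics pinned down by test_shape above:

    import torch

    in_channels, num_params = 3, 10
    h = torch.randn(2, num_params * in_channels, 32, 32)     # output of the last masked conv
    p = h.reshape(2, in_channels, num_params, 32, 32).permute(0, 1, 3, 4, 2)
    assert p.shape == (2, in_channels, 32, 32, num_params)   # shape ElementwiseParams2d(num_params) yields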
Example 20
 def __init__(self, channels, context_channels, params,
              autoregressive_order):
     super(CondNet, self).__init__()
     if autoregressive_order == 'raster_cwh':
         self.conv = MaskedConv2d(channels,
                                  channels * params,
                                  kernel_size=3,
                                  padding=1,
                                  mask_type='A')
     elif autoregressive_order == 'raster_wh':
         self.conv = SpatialMaskedConv2d(channels,
                                         channels * params,
                                         kernel_size=3,
                                         padding=1,
                                         mask_type='A')
     self.context = nn.Conv2d(context_channels,
                              channels * params,
                              kernel_size=1)
     self.out = ElementwiseParams2d(params)
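Only __init__ is shown for this conditional net; a plausible forward (an assumption, not part of the original snippet) would add the 1x1-projected context to the masked convolution's output before regrouping the parameters:

 def forward(self, x, context):
     # hypothetical: combine autoregressive features with the projected context
     return self.out(self.conv(x) + self.context(context))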
Example 21
    def __init__(self,
                 in_channels,
                 num_context,
                 mid_channels,
                 num_mixtures,
                 num_blocks,
                 dropout,
                 use_attn=True):
        coupling_net = nn.Sequential(
            TransformerNet(in_channels // 2,
                           context_channels=num_context,
                           mid_channels=mid_channels,
                           num_blocks=num_blocks,
                           num_mixtures=num_mixtures,
                           use_attn=use_attn,
                           dropout=dropout),
            ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))

        super(ConditionalMixtureCoupling,
              self).__init__(coupling_net=coupling_net,
                             num_mixtures=num_mixtures,
                             scale_fn=scale_fn("tanh_exp"))
Example 22
 def test_layer_is_well_behaved(self):
     module = ElementwiseParams2d(3)
     self.assert_layer_is_well_behaved(module, self.x)
Example 23
    def test_bijection_is_well_behaved(self):
        num_bins = 4
        num_mix = 8
        batch_size = 10
        shape = (3, 4, 4)
        elementwise_params = 3 * num_mix
        x = torch.rand(batch_size, *shape)
        net_spatial = nn.Sequential(
            SpatialMaskedConv2d(3, 3 * elementwise_params, kernel_size=3,
                                padding=1, mask_type='A'),
            ElementwiseParams2d(elementwise_params))
        net = nn.Sequential(
            MaskedConv2d(3, 3 * elementwise_params, kernel_size=3,
                         padding=1, mask_type='A'),
            ElementwiseParams2d(elementwise_params))

        self.eps = 1e-6
        bijection = CensoredLogisticMixtureAutoregressiveBijection2d(
            net, num_mixtures=num_mix, num_bins=num_bins,
            autoregressive_order='raster_cwh')
        self.assert_bijection_is_well_behaved(
            bijection, x, z_shape=(batch_size, *shape))

        bijection = CensoredLogisticMixtureAutoregressiveBijection2d(
            net_spatial, num_mixtures=num_mix, num_bins=num_bins,
            autoregressive_order='raster_wh')
        self.assert_bijection_is_well_behaved(
            bijection, x, z_shape=(batch_size, *shape))