    def test_shape(self):
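        # self.x has shape (10, 6); ElementwiseParams(num_params) reshapes the
        # last dim into (6 // num_params, num_params).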
        module = ElementwiseParams(3)
        y = module(self.x)
        expected_shape = (10, 2, 3)
        self.assertEqual(y.shape, expected_shape)

        module = ElementwiseParams(2)
        y = module(self.x)
        expected_shape = (10, 3, 2)
        self.assertEqual(y.shape, expected_shape)
Example #2
    def test_bijection_is_well_behaved(self):
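        # Checks forward/inverse consistency and that the identity half of the
        # coupling passes through unchanged.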
        num_bins = 16
        num_mix = 8
        batch_size = 10
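        # 3 parameters per mixture component: unnormalized weight, mean and scale.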
        elementwise_params = 3 * num_mix

        self.eps = 1e-6
        for shape in [(6, ), (6, 4, 4)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition):
                    x = torch.rand(batch_size, *shape)
                    if num_condition is None:
                        if len(shape) == 1:
                            net = nn.Sequential(
                                nn.Linear(3, 3 * elementwise_params),
                                ElementwiseParams(elementwise_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(3,
                                          3 * elementwise_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(elementwise_params))
                    else:
                        if len(shape) == 1:
                            net = nn.Sequential(
                                nn.Linear(1, 5 * elementwise_params),
                                ElementwiseParams(elementwise_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(1,
                                          5 * elementwise_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(elementwise_params))
                    bijection = CensoredLogisticMixtureCouplingBijection(
                        net,
                        num_mixtures=num_mix,
                        num_bins=num_bins,
                        num_condition=num_condition)
                    self.assert_bijection_is_well_behaved(bijection,
                                                          x,
                                                          z_shape=(batch_size,
                                                                   *shape))
                    z, _ = bijection.forward(x)
                    if num_condition is None:
                        self.assertEqual(x[:, :3], z[:, :3])
                    else:
                        self.assertEqual(x[:, :1], z[:, :1])
Example #3
    def test_order(self):
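        # 'interleaved' lays the input out as [all first params | all second params];
        # 'sequential' reads contiguous per-element pairs.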
        module = ElementwiseParams(2, mode='interleaved')
        y = module(self.x)
        self.assertEqual(y[:, 0],
                         torch.stack([self.x[:, 0], self.x[:, 3]], dim=-1))
        self.assertEqual(y[:, 1],
                         torch.stack([self.x[:, 1], self.x[:, 4]], dim=-1))
        self.assertEqual(y[:, 2],
                         torch.stack([self.x[:, 2], self.x[:, 5]], dim=-1))

        module = ElementwiseParams(2, mode='sequential')
        y = module(self.x)
        self.assertEqual(y[:, 0], self.x[:, 0:2])
        self.assertEqual(y[:, 1], self.x[:, 2:4])
        self.assertEqual(y[:, 2], self.x[:, 4:6])
Example #4
    def test_bijection_is_well_behaved(self):
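        # Checks forward/inverse consistency for each scale function, with a
        # context tensor passed to the conditional coupling.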
        batch_size = 10

        self.eps = 5e-6
        for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
            for shape in [(6, ), (6, 8, 8)]:
                for num_condition in [None, 1]:
                    with self.subTest(shape=shape,
                                      num_condition=num_condition,
                                      scale_str=scale_str):
                        x = torch.randn(batch_size, *shape)
                        context = torch.randn(batch_size, *shape)
                        if num_condition is None:
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(3 + 6, 3 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(3 + 6,
                                              3 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        else:
                            if len(shape) == 1:
                                net = nn.Sequential(nn.Linear(1 + 6, 5 * 2),
                                                    ElementwiseParams(2))
                            if len(shape) == 3:
                                net = nn.Sequential(
                                    nn.Conv2d(1 + 6,
                                              5 * 2,
                                              kernel_size=3,
                                              padding=1),
                                    ElementwiseParams2d(2))
                        bijection = ConditionalAffineCouplingBijection(
                            net,
                            num_condition=num_condition,
                            scale_fn=scale_fn(scale_str))
                        self.assert_bijection_is_well_behaved(
                            bijection,
                            x,
                            context,
                            z_shape=(batch_size, *shape))
                        z, _ = bijection.forward(x, context)
                        if num_condition is None:
                            self.assertEqual(x[:, :3], z[:, :3])
                        else:
                            self.assertEqual(x[:, :1], z[:, :1])
Example #5
    def test_bijection_is_well_behaved(self):
        num_bins = 16
        batch_size = 10

        num_params = 2 * num_bins + 1  # num_bins bin widths + (num_bins + 1) bin heights

        self.eps = 5e-3
        for shape in [(6, ), (6, 8, 8)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition):
                    x = torch.rand(batch_size, *shape)
                    if num_condition is None:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(3, 3 * num_params),
                                                ElementwiseParams(num_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(3,
                                          3 * num_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(num_params))
                    else:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(1, 5 * num_params),
                                                ElementwiseParams(num_params))
                        if len(shape) == 3:
                            net = nn.Sequential(
                                nn.Conv2d(1,
                                          5 * num_params,
                                          kernel_size=3,
                                          padding=1),
                                ElementwiseParams2d(num_params))
                    bijection = QuadraticSplineCouplingBijection(
                        net, num_bins=num_bins, num_condition=num_condition)
                    self.assert_bijection_is_well_behaved(bijection,
                                                          x,
                                                          z_shape=(batch_size,
                                                                   *shape))
                    z, _ = bijection.forward(x)
                    if num_condition is None:
                        self.assertEqual(x[:, :3], z[:, :3])
                    else:
                        self.assertEqual(x[:, :1], z[:, :1])
Example #6
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist, cond_size):

        D = 2  # Number of data dimensions
        A = D + augment_size  # Number of augmented data dimensions
        P = 2 if affine else 1  # Number of elementwise parameters

        # Initialize the context. Only upsample the context in ContextInit if the latent shape doesn't change during the flow.
        context_init = MLP(input_size=cond_size,
                           output_size=D,
                           hidden_units=hidden_units,
                           activation=activation)

        # initialize flow with either augmentation or Abs surjection
        if augment_size > 0:
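            # Keep A = D + augment_size even so the coupling layers below split it evenly.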
            assert augment_size % 2 == 0
            transforms = [Augment(StandardNormal((augment_size, )), x_size=D)]

        else:
            transforms = [SimpleAbsSurjection()]
            if range_flow == 'logit':
                transforms += [
                    ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                    Logit()
                ]
            elif range_flow == 'softplus':
                transforms += [SoftplusInverse()]

        # apply coupling layer flows
        for _ in range(num_flows):
            net = nn.Sequential(
                MLP(A // 2 + D,
                    P * A // 2,
                    hidden_units=hidden_units,
                    activation=activation), ElementwiseParams(P))
            if affine:
                transforms.append(
                    ConditionalAffineCouplingBijection(
                        net, scale_fn=scale_fn(scale_fn_str)))
            else:
                transforms.append(ConditionalAdditiveCouplingBijection(net))
            if actnorm: transforms.append(ActNormBijection(A))  # act on all A dims, not just the D data dims
            transforms.append(Reverse(A))

        # Remove the final Reverse so the flow does not end with a permutation.
        transforms.pop()

        if base_dist == "uniform":
            base = StandardUniform((A, ))
        else:
            base = StandardNormal((A, ))

        super(SRFlow, self).__init__(base_dist=base,
                                     transforms=transforms,
                                     context_init=context_init)
Example #7
    def __init__(self,
                 features,
                 num_params,
                 hidden_features,
                 random_order=False,
                 random_mask=False,
                 random_seed=None,
                 activation='relu',
                 dropout_prob=0.,
                 batch_norm=False):

        layers = []

        # Build layers
        data_degrees = MaskedLinear.get_data_degrees(features,
                                                     random_order=random_order,
                                                     random_seed=random_seed)
        in_degrees = copy.deepcopy(data_degrees)
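        # Masked hidden layers: each layer's output degrees feed the next as
        # input degrees, preserving the autoregressive structure.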
        for i, out_features in enumerate(hidden_features):
            layers.append(
                MaskedLinear(
                    in_degrees=in_degrees,
                    out_features=out_features,
                    data_features=features,
                    random_mask=random_mask,
                    random_seed=random_seed + i if random_seed is not None
                    else None,  # Offset the seed so each layer gets a different mask
                    is_output=False))
            in_degrees = layers[-1].degrees
            if batch_norm:
                layers.append(nn.BatchNorm1d(out_features))
            layers.append(act_module(activation))
            if dropout_prob > 0.0:
                layers.append(nn.Dropout(dropout_prob))

        # Build output layer
        layers.append(
            MaskedLinear(in_degrees=in_degrees,
                         out_features=features * num_params,
                         data_features=features,
                         random_mask=random_mask,
                         random_seed=random_seed,
                         is_output=True,
                         data_degrees=data_degrees))
        layers.append(ElementwiseParams(num_params, mode='sequential'))

        super(MADE, self).__init__(*layers)
Example #8
###################
## Specify model ##
###################

assert args.augdim % 2 == 0

D = 2  # Number of data dimensions
A = 2 + args.augdim  # Number of augmented data dimensions
P = 2 if args.affine else 1  # Number of elementwise parameters

transforms = [Augment(StandardNormal((args.augdim, )), x_size=D)]
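# Each step: an (affine or additive) coupling parameterized by an MLP,
# optional ActNorm, then a dimension reversal.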
for _ in range(args.num_flows):
    net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm: transforms.append(ActNormBijection(A))  # A dims, matching the coupling layers
    transforms.append(Reverse(A))
transforms.pop()  # remove the trailing Reverse

model = Flow(base_dist=StandardNormal((A, )),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################
Example #9
D = args.num_dims  # Number of data dimensions
P = 2 if args.affine else 1  # Number of elementwise parameters
C = args.context_size  # Size of context

I = D // 2  # Number of dimensions conditioned on
O = D // 2 + D % 2  # Number of dimensions transformed

# Decoder
if args.num_bits is not None:
    transforms = [Logit()]
    for _ in range(args.num_flows):
        net = nn.Sequential(MLP(C+I, P*O,
                                hidden_units=args.hidden_units,
                                activation=args.activation),
                            ElementwiseParams(P))
        # Rescale integers in {0, ..., 2**num_bits - 1} to [-1, 1] before the context MLP.
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1),
                                    MLP(D, C,
                                        hidden_units=args.hidden_units,
                                        activation=args.activation))
        if args.affine: transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
        else:           transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I))
        if args.actnorm: transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':   transforms.append(Reverse(D))
        elif args.permutation == 'shuffle': transforms.append(Shuffle(D))
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
Example #10
    def test_layer_is_well_behaved(self):
        module = ElementwiseParams(3)
        self.assert_layer_is_well_behaved(module, self.x)
Example #11
def net():
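    # Maps a 1-d input to 2 parameters per element; ElementwiseParams(2)
    # reshapes the output to (..., 1, 2).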
    return nn.Sequential(nn.Linear(1, 200), nn.ReLU(), nn.Linear(200, 100),
                         nn.ReLU(), nn.Linear(100, 2), ElementwiseParams(2))
Example #12
    def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
                 activation, range_flow, augment_size, base_dist):

        D = 2  # Number of data dimensions

        if base_dist == "uniform":
            classifier = MLP(D,
                             D // 2,
                             hidden_units=hidden_units,
                             activation=activation,
                             out_lambda=lambda x: x.view(-1))
            transforms = [
                ElementAbsSurjection(classifier=classifier),
                ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))
            ]
            base = StandardUniform((D, ))

        else:
            A = D + augment_size  # Number of augmented data dimensions
            P = 2 if affine else 1  # Number of elementwise parameters

            # initialize flow with either augmentation or Abs surjection
            if augment_size > 0:
                assert augment_size % 2 == 0
                transforms = [
                    Augment(StandardNormal((augment_size, )), x_size=D)
                ]

            else:
                transforms = [SimpleAbsSurjection()]
                if range_flow == 'logit':
                    transforms += [
                        ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
                        Logit()
                    ]
                elif range_flow == 'softplus':
                    transforms += [SoftplusInverse()]

            # apply coupling layer flows
            for _ in range(num_flows):
                net = nn.Sequential(
                    MLP(A // 2,
                        P * A // 2,
                        hidden_units=hidden_units,
                        activation=activation), ElementwiseParams(P))

                if affine:
                    transforms.append(
                        AffineCouplingBijection(
                            net, scale_fn=scale_fn(scale_fn_str)))
                else:
                    transforms.append(AdditiveCouplingBijection(net))
                if actnorm: transforms.append(ActNormBijection(A))
                transforms.append(Reverse(A))

            transforms.pop()
            base = StandardNormal((A, ))

        super(UnconditionalFlow, self).__init__(base_dist=base,
                                                transforms=transforms)