Example #1
0
    def test_autoregressive_type_A(self):
        """Perturbing input feature 2 must leave outputs 0..2 unchanged.

        With type-A conditioning and left-to-right ordering, each output
        may only depend on strictly earlier input features, so a large
        perturbation of feature 2 should propagate only to outputs 3 and up.
        """
        batch_size = 16
        features = 10
        hidden_features = [50, 50]
        num_params = 3

        inputs = torch.randn(batch_size, features)
        perturbed = inputs.clone()
        perturbed[:, 2] += 100.0  # Alter feature number 2

        for use_random_mask in (True, False):
            with self.subTest(random_mask=use_random_mask):
                net = MADE(
                    features=features,
                    num_params=num_params,
                    hidden_features=hidden_features,
                    random_order=False,
                    random_mask=use_random_mask,
                )
                out = net(inputs)
                out_perturbed = net(perturbed)

                # Outputs for features 0..2 (inclusive) must be unaltered.
                self.assertEqual(out[:, :3], out_perturbed[:, :3])

                # At least one output for features 3.. must have changed.
                self.assertFalse(
                    (out[:, 3:] == out_perturbed[:, 3:]).view(-1).all())
Example #2
0
    def test_bijection_is_well_behaved(self):
        """The affine autoregressive bijection should round-trip its input."""
        batch_size = 10
        features = 7

        samples = torch.randn(batch_size, features)
        # Two params per feature: shift and log-scale for the affine transform.
        conditioner = MADE(features, num_params=2, hidden_features=[21])

        self.eps = 1e-6
        bijection = AffineAutoregressiveBijection(
            conditioner, autoregressive_order='ltr')
        self.assert_bijection_is_well_behaved(
            bijection, samples, z_shape=(batch_size, features))
Example #3
0
    def test_bijection_is_well_behaved(self):
        """The logistic-mixture autoregressive bijection should round-trip."""
        num_mix = 4
        batch_size = 10
        features = 7

        samples = torch.randn(batch_size, features)
        # Three params (weight, mean, scale) per mixture component.
        conditioner = MADE(
            features, num_params=3 * num_mix, hidden_features=[21])

        self.eps = 5e-5
        bijection = LogisticMixtureAutoregressiveBijection(
            conditioner, num_mixtures=num_mix, autoregressive_order='ltr')
        self.assert_bijection_is_well_behaved(
            bijection, samples, z_shape=(batch_size, features))
Example #4
0
    def test_total_mask_random(self):
        """Composed random masks must yield a strictly causal connectivity.

        Multiplying the per-layer masks gives the end-to-end input->output
        connectivity; its upper triangle (including the diagonal) must be
        zero so no output depends on its own or any later input feature.
        """
        features = 10
        hidden_features = [50] * 5
        num_params = 1

        made = MADE(
            features=features,
            num_params=num_params,
            hidden_features=hidden_features,
            random_order=False,
            random_mask=True,
        )

        # Chain the layer masks into a single connectivity matrix.
        composed = None
        for layer in made.modules():
            if not isinstance(layer, MaskedLinear):
                continue
            composed = layer.mask if composed is None else layer.mask @ composed

        connectivity = (composed > 0).float()
        self.assertEqual(torch.triu(connectivity),
                         torch.zeros([features, features]))
Example #5
0
    def test_total_mask_sequential(self):
        """Composed sequential masks must connect every earlier input.

        With the deterministic (non-random) mask construction, the
        end-to-end connectivity should equal the full strictly lower
        triangular matrix: each output sees all strictly earlier inputs.
        """
        features = 10
        hidden_features = [50] * 5
        num_params = 1

        made = MADE(
            features=features,
            num_params=num_params,
            hidden_features=hidden_features,
            random_order=False,
            random_mask=False,
        )

        # Chain the layer masks into a single connectivity matrix.
        composed = None
        for layer in made.modules():
            if not isinstance(layer, MaskedLinear):
                continue
            composed = layer.mask if composed is None else layer.mask @ composed

        connectivity = (composed > 0).float()
        expected = torch.tril(torch.ones([features, features]), -1)
        self.assertEqual(connectivity, expected)
    def test_bijection_is_well_behaved(self):
        """The rational-quadratic spline bijection should round-trip."""
        num_bins = 4
        batch_size = 10
        features = 7

        # Spline inputs live in [0, 1), hence rand rather than randn.
        samples = torch.rand(batch_size, features)
        # 3*K + 1 params per feature: K widths, K heights, K+1 derivatives.
        conditioner = MADE(
            features, num_params=3 * num_bins + 1, hidden_features=[21])

        self.eps = 1e-5
        bijection = RationalQuadraticSplineAutoregressiveBijection(
            conditioner, num_bins=num_bins, autoregressive_order='ltr')
        self.assert_bijection_is_well_behaved(
            bijection, samples, z_shape=(batch_size, features))
Example #7
0
    def test_shape(self):
        """MADE output must be (batch, features, num_params) in every mode."""
        batch_size = 16
        features = 10
        hidden_features = [50] * 5
        num_params = 3

        inputs = torch.randn(batch_size, features)

        # Exercise every combination of ordering and masking strategy.
        for use_random_order in (False, True):
            for use_random_mask in (False, True):
                with self.subTest(random_order=use_random_order,
                                  random_mask=use_random_mask):
                    net = MADE(
                        features=features,
                        num_params=num_params,
                        hidden_features=hidden_features,
                        random_order=use_random_order,
                        random_mask=use_random_mask,
                    )
                    outputs = net(inputs)
                    self.assertEqual(outputs.dim(), 3)
                    self.assertEqual(
                        outputs.shape,
                        torch.Size([batch_size, features, num_params]))