Пример #1
0
    def test_total_mask_random(self):
        features = 10
        hidden_features = 5 * [50]
        num_params = 1

        model = MADE(
            features=features,
            num_params=num_params,
            hidden_features=hidden_features,
            random_order=False,
            random_mask=True,
        )
        total_mask = None
        for module in model.modules():
            if isinstance(module, MaskedLinear):
                if total_mask is None:
                    total_mask = module.mask
                else:
                    total_mask = module.mask @ total_mask
        total_mask = (total_mask > 0).float()
        self.assertEqual(torch.triu(total_mask),
                         torch.zeros([features, features]))
Пример #2
0
    def test_total_mask_sequential(self):
        features = 10
        hidden_features = 5 * [50]
        num_params = 1

        model = MADE(
            features=features,
            num_params=num_params,
            hidden_features=hidden_features,
            random_order=False,
            random_mask=False,
        )
        total_mask = None
        for module in model.modules():
            if isinstance(module, MaskedLinear):
                if total_mask is None:
                    total_mask = module.mask
                else:
                    total_mask = module.mask @ total_mask
        total_mask = (total_mask > 0).float()
        reference = torch.tril(torch.ones([features, features]), -1)
        self.assertEqual(total_mask, reference)