def test_total_mask_random(self):
    """Verify that a MADE with random hidden masks is autoregressive.

    The binary masks of every ``MaskedLinear`` layer are chained into a single
    reachability matrix. Entry ``[i, j]`` is nonzero when some path links input
    ``j`` to output ``i``. The autoregressive property requires that no
    output can reach itself or any later input. The matrix's upper triangle,
    including the diagonal, must therefore be all zeros.
    """
    n_features = 10
    model = MADE(
        features=n_features,
        num_params=1,
        hidden_features=[50] * 5,
        random_order=False,
        random_mask=True,
    )

    # Collect the masked layers in the order ``model.modules()`` yields them.
    masked_layers = [
        layer for layer in model.modules() if isinstance(layer, MaskedLinear)
    ]

    reachability = None
    for layer in masked_layers:
        # Left-multiplying composes this layer onto the paths found so far.
        # This assumes masks are laid out (out_features, in_features), as the
        # ``mask @ accumulated`` order implies.
        combined = layer.mask if reachability is None else layer.mask @ reachability
        # Collapse path counts back to 0/1 "is there any path" indicators.
        reachability = (combined > 0).float()

    self.assertEqual(
        torch.triu(reachability), torch.zeros([n_features, n_features])
    )
def test_total_mask_sequential(self):
    """Verify that a MADE with sequential hidden masks is fully autoregressive.

    The binary masks of every ``MaskedLinear`` layer are chained into a single
    reachability matrix. Entry ``[i, j]`` is nonzero when some path links input
    ``j`` to output ``i``. With deterministic (non-random) masks, each output
    ``i`` is expected to reach every earlier input ``j < i`` and nothing else.
    The matrix must therefore equal the strict lower triangle of ones.
    """
    n_features = 10
    model = MADE(
        features=n_features,
        num_params=1,
        hidden_features=[50] * 5,
        random_order=False,
        random_mask=False,
    )

    # Collect the masked layers in the order ``model.modules()`` yields them.
    masked_layers = [
        layer for layer in model.modules() if isinstance(layer, MaskedLinear)
    ]

    reachability = None
    for layer in masked_layers:
        # Left-multiplying composes this layer onto the paths found so far.
        # This assumes masks are laid out (out_features, in_features), as the
        # ``mask @ accumulated`` order implies.
        combined = layer.mask if reachability is None else layer.mask @ reachability
        # Collapse path counts back to 0/1 "is there any path" indicators.
        reachability = (combined > 0).float()

    # Ones strictly below the diagonal: output i sees exactly inputs 0..i-1.
    expected = torch.tril(torch.ones([n_features, n_features]), -1)
    self.assertEqual(reachability, expected)