def test_autoregressive_type_A(self):
    """Check that MADE's outputs respect left-to-right autoregressive masking.

    Input feature 2 is perturbed. The outputs for features 0, 1 and 2 must be
    unchanged: none of them may see input 2. At least one output for a later
    feature must change. This is checked with both random and non-random
    hidden-unit masks.
    """
    batch_size = 16
    features = 10
    hidden_features = 2 * [50]
    num_params = 3
    x = torch.randn(batch_size, features)
    # clone() gives an independent copy of the tensor's data. It is the
    # idiomatic replacement for copy.deepcopy on a plain tensor.
    x_altered = x.clone()
    x_altered[:, 2] += 100.0  # Alter feature number 2
    for random_mask in [True, False]:
        with self.subTest(random_mask=random_mask):
            module = MADE(
                features=features,
                num_params=num_params,
                hidden_features=hidden_features,
                random_order=False,
                random_mask=random_mask,
            )
            y = module(x)
            y_altered = module(x_altered)

            # Outputs for features 0..2 (all params) must ignore input 2,
            # including output 2 itself.
            self.assertEqual(y[:, :3], y_altered[:, :3])

            # At least one output after feature 2 must change. Every later
            # output is not required to change: with a random mask, a given
            # later output need not have a path from input 2.
            self.assertFalse((y[:, 3:] == y_altered[:, 3:]).all())
def test_bijection_is_well_behaved(self):
    """Run the shared well-behavedness check on an affine autoregressive bijection.

    The conditioner is a MADE network that emits two parameters per feature.
    It uses left-to-right autoregressive ordering.
    """
    batch_size, features = 10, 7
    inputs = torch.randn(batch_size, features)
    conditioner = MADE(features, num_params=2, hidden_features=[21])
    # Numerical tolerance; presumably read by
    # assert_bijection_is_well_behaved -- TODO confirm.
    self.eps = 1e-6
    bijection = AffineAutoregressiveBijection(
        conditioner,
        autoregressive_order='ltr',
    )
    self.assert_bijection_is_well_behaved(
        bijection,
        inputs,
        z_shape=(batch_size, features),
    )
def test_bijection_is_well_behaved(self):
    """Run the shared well-behavedness check on a logistic-mixture autoregressive bijection.

    The MADE conditioner emits three parameters per mixture component for
    each feature. It uses left-to-right autoregressive ordering.
    """
    num_mix = 4
    batch_size, features = 10, 7
    inputs = torch.randn(batch_size, features)
    params_per_feature = 3 * num_mix
    conditioner = MADE(features, num_params=params_per_feature, hidden_features=[21])
    # Looser tolerance than the affine case; presumably consumed by
    # assert_bijection_is_well_behaved -- TODO confirm.
    self.eps = 5e-5
    bijection = LogisticMixtureAutoregressiveBijection(
        conditioner,
        num_mixtures=num_mix,
        autoregressive_order='ltr',
    )
    self.assert_bijection_is_well_behaved(
        bijection,
        inputs,
        z_shape=(batch_size, features),
    )
def test_total_mask_random(self):
    """With random masks, no output may be connected to its own or any later input.

    Every MaskedLinear mask is chained by matrix product, in module order,
    into one output-by-input connectivity matrix. The result is binarised,
    and its upper triangle (diagonal included) must be all zeros.
    """
    features = 10
    model = MADE(
        features=features,
        num_params=1,
        hidden_features=5 * [50],
        random_order=False,
        random_mask=True,
    )
    connectivity = None
    layer_masks = (m.mask for m in model.modules() if isinstance(m, MaskedLinear))
    for layer_mask in layer_masks:
        # Later layers multiply on the left, so rows index the final outputs.
        connectivity = layer_mask if connectivity is None else layer_mask @ connectivity
    connectivity = (connectivity > 0).float()
    expected_upper = torch.zeros([features, features])
    self.assertEqual(torch.triu(connectivity), expected_upper)
def test_total_mask_sequential(self):
    """With sequential masks, each output connects to exactly the earlier inputs.

    Every MaskedLinear mask is chained by matrix product, in module order.
    The binarised result must equal the strictly lower-triangular all-ones
    matrix.
    """
    features = 10
    model = MADE(
        features=features,
        num_params=1,
        hidden_features=5 * [50],
        random_order=False,
        random_mask=False,
    )
    connectivity = None
    layer_masks = (m.mask for m in model.modules() if isinstance(m, MaskedLinear))
    for layer_mask in layer_masks:
        # Later layers multiply on the left, so rows index the final outputs.
        connectivity = layer_mask if connectivity is None else layer_mask @ connectivity
    connectivity = (connectivity > 0).float()
    expected = torch.tril(torch.ones([features, features]), -1)
    self.assertEqual(connectivity, expected)
def test_bijection_is_well_behaved(self):
    """Run the shared well-behavedness check on a rational-quadratic spline autoregressive bijection.

    The MADE conditioner emits 3 * num_bins + 1 parameters per feature. It
    uses left-to-right autoregressive ordering.
    """
    num_bins = 4
    batch_size, features = 10, 7
    # Uniform samples in [0, 1); presumably the spline's expected input
    # domain -- TODO confirm.
    inputs = torch.rand(batch_size, features)
    params_per_feature = 3 * num_bins + 1
    conditioner = MADE(features, num_params=params_per_feature, hidden_features=[21])
    self.eps = 1e-5
    bijection = RationalQuadraticSplineAutoregressiveBijection(
        conditioner,
        num_bins=num_bins,
        autoregressive_order='ltr',
    )
    self.assert_bijection_is_well_behaved(
        bijection,
        inputs,
        z_shape=(batch_size, features),
    )
def test_shape(self):
    """MADE output has shape (batch_size, features, num_params).

    This is checked for every combination of random_order and random_mask.
    """
    batch_size = 16
    features = 10
    hidden = 5 * [50]
    num_params = 3
    inputs = torch.randn(batch_size, features)
    # Same order as the original: (F, F), (F, T), (T, F), (T, T).
    settings = [(order, mask) for order in (False, True) for mask in (False, True)]
    expected_dims = (batch_size, features, num_params)
    for random_order, random_mask in settings:
        with self.subTest(random_order=random_order, random_mask=random_mask):
            model = MADE(
                features=features,
                num_params=num_params,
                hidden_features=hidden,
                random_order=random_order,
                random_mask=random_mask,
            )
            outputs = model(inputs)
            self.assertEqual(outputs.dim(), 3)
            for axis, expected in enumerate(expected_dims):
                self.assertEqual(outputs.shape[axis], expected)