def test_total_mask_sequential(self):
    features = 10
    hidden_features = 50
    num_blocks = 5
    output_multiplier = 1

    for use_residual_blocks in [True, False]:
        with self.subTest(use_residual_blocks=use_residual_blocks):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=use_residual_blocks,
                random_mask=False,
            )
            # Compose the masks of all layers; the matrix product counts the
            # paths from each input to each output, i.e. the network's overall
            # connectivity.
            total_mask = model.initial_layer.mask
            for block in model.blocks:
                if use_residual_blocks:
                    self.assertIsInstance(block, made.MaskedResidualBlock)
                    total_mask = block.linear_layers[0].mask @ total_mask
                    total_mask = block.linear_layers[1].mask @ total_mask
                else:
                    self.assertIsInstance(block, made.MaskedFeedforwardBlock)
                    total_mask = block.linear.mask @ total_mask
            total_mask = model.final_layer.mask @ total_mask
            # Binarize: we only care whether a path exists, not how many.
            total_mask = (total_mask > 0).float()
            # With sequential degrees, output i may depend only on inputs < i,
            # so the connectivity must be strictly lower triangular.
            reference = torch.tril(torch.ones([features, features]), -1)
            self.assertEqual(total_mask, reference)
def test_unconditional(self):
    features = 100
    hidden_features = 200
    num_blocks = 5
    output_multiplier = 3
    batch_size = 16

    inputs = torch.randn(batch_size, features)

    for use_residual_blocks, random_mask in [
        (False, False),
        (False, True),
        (True, False),
    ]:
        with self.subTest(
            use_residual_blocks=use_residual_blocks, random_mask=random_mask
        ):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=use_residual_blocks,
                random_mask=random_mask,
            )
            outputs = model(inputs)
            self.assertEqual(outputs.dim(), 2)
            self.assertEqual(outputs.shape[0], batch_size)
            self.assertEqual(outputs.shape[1], output_multiplier * features)
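# Sketch of the output layout (an inference from the indexing
# dim = k // output_multiplier used in test_gradients below, not asserted by
# the original tests): the flat output can be viewed as output_multiplier
# parameters per input dimension, e.g.
#
#   params = outputs.view(batch_size, features, output_multiplier)
#   # params[:, i, :] are the parameters predicted for input dimension i.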
def test_gradients(self):
    features = 10
    hidden_features = 256
    num_blocks = 20
    output_multiplier = 3

    for use_residual_blocks, random_mask in [
        (False, False),
        (False, True),
        (True, False),
    ]:
        with self.subTest(
            use_residual_blocks=use_residual_blocks, random_mask=random_mask
        ):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=use_residual_blocks,
                random_mask=random_mask,
            )
            inputs = torch.randn(1, features)
            inputs.requires_grad = True
            for k in range(features * output_multiplier):
                # Reset accumulated gradients so each output is checked
                # independently.
                inputs.grad = None
                outputs = model(inputs)
                outputs[0, k].backward()
                depends = inputs.grad[0] != 0.0
                # Output k parameterizes input dimension k // output_multiplier;
                # the autoregressive property forbids dependence on inputs
                # with index >= dim.
                dim = k // output_multiplier
                self.assertEqual(torch.all(depends[dim:] == 0), 1)
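# A minimal alternative sketch (not part of the original tests) that checks the
# same autoregressive property via the full Jacobian in one pass. It assumes a
# `model` built with the same arguments as in test_gradients above;
# torch.autograd.functional.jacobian is standard PyTorch.
def _check_autoregressive_jacobian(model, features, output_multiplier):
    inputs = torch.randn(1, features)
    # Jacobian has shape (1, features * output_multiplier, 1, features);
    # slice away the batch axes.
    jac = torch.autograd.functional.jacobian(model, inputs)[0, :, 0, :]
    for k in range(features * output_multiplier):
        dim = k // output_multiplier
        # Output k may only depend on inputs with index < dim.
        assert torch.all(jac[k, dim:] == 0.0)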
def test_total_mask_random(self):
    features = 10
    hidden_features = 50
    num_blocks = 5
    output_multiplier = 1

    model = made.MADE(
        features=features,
        hidden_features=hidden_features,
        num_blocks=num_blocks,
        output_multiplier=output_multiplier,
        use_residual_blocks=False,
        random_mask=True,
    )
    total_mask = model.initial_layer.mask
    for block in model.blocks:
        self.assertIsInstance(block, made.MaskedFeedforwardBlock)
        total_mask = block.linear.mask @ total_mask
    total_mask = model.final_layer.mask @ total_mask
    total_mask = (total_mask > 0).float()
    # With random degrees the exact connectivity pattern varies, but the upper
    # triangle (diagonal included) must still be zero.
    self.assertEqual(torch.triu(total_mask), torch.zeros([features, features]))
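# A small refactoring sketch (not in the original file): the mask composition
# repeated in test_total_mask_sequential and test_total_mask_random could be
# factored into a helper. It assumes the same MADE attributes used above.
def _binary_connectivity(model):
    total_mask = model.initial_layer.mask
    for block in model.blocks:
        if isinstance(block, made.MaskedResidualBlock):
            total_mask = block.linear_layers[0].mask @ total_mask
            total_mask = block.linear_layers[1].mask @ total_mask
        else:
            total_mask = block.linear.mask @ total_mask
    total_mask = model.final_layer.mask @ total_mask
    return (total_mask > 0).float()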
def __init__(
    self,
    features,
    hidden_features,
    context_features=None,
    num_bins=10,
    tails=None,
    tail_bound=1.0,
    num_blocks=2,
    use_residual_blocks=True,
    random_mask=False,
    permute_mask=False,
    activation=F.relu,
    dropout_probability=0.0,
    use_batch_norm=False,
    init_identity=True,
    min_bin_width=splines.DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=splines.DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=splines.DEFAULT_MIN_DERIVATIVE,
):
    self.num_bins = num_bins
    self.min_bin_width = min_bin_width
    self.min_bin_height = min_bin_height
    self.min_derivative = min_derivative
    self.tails = tails

    # Per-dimension tails: wrap circular dimensions in periodic features so
    # the network sees angles rescaled to a consistent range.
    if isinstance(self.tails, (list, tuple)):
        ind_circ = []
        for i in range(features):
            if self.tails[i] == 'circular':
                ind_circ += [i]
        if torch.is_tensor(tail_bound):
            scale_pf = np.pi / tail_bound[ind_circ]
        else:
            scale_pf = np.pi / tail_bound
        preprocessing = PeriodicFeatures(features, ind_circ, scale_pf)
    else:
        preprocessing = None

    autoregressive_net = made_module.MADE(
        features=features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        output_multiplier=self._output_dim_multiplier(),
        use_residual_blocks=use_residual_blocks,
        random_mask=random_mask,
        permute_mask=permute_mask,
        activation=activation,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
        preprocessing=preprocessing,
    )

    if init_identity:
        # Zero the final layer's weights and set its bias so that the
        # transform starts out as (approximately) the identity map.
        torch.nn.init.constant_(autoregressive_net.final_layer.weight, 0.0)
        torch.nn.init.constant_(
            autoregressive_net.final_layer.bias,
            np.log(np.exp(1 - min_derivative) - 1),
        )

    super().__init__(autoregressive_net)

    if torch.is_tensor(tail_bound):
        self.register_buffer('tail_bound', tail_bound)
    else:
        self.tail_bound = tail_bound
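# A minimal usage sketch, not from the source file. It assumes this __init__
# belongs to a masked piecewise rational-quadratic autoregressive transform
# (named here hypothetically `MaskedPiecewiseRQAutoregressive`) whose base
# class takes `autoregressive_net` and whose forward returns (outputs, logabsdet).
#
#   flow = MaskedPiecewiseRQAutoregressive(
#       features=4,
#       hidden_features=64,
#       num_bins=8,
#       tails=['linear', 'linear', 'circular', 'circular'],
#       tail_bound=3.0,
#   )
#   x = torch.randn(16, 4)
#   z, log_det = flow(x)  # assumed forward signature of the base transform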