def __init__(self, num_bins, features, hidden_features, context_features=None, num_blocks=2, use_residual_blocks=True, random_mask=False, activation=F.relu, dropout_probability=0., use_batch_norm=False):
    """Create a masked piecewise-cubic autoregressive transform.

    A MADE network is built as the autoregressive conditioner and handed to
    the parent constructor.

    Args:
        num_bins: number of spline bins; stored on ``self``.
        features: dimensionality of the transformed variable; stored on
            ``self`` and forwarded to MADE.
        hidden_features, context_features, num_blocks, use_residual_blocks,
        random_mask, activation, dropout_probability, use_batch_norm:
            forwarded unchanged to ``made_module.MADE``.
    """
    # These attributes are set before the conditioner is built because
    # ``self._output_dim_multiplier()`` is called while constructing it and
    # may read them (NOTE(review): confirm which ones it needs).
    self.features = features
    self.num_bins = num_bins

    autoregressive_net = made_module.MADE(
        features=features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        output_multiplier=self._output_dim_multiplier(),
        use_residual_blocks=use_residual_blocks,
        random_mask=random_mask,
        activation=activation,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )
    super().__init__(autoregressive_net)
def test_total_mask_sequential(self):
    """Sequential degrees must yield exactly strict lower-triangular connectivity.

    The masks of every layer are composed by matrix multiplication. The
    binarised product must equal ``tril(ones, -1)`` for both the residual
    and the feed-forward block variants.
    """
    features, hidden_features = 10, 50
    num_blocks, output_multiplier = 5, 1
    expected = torch.tril(torch.ones([features, features]), -1)

    for use_residual_blocks in (True, False):
        with self.subTest(use_residual_blocks=use_residual_blocks):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=use_residual_blocks,
                random_mask=False,
            )

            connectivity = model.initial_layer.mask
            for block in model.blocks:
                if use_residual_blocks:
                    self.assertIsInstance(block, made.MaskedResidualBlock)
                    # Residual blocks hold two masked linear layers, applied in order.
                    block_masks = (block.linear_layers[0].mask, block.linear_layers[1].mask)
                else:
                    self.assertIsInstance(block, made.MaskedFeedforwardBlock)
                    block_masks = (block.linear.mask,)
                for mask in block_masks:
                    connectivity = mask @ connectivity

            # Any positive path count means "connected".
            connectivity = (model.final_layer.mask @ connectivity > 0).float()
            self.assertEqual(connectivity, expected)
def test_gradients(self):
    """Check MADE's autoregressive property through input gradients.

    Output unit ``k`` belongs to input dimension ``k // output_multiplier``.
    Its gradient with respect to the inputs must be zero for that dimension
    and every later one, so it may depend only on strictly earlier inputs.
    """
    features = 10
    hidden_features = 256
    num_blocks = 20
    output_multiplier = 3
    for use_residual_blocks, random_mask in [(False, False), (False, True), (True, False)]:
        with self.subTest(use_residual_blocks=use_residual_blocks, random_mask=random_mask):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=use_residual_blocks,
                random_mask=random_mask,
            )
            inputs = torch.randn(1, features, requires_grad=True)
            # One forward pass suffices. The graph is retained so each
            # output unit can be differentiated on its own.
            outputs = model(inputs)
            for k in range(features * output_multiplier):
                # torch.autograd.grad returns a fresh gradient for this single
                # output. Repeated .backward() calls would accumulate into
                # inputs.grad, mixing earlier outputs' dependencies into this
                # check (and possibly letting contributions cancel).
                (grad,) = torch.autograd.grad(
                    outputs[0, k], inputs, retain_graph=True, allow_unused=True
                )
                if grad is None:
                    # Output not connected to the inputs at all: no dependence.
                    continue
                depends = grad[0] != 0.0
                dim = k // output_multiplier
                self.assertFalse(
                    bool(depends[dim:].any()),
                    msg=f"output {k} depends on input dimensions >= {dim}",
                )
def test_unconditional(self):
    """Without context, MADE maps a 2-D input batch to a 2-D output.

    The output has ``batch_size`` rows and ``output_multiplier * features``
    columns for every block/mask variant exercised.
    """
    batch_size = 16
    features = 100
    hidden_features = 200
    num_blocks = 5
    output_multiplier = 3
    # The same random batch is reused for every configuration.
    inputs = torch.randn(batch_size, features)

    configurations = [(False, False), (False, True), (True, False)]
    for residual, use_random_mask in configurations:
        with self.subTest(use_residual_blocks=residual, random_mask=use_random_mask):
            model = made.MADE(
                features=features,
                hidden_features=hidden_features,
                num_blocks=num_blocks,
                output_multiplier=output_multiplier,
                use_residual_blocks=residual,
                random_mask=use_random_mask,
            )
            outputs = model(inputs)
            self.assertEqual(outputs.dim(), 2)
            n_rows, n_cols = outputs.shape
            self.assertEqual(n_rows, batch_size)
            self.assertEqual(n_cols, output_multiplier * features)
def test_total_mask_random(self):
    """Random degrees must still give connectivity that is zero on and above the diagonal.

    Only feed-forward blocks are exercised. All layer masks are composed by
    matrix multiplication and binarised. The upper triangle of the result,
    diagonal included, must be all zeros.
    """
    features = 10
    model = made.MADE(
        features=features,
        hidden_features=50,
        num_blocks=5,
        output_multiplier=1,
        use_residual_blocks=False,
        random_mask=True,
    )

    masks = [model.initial_layer.mask]
    for block in model.blocks:
        self.assertIsInstance(block, made.MaskedFeedforwardBlock)
        masks.append(block.linear.mask)
    masks.append(model.final_layer.mask)

    # Compose from the input side outwards: later masks multiply on the left.
    connectivity = masks[0]
    for mask in masks[1:]:
        connectivity = mask @ connectivity
    connectivity = (connectivity > 0).float()

    self.assertEqual(torch.triu(connectivity), torch.zeros([features, features]))
def __init__(
    self,
    features,
    hidden_features,
    context_features=None,
    num_bins=10,
    tails=None,
    tail_bound=1.,
    num_blocks=2,
    use_residual_blocks=True,
    random_mask=False,
    activation=F.relu,
    dropout_probability=0.,
    use_batch_norm=False,
    min_bin_width=splines.rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=splines.rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=splines.rational_quadratic.DEFAULT_MIN_DERIVATIVE,
):
    """Create a masked rational-quadratic-spline autoregressive transform.

    The spline settings are stored on ``self``. A MADE conditioner is then
    built and passed to the parent constructor.

    Args:
        num_bins, tails, tail_bound, min_bin_width, min_bin_height,
        min_derivative: spline configuration, stored as attributes of the
            same name.
        features, hidden_features, context_features, num_blocks,
        use_residual_blocks, random_mask, activation, dropout_probability,
        use_batch_norm: forwarded unchanged to ``made_module.MADE``.
    """
    # These attributes are stored before building the conditioner because
    # ``self._output_dim_multiplier()`` is called during that construction
    # and may read them (NOTE(review): confirm which ones it uses).
    self.num_bins = num_bins
    self.tails = tails
    self.tail_bound = tail_bound
    self.min_bin_width = min_bin_width
    self.min_bin_height = min_bin_height
    self.min_derivative = min_derivative

    conditioner = made_module.MADE(
        features=features,
        hidden_features=hidden_features,
        context_features=context_features,
        num_blocks=num_blocks,
        output_multiplier=self._output_dim_multiplier(),
        use_residual_blocks=use_residual_blocks,
        random_mask=random_mask,
        activation=activation,
        dropout_probability=dropout_probability,
        use_batch_norm=use_batch_norm,
    )
    super().__init__(conditioner)