Example #1
    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        self.features = features
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        # Small offset that keeps the predicted affine scale strictly positive.
        self._epsilon = 1e-3
        super(MaskedAffineAutoregressiveTransform, self).__init__(made)
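
Assuming this constructor comes from the nflows library (the snippet itself names MaskedAffineAutoregressiveTransform), a minimal usage sketch might look like the following; the feature sizes are illustrative:

    import torch
    from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform

    # Affine autoregressive transform over 4-dimensional inputs.
    transform = MaskedAffineAutoregressiveTransform(features=4, hidden_features=32)
    x = torch.randn(16, 4)            # batch of 16 samples
    y, logabsdet = transform(x)       # returns outputs and log|det J| per sample
    x_back, _ = transform.inverse(y)  # inverse proceeds one dimension at a time
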
Example #2
    def __init__(
        self,
        num_bins,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        self.num_bins = num_bins
        self.features = features
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        super().__init__(made)
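
The MADE output is sized by _output_dim_multiplier(). For a piecewise-linear transform that predicts one unnormalized bin width per bin, the multiplier is plausibly just the bin count; a sketch of such a method, assuming the nflows piecewise-linear variant:

    def _output_dim_multiplier(self):
        # One unnormalized bin width per bin (and per feature).
        return self.num_bins
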
Example #3
    def test_gradients(self):
        features = 10
        hidden_features = 256
        num_blocks = 20
        output_multiplier = 3

        for use_residual_blocks, random_mask in [
            (False, False),
            (False, True),
            (True, False),
        ]:
            with self.subTest(use_residual_blocks=use_residual_blocks,
                              random_mask=random_mask):
                model = made.MADE(
                    features=features,
                    hidden_features=hidden_features,
                    num_blocks=num_blocks,
                    output_multiplier=output_multiplier,
                    use_residual_blocks=use_residual_blocks,
                    random_mask=random_mask,
                )
                inputs = torch.randn(1, features)
                inputs.requires_grad = True
                for k in range(features * output_multiplier):
                    outputs = model(inputs)
                    outputs[0, k].backward()
                    # Output k may depend only on inputs with index below
                    # k // output_multiplier, so the gradient w.r.t. every
                    # later input must be exactly zero.
                    depends = inputs.grad.data[0] != 0.0
                    dim = k // output_multiplier
                    self.assertEqual(torch.all(depends[dim:] == 0), 1)
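
The same property can be verified in one shot from the full Jacobian. A minimal sketch, assuming the nflows MADE module and output_multiplier=1 so the dependency pattern is strictly lower-triangular:

    import torch
    from nflows.transforms import made as made_module

    model = made_module.MADE(features=5, hidden_features=64,
                             num_blocks=2, output_multiplier=1)
    x = torch.randn(1, 5)
    jac = torch.autograd.functional.jacobian(model, x)  # shape (1, 5, 1, 5)
    # Output i may depend only on inputs 0..i-1, so the upper triangle
    # (diagonal included) of the Jacobian must be exactly zero.
    assert torch.all(torch.triu(jac[0, :, 0, :]) == 0)
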
Example #4
    def test_unconditional(self):
        features = 100
        hidden_features = 200
        num_blocks = 5
        output_multiplier = 3
        batch_size = 16

        inputs = torch.randn(batch_size, features)

        for use_residual_blocks, random_mask in [
            (False, False),
            (False, True),
            (True, False),
        ]:
            with self.subTest(use_residual_blocks=use_residual_blocks,
                              random_mask=random_mask):
                model = made.MADE(
                    features=features,
                    hidden_features=hidden_features,
                    num_blocks=num_blocks,
                    output_multiplier=output_multiplier,
                    context_features=None,
                    use_residual_blocks=use_residual_blocks,
                    random_mask=random_mask,
                )
                outputs = model(inputs)
                self.assertEqual(outputs.dim(), 2)
                self.assertEqual(outputs.shape[0], batch_size)
                self.assertEqual(outputs.shape[1],
                                 output_multiplier * features)
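
The conditional counterpart of the test above just supplies context_features and threads a context tensor through the forward call; a sketch under the same assumptions:

    import torch
    from nflows.transforms import made as made_module

    model = made_module.MADE(features=100, hidden_features=200,
                             num_blocks=5, output_multiplier=3,
                             context_features=32)
    inputs = torch.randn(16, 100)
    context = torch.randn(16, 32)   # one conditioning vector per sample
    outputs = model(inputs, context=context)
    assert outputs.shape == (16, 300)
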
Example #5
    def test_total_mask_sequential(self):
        features = 10
        hidden_features = 50
        num_blocks = 5
        output_multiplier = 1

        for use_residual_blocks in [True, False]:
            with self.subTest(use_residual_blocks=use_residual_blocks):
                model = made.MADE(
                    features=features,
                    hidden_features=hidden_features,
                    num_blocks=num_blocks,
                    output_multiplier=output_multiplier,
                    use_residual_blocks=use_residual_blocks,
                    random_mask=False,
                )
                # Compose the per-layer binary masks to obtain the network's
                # end-to-end connectivity pattern.
                total_mask = model.initial_layer.mask
                for block in model.blocks:
                    if use_residual_blocks:
                        self.assertIsInstance(block, made.MaskedResidualBlock)
                        total_mask = block.linear_layers[0].mask @ total_mask
                        total_mask = block.linear_layers[1].mask @ total_mask
                    else:
                        self.assertIsInstance(block,
                                              made.MaskedFeedforwardBlock)
                        total_mask = block.linear.mask @ total_mask
                total_mask = model.final_layer.mask @ total_mask
                total_mask = (total_mask > 0).float()
                reference = torch.tril(torch.ones([features, features]), -1)
                self.assertEqual(total_mask, reference)
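
Why the mask product works: entry (i, j) of the composed mask counts all-unmasked paths from input j to output i, so thresholding it at zero recovers the network's dependency pattern. A toy sketch with two features and two hidden units:

    import torch

    # Both hidden units have degree 1, so they may see only input 0.
    hidden_mask = torch.tensor([[1., 0.],
                                [1., 0.]])   # hidden <- input
    # Output 0 (degree 1) sees no hidden unit; output 1 (degree 2) sees both.
    output_mask = torch.tensor([[0., 0.],
                                [1., 1.]])   # output <- hidden
    total = (output_mask @ hidden_mask > 0).float()
    print(total)  # tensor([[0., 0.], [1., 0.]]) -- strictly lower-triangular
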
Example #6
    def __init__(self,
                 features,
                 hidden_features,
                 lower_bounds,
                 upper_bounds,
                 context_features=None,
                 num_bins=10,
                 tails=None,
                 tail_bound=1.0,
                 num_blocks=2,
                 use_residual_blocks=True,
                 random_mask=False,
                 activation=F.relu,
                 dropout_probability=0.0,
                 use_batch_norm=False,
                 min_bin_width=rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
                 min_bin_height=rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
                 min_derivative=rational_quadratic.DEFAULT_MIN_DERIVATIVE,
                 permutation=0):
        self.num_bins = num_bins
        self.min_bin_width = min_bin_width
        self.min_bin_height = min_bin_height
        self.min_derivative = min_derivative
        self.tails = tails
        self.tail_bound = tail_bound
        self.lower_bounds = lower_bounds
        self.upper_bounds = upper_bounds
        self.permutation = permutation

        autoregressive_net = made_module.MADE(
            features=2 * features,
            hidden_features=hidden_features,
            autoregressive_features=features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )

        super().__init__(autoregressive_net)

        self.phase = torch.nn.Parameter(torch.zeros_like(self.lower_bounds))
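
The spline constraint defaults referenced here live in nflows' rational-quadratic spline module, where all three default to 1e-3; assuming that dependency, they can be inspected directly:

    from nflows.transforms.splines import rational_quadratic

    print(rational_quadratic.DEFAULT_MIN_BIN_WIDTH)   # 1e-3
    print(rational_quadratic.DEFAULT_MIN_BIN_HEIGHT)  # 1e-3
    print(rational_quadratic.DEFAULT_MIN_DERIVATIVE)  # 1e-3
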
Example #7
    def test_total_mask_random(self):
        features = 10
        hidden_features = 50
        num_blocks = 5
        output_multiplier = 1

        model = made.MADE(
            features=features,
            hidden_features=hidden_features,
            num_blocks=num_blocks,
            output_multiplier=output_multiplier,
            use_residual_blocks=False,
            random_mask=True,
        )
        total_mask = model.initial_layer.mask
        for block in model.blocks:
            self.assertIsInstance(block, made.MaskedFeedforwardBlock)
            total_mask = block.linear.mask @ total_mask
        total_mask = model.final_layer.mask @ total_mask
        total_mask = (total_mask > 0).float()
        # With random hidden degrees, only the upper triangle (diagonal
        # included) is guaranteed to be zero.
        self.assertEqual(torch.triu(total_mask),
                         torch.zeros([features, features]))
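
A random mask guarantees only this weaker property: randomly drawn hidden degrees can cut some of the allowed connections, so the full strictly-lower-triangular reference from the sequential test need not be reproduced. A small sketch making that visible (sizes are illustrative):

    import torch
    from nflows.transforms import made as made_module

    torch.manual_seed(0)
    model = made_module.MADE(features=5, hidden_features=8, num_blocks=1,
                             output_multiplier=1, use_residual_blocks=False,
                             random_mask=True)
    total = model.initial_layer.mask
    for block in model.blocks:
        total = block.linear.mask @ total
    total = ((model.final_layer.mask @ total) > 0).float()
    print(total)  # upper triangle is zero; some lower entries may be zero too
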
Example #8
    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_bins=10,
        num_blocks=2,
        tails=None,
        tail_bound=1.0,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        min_bin_width=rational_quadratic.DEFAULT_MIN_BIN_WIDTH,
        min_bin_height=rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
        min_derivative=rational_quadratic.DEFAULT_MIN_DERIVATIVE,
    ):
        self.num_bins = num_bins
        self.min_bin_width = min_bin_width
        self.min_bin_height = min_bin_height
        self.min_derivative = min_derivative
        self.tails = tails
        self.tail_bound = tail_bound
        self.features = features
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        super().__init__(made)
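
For reference, _output_dim_multiplier() here must cover num_bins widths, num_bins heights, and the knot derivatives. In nflows the rational-quadratic version amounts to the following (reproduced from memory, so treat it as a sketch):

    def _output_dim_multiplier(self):
        # num_bins widths + num_bins heights + knot derivatives:
        # num_bins + 1 derivatives without tails, num_bins - 1 with
        # linear tails (the two boundary derivatives are fixed).
        if self.tails == "linear":
            return self.num_bins * 3 - 1
        elif self.tails is None:
            return self.num_bins * 3 + 1
        else:
            raise ValueError("tails must be None or 'linear'")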