Exemplo n.º 1
0
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        """Apply the piecewise-cubic spline elementwise (forward or inverse).

        Args:
            inputs: Tensor whose first dimension is the batch; presumably
                shaped (batch_size, self.features) -- TODO confirm with caller.
            autoregressive_params: Tensor that is reshaped here (via ``.view``)
                to (batch_size, self.features, 2 * self.num_bins + 2), so it
                must contain exactly that many elements and be view-compatible.
            inverse: If True, evaluate the inverse spline instead of the
                forward one.

        Returns:
            Tuple ``(outputs, logabsdet)``: ``outputs`` as returned by
            ``splines.cubic_spline``, and its log-abs-det summed over every
            non-batch dimension by ``utils.sum_except_batch``.
        """
        batch_size = inputs.shape[0]

        # Per feature: num_bins widths, num_bins heights, then 2 boundary
        # derivatives (left, right).
        transform_params = autoregressive_params.view(batch_size,
                                                      self.features,
                                                      self.num_bins * 2 + 2)

        unnormalized_widths = transform_params[..., :self.num_bins]
        unnormalized_heights = transform_params[..., self.num_bins:2 *
                                                self.num_bins]
        derivatives = transform_params[..., 2 * self.num_bins:]
        # `[..., None]` keeps a trailing singleton dimension on each derivative.
        unnorm_derivatives_left = derivatives[..., 0][..., None]
        unnorm_derivatives_right = derivatives[..., 1][..., None]

        if hasattr(self.autoregressive_net, 'hidden_features'):
            # Rescale by 1/sqrt(hidden_features). This must be out of place:
            # the slices above are views into `autoregressive_params`, so `/=`
            # would overwrite the caller's tensor (and can break autograd if
            # that tensor is saved for the backward pass).
            scale = np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_widths = unnormalized_widths / scale
            unnormalized_heights = unnormalized_heights / scale

        outputs, logabsdet = splines.cubic_spline(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnorm_derivatives_left=unnorm_derivatives_left,
            unnorm_derivatives_right=unnorm_derivatives_right,
            inverse=inverse)
        return outputs, utils.sum_except_batch(logabsdet)
Exemplo n.º 2
0
 def _elementwise_inverse(self, inputs, autoregressive_params):
     """Undo an elementwise affine map of the form ``scale * x + shift``.

     The scale is squashed through a sigmoid (input offset by +2, so a raw
     value of 0 gives roughly 0.88) and floored by 1e-3 to stay positive.

     Returns:
         Tuple ``(outputs, logabsdet)`` where ``outputs = (inputs - shift) /
         scale`` and ``logabsdet`` is minus the sum of ``log(scale)`` over all
         dimensions after the first (batch) one.
     """
     raw_scale, offset = self._unconstrained_scale_and_shift(
         autoregressive_params)
     positive_scale = torch.sigmoid(raw_scale + 2.) + 1e-3
     # Dividing by the scale contributes -log(scale) per element.
     summed_log_scale = utils.sum_except_batch(
         torch.log(positive_scale), num_batch_dims=1)
     return (inputs - offset) / positive_scale, -summed_log_scale
Exemplo n.º 3
0
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        """Apply the piecewise-linear spline elementwise (forward or inverse).

        ``autoregressive_params`` is viewed as an unnormalized pdf of shape
        (batch_size, self.features, self._output_dim_multiplier()), where
        batch_size is ``inputs.shape[0]``.

        Returns:
            Tuple ``(outputs, logabsdet)``: ``outputs`` from
            ``splines.linear_spline`` and its log-abs-det summed over every
            non-batch dimension by ``utils.sum_except_batch``.
        """
        pdf_shape = (inputs.shape[0], self.features,
                     self._output_dim_multiplier())
        outputs, logabsdet = splines.linear_spline(
            inputs=inputs,
            unnormalized_pdf=autoregressive_params.view(*pdf_shape),
            inverse=inverse,
        )
        return outputs, utils.sum_except_batch(logabsdet)
Exemplo n.º 4
0
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        """Apply the rational-quadratic spline elementwise (forward or inverse).

        Args:
            inputs: Tensor indexed as (batch_size, features, ...) via
                ``inputs.shape[0]`` and ``inputs.shape[1]``; presumably 2-D
                -- TODO confirm with caller.
            autoregressive_params: Tensor that is reshaped here (via ``.view``)
                to (batch_size, features, self._output_dim_multiplier()).
            inverse: If True, evaluate the inverse spline instead of the
                forward one.

        Returns:
            Tuple ``(outputs, logabsdet)``: spline outputs and the log-abs-det
            summed over every non-batch dimension by
            ``utils.sum_except_batch``.

        Raises:
            ValueError: If ``self.tails`` is neither ``None`` nor ``'linear'``.
        """
        batch_size, features = inputs.shape[0], inputs.shape[1]

        transform_params = autoregressive_params.view(
            batch_size, features, self._output_dim_multiplier())

        # Layout per feature: num_bins widths, num_bins heights, then the
        # remaining entries are derivatives (their count depends on how
        # `_output_dim_multiplier` accounts for `tails` -- not visible here).
        unnormalized_widths = transform_params[..., :self.num_bins]
        unnormalized_heights = transform_params[..., self.num_bins:2 *
                                                self.num_bins]
        unnormalized_derivatives = transform_params[..., 2 * self.num_bins:]

        if hasattr(self.autoregressive_net, 'hidden_features'):
            # Rescale by 1/sqrt(hidden_features). This must be out of place:
            # the slices above are views into `autoregressive_params`, so `/=`
            # would overwrite the caller's tensor (and can break autograd if
            # that tensor is saved for the backward pass).
            scale = np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_widths = unnormalized_widths / scale
            unnormalized_heights = unnormalized_heights / scale

        if self.tails is None:
            spline_fn = splines.rational_quadratic_spline
            spline_kwargs = {}
        elif self.tails == 'linear':
            spline_fn = splines.unconstrained_rational_quadratic_spline
            spline_kwargs = {
                'tails': self.tails,
                'tail_bound': self.tail_bound
            }
        else:
            raise ValueError(
                "Unsupported tails value {!r}; expected None or "
                "'linear'.".format(self.tails))

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            min_derivative=self.min_derivative,
            **spline_kwargs)

        return outputs, utils.sum_except_batch(logabsdet)