def _elementwise(self, inputs, elementwise_params, inverse):
    """Apply the multivariate CMOL elementwise transform.

    Args:
        inputs: Tensor to transform. In the inverse direction a 2-D tensor is
            treated as a single "pixel": two trailing singleton dimensions are
            appended before the transform and removed from the result.
        elementwise_params: Parameter tensor; dimension 1 must equal
            ``self._params()``. In the single-pixel case it receives the same
            two trailing singleton dimensions as ``inputs``.
        inverse: If true, run the inverse transform and return its output
            only; otherwise run the forward transform.

    Returns:
        Inverse direction: the transformed tensor.
        Forward direction: ``(z, ldj)`` where ``ldj`` is the elementwise
        log-det-Jacobian reduced with ``sum_except_batch``.
    """
    assert elementwise_params.shape[1] == self._params()

    # Inverse calls may receive flat (2-D) inputs; lift inputs and params to
    # a 1x1 spatial map so the "_2d" mixture helpers below can consume them.
    single_pixel = False
    if inverse and inputs.dim() == 2:
        inputs = inputs.reshape(inputs.shape + (1, 1))
        elementwise_params = elementwise_params.reshape(
            elementwise_params.shape + (1, 1))
        single_pixel = True

    logit_weights, means, log_scales, unnormalized_corr = (
        get_matching_multivariate_mixture_params_2d(
            elementwise_params, num_mixtures=self.num_mixtures))

    x = multivariate_cmol_transform(inputs=inputs,
                                    logit_weights=logit_weights,
                                    means=means,
                                    log_scales=log_scales,
                                    unnormalized_corr=unnormalized_corr,
                                    K=self.num_bins,
                                    eps=self.eps,
                                    max_iters=self.max_iters,
                                    mean_lambd=self.mean_lambd,
                                    inverse=inverse)

    if inverse:
        # Undo the single-pixel lift so the output has the caller's rank.
        if single_pixel:
            x = x.squeeze(-1).squeeze(-1)
        return x

    z, ldj_elementwise = x
    ldj = sum_except_batch(ldj_elementwise)
    return z, ldj
 def _elementwise_forward(self, x, elementwise_params):
     """Forward pass of the piecewise-linear spline.

     Returns ``(z, ldj)``: the spline output and the per-element
     log-det-Jacobian from ``splines.linear_spline`` reduced with
     ``sum_except_batch``.
     """
     assert elementwise_params.shape[-1] == self._output_dim_multiplier()
     spline_out = splines.linear_spline(x, elementwise_params, inverse=False)
     z, per_element_ldj = spline_out
     return z, sum_except_batch(per_element_ldj)
 def _elementwise_forward(self, x, elementwise_params):
     """Forward pass of the affine map ``z = scale * x + shift``.

     The scale is ``sigmoid(unconstrained + 2) + 1e-3``. The ``+ 1e-3``
     keeps it at least 1e-3, so its log stays finite. Returns ``(z, ldj)``
     with ``ldj`` the log-scale reduced with ``sum_except_batch``.
     """
     assert elementwise_params.shape[-1] == self._output_dim_multiplier()
     raw_scale, shift = self._unconstrained_scale_and_shift(
         elementwise_params)
     scale = torch.sigmoid(raw_scale + 2.) + 1e-3
     log_det = sum_except_batch(torch.log(scale))
     z = scale * x + shift
     return z, log_det
 def _elementwise_forward(self, x, elementwise_params):
     """Forward pass of the piecewise-quadratic spline.

     The last dimension of ``elementwise_params`` holds the unnormalized bin
     widths in its first ``self.num_bins`` entries and the unnormalized bin
     heights in the remaining entries. Returns ``(z, ldj)``.
     """
     assert elementwise_params.shape[-1] == self._output_dim_multiplier()
     boundary = self.num_bins
     widths = elementwise_params[..., :boundary]
     heights = elementwise_params[..., boundary:]
     z, ldj_elementwise = splines.quadratic_spline(
         x,
         unnormalized_widths=widths,
         unnormalized_heights=heights,
         inverse=False,
     )
     return z, sum_except_batch(ldj_elementwise)
 def _elementwise_forward(self, x, elementwise_params):
     """Forward pass of the piecewise-cubic spline.

     Layout of the last dimension of ``elementwise_params`` (K = num_bins):
     ``[0, K)`` unnormalized widths, ``[K, 2K)`` unnormalized heights,
     ``[2K, 2K+1)`` unnormalized left derivative, ``[2K+1, end)``
     unnormalized right derivative. Returns ``(z, ldj)``.
     """
     assert elementwise_params.shape[-1] == self._output_dim_multiplier()
     k = self.num_bins
     z, ldj_elementwise = splines.cubic_spline(
         x,
         unnormalized_widths=elementwise_params[..., :k],
         unnormalized_heights=elementwise_params[..., k:2 * k],
         unnorm_derivatives_left=elementwise_params[..., 2 * k:2 * k + 1],
         unnorm_derivatives_right=elementwise_params[..., 2 * k + 1:],
         inverse=False,
     )
     return z, sum_except_batch(ldj_elementwise)
    def _elementwise(self, inputs, elementwise_params, inverse):
        """Apply the CMOL elementwise transform in either direction.

        When ``inverse`` is true, returns the transformed tensor. Otherwise
        returns ``(z, ldj)`` with the elementwise log-det-Jacobian reduced by
        ``sum_except_batch``.
        """
        assert elementwise_params.shape[-1] == self._output_dim_multiplier()

        mixture = get_mixture_params(elementwise_params,
                                     num_mixtures=self.num_mixtures)
        logit_weights, means, log_scales = mixture

        result = cmol_transform(inputs=inputs,
                                logit_weights=logit_weights,
                                means=means,
                                log_scales=log_scales,
                                K=self.num_bins,
                                eps=self.eps,
                                max_iters=self.max_iters,
                                inverse=inverse)

        if not inverse:
            # The forward direction also yields per-element log-dets.
            z, ldj_elementwise = result
            return z, sum_except_batch(ldj_elementwise)
        return result
 def log_prob(self, x):
     """Log of the width of the interval ``x`` maps to, summed per batch item.

     ``self.forward_transform`` returns lower and upper bounds. Their
     difference is floored at 1e-12 so the log stays finite, then reduced
     with ``sum_except_batch``.
     """
     lower, upper = self.forward_transform(x)
     width = (upper - lower).clamp(min=1e-12)
     return sum_except_batch(width.log())
# Example #8
# 0
 def log_prob(self, x):
     """Return ``self.base_measure - 0.5 * sum_except_batch(x ** 2)``.

     NOTE(review): presumably a standard-normal log-density with
     ``base_measure`` as the normalizing constant; confirm against the class.
     """
     squared = sum_except_batch(x ** 2)
     return self.base_measure - 0.5 * squared