def _elementwise(self, inputs, elementwise_params, inverse):
    """Apply the multivariate CMoL elementwise transform.

    Args:
        inputs: Tensor to transform. In the inverse direction a 2-D tensor
            is treated as a single pixel: it is temporarily reshaped to
            ``inputs.shape + (1, 1)`` and squeezed back afterwards.
        elementwise_params: Mixture parameters. Dimension 1 must equal
            ``self._params()``. When ``inputs`` is promoted, these are
            promoted the same way.
        inverse: If truthy, run the inverse transform.

    Returns:
        If ``inverse``: the transformed tensor (returned as 2-D again when
        the input was 2-D).
        Otherwise: a tuple ``(z, ldj)``. ``ldj`` is the per-element
        log-det returned by ``multivariate_cmol_transform``, reduced with
        ``sum_except_batch``.
    """
    assert elementwise_params.shape[1] == self._params()

    # Promote a flat (2-D) input to a 1x1 spatial map in the inverse
    # direction. Presumably the `_2d` mixture helpers require trailing
    # spatial dims (TODO confirm). The promotion is undone after the
    # transform.
    single_pixel = bool(inverse) and inputs.dim() == 2
    if single_pixel:
        inputs = inputs.reshape(inputs.shape + (1, 1))
        elementwise_params = elementwise_params.reshape(
            elementwise_params.shape + (1, 1))

    logit_weights, means, log_scales, unnormalized_corr = (
        get_matching_multivariate_mixture_params_2d(
            elementwise_params, num_mixtures=self.num_mixtures))

    x = multivariate_cmol_transform(
        inputs=inputs,
        logit_weights=logit_weights,
        means=means,
        log_scales=log_scales,
        unnormalized_corr=unnormalized_corr,
        K=self.num_bins,
        eps=self.eps,
        max_iters=self.max_iters,
        mean_lambd=self.mean_lambd,
        inverse=inverse)

    if inverse:
        if single_pixel:
            # Drop the two trailing singleton dims added above.
            x = x.squeeze(-1).squeeze(-1)
        return x

    # The forward transform returns (outputs, per-element log-det).
    z, ldj_elementwise = x
    ldj = sum_except_batch(ldj_elementwise)
    return z, ldj
def _elementwise_forward(self, x, elementwise_params):
    """Run the piecewise-linear spline elementwise in the forward direction.

    Args:
        x: Input tensor.
        elementwise_params: Spline parameters. The last dimension must
            equal ``self._output_dim_multiplier()``.

    Returns:
        Tuple ``(z, ldj)``. ``z`` is the output of ``splines.linear_spline``.
        ``ldj`` is that function's per-element log-det reduced with
        ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    outputs, log_abs_det = splines.linear_spline(
        x, elementwise_params, inverse=False)
    return outputs, sum_except_batch(log_abs_det)
def _elementwise_forward(self, x, elementwise_params):
    """Apply an affine map ``z = scale * x + shift`` with a bounded scale.

    The scale is ``sigmoid(raw + 2) + 1e-3``, so it lies strictly within
    ``(1e-3, 1 + 1e-3)``.

    Args:
        x: Input tensor.
        elementwise_params: Parameters that
            ``self._unconstrained_scale_and_shift`` splits into
            ``(raw_scale, shift)``. The last dimension must equal
            ``self._output_dim_multiplier()``.

    Returns:
        Tuple ``(z, ldj)``. ``ldj`` is ``log(scale)`` reduced with
        ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    raw_scale, shift = self._unconstrained_scale_and_shift(elementwise_params)
    # The +2 offset makes scale ~= sigmoid(2) ~= 0.88 when raw_scale is 0.
    # The +1e-3 floor keeps scale positive, so the log below stays finite.
    scale = torch.sigmoid(raw_scale + 2.) + 1e-3
    log_scale = torch.log(scale)
    z = scale * x + shift
    return z, sum_except_batch(log_scale)
def _elementwise_forward(self, x, elementwise_params):
    """Run the quadratic spline elementwise in the forward direction.

    The first ``self.num_bins`` entries of the last dimension are the
    unnormalized bin widths. All remaining entries are the unnormalized
    heights.

    Returns:
        Tuple ``(z, ldj)``. ``ldj`` is the per-element log-det from
        ``splines.quadratic_spline`` reduced with ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    n_bins = self.num_bins
    width_logits = elementwise_params[..., :n_bins]
    height_logits = elementwise_params[..., n_bins:]
    outputs, log_abs_det = splines.quadratic_spline(
        x,
        unnormalized_widths=width_logits,
        unnormalized_heights=height_logits,
        inverse=False,
    )
    return outputs, sum_except_batch(log_abs_det)
def _elementwise_forward(self, x, elementwise_params):
    """Run the cubic spline elementwise in the forward direction.

    The last dimension of ``elementwise_params`` is split as follows, with
    ``n = self.num_bins``:

    * ``[0, n)``: unnormalized widths.
    * ``[n, 2n)``: unnormalized heights.
    * ``[2n, 2n+1)``: passed as ``unnorm_derivatives_left``.
    * ``[2n+1, end)``: passed as ``unnorm_derivatives_right``.

    Returns:
        Tuple ``(z, ldj)``. ``ldj`` is the per-element log-det from
        ``splines.cubic_spline`` reduced with ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    n = self.num_bins
    p = elementwise_params
    widths = p[..., :n]
    heights = p[..., n:2 * n]
    # Single-element slices (kept with a trailing dim of size 1);
    # presumably the boundary derivatives. TODO confirm against splines.
    deriv_left = p[..., 2 * n:2 * n + 1]
    deriv_right = p[..., 2 * n + 1:]
    outputs, log_abs_det = splines.cubic_spline(
        x,
        unnormalized_widths=widths,
        unnormalized_heights=heights,
        unnorm_derivatives_left=deriv_left,
        unnorm_derivatives_right=deriv_right,
        inverse=False,
    )
    return outputs, sum_except_batch(log_abs_det)
def _elementwise(self, inputs, elementwise_params, inverse):
    """Apply the CMoL mixture transform elementwise.

    Args:
        inputs: Tensor to transform.
        elementwise_params: Mixture parameters unpacked by
            ``get_mixture_params``. The last dimension must equal
            ``self._output_dim_multiplier()``.
        inverse: If truthy, run the inverse transform.

    Returns:
        If ``inverse``: whatever ``cmol_transform`` returns (the transformed
        tensor).
        Otherwise: ``(z, ldj)``, where the per-element log-det from
        ``cmol_transform`` is reduced with ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    logit_weights, means, log_scales = get_mixture_params(
        elementwise_params, num_mixtures=self.num_mixtures)
    result = cmol_transform(
        inputs=inputs,
        logit_weights=logit_weights,
        means=means,
        log_scales=log_scales,
        K=self.num_bins,
        eps=self.eps,
        max_iters=self.max_iters,
        inverse=inverse,
    )
    if not inverse:
        outputs, per_element_ldj = result
        return outputs, sum_except_batch(per_element_ldj)
    return result
def log_prob(self, x):
    """Return the summed log-width of the interval ``x`` maps to.

    ``self.forward_transform(x)`` yields ``(lower, upper)`` bounds. This
    method returns ``sum_except_batch(log(max(upper - lower, 1e-12)))``.
    """
    lower, upper = self.forward_transform(x)
    # The width is treated as a probability in [0, 2^bits] (scale not
    # verified here). The 1e-12 floor avoids log(0) for empty or negative
    # intervals.
    interval_width = (upper - lower).clamp(min=1e-12)
    return sum_except_batch(torch.log(interval_width))
def log_prob(self, x):
    """Return ``self.base_measure - 0.5 * sum_except_batch(x ** 2)``.

    NOTE(review): this looks like a standard-normal log-density, with
    ``base_measure`` holding the normalizing constant. Confirm where
    ``base_measure`` is set.
    """
    squared_norm = sum_except_batch(x ** 2)
    return self.base_measure - 0.5 * squared_norm