def _elementwise_inverse(self, z, elementwise_params):
    """Invert the Flow++-style elementwise transform.

    The affine step ``scale * y + shift`` is undone first. The result is then
    clamped into ``[1e-5, 1 - 1e-5]`` and passed through
    ``logistic_mixture_transform(..., inverse=True)``.

    Args:
        z: Values to map back.
        elementwise_params: Per-element parameters. Their last dimension must
            equal ``self._output_dim_multiplier()``. They are split by
            ``get_flowpp_params`` into (unconstrained scale, shift, mixture
            logits, means, log-scales).

    Returns:
        Whatever ``logistic_mixture_transform`` returns in inverse mode.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    params = get_flowpp_params(elementwise_params, num_mixtures=self.num_mixtures)
    unconstrained_scale, shift, logit_weights, means, log_scales = params

    scale = self.scale_fn(unconstrained_scale)
    # Floor the component log-scales at -7 (from the code in original Flow++ paper).
    log_scales = log_scales.clamp(min=-7)

    # Undo the affine map, then keep the value away from 0 and 1 before
    # inverting the mixture. Presumably the mixture inverse expects an open
    # unit-interval input; TODO confirm against logistic_mixture_transform.
    unit_value = ((z - shift) / scale).clamp(1e-5, 1.0 - 1e-5)

    return logistic_mixture_transform(
        inputs=unit_value,
        logit_weights=logit_weights,
        means=means,
        log_scales=log_scales,
        eps=self.eps,
        max_iters=self.max_iters,
        inverse=True,
    )
def _elementwise(self, inputs, elementwise_params, inverse):
    """Run the logistic-mixture elementwise transform in either direction.

    Args:
        inputs: Values to transform.
        elementwise_params: Per-element parameters. Their last dimension must
            equal ``self._output_dim_multiplier()``. They are split by
            ``get_mixture_params`` into (mixture logits, means, log-scales).
        inverse: If truthy, run the inverse transform and return its result
            unchanged.

    Returns:
        In inverse mode, the output of ``logistic_mixture_transform``.
        Otherwise a pair ``(z, ldj)``. ``ldj`` is the transform's per-element
        term (presumably the log-det-Jacobian) reduced with
        ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    mixture_params = get_mixture_params(elementwise_params, num_mixtures=self.num_mixtures)
    logit_weights, means, log_scales = mixture_params

    outcome = logistic_mixture_transform(
        inputs=inputs,
        logit_weights=logit_weights,
        means=means,
        log_scales=log_scales,
        eps=self.eps,
        max_iters=self.max_iters,
        inverse=inverse,
    )

    if not inverse:
        # The forward direction yields a pair; reduce the per-element term.
        z, per_element_ldj = outcome
        return z, sum_except_batch(per_element_ldj)

    return outcome
def _elementwise_forward(self, x, elementwise_params):
    """Apply the Flow++-style elementwise transform in the forward direction.

    The inputs go through ``logistic_mixture_transform`` first. The mixture
    output ``y`` is then mapped affinely to ``scale * y + shift``.

    Args:
        x: Values to transform.
        elementwise_params: Per-element parameters. Their last dimension must
            equal ``self._output_dim_multiplier()``. They are split by
            ``get_flowpp_params`` into (unconstrained scale, shift, mixture
            logits, means, log-scales).

    Returns:
        A pair ``(z, ldj)``. ``ldj`` is the ``sum_except_batch`` of the
        mixture transform's per-element term plus the ``sum_except_batch`` of
        ``log(scale)``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    params = get_flowpp_params(elementwise_params, num_mixtures=self.num_mixtures)
    unconstrained_scale, shift, logit_weights, means, log_scales = params

    scale = self.scale_fn(unconstrained_scale)
    # Floor the component log-scales at -7 (from the code in original Flow++ paper).
    log_scales = log_scales.clamp(min=-7)

    mixed, mixture_term = logistic_mixture_transform(
        inputs=x,
        logit_weights=logit_weights,
        means=means,
        log_scales=log_scales,
        eps=self.eps,
        max_iters=self.max_iters,
        inverse=False,
    )

    # Affine step applied on top of the mixture output.
    z = scale * mixed + shift

    # The affine map contributes log(scale) per element.
    mixture_ldj = sum_except_batch(mixture_term)
    affine_ldj = sum_except_batch(torch.log(scale))
    return z, mixture_ldj + affine_ldj