Ejemplo n.º 1
0
 def _elementwise_inverse(self, z, elementwise_params):
     """Invert the Flow++ element-wise transform: undo the affine step, then the mixture.

     Args:
         z: Transformed values to invert.
         elementwise_params: Per-element parameters whose last dimension must equal
             ``self._output_dim_multiplier()``. ``get_flowpp_params`` splits them into
             (unconstrained_scale, shift, logit_weights, means, log_scales).

     Returns:
         The result of ``logistic_mixture_transform(..., inverse=True)`` applied to the
         affine-inverted and clamped values.
     """
     assert elementwise_params.shape[-1] == self._output_dim_multiplier()
     flowpp_params = get_flowpp_params(elementwise_params, num_mixtures=self.num_mixtures)
     unconstrained_scale, shift, logit_weights, means, log_scales = flowpp_params
     scale = self.scale_fn(unconstrained_scale)
     # Floor the mixture log-scales at -7, following the original Flow++ code.
     log_scales = log_scales.clamp(min=-7)

     # Undo the affine step of the forward pass (z = scale * y + shift).
     affine_inverted = (z - shift) / scale
     # Keep values within [1e-5, 1 - 1e-5] before inverting the mixture.
     # NOTE(review): presumably because the mixture's forward output lies in (0, 1)
     # (a mixture CDF); confirm against logistic_mixture_transform.
     bounded = affine_inverted.clamp(1e-5, 1.0 - 1e-5)

     return logistic_mixture_transform(inputs=bounded,
                                       logit_weights=logit_weights,
                                       means=means,
                                       log_scales=log_scales,
                                       eps=self.eps,
                                       max_iters=self.max_iters,
                                       inverse=True)
Ejemplo n.º 2
0
    def _elementwise(self, inputs, elementwise_params, inverse):
        """Apply the logistic-mixture element-wise transform in either direction.

        Args:
            inputs: Values to transform.
            elementwise_params: Per-element parameters whose last dimension must equal
                ``self._output_dim_multiplier()``. ``get_mixture_params`` splits them
                into (logit_weights, means, log_scales).
            inverse: If truthy, run the inverse transform.

        Returns:
            When ``inverse`` is truthy, the raw result of
            ``logistic_mixture_transform(..., inverse=True)``. Otherwise a tuple
            ``(z, ldj)`` where ``ldj`` is the element-wise log-det term reduced with
            ``sum_except_batch``.
        """
        assert elementwise_params.shape[-1] == self._output_dim_multiplier()

        mixture_params = get_mixture_params(elementwise_params, num_mixtures=self.num_mixtures)
        logit_weights, means, log_scales = mixture_params

        transformed = logistic_mixture_transform(inputs=inputs,
                                                 logit_weights=logit_weights,
                                                 means=means,
                                                 log_scales=log_scales,
                                                 eps=self.eps,
                                                 max_iters=self.max_iters,
                                                 inverse=inverse)

        # Inverse direction: hand back the transform result unchanged.
        if inverse:
            return transformed

        # Forward direction: the transform yields (outputs, per-element log-det).
        outputs, ldj_per_element = transformed
        return outputs, sum_except_batch(ldj_per_element)
Ejemplo n.º 3
0
    def _elementwise_forward(self, x, elementwise_params):
        """Apply the Flow++ element-wise transform: logistic mixture, then affine.

        Args:
            x: Values to transform.
            elementwise_params: Per-element parameters whose last dimension must equal
                ``self._output_dim_multiplier()``. ``get_flowpp_params`` splits them into
                (unconstrained_scale, shift, logit_weights, means, log_scales).

        Returns:
            Tuple ``(z, ldj)`` with ``z = scale * m + shift``, where ``m`` is the mixture
            output. ``ldj`` is the summed mixture log-det plus
            ``sum_except_batch(torch.log(scale))``.
        """
        assert elementwise_params.shape[-1] == self._output_dim_multiplier()
        flowpp_params = get_flowpp_params(elementwise_params, num_mixtures=self.num_mixtures)
        unconstrained_scale, shift, logit_weights, means, log_scales = flowpp_params
        scale = self.scale_fn(unconstrained_scale)
        # Floor the mixture log-scales at -7, following the original Flow++ code.
        log_scales = log_scales.clamp(min=-7)

        mixed, mixture_ldj = logistic_mixture_transform(inputs=x,
                                                        logit_weights=logit_weights,
                                                        means=means,
                                                        log_scales=log_scales,
                                                        eps=self.eps,
                                                        max_iters=self.max_iters,
                                                        inverse=False)

        # Affine step; its per-element log|det| is log(scale).
        # NOTE(review): assumes scale_fn yields strictly positive values, since log is taken.
        z = scale * mixed + shift
        ldj = sum_except_batch(mixture_ldj) + sum_except_batch(torch.log(scale))
        return z, ldj