Exemplo n.º 1
0
    def inverse(self, y, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        """Undo the mixture-of-logistics coupling transform on ``y = (cf, ef)``.

        ``cf`` passes through unchanged and conditions ``self.template``, which
        yields affine parameters ``(s, t)`` plus the mixture's logits, means
        and log-scales. ``ef`` is mapped back by reversing the forward steps in
        reverse order:

        1. the optional affine map (``(z - t) * exp(-s)``, only when
           ``self.with_affine``);
        2. ``self.scale_flow.inverse``;
        3. the inverse mixture CDF, after clipping to ``[1e-5, 1 - 1e-5]``.

        Args:
            y: Tuple ``(cf, ef)``. The final assert requires the result, ``ef``
                and ``cf`` to share a shape and dtype.
            init, ema, dropout_p, verbose, context: Forwarded to
                ``self.template``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``((cf, recovered), logd)``. ``logd`` is the negated sum, over all
            non-leading axes, of the mixture log-density at ``recovered`` (plus
            ``s`` when affine), plus the log-det term returned by
            ``self.scale_flow.inverse``. Both outputs go through
            ``to_default_floatx``.
        """
        assert isinstance(y, tuple)
        cond, ef = y
        # Presumably upcasts lower-precision inputs for the mixture math
        # (inferred from the helper's name -- confirm).
        ef32 = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cond, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )
        mixture = dict(prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)

        # Reverse the forward pipeline, starting from its last step.
        z = ef32
        if self.with_affine:
            z = tf.exp(-s) * (z - t)
        z, invscale_logd = self.scale_flow.inverse(z)
        # Keep values away from 0 and 1 before inverting the CDF; presumably
        # mixlogistic_invcdf is ill-conditioned at the endpoints.
        z = tf.clip_by_value(z, 1e-5, 1. - 1e-5)
        recovered = mixlogistic_invcdf(y=z, **mixture)

        # Per-element log-derivative of the forward map, evaluated at the
        # recovered input; the inverse's log-det is its negation.
        elementwise = mixlogistic_logpdf(x=recovered, **mixture)
        if self.with_affine:
            assert s.shape == elementwise.shape
            elementwise = elementwise + s
        neg_total = -tf.reduce_sum(tf.layers.flatten(elementwise), axis=1)
        assert invscale_logd.shape == neg_total.shape
        logd = neg_total + invscale_logd

        recovered = to_default_floatx(recovered)
        logd = to_default_floatx(logd)
        assert recovered.shape == ef.shape == cond.shape
        assert recovered.dtype == ef.dtype == logd.dtype == cond.dtype
        return (cond, recovered), logd
Exemplo n.º 2
0
    def forward(self, x, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
        """Apply the mixture-of-logistics coupling transform to ``x = (cf, ef)``.

        ``cf`` passes through unchanged and conditions ``self.template``, which
        yields affine parameters ``(s, t)`` plus the mixture's logits, means
        and log-scales. ``ef`` is transformed in three steps:

        1. the mixture CDF, computed as ``exp(mixlogistic_logcdf(...))``;
        2. ``self.scale_flow.forward``;
        3. when ``self.with_affine``, the elementwise map ``exp(s) * z + t``.

        Args:
            x: Tuple ``(cf, ef)``. The final assert requires the result, ``ef``
                and ``cf`` to share a shape and dtype.
            init, ema, dropout_p, verbose, context: Forwarded to
                ``self.template``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``((cf, z), logd)``. ``logd`` sums, over all non-leading axes, the
            mixture log-density at ``ef`` (plus ``s`` when affine), then adds
            the log-det term returned by ``self.scale_flow.forward``. Both
            outputs go through ``to_default_floatx``.
        """
        assert isinstance(x, tuple)
        cond, ef = x
        # Presumably upcasts lower-precision inputs for the mixture math
        # (inferred from the helper's name -- confirm).
        ef32 = at_least_float32(ef)
        s, t, ml_logits, ml_means, ml_logscales = self.template(
            cond, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
        )
        mixture = dict(prior_logits=ml_logits, means=ml_means, logscales=ml_logscales)

        # CDF is evaluated in log space, then exponentiated.
        z = tf.exp(mixlogistic_logcdf(x=ef32, **mixture))
        z, scale_logd = self.scale_flow.forward(z)
        if self.with_affine:
            assert z.shape == s.shape == t.shape
            z = tf.exp(s) * z + t

        # Per-element log-derivative: the CDF's derivative is the pdf, and the
        # affine scale exp(s) contributes s.
        elementwise = mixlogistic_logpdf(x=ef32, **mixture)
        if self.with_affine:
            assert s.shape == elementwise.shape
            elementwise = elementwise + s
        total = tf.reduce_sum(tf.layers.flatten(elementwise), axis=1)
        assert scale_logd.shape == total.shape
        logd = total + scale_logd

        z = to_default_floatx(z)
        logd = to_default_floatx(logd)
        assert z.shape == ef.shape == cond.shape
        assert z.dtype == ef.dtype == logd.dtype == cond.dtype
        return (cond, z), logd