def inverse(self, y, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
    """Invert the mixture-of-logistics coupling for the transformed half of ``y``.

    Args:
        y: A ``(cf, ef)`` tuple. ``cf`` is the conditioning half: it is returned
            unchanged and is the only input given to ``self.template``. ``ef`` is
            the transformed half that gets inverted.
        init, ema, dropout_p, verbose, context: Forwarded verbatim to
            ``self.template``.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``((cf, x), logd)``. ``x`` is ``ef`` after these steps, in order:
        undoing the affine map when ``self.with_affine``; ``self.scale_flow.inverse``;
        clipping to ``[1e-5, 1 - 1e-5]``; ``mixlogistic_invcdf``.
        ``logd`` is the negated sum, over all non-batch axes, of
        ``mixlogistic_logpdf(x)`` (plus ``s`` when ``with_affine``), plus the
        log-determinant returned by ``self.scale_flow.inverse``. Both outputs
        are cast with ``to_default_floatx``.

    NOTE(review): because of the clip, this is only an approximate inverse of
    ``forward`` for CDF values closer than 1e-5 to 0 or 1. The returned
    log-determinant does not account for the clip.
    """
    assert isinstance(y, tuple)
    cond_half, trans_half = y
    trans_f32 = at_least_float32(trans_half)
    s, t, logits, means, logscales = self.template(
        cond_half, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
    )
    # Identical mixture parameters are passed to both the inverse CDF and the log-density.
    mixture = dict(prior_logits=logits, means=means, logscales=logscales)

    z = trans_f32
    if self.with_affine:
        # Undo ``exp(s) * u + t`` applied at the end of ``forward``.
        z = tf.exp(-s) * (z - t)
    z, invscale_logd = self.scale_flow.inverse(z)
    # Keep the CDF value strictly inside (0, 1) before inverting it.
    # Presumably mixlogistic_invcdf is unstable at the endpoints; confirm.
    z = tf.clip_by_value(z, 1e-5, 1. - 1e-5)
    recovered = mixlogistic_invcdf(y=z, **mixture)

    # Elementwise log-Jacobian of the forward CDF step, evaluated at the recovered input.
    elem_logd = mixlogistic_logpdf(x=recovered, **mixture)
    if self.with_affine:
        assert s.shape == elem_logd.shape
        elem_logd += s
    # Flatten keeps the batch axis, so this is a per-example sum; negate for the inverse direction.
    total_logd = -tf.reduce_sum(tf.layers.flatten(elem_logd), axis=1)
    assert invscale_logd.shape == total_logd.shape
    total_logd += invscale_logd

    recovered = to_default_floatx(recovered)
    total_logd = to_default_floatx(total_logd)
    assert (
        recovered.shape == trans_half.shape == cond_half.shape
        and recovered.dtype == trans_half.dtype == total_logd.dtype == cond_half.dtype
    )
    return (cond_half, recovered), total_logd
def forward(self, x, init=False, ema=None, dropout_p=0., verbose=True, context=None, **kwargs):
    """Apply the mixture-of-logistics coupling to the transformed half of ``x``.

    Args:
        x: A ``(cf, ef)`` tuple. ``cf`` is the conditioning half: it is returned
            unchanged and is the only input given to ``self.template``. ``ef`` is
            the half that gets transformed.
        init, ema, dropout_p, verbose, context: Forwarded verbatim to
            ``self.template``.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``((cf, y), logd)``. ``y`` is produced in three steps:
        ``exp(mixlogistic_logcdf(ef))``, then ``self.scale_flow.forward``, then,
        when ``self.with_affine``, ``exp(s) * . + t``.
        ``logd`` is the sum, over all non-batch axes, of ``mixlogistic_logpdf(ef)``
        (plus ``s`` when ``with_affine``), plus the log-determinant returned by
        ``self.scale_flow.forward``. Both outputs are cast with
        ``to_default_floatx``.
    """
    assert isinstance(x, tuple)
    cond_half, trans_half = x
    trans_f32 = at_least_float32(trans_half)
    s, t, logits, means, logscales = self.template(
        cond_half, init=init, ema=ema, dropout_p=dropout_p, verbose=verbose, context=context
    )
    # Identical mixture parameters are passed to both the CDF and the log-density.
    mixture = dict(prior_logits=logits, means=means, logscales=logscales)

    # Elementwise mixture CDF, computed in log space and then exponentiated.
    cdf = tf.exp(mixlogistic_logcdf(x=trans_f32, **mixture))
    y, scale_logd = self.scale_flow.forward(cdf)
    if self.with_affine:
        assert y.shape == s.shape == t.shape
        y = tf.exp(s) * y + t

    # Elementwise log-Jacobian of the CDF step. This assumes mixlogistic_logpdf
    # is the log-derivative of the mixlogistic CDF; confirm the helpers match.
    elem_logd = mixlogistic_logpdf(x=trans_f32, **mixture)
    if self.with_affine:
        assert s.shape == elem_logd.shape
        elem_logd += s
    # Flatten keeps the batch axis, so this is a per-example sum.
    total_logd = tf.reduce_sum(tf.layers.flatten(elem_logd), axis=1)
    assert scale_logd.shape == total_logd.shape
    total_logd += scale_logd

    y = to_default_floatx(y)
    total_logd = to_default_floatx(total_logd)
    assert (
        y.shape == trans_half.shape == cond_half.shape
        and y.dtype == trans_half.dtype == total_logd.dtype == cond_half.dtype
    )
    return (cond_half, y), total_logd