def top_cond_prior(self, name, cond_top_latents):
  """Maps the conditional top latents to a distribution.

  Args:
    name: variable scope.
    cond_top_latents: Tensor or a list of tensors.
        Latent variables at the previous time-step.
        If "pointwise" or "conv_lstm", this is a single tensor.
        If "conv_net", this is a list of tensors with length equal to
        hparams.num_cond_latents + int(hparams.cond_first_frame); the tensors
        are concatenated along the last axis.
        If "conv3d_net", this is a list of tensors that is stacked along a new
        axis 1.
  Returns:
    cond_dist: tfp.distributions.Normal. When hparams.latent_skip is set, its
        loc is offset by the most recent conditioning latent.
  Raises:
    ValueError: If hparams.latent_dist_encoder is not a supported encoder, or
        if cond_top_latents are not of the expected length.
  """
  hparams = self.hparams
  encoder = hparams.latent_dist_encoder
  supported_encoders = ("pointwise", "conv_net", "conv_lstm", "conv3d_net")
  # Without this check an unknown encoder falls through every branch below and
  # fails later on an unbound local (`top` / `last_latent`).
  if encoder not in supported_encoders:
    raise ValueError(
        "Unknown latent_dist_encoder %r, expected one of %s"
        % (encoder, ", ".join(supported_encoders)))
  # Validate inputs before building any graph ops.
  if encoder == "conv_net":
    num_cond_latents = (hparams.num_cond_latents +
                        int(hparams.cond_first_frame))
    if len(cond_top_latents) != num_cond_latents:
      raise ValueError(
          "Expected length of cond_top_latents %d, got %d" %
          (num_cond_latents, len(cond_top_latents)))

  with tf.variable_scope("top", reuse=tf.AUTO_REUSE):
    if encoder == "pointwise":
      last_latent = cond_top_latents
      top = glow_ops.scale_gaussian_prior(
          name, cond_top_latents, trainable=hparams.learn_top_scale)
    elif encoder == "conv_net":
      last_latent = cond_top_latents[-1]
      output_channels = common_layers.shape_list(last_latent)[-1]
      cond_top_latents = tf.concat(cond_top_latents, axis=-1)
      # Maps the latent-stack to a distribution.
      cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
      top = glow_ops.latent_to_dist(
          name, cond_top_latents, hparams=hparams,
          output_channels=output_channels)
    elif encoder == "conv_lstm":
      last_latent = cond_top_latents
      output_channels = common_layers.shape_list(cond_top_latents)[-1]
      # (h_t, c_t) = LSTM(z_{t-1}; (h_{t-1}, c_{t-1}))
      # (mu_t, sigma_t) = conv(h_t)
      cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
      # Side effect: the recurrent state is carried across calls on self.
      _, self.top_state = common_video.conv_lstm_2d(
          cond_top_latents, self.top_state, hparams.latent_encoder_width,
          kernel_size=3, name="conv_lstm")
      top = glow_ops.single_conv_dist(
          name, self.top_state.h, output_channels=output_channels)
    else:  # encoder == "conv3d_net", guaranteed by the check above.
      last_latent = cond_top_latents[-1]
      cond_top_latents = tf.stack(cond_top_latents, axis=1)
      cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
      top = glow_ops.temporal_latent_to_dist(
          "conv3d", cond_top_latents, hparams)

    # mu(z_{t}) = z_{t-1} + latent_encoder(z_{cond})
    if hparams.latent_skip:
      top = tfp.distributions.Normal(last_latent + top.loc, top.scale)
  return top
def check_latent_to_dist(self, architecture):
  """Checks the initial output of `glow_ops.latent_to_dist`.

  In a fresh graph, a random uniform input of shape (16, 5, 5, 32) is mapped
  through `latent_to_dist` with 64 output channels. After variable
  initialization, the loc and scale of the returned distribution must have
  shape (16, 5, 5, 64) and be close to 0.0 and 1.0 respectively.

  Args:
    architecture: value stored as `architecture` in the HParams passed to
        `latent_to_dist`.
  """
  expected_shape = (16, 5, 5, 64)
  with tf.Graph().as_default():
    latents = tf.random_uniform(shape=(16, 5, 5, 32))
    prior_hparams = tf.contrib.training.HParams(architecture=architecture)
    prior = glow_ops.latent_to_dist(
        "split_prior", latents, hparams=prior_hparams, output_channels=64)
    loc_tensor = prior.loc
    scale_tensor = prior.scale
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      loc_value, scale_value = sess.run([loc_tensor, scale_tensor])
      # Shapes first, then values, so failures surface in the same order.
      self.assertEqual(loc_value.shape, expected_shape)
      self.assertEqual(scale_value.shape, expected_shape)
      self.assertTrue(np.allclose(loc_value, 0.0))
      self.assertTrue(np.allclose(scale_value, 1.0))