def top_cond_prior(self, name, cond_top_latents):
    """Maps the conditional top latents to a distribution.

    Args:
      name: variable scope.
      cond_top_latents: Tensor or a list of tensors. Latent variables at the
        previous time-step(s).
        If "pointwise" or "conv_lstm", this is used as a single tensor.
        If "conv_net", this is a list of tensors with length equal to
        hparams.num_cond_latents + int(hparams.cond_first_frame).
        If "conv3d_net", this is a sequence of tensors; they are stacked
        along a new axis 1 and the last element is used for the skip
        connection.

    Returns:
      cond_dist: tfp.distributions.Normal

    Raises:
      ValueError: If hparams.latent_dist_encoder is not one of the supported
        encoders, or if cond_top_latents are not of the expected length.
    """
    hparams = self.hparams
    encoder = hparams.latent_dist_encoder
    supported_encoders = ("pointwise", "conv_net", "conv_lstm", "conv3d_net")
    # Fail fast with a readable error. Without this check an unknown encoder
    # matches no branch below and surfaces as an UnboundLocalError on `top`.
    if encoder not in supported_encoders:
        raise ValueError(
            "Unsupported latent_dist_encoder %r, expected one of: %s" %
            (encoder, ", ".join(supported_encoders)))

    with tf.variable_scope("top", reuse=tf.AUTO_REUSE):
        if encoder == "pointwise":
            last_latent = cond_top_latents
            top = glow_ops.scale_gaussian_prior(
                name, cond_top_latents, trainable=hparams.learn_top_scale)

        elif encoder == "conv_net":
            num_cond_latents = (
                hparams.num_cond_latents + int(hparams.cond_first_frame))
            if len(cond_top_latents) != num_cond_latents:
                raise ValueError(
                    "Expected length of cond_top_latents %d, got %d" %
                    (num_cond_latents, len(cond_top_latents)))
            last_latent = cond_top_latents[-1]
            output_channels = common_layers.shape_list(last_latent)[-1]
            cond_top_latents = tf.concat(cond_top_latents, axis=-1)

            # Maps the latent-stack to a distribution.
            cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
            top = glow_ops.latent_to_dist(
                name, cond_top_latents, hparams=hparams,
                output_channels=output_channels)

        elif encoder == "conv_lstm":
            last_latent = cond_top_latents
            output_channels = common_layers.shape_list(cond_top_latents)[-1]
            # (h_t, c_t) = LSTM(z_{t-1}; (h_{t-1}, c_{t-1}))
            # (mu_t, sigma_t) = conv(h_t)
            cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
            # The recurrent state is kept on the instance: it is read here
            # and overwritten with the state returned by the LSTM.
            _, self.top_state = common_video.conv_lstm_2d(
                cond_top_latents, self.top_state,
                hparams.latent_encoder_width, kernel_size=3,
                name="conv_lstm")
            top = glow_ops.single_conv_dist(
                name, self.top_state.h, output_channels=output_channels)

        else:  # encoder == "conv3d_net"
            last_latent = cond_top_latents[-1]
            cond_top_latents = tf.stack(cond_top_latents, axis=1)
            cond_top_latents = glow_ops.noise_op(cond_top_latents, hparams)
            top = glow_ops.temporal_latent_to_dist(
                "conv3d", cond_top_latents, hparams)

        # mu(z_{t}) = z_{t-1} + latent_encoder(z_{cond})
        if hparams.latent_skip:
            top = tfp.distributions.Normal(last_latent + top.loc, top.scale)
    return top
def level_cond_prior(prior_dist, z, latent, hparams, state):
    """Returns a conditional prior for each level.

    Args:
      prior_dist: Distribution conditioned on the previous levels.
      z: Tensor, output of the previous levels.
      latent: Tensor or a list of tensors to condition the
        latent_distribution. It is used as a list for "conv_net" and as a
        single tensor for "pointwise" and "conv_lstm".
      hparams: next_frame_glow hparams.
      state: Current LSTM state. Used only if hparams.latent_dist_encoder is
        a lstm.

    Returns:
      A tuple (loc, scale, state): the mean and scale of the conditional
      distribution, and the LSTM state (updated for "conv_lstm", otherwise
      passed through unchanged).

    Raises:
      ValueError: If hparams.latent_dist_encoder is "pointwise" and if the
        shape of latent is different from z, or if
        hparams.latent_dist_encoder is not a supported encoder.
    """
    latent_dist_encoder = hparams.get("latent_dist_encoder", None)
    latent_skip = hparams.get("latent_skip", False)

    if latent_dist_encoder == "pointwise":
        merge_std = hparams.level_scale
        latent_shape = common_layers.shape_list(latent)
        z_shape = common_layers.shape_list(z)
        if latent_shape != z_shape:
            # The latent is expected to match z, so z's shape is the
            # "expected" value and the latent's shape is what we "got".
            raise ValueError("Expected latent_shape to be %s, got %s" %
                             (z_shape, latent_shape))
        latent_dist = scale_gaussian_prior(
            "latent_prior", latent, logscale_factor=3.0)
        cond_dist = merge_level_and_latent_dist(
            prior_dist, latent_dist, merge_std=merge_std)

    elif latent_dist_encoder == "conv_net":
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
        cond_dist = tensor_to_dist(
            "latent_stack", latent_stack, output_channels=output_channels,
            architecture=hparams.latent_architecture,
            depth=hparams.latent_encoder_depth,
            pre_output_channels=hparams.latent_pre_output_channels,
            width=hparams.latent_encoder_width)
        if latent_skip:
            # Skip connection from the most recent conditioning latent.
            cond_dist = tf.distributions.Normal(
                cond_dist.loc + latent[-1], cond_dist.scale)

    elif latent_dist_encoder == "conv_lstm":
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
        _, state = common_video.conv_lstm_2d(
            latent_stack, state, hparams.latent_encoder_width,
            kernel_size=3, name="conv_lstm")
        cond_dist = tensor_to_dist(
            "state_to_dist", state.h, output_channels=output_channels)
        if latent_skip:
            cond_dist = tf.distributions.Normal(
                cond_dist.loc + latent, cond_dist.scale)

    else:
        # Without this branch an unknown (or missing) encoder leaves
        # `cond_dist` unbound and fails with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported latent_dist_encoder %r, expected one of: "
            "pointwise, conv_net, conv_lstm" % (latent_dist_encoder,))

    return cond_dist.loc, cond_dist.scale, state
def level_cond_prior(prior_dist, z, latent, hparams, state):
    """Returns a conditional prior for each level.

    Args:
      prior_dist: Distribution conditioned on the previous levels.
      z: Tensor, output of the previous levels.
      latent: Tensor or a list of tensors to condition the
        latent_distribution. It is used as a list for "conv_net" and
        "conv3d_net", and as a single tensor for "pointwise" and
        "conv_lstm".
      hparams: next_frame_glow hparams.
      state: Current LSTM state. Used only if hparams.latent_dist_encoder is
        a lstm.

    Returns:
      A tuple (loc, scale, state): the mean and scale of the conditional
      distribution, and the LSTM state (updated for "conv_lstm", otherwise
      passed through unchanged).

    Raises:
      ValueError: If hparams.latent_dist_encoder is "pointwise" and if the
        shape of latent is different from z, or if
        hparams.latent_dist_encoder is not a supported encoder.
    """
    latent_dist_encoder = hparams.get("latent_dist_encoder", None)
    latent_skip = hparams.get("latent_skip", False)

    if latent_dist_encoder == "pointwise":
        # The skip connection below needs `last_latent` in every branch;
        # for the pointwise encoder the latent itself is the last latent.
        last_latent = latent
        merge_std = hparams.level_scale
        latent_shape = common_layers.shape_list(latent)
        z_shape = common_layers.shape_list(z)
        if latent_shape != z_shape:
            # The latent is expected to match z, so z's shape is the
            # "expected" value and the latent's shape is what we "got".
            raise ValueError("Expected latent_shape to be %s, got %s" %
                             (z_shape, latent_shape))
        latent_dist = scale_gaussian_prior(
            "latent_prior", latent, logscale_factor=3.0)
        cond_dist = merge_level_and_latent_dist(
            prior_dist, latent_dist, merge_std=merge_std)

    elif latent_dist_encoder == "conv_net":
        output_channels = common_layers.shape_list(z)[-1]
        last_latent = latent[-1]
        latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
        cond_dist = latent_to_dist(
            "latent_stack", latent_stack, hparams=hparams,
            output_channels=output_channels)

    elif latent_dist_encoder == "conv3d_net":
        last_latent = latent[-1]
        output_channels = common_layers.shape_list(last_latent)[-1]
        num_steps = len(latent)

        # Stack across time.
        cond_latents = tf.stack(latent, axis=1)

        # Concat latents from previous levels across channels. The level
        # prior mean is repeated along the new axis 1 once per stacked
        # latent so that it can be concatenated on the last axis.
        prev_latents = tf.tile(tf.expand_dims(prior_dist.loc, axis=1),
                               [1, num_steps, 1, 1, 1])
        cond_latents = tf.concat((cond_latents, prev_latents), axis=-1)
        cond_dist = temporal_latent_to_dist(
            "latent_stack", cond_latents, hparams,
            output_channels=output_channels)

    elif latent_dist_encoder == "conv_lstm":
        last_latent = latent
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
        _, state = common_video.conv_lstm_2d(
            latent_stack, state, hparams.latent_encoder_width,
            kernel_size=3, name="conv_lstm")
        cond_dist = single_conv_dist(
            "state_to_dist", state.h, output_channels=output_channels)

    else:
        # Without this branch an unknown (or missing) encoder leaves
        # `cond_dist` unbound and fails with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported latent_dist_encoder %r, expected one of: "
            "pointwise, conv_net, conv3d_net, conv_lstm" %
            (latent_dist_encoder,))

    if latent_skip:
        # Skip connection: shift the predicted mean by the last latent.
        new_mean = cond_dist.loc + last_latent
        cond_dist = tf.distributions.Normal(new_mean, cond_dist.scale)
    return cond_dist.loc, cond_dist.scale, state
def compute_prior(name, z, latent, hparams, state=None):
    """Distribution on z_t conditioned on z_{t-1} and latent.

    Args:
      name: variable scope.
      z: 4-D Tensor.
      latent: optional. If hparams.latent_dist_encoder == "conv_net", this is
        a list of 4-D Tensors (presumably of length
        hparams.num_cond_latents; not checked here). Otherwise ("pointwise"
        or "conv_lstm") this is just a 4-D Tensor.
        The first-three dimensions of the latent should be the same as z.
        If None, the level prior computed from z is returned unconditioned.
      hparams: next_frame_glow_hparams.
      state: tf.contrib.rnn.LSTMStateTuple.
        the current state of a LSTM used to model the distribution. Used
        only if hparams.latent_dist_encoder = "conv_lstm".

    Returns:
      prior_dist: instance of tf.distributions.Normal
      state: Returns updated state.

    Raises:
      ValueError: If hparams.latent_dist_encoder is "pointwise" and if the
        shape of latent is different from z.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        prior_dist = tensor_to_dist(
            "level_prior", z, architecture="single_conv")

        # TODO(mechcoder) Refactor into separate sub-functions.
        if latent is not None:
            latent_dist_encoder = hparams.get("latent_dist_encoder", None)
            latent_skip = hparams.get("latent_skip", False)

            # NOTE: an encoder that matches none of the branches below leaves
            # the level prior untouched, i.e. the latent is ignored.
            if latent_dist_encoder == "pointwise":
                merge_std = hparams.level_scale
                latent_shape = common_layers.shape_list(latent)
                z_shape = common_layers.shape_list(z)
                if latent_shape != z_shape:
                    # The latent is expected to match z, so z's shape is the
                    # "expected" value and the latent's shape is what we
                    # "got".
                    raise ValueError(
                        "Expected latent_shape to be %s, got %s" %
                        (z_shape, latent_shape))
                latent_dist = scale_gaussian_prior(
                    "latent_prior", latent, logscale_factor=3.0)
                prior_dist = merge_level_and_latent_dist(
                    prior_dist, latent_dist, merge_std=merge_std)

            elif latent_dist_encoder == "conv_net":
                output_channels = common_layers.shape_list(z)[-1]
                latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
                prior_dist = tensor_to_dist(
                    "latent_stack", latent_stack,
                    output_channels=output_channels,
                    architecture=hparams.latent_architecture,
                    depth=hparams.latent_encoder_depth,
                    pre_output_channels=hparams.latent_pre_output_channels)
                if latent_skip:
                    # Skip connection from the most recent conditioning
                    # latent.
                    prior_dist = tf.distributions.Normal(
                        prior_dist.loc + latent[-1], prior_dist.scale)

            elif latent_dist_encoder == "conv_lstm":
                output_channels = common_layers.shape_list(z)[-1]
                latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
                _, state = common_video.conv_lstm_2d(
                    latent_stack, state, output_channels, kernel_size=3,
                    name="conv_lstm")
                prior_dist = tensor_to_dist(
                    "state_to_dist", state.h,
                    output_channels=output_channels)
                if latent_skip:
                    prior_dist = tf.distributions.Normal(
                        prior_dist.loc + latent, prior_dist.scale)

        tf.summary.histogram("split_prior_mean", prior_dist.loc)
        tf.summary.histogram("split_prior_scale", prior_dist.scale)
        return prior_dist, state