def visualize_predictions(self, real_frames, gen_frames, actions=None):
  def concat_on_y_axis(x):
    x = tf.unstack(x, axis=1)
    x = tf.concat(x, axis=1)
    return x

  frames_gd = common_video.swap_time_and_batch_axes(real_frames)
  frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
  if actions is not None:
    actions = common_video.swap_time_and_batch_axes(actions)

  if self.is_per_pixel_softmax:
    frames_pd_shape = common_layers.shape_list(frames_pd)
    frames_pd = tf.reshape(frames_pd, [-1, 256])
    frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
    frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])

  frames_gd = concat_on_y_axis(frames_gd)
  frames_pd = concat_on_y_axis(frames_pd)
  if actions is not None:
    actions = tf.clip_by_value(actions, 0, 1)
    summary("action_vid", tf.cast(actions * 255, tf.uint8))
    actions = concat_on_y_axis(actions)
    side_by_side_video = tf.concat([frames_gd, frames_pd, actions], axis=2)
  else:
    side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
  tf.summary.image("full_video", side_by_side_video)

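# NOTE: illustrative sketch, not part of the model code above. Assuming the
# frames arrive time-major as [time, batch, height, width, channels],
# swap_time_and_batch_axes makes them batch-major and concat_on_y_axis then
# stacks the time steps along the height axis, so the final summary image
# shows ground truth and predictions side by side with time running
# top-to-bottom. The dummy shapes below are hypothetical.
import tensorflow as tf

def _sketch_concat_on_y_axis(x):
  x = tf.unstack(x, axis=1)    # `time` tensors of shape [batch, H, W, C]
  return tf.concat(x, axis=1)  # [batch, time * H, W, C]

frames = tf.zeros([4, 10, 64, 64, 3])              # [batch, time, H, W, C]
tiled = _sketch_concat_on_y_axis(frames)           # [4, 640, 64, 3]
side_by_side = tf.concat([tiled, tiled], axis=2)   # [4, 640, 128, 3]
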
def visualize_predictions(self, real_frames, gen_frames):
  def concat_on_y_axis(x):
    x = tf.unstack(x, axis=1)
    x = tf.concat(x, axis=1)
    return x

  frames_gd = common_video.swap_time_and_batch_axes(real_frames)
  frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
  frames_gd = concat_on_y_axis(frames_gd)
  frames_pd = concat_on_y_axis(frames_pd)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
  tf.summary.image("full_video", side_by_side_video)

def discriminator(self, frames):
  """3-D SNGAN discriminator.

  Args:
    frames: a list of batch-major tensors indexed by time.

  Returns:
    logits: 1-D Tensor of shape [batch_size]. Positive logits imply that the
      discriminator thinks the frames belong to the true class.
  """
  ndf = self.hparams.num_discriminator_filters
  frames = tf.stack(frames)

  # Switch from time-major axis to batch-major axis.
  frames = common_video.swap_time_and_batch_axes(frames)

  # 3-D Conv-net mapping inputs to activations.
  num_outputs = [ndf, ndf*2, ndf*2, ndf*4, ndf*4, ndf*8, ndf*8]
  kernel_sizes = [3, 4, 3, 4, 3, 4, 3]
  strides = [[1, 1, 1], [1, 2, 2], [1, 1, 1], [1, 2, 2], [1, 1, 1],
             [2, 2, 2], [1, 1, 1]]
  names = ["video_sn_conv0_0", "video_sn_conv0_1", "video_sn_conv1_0",
           "video_sn_conv1_1", "video_sn_conv2_0", "video_sn_conv2_1",
           "video_sn_conv3_0"]
  iterable = zip(num_outputs, kernel_sizes, strides, names)
  activations = frames
  for num_filters, kernel_size, stride, name in iterable:
    activations = self.pad_conv3d_lrelu(activations, num_filters, kernel_size,
                                        stride, name)
  num_fc_dimensions = self.get_fc_dimensions(strides, kernel_sizes)
  activations = tf.reshape(activations, (-1, num_fc_dimensions))
  return tf.squeeze(tf.layers.dense(activations, 1))

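# NOTE: illustrative sketch, not part of the model code above. get_fc_dimensions
# is assumed to count the activations left after the strided 3-D convolutions;
# the helper below shows one way to do that, assuming SAME padding (ceil
# division by each stride) and a hypothetical [time, height, width] input.
import math

def _sketch_fc_dimensions(strides, video_shape, last_num_filters):
  time, height, width = video_shape
  for stride_t, stride_h, stride_w in strides:
    time = int(math.ceil(time / float(stride_t)))
    height = int(math.ceil(height / float(stride_h)))
    width = int(math.ceil(width / float(stride_w)))
  return time * height * width * last_num_filters

# Example: an 8 x 64 x 64 clip passed through the strides above ends up as
# 4 x 8 x 8 spatially, so the dense layer sees 4 * 8 * 8 * (ndf * 8) features.
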
def visualize_predictions(self, real_frames, gen_frames):
  def concat_on_y_axis(x):
    x = tf.unstack(x, axis=1)
    x = tf.concat(x, axis=1)
    return x

  frames_gd = common_video.swap_time_and_batch_axes(real_frames)
  frames_pd = common_video.swap_time_and_batch_axes(gen_frames)

  if self.is_per_pixel_softmax:
    frames_pd_shape = common_layers.shape_list(frames_pd)
    frames_pd = tf.reshape(frames_pd, [-1, 256])
    frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
    frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])

  frames_gd = concat_on_y_axis(frames_gd)
  frames_pd = concat_on_y_axis(frames_pd)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
  tf.summary.image("full_video", side_by_side_video)

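# NOTE: illustrative sketch, not part of the model code above. When
# is_per_pixel_softmax is set, the prediction tensor is assumed to carry
# 3 * 256 logits per pixel; the reshape/argmax/reshape sequence above then
# recovers one intensity per RGB channel. Shapes below are hypothetical.
import tensorflow as tf

logits = tf.zeros([4, 10, 64, 64, 3 * 256])            # [batch, time, H, W, 3*256]
flat = tf.reshape(logits, [-1, 256])                    # one row per channel value
intensities = tf.cast(tf.argmax(flat, axis=-1), tf.float32)
decoded = tf.reshape(intensities, [4, 10, 64, 64, 3])   # back to an RGB video
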
def get_input_if_exists(self, features, key, batch_size, num_frames):
  if key in features:
    x = features[key]
  else:
    x = tf.zeros((batch_size, num_frames, 1, self.hparams.hidden_size))
  return common_video.swap_time_and_batch_axes(x)

def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  retvals = self.construct_model(images=all_frames,
                                 actions=all_actions,
                                 rewards=all_rewards)

  # Retrieve tensors returned by the model constructor.
  gen_images = retvals[0]
  gen_rewards = retvals[1]
  latent_means = retvals[2]
  latent_logvars = retvals[3]
  latent_means_p = retvals[4]
  latent_logvars_p = retvals[5]

  extra_loss = self.get_extra_loss(latent_means=latent_means,
                                   latent_logvars=latent_logvars,
                                   latent_means_p=latent_means_p,
                                   latent_logvars_p=latent_logvars_p)

  # Visualize predictions in Tensorboard.
  if self.is_training:
    self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  if self.is_training and hparams.internal_loss:
    # Add the loss for input frames as well.
    extra_gts = all_frames[1:hparams.video_num_input_frames]
    extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
    extra_pds = gen_images[:hparams.video_num_input_frames - 1]
    extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
    extra_raw_gts = features["inputs_raw"][:, 1:]
    recon_loss = self.get_extra_internal_loss(extra_raw_gts, extra_gts,
                                              extra_pds)
    extra_loss += recon_loss

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
  del args, kwargs

  # Make a copy of features that can be used in the call to self
  # that builds the graph.
  new_features = {}
  new_features["inputs"] = features["inputs"]
  new_features["targets"] = features["infer_targets"]
  _, _ = self(new_features)  # pylint: disable=not-callable

  if self.hparams.gen_mode == "unconditional":
    num_target_frames = 1
  else:
    num_target_frames = self.hparams.video_num_target_frames

  ops = [glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout]
  var_scope = tf.variable_scope("next_frame_glow/body", reuse=True)
  all_frames = []

  # If eps=None, images are sampled from the prior.
  with arg_scope(ops, init=False), var_scope:
    for target_frame in range(1, num_target_frames + 1):
      # subscript -> timestep, superscript -> level.
      # self.z_sample equals z^0_{t} (top-level latent)
      # (X_{t}, z^{1..l}_{t}) = Glow(z^0_{t}, z^{1..l}_{t-1})

      # Get current set of cond_latents.
      cond_level, cond_level_latents = get_cond_latents(
          self.all_level_latents, self.hparams)

      glow_vals = glow_ops.encoder_decoder(
          "codec", self.z_sample, self.hparams, eps=None, reverse=True,
          cond_latents=cond_level_latents, states=self.level_states,
          condition=cond_level, temperature=self.temperature)
      predicted_frame, _, curr_latents, self.level_states = glow_vals
      all_frames.append(predicted_frame)
      self.all_level_latents.append(curr_latents)

      # Compute z^0_{t+1} = f(z^0_{t})
      if target_frame < num_target_frames:
        cond_top, cond_top_latents = get_cond_latents(
            self.all_top_latents, self.hparams)
        prior_dist = self.top_prior(
            condition=cond_top, cond_latents=cond_top_latents)
        self.z_sample = prior_dist.sample()
        self.all_top_latents.append(self.z_sample)

  all_frames = tf.stack(all_frames)
  predicted_video = common_video.swap_time_and_batch_axes(all_frames)

  # The video-decode API requires the predicted video to be the same shape
  # as the target video. Hence, for unconditional generation, tile across
  # time to ensure the same shape.
  if self.hparams.gen_mode == "unconditional":
    predicted_video = tf.tile(
        predicted_video, [1, self.hparams.video_num_target_frames, 1, 1, 1])
  predicted_video = glow_ops.postprocess(predicted_video)

  # Output of a single decode / sample.
  output_features = {}
  output_features["targets"] = tf.zeros_like(predicted_video)
  output_features["outputs"] = predicted_video
  output_features["scores"] = tf.zeros_like(predicted_video)
  return output_features

def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  beta = self.get_beta()
  extra_loss = self.get_extra_loss(latent_means=latent_means,
                                   latent_stds=latent_stds,
                                   beta=beta,
                                   true_frames=all_frames,
                                   gen_frames=gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove unneeded dimension.

  # TODO(mbz): clean this up!
  def fix_video_dims_and_concat_on_x_axis(x):
    x = tf.transpose(x, [1, 3, 4, 0, 2])
    x = tf.reshape(x, [batch_size, 64, 3, -1])
    x = tf.transpose(x, [0, 3, 1, 2])
    return x

  frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
  frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
  tf.summary.image("full_video", side_by_side_video)

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  return_targets = predictions
  if "target_reward" in features:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

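# NOTE: illustrative sketch, not part of the model code above. get_extra_loss
# is assumed here to return a beta-weighted KL term between the posterior
# latents and a standard normal prior; the helper below writes out the usual
# diagonal-Gaussian KL, assuming the latent statistics are passed as means and
# log-variances (names and signature are hypothetical).
import tensorflow as tf

def _sketch_kl_loss(latent_means, latent_logvars, beta):
  kl = 0.0
  for mean, logvar in zip(latent_means, latent_logvars):
    kl += -0.5 * tf.reduce_sum(
        1.0 + logvar - tf.square(mean) - tf.exp(logvar), axis=-1)
  return beta * tf.reduce_mean(kl)
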
def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  extra_loss = self.get_extra_loss(latent_means=latent_means,
                                   latent_stds=latent_stds,
                                   true_frames=all_frames,
                                   gen_frames=gen_images)

  # Visualize predictions in Tensorboard.
  if self.is_training:
    self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  if self.is_training and hparams.internal_loss:
    # Add the MSE loss for input frames as well.
    extra_gts = all_frames[1:hparams.video_num_input_frames + 1]
    extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
    extra_pds = gen_images[:hparams.video_num_input_frames]
    extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
    if self._target_modality == "VideoModalityL2Raw":
      recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
    elif self._target_modality == "VideoModality":
      shape = common_layers.shape_list(extra_pds)
      updated_shape = shape[:-1] + [3, 256]
      extra_pds = tf.reshape(extra_pds, updated_shape)
      # Merge time and batch.
      logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
      targets_shape = common_layers.shape_list(features["targets_raw"])
      targets = tf.reshape(features["targets_raw"], [-1] + targets_shape[2:])
      mod = self.hparams.problem_hparams.target_modality["targets"]
      numerator, denominator = common_layers.padded_cross_entropy(
          logits,
          targets,
          hparams.label_smoothing,
          cutoff=getattr(hparams, "video_modality_loss_cutoff", 0.01),
          weights_fn=mod.targets_weights_fn)
      recon_loss = numerator / denominator
    else:
      raise ValueError("internal loss only supports specific modalities.")
    tf.summary.scalar("recon_extra", recon_loss)
    extra_loss += recon_loss

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  step_num = tf.train.get_global_step()
  # TODO(mbz): what should it be if it's undefined?
  if step_num is None:
    step_num = _LARGE_STEP_NUMBER

  schedule = self.hparams.latent_loss_multiplier_schedule
  second_stage = self.hparams.num_iterations_2nd_stage
  # TODO(mechcoder): Add log_annealing schedule.
  if schedule == "constant":
    beta = tf.cond(tf.greater(step_num, second_stage),
                   lambda: self.hparams.latent_loss_multiplier,
                   lambda: 0.0)
  elif schedule == "linear_anneal":
    # Linearly anneal beta from 0.0 to self.hparams.latent_loss_multiplier
    # between self.hparams.num_iterations_2nd_stage and anneal_end:
    # beta = latent_loss * (1 - (global_step - 2nd_stage) / (anneal_end - 2nd_stage))  # pylint:disable=line-too-long
    anneal_end = self.hparams.anneal_end
    latent_multiplier = self.hparams.latent_loss_multiplier
    if anneal_end < second_stage:
      raise ValueError("Expected hparams.num_iterations_2nd_stage < "
                       "hparams.anneal_end %d, got %d." %
                       (second_stage, anneal_end))

    def anneal_loss(step_num):
      step_num = tf.cast(step_num, dtype=tf.float32)
      fraction = (float(anneal_end) - step_num) / (anneal_end - second_stage)
      return self.hparams.latent_loss_multiplier * (1 - fraction)

    beta = tf.case(
        pred_fn_pairs={
            tf.less(step_num, second_stage): lambda: 0.0,
            tf.greater(step_num, anneal_end): lambda: latent_multiplier
        },
        default=lambda: anneal_loss(step_num))

  kl_loss = 0.0
  if self.is_training:
    for i, (mean, std) in enumerate(zip(latent_means, latent_stds)):
      kl_loss += common_layers.kl_divergence(mean, std)
      tf.summary.histogram("posterior_mean_%d" % i, mean)
      tf.summary.histogram("posterior_std_%d" % i, std)
    tf.summary.scalar("beta", beta)
    tf.summary.scalar("kl_raw", tf.reduce_mean(kl_loss))

  extra_loss = beta * kl_loss

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove unneeded dimension.

  # TODO(mbz): clean this up!
  def fix_video_dims_and_concat_on_x_axis(x):
    x = tf.transpose(x, [1, 3, 4, 0, 2])
    x = tf.reshape(x, [batch_size, 64, 3, -1])
    x = tf.transpose(x, [0, 3, 1, 2])
    return x

  frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
  frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
  tf.summary.image("full_video", side_by_side_video)

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  return_targets = predictions
  if "target_reward" in features:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

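# NOTE: illustrative sketch, not part of the model code above. It re-derives
# the "linear_anneal" schedule with plain Python to show how beta ramps from
# 0 at num_iterations_2nd_stage up to latent_loss_multiplier at anneal_end;
# the example numbers are hypothetical.
def _sketch_linear_anneal_beta(step, second_stage, anneal_end, latent_multiplier):
  if step < second_stage:
    return 0.0
  if step > anneal_end:
    return latent_multiplier
  fraction = float(anneal_end - step) / (anneal_end - second_stage)
  return latent_multiplier * (1.0 - fraction)

# Example: with second_stage=10000, anneal_end=50000 and multiplier=1e-3,
# beta is 0.0 at step 10000, 5e-4 at step 30000 and 1e-3 at step 50000.
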
def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  extra_loss = self.get_extra_loss(
      latent_means=latent_means,
      latent_stds=latent_stds,
      true_frames=all_frames,
      gen_frames=gen_images)

  # Visualize predictions in Tensorboard.
  if self.is_training:
    self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames-1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames-1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  if self.is_training and hparams.internal_loss:
    # Add the loss for input frames as well.
    extra_gts = all_frames[1:hparams.video_num_input_frames]
    extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
    extra_pds = gen_images[:hparams.video_num_input_frames-1]
    extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
    extra_raw_gts = features["inputs_raw"][:, 1:]
    recon_loss = self.get_extra_internal_loss(
        extra_raw_gts, extra_gts, extra_pds)
    extra_loss += recon_loss

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  extra_loss = self.get_extra_loss(latent_means=latent_means,
                                   latent_stds=latent_stds,
                                   true_frames=all_frames,
                                   gen_frames=gen_images)

  # Visualize predictions in Tensorboard.
  self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  if hparams.internal_loss:
    # Add the MSE loss for input frames as well.
    # We are assuming the modality is L2; otherwise the loss would be
    # inconsistent across the frames.
    modality = self.hparams.problem_hparams.target_modality["targets"]
    if modality.__class__.__name__ != "VideoModalityL2Raw":
      raise ValueError("internal loss only works with L2.")
    recon_loss = tf.losses.mean_squared_error(
        all_frames[1:hparams.video_num_input_frames + 1],
        gen_images[:hparams.video_num_input_frames])
    tf.summary.scalar("mse_extra", recon_loss)
    extra_loss += recon_loss

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  return return_targets, extra_loss

def body(self, features):
  hparams = self.hparams
  batch_size = common_layers.shape_list(features["inputs"])[0]

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
  target_frames = common_video.swap_time_and_batch_axes(features["targets"])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, "input_action", batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, "target_action", batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  input_rewards = self.get_input_if_exists(
      features, "input_reward", batch_size, hparams.video_num_input_frames)
  target_rewards = self.get_input_if_exists(
      features, "target_reward", batch_size, hparams.video_num_target_frames)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  # Each image is used twice, in the latent tower and the main tower.
  # This is to make sure we are using the *same* image for both,
  # given how TF queues work. Not sure if this is required at all.
  # Doesn't hurt though! :)
  all_frames = tf.identity(all_frames)

  gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
      images=all_frames,
      actions=all_actions,
      rewards=all_rewards,
  )

  beta = self.get_beta()
  extra_loss = self.get_extra_loss(latent_means=latent_means,
                                   latent_stds=latent_stds,
                                   beta=beta,
                                   true_frames=all_frames,
                                   gen_frames=gen_images)

  # Visualize predictions in Tensorboard.
  self.visualize_predictions(all_frames[1:], gen_images)

  # Ignore the predictions from the input frames.
  # This is NOT the same as original paper/implementation.
  predictions = gen_images[hparams.video_num_input_frames - 1:]
  reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
  if self.is_training:
    reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

  # Swap back time and batch axes.
  predictions = common_video.swap_time_and_batch_axes(predictions)
  reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

  return_targets = predictions
  if hparams.reward_prediction:
    return_targets = {"targets": predictions, "target_reward": reward_pred}

  if hparams.internal_loss:
    loss = tf.losses.mean_squared_error(all_frames[1:], gen_images)
    extra_loss = {"training": loss + extra_loss}

  return return_targets, extra_loss

def body(self, features):
  hparams = self.hparams
  input_shape = common_layers.shape_list(features['inputs'])
  batch_size, _, frame_width, frame_height, frame_channels = input_shape  # pylint: disable=unused-variable

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(
      tf.to_float(features['inputs']))
  target_frames = common_video.swap_time_and_batch_axes(features['targets'])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, 'input_action', batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, 'target_action', batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  # TODO(blazej) enable rewards.
  # input_rewards = self.get_input_if_exists(
  #     features, 'input_reward', batch_size, hparams.video_num_input_frames)
  # target_rewards = self.get_input_if_exists(
  #     features, 'target_reward', batch_size, hparams.video_num_target_frames)
  # all_rewards = tf.concat([input_rewards, target_rewards], axis=0)

  all_actions = tf.concat([input_actions, target_actions], axis=0)

  # Flatten the actions tensor to shape: frames x batch_size x action_dims.
  actions_shape = common_layers.shape_list(all_actions)
  all_actions = tf.reshape(
      all_actions,
      [actions_shape[0], -1, reduce(lambda x, y: x * y, actions_shape[2:])])

  all_frames = tf.concat([input_frames, target_frames], axis=0)
  all_frames = tf.unstack(all_frames, axis=0)
  all_actions = tf.unstack(all_actions, axis=0)

  # TODO(blazej) - most likely this downsize is too strong.
  all_frames = [
      tf.image.resize_images(
          image, (IMG_HEIGHT, IMG_WIDTH),
          method=tf.image.ResizeMethod.BICUBIC) for image in all_frames
  ]

  enc_out_all, pred_out_all, _, van_on_enc_all = construct_model(
      all_frames,
      all_actions,
      context_frames=hparams.context_frames,
      hparams=hparams,
      is_training=self.is_training)

  enc_pred_loss, _ = calc_loss_psnr(
      enc_out_all[1:],
      pred_out_all,
      'enc_pred_loss',
      hparams=hparams,
      use_l1_loss=hparams.enc_pred_use_l1_loss)

  van_on_enc_loss, _ = calc_loss_psnr(
      van_on_enc_all, all_frames[1:], 'van_on_enc_loss', hparams=hparams)

  enc_pred_loss_scale_delay = max(hparams.enc_pred_loss_scale_delay, 1)
  enc_pred_loss_scale = tf.nn.sigmoid(
      (tf.to_float(tf.train.get_or_create_global_step()) -
       enc_pred_loss_scale_delay) /
      (enc_pred_loss_scale_delay * .1)) * hparams.enc_pred_loss_scale
  tf.summary.scalar('enc_pred_loss_scale', enc_pred_loss_scale)
  epva_loss = enc_pred_loss * enc_pred_loss_scale + van_on_enc_loss
  tf.summary.scalar('epva_loss', epva_loss)

  predictions = tf.stack(van_on_enc_all)

  if hparams.clip_pixel_values:
    predictions = tf.clip_by_value(predictions, 0.0, 1.0)

  # TODO(mbz): clean this up!
  def fix_video_dims_and_concat_on_x_axis(x):
    x = tf.transpose(x, [1, 3, 4, 0, 2])
    x = tf.reshape(x, [batch_size, frame_height, frame_channels, -1])
    x = tf.transpose(x, [0, 3, 1, 2])
    return x

  frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
  frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=1)
  tf.summary.image('full_video', side_by_side_video)

  predictions = tf.unstack(predictions)
  predictions = [
      tf.image.resize_images(
          image, (frame_width, frame_height),
          method=tf.image.ResizeMethod.BICUBIC) for image in predictions
  ]
  predictions = tf.stack(predictions)
  predictions = common_video.swap_time_and_batch_axes(predictions)
  predictions = tf.slice(predictions,
                         [0, hparams.video_num_input_frames - 1, 0, 0, 0],
                         [-1] * 5)
  return predictions, {'extra': epva_loss}

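# NOTE: illustrative sketch, not part of the model code above. It evaluates
# the sigmoid ramp used for enc_pred_loss_scale with plain Python: the scale
# is near 0 well before enc_pred_loss_scale_delay steps, reaches half of
# hparams.enc_pred_loss_scale exactly at the delay, and saturates shortly
# after; the example numbers are hypothetical.
import math

def _sketch_enc_pred_loss_scale(step, delay, max_scale):
  delay = max(delay, 1)
  return max_scale / (1.0 + math.exp(-(step - delay) / (delay * 0.1)))

# Example: with delay=2000 and max_scale=0.1 the scale is ~0.0 at step 0,
# 0.05 at step 2000 and ~0.1 by step 4000.
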