Example #1
    def visualize_predictions(self, real_frames, gen_frames, actions=None):
        def concat_on_y_axis(x):
            x = tf.unstack(x, axis=1)
            x = tf.concat(x, axis=1)
            return x

        frames_gd = common_video.swap_time_and_batch_axes(real_frames)
        frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
        if actions is not None:
            actions = common_video.swap_time_and_batch_axes(actions)

        if self.is_per_pixel_softmax:
            # Decode per-pixel softmax logits (256 intensity classes per
            # RGB channel) back into frames via argmax.
            frames_pd_shape = common_layers.shape_list(frames_pd)
            frames_pd = tf.reshape(frames_pd, [-1, 256])
            frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
            frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])

        frames_gd = concat_on_y_axis(frames_gd)
        frames_pd = concat_on_y_axis(frames_pd)
        if actions is not None:
            actions = tf.clip_by_value(actions, 0, 1)
            summary("action_vid", tf.cast(actions * 255, tf.uint8))
            actions = concat_on_y_axis(actions)
            side_by_side_video = tf.concat([frames_gd, frames_pd, actions],
                                           axis=2)
        else:
            side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
        tf.summary.image("full_video", side_by_side_video)
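Note: all of these visualize_predictions variants rely on
common_video.swap_time_and_batch_axes to move between batch-major and
time-major layouts. The helper is not shown on this page; a minimal sketch
of what it presumably does (transpose the first two axes, leaving the rest
in place):

    import tensorflow as tf

    def swap_time_and_batch_axes(inputs):
        """Swap the first two axes (time and batch) of a tensor."""
        perm = tf.concat([[1, 0], tf.range(2, tf.rank(inputs))], axis=0)
        return tf.transpose(inputs, perm)

Applying it twice is a no-op, which is why the examples can swap axes on the
way in and swap back just before returning predictions.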
Example #2
    def visualize_predictions(self, real_frames, gen_frames):
        def concat_on_y_axis(x):
            x = tf.unstack(x, axis=1)
            x = tf.concat(x, axis=1)
            return x

        frames_gd = common_video.swap_time_and_batch_axes(real_frames)
        frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
        frames_gd = concat_on_y_axis(frames_gd)
        frames_pd = concat_on_y_axis(frames_pd)
        side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
        tf.summary.image("full_video", side_by_side_video)
Example #3
  def discriminator(self, frames):
    """3-D SNGAN discriminator.

    Args:
      frames: a list of batch-major tensors indexed by time.

    Returns:
      logits: 1-D Tensor of shape [batch_size]. Positive logits imply
              that the discriminator judges the input video to be real.
    """
    ndf = self.hparams.num_discriminator_filters
    frames = tf.stack(frames)

    # Switch from time-major axis to batch-major axis.
    frames = common_video.swap_time_and_batch_axes(frames)

    # 3-D Conv-net mapping inputs to activations.
    num_outputs = [ndf, ndf*2, ndf*2, ndf*4, ndf*4, ndf*8, ndf*8]
    kernel_sizes = [3, 4, 3, 4, 3, 4, 3]
    strides = [[1, 1, 1], [1, 2, 2], [1, 1, 1], [1, 2, 2], [1, 1, 1],
               [2, 2, 2], [1, 1, 1]]

    names = ["video_sn_conv0_0", "video_sn_conv0_1", "video_sn_conv1_0",
             "video_sn_conv1_1", "video_sn_conv2_0", "video_sn_conv2_1",
             "video_sn_conv3_0"]
    iterable = zip(num_outputs, kernel_sizes, strides, names)
    activations = frames
    for num_filters, kernel_size, stride, name in iterable:
      activations = self.pad_conv3d_lrelu(activations, num_filters, kernel_size,
                                          stride, name)
    num_fc_dimensions = self.get_fc_dimensions(strides, kernel_sizes)
    activations = tf.reshape(activations, (-1, num_fc_dimensions))
    return tf.squeeze(tf.layers.dense(activations, 1))
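Note: get_fc_dimensions is not shown on this page. Under SAME padding, a
conv layer with stride s shrinks an axis of size n to ceil(n / s), so the
flattened width of the final activations can be derived from the stride
list and the final channel count alone. A hypothetical sketch (the name,
signature, and SAME-padding assumption are ours, not the library's):

    import math

    def fc_dimensions_sketch(input_shape, strides, final_channels):
        # input_shape: [time, height, width] of the discriminator input.
        # strides: one [st, sh, sw] triple per conv layer; under SAME
        # padding the kernel size does not affect the output shape.
        dims = list(input_shape)
        for stride in strides:
            dims = [math.ceil(n / s) for n, s in zip(dims, stride)]
        total = final_channels
        for n in dims:
            total *= n
        return total

With the strides above, H and W are halved three times and T once; e.g. for
16 frames of 64x64 input and ndf = 64,
fc_dimensions_sketch([16, 64, 64], strides, 64 * 8) gives 262144.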
Example #4
  def visualize_predictions(self, real_frames, gen_frames):
    def concat_on_y_axis(x):
      x = tf.unstack(x, axis=1)
      x = tf.concat(x, axis=1)
      return x

    frames_gd = common_video.swap_time_and_batch_axes(real_frames)
    frames_pd = common_video.swap_time_and_batch_axes(gen_frames)

    if self.is_per_pixel_softmax:
      frames_pd_shape = common_layers.shape_list(frames_pd)
      frames_pd = tf.reshape(frames_pd, [-1, 256])
      frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
      frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])

    frames_gd = concat_on_y_axis(frames_gd)
    frames_pd = concat_on_y_axis(frames_pd)
    side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
    tf.summary.image("full_video", side_by_side_video)
Example #5
    def visualize_predictions(self, real_frames, gen_frames):
        def concat_on_y_axis(x):
            x = tf.unstack(x, axis=1)
            x = tf.concat(x, axis=1)
            return x

        frames_gd = common_video.swap_time_and_batch_axes(real_frames)
        frames_pd = common_video.swap_time_and_batch_axes(gen_frames)

        if self.is_per_pixel_softmax:
            frames_pd_shape = common_layers.shape_list(frames_pd)
            frames_pd = tf.reshape(frames_pd, [-1, 256])
            frames_pd = tf.to_float(tf.argmax(frames_pd, axis=-1))
            frames_pd = tf.reshape(frames_pd, frames_pd_shape[:-1] + [3])

        frames_gd = concat_on_y_axis(frames_gd)
        frames_pd = concat_on_y_axis(frames_pd)
        side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
        tf.summary.image("full_video", side_by_side_video)
Example #6
    def get_input_if_exists(self, features, key, batch_size, num_frames):
        if key in features:
            x = features[key]
        else:
            x = tf.zeros((batch_size, num_frames, 1, self.hparams.hidden_size))
        return common_video.swap_time_and_batch_axes(x)
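Usage note: missing features are silently replaced by a zero tensor of a
compatible shape, so downstream code never has to branch on their presence.
A hypothetical call (instance name and shapes are illustrative only):

    # Returns features["input_action"], swapped to time-major layout, if the
    # key exists; otherwise a time-major zero tensor derived from shape
    # (batch_size, num_frames, 1, hidden_size).
    input_actions = model.get_input_if_exists(
        features, "input_action", batch_size=8, num_frames=4)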
Example #7
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        retvals = self.construct_model(images=all_frames,
                                       actions=all_actions,
                                       rewards=all_rewards)

        # Retrieve the tensors returned by the model constructor.
        gen_images = retvals[0]
        gen_rewards = retvals[1]
        latent_means = retvals[2]
        latent_logvars = retvals[3]
        latent_means_p = retvals[4]
        latent_logvars_p = retvals[5]

        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_logvars=latent_logvars,
                                         latent_means_p=latent_means_p,
                                         latent_logvars_p=latent_logvars_p)

        # Visualize predictions in TensorBoard.
        if self.is_training:
            self.visualize_predictions(all_frames[1:], gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove extra dimension.

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        if self.is_training and hparams.internal_loss:
            # Add the loss for the input frames as well.
            extra_gts = all_frames[1:hparams.video_num_input_frames]
            extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
            extra_pds = gen_images[:hparams.video_num_input_frames - 1]
            extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
            extra_raw_gts = features["inputs_raw"][:, 1:]
            recon_loss = self.get_extra_internal_loss(extra_raw_gts, extra_gts,
                                                      extra_pds)
            extra_loss += recon_loss

        return_targets = predictions
        if hparams.reward_prediction:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
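Note on the indexing above: the slicing implies that construct_model
consumes frames 0..T-1 and emits gen_images[t] as the prediction of frame
t+1, so gen_images lines up with all_frames[1:]. With
video_num_input_frames = k, the first k-1 generated frames correspond to
conditioning frames and are dropped. A small pure-Python illustration
(T and k are hypothetical):

    T, k = 10, 4
    gen_indices = list(range(1, T))   # frames predicted by gen_images
    kept = gen_indices[k - 1:]        # predictions returned by body()
    assert kept == list(range(k, T))  # exactly the target frames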
Example #8
    def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
        del args, kwargs

        # Make a copy of features that can be used in the call to self
        # that builds the graph.
        new_features = {}
        new_features["inputs"] = features["inputs"]
        new_features["targets"] = features["infer_targets"]
        _, _ = self(new_features)  # pylint: disable=not-callable

        if self.hparams.gen_mode == "unconditional":
            num_target_frames = 1
        else:
            num_target_frames = self.hparams.video_num_target_frames

        ops = [
            glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout
        ]
        var_scope = tf.variable_scope("next_frame_glow/body", reuse=True)
        all_frames = []

        # If eps=None, images are sampled from the prior.
        with arg_scope(ops, init=False), var_scope:
            for target_frame in range(1, num_target_frames + 1):

                # subscript -> timestep, superscript -> level.
                # self.z_sample equals z^0_{t} (top-level latent)
                # (X_{t}, z^{1..l}_{t}) = Glow(z^0_{t}, z^{1..l}_{t-1})
                # Get current set of cond_latents.
                cond_level, cond_level_latents = get_cond_latents(
                    self.all_level_latents, self.hparams)

                glow_vals = glow_ops.encoder_decoder(
                    "codec",
                    self.z_sample,
                    self.hparams,
                    eps=None,
                    reverse=True,
                    cond_latents=cond_level_latents,
                    states=self.level_states,
                    condition=cond_level,
                    temperature=self.temperature)
                predicted_frame, _, curr_latents, self.level_states = glow_vals
                all_frames.append(predicted_frame)
                self.all_level_latents.append(curr_latents)

                # Compute z^0_{t+1} = f(z^0_{t})
                if target_frame < num_target_frames:
                    cond_top, cond_top_latents = get_cond_latents(
                        self.all_top_latents, self.hparams)
                    prior_dist = self.top_prior(condition=cond_top,
                                                cond_latents=cond_top_latents)
                    self.z_sample = prior_dist.sample()
                    self.all_top_latents.append(self.z_sample)

        all_frames = tf.stack(all_frames)
        predicted_video = common_video.swap_time_and_batch_axes(all_frames)

        # The video-decode API requires the predicted video to be the same shape
        # as the target-video. Hence, for unconditional generation,
        # tile across time to ensure same shape.
        if self.hparams.gen_mode == "unconditional":
            predicted_video = tf.tile(
                predicted_video,
                [1, self.hparams.video_num_target_frames, 1, 1, 1])
        predicted_video = glow_ops.postprocess(predicted_video)

        # Output of a single decode / sample.
        output_features = {}
        output_features["targets"] = tf.zeros_like(predicted_video)
        output_features["outputs"] = predicted_video
        output_features["scores"] = tf.zeros_like(predicted_video)
        return output_features
Example #9
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        beta = self.get_beta()
        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_stds=latent_stds,
                                         beta=beta,
                                         true_frames=all_frames,
                                         gen_frames=gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove unneeded dimension.

        # TODO(mbz): clean this up!
        def fix_video_dims_and_concat_on_x_axis(x):
            # Assumes 64-pixel-wide RGB frames; flattens the time axis into
            # a spatial axis so the whole video becomes one summary image.
            x = tf.transpose(x, [1, 3, 4, 0, 2])
            x = tf.reshape(x, [batch_size, 64, 3, -1])
            x = tf.transpose(x, [0, 3, 1, 2])
            return x

        frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
        frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
        side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
        tf.summary.image("full_video", side_by_side_video)

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        return_targets = predictions
        if "target_reward" in features:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
Example #10
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_stds=latent_stds,
                                         true_frames=all_frames,
                                         gen_frames=gen_images)

        # Visualize predictions in TensorBoard.
        if self.is_training:
            self.visualize_predictions(all_frames[1:], gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove extra dimension.

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        if self.is_training and hparams.internal_loss:
            # Add the MSE loss for the input frames as well.
            extra_gts = all_frames[1:hparams.video_num_input_frames + 1]
            extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
            extra_pds = gen_images[:hparams.video_num_input_frames]
            extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
            if self._target_modality == "VideoModalityL2Raw":
                recon_loss = tf.losses.mean_squared_error(extra_gts, extra_pds)
            elif self._target_modality == "VideoModality":
                shape = common_layers.shape_list(extra_pds)
                updated_shape = shape[:-1] + [3, 256]
                extra_pds = tf.reshape(extra_pds, updated_shape)
                # Merge time and batch
                logits = tf.reshape(extra_pds, [-1] + updated_shape[2:])
                targets_shape = common_layers.shape_list(
                    features["targets_raw"])
                targets = tf.reshape(features["targets_raw"],
                                     [-1] + targets_shape[2:])
                mod = self.hparams.problem_hparams.target_modality["targets"]
                numerator, denominator = common_layers.padded_cross_entropy(
                    logits,
                    targets,
                    hparams.label_smoothing,
                    cutoff=getattr(hparams, "video_modality_loss_cutoff",
                                   0.01),
                    weights_fn=mod.targets_weights_fn)
                recon_loss = numerator / denominator
            else:
                raise ValueError(
                    "internal loss only supports specific modalities.")

            tf.summary.scalar("recon_extra", recon_loss)
            extra_loss += recon_loss

        return_targets = predictions
        if hparams.reward_prediction:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
Example #11
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        step_num = tf.train.get_global_step()
        # TODO(mbz): what should it be if it's undefined?
        if step_num is None:
            step_num = _LARGE_STEP_NUMBER

        schedule = self.hparams.latent_loss_multiplier_schedule
        second_stage = self.hparams.num_iterations_2nd_stage
        # TODO(mechcoder): Add log_annealing schedule.
        if schedule == "constant":
            beta = tf.cond(tf.greater(step_num, second_stage),
                           lambda: self.hparams.latent_loss_multiplier,
                           lambda: 0.0)
        elif schedule == "linear_anneal":
            # Linearly anneal beta from 0.0 to
            # self.hparams.latent_loss_multiplier between
            # self.hparams.num_iterations_2nd_stage and anneal_end.
            # beta = latent_loss * (1 - (global_step - 2nd_stage) / (anneal_end - 2nd_stage))  # pylint:disable=line-too-long
            anneal_end = self.hparams.anneal_end
            latent_multiplier = self.hparams.latent_loss_multiplier
            if anneal_end < second_stage:
                raise ValueError(
                    "Expected hparams.num_iterations_2nd_stage (%d) to be "
                    "less than hparams.anneal_end (%d)." %
                    (second_stage, anneal_end))

            def anneal_loss(step_num):
                step_num = tf.cast(step_num, dtype=tf.float32)
                fraction = (float(anneal_end) - step_num) / (anneal_end -
                                                             second_stage)
                return self.hparams.latent_loss_multiplier * (1 - fraction)

            beta = tf.case(
                pred_fn_pairs={
                    tf.less(step_num, second_stage): lambda: 0.0,
                    tf.greater(step_num, anneal_end): lambda: latent_multiplier
                },
                default=lambda: anneal_loss(step_num))

        kl_loss = 0.0
        if self.is_training:
            for i, (mean, std) in enumerate(zip(latent_means, latent_stds)):
                kl_loss += common_layers.kl_divergence(mean, std)
                tf.summary.histogram("posterior_mean_%d" % i, mean)
                tf.summary.histogram("posterior_std_%d" % i, std)

            tf.summary.scalar("beta", beta)
            tf.summary.scalar("kl_raw", tf.reduce_mean(kl_loss))

        extra_loss = beta * kl_loss

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove unneeded dimension.

        # TODO(mbz): clean this up!
        def fix_video_dims_and_concat_on_x_axis(x):
            # Assumes 64-pixel-wide RGB frames; flattens the time axis into
            # a spatial axis so the whole video becomes one summary image.
            x = tf.transpose(x, [1, 3, 4, 0, 2])
            x = tf.reshape(x, [batch_size, 64, 3, -1])
            x = tf.transpose(x, [0, 3, 1, 2])
            return x

        frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
        frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
        side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
        tf.summary.image("full_video", side_by_side_video)

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        return_targets = predictions
        if "target_reward" in features:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
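For intuition, the linear_anneal branch above is a piecewise-linear ramp.
A pure-Python mirror of the same schedule (constants hypothetical):

    def linear_anneal_beta(step, second_stage, anneal_end, multiplier):
        if step < second_stage:
            return 0.0
        if step > anneal_end:
            return multiplier
        fraction = (anneal_end - step) / (anneal_end - second_stage)
        return multiplier * (1.0 - fraction)

    assert linear_anneal_beta(1000, 1000, 2000, 1e-3) == 0.0   # ramp start
    assert linear_anneal_beta(2000, 1000, 2000, 1e-3) == 1e-3  # ramp end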
Example #12
  def body(self, features):
    hparams = self.hparams
    batch_size = common_layers.shape_list(features["inputs"])[0]

    # Swap time and batch axes.
    input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
    target_frames = common_video.swap_time_and_batch_axes(features["targets"])

    # Get actions if they exist, otherwise use zeros.
    input_actions = self.get_input_if_exists(
        features, "input_action", batch_size, hparams.video_num_input_frames)
    target_actions = self.get_input_if_exists(
        features, "target_action", batch_size, hparams.video_num_target_frames)

    # Get rewards if they exist, otherwise use zeros.
    input_rewards = self.get_input_if_exists(
        features, "input_reward", batch_size, hparams.video_num_input_frames)
    target_rewards = self.get_input_if_exists(
        features, "target_reward", batch_size, hparams.video_num_target_frames)

    all_actions = tf.concat([input_actions, target_actions], axis=0)
    all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
    all_frames = tf.concat([input_frames, target_frames], axis=0)

    # Each image is used twice, in the latent tower and in the main tower.
    # This makes sure we use the *same* image for both, given how TF queues
    # work.
    # NOT sure if this is required at all. Doesn't hurt though! :)
    all_frames = tf.identity(all_frames)

    gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
        images=all_frames,
        actions=all_actions,
        rewards=all_rewards,
    )

    extra_loss = self.get_extra_loss(
        latent_means=latent_means,
        latent_stds=latent_stds,
        true_frames=all_frames,
        gen_frames=gen_images)

    # Visualize predictions in TensorBoard.
    if self.is_training:
      self.visualize_predictions(all_frames[1:], gen_images)

    # Ignore the predictions from the input frames.
    # This is NOT the same as original paper/implementation.
    predictions = gen_images[hparams.video_num_input_frames-1:]
    reward_pred = gen_rewards[hparams.video_num_input_frames-1:]
    reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

    # Swap back time and batch axes.
    predictions = common_video.swap_time_and_batch_axes(predictions)
    reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

    if self.is_training and hparams.internal_loss:
      # Add the loss for the input frames as well.
      extra_gts = all_frames[1:hparams.video_num_input_frames]
      extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
      extra_pds = gen_images[:hparams.video_num_input_frames-1]
      extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
      extra_raw_gts = features["inputs_raw"][:, 1:]
      recon_loss = self.get_extra_internal_loss(
          extra_raw_gts, extra_gts, extra_pds)
      extra_loss += recon_loss

    return_targets = predictions
    if hparams.reward_prediction:
      return_targets = {"targets": predictions, "target_reward": reward_pred}

    return return_targets, extra_loss
Example #13
  def get_input_if_exists(self, features, key, batch_size, num_frames):
    if key in features:
      x = features[key]
    else:
      x = tf.zeros((batch_size, num_frames, 1, self.hparams.hidden_size))
    return common_video.swap_time_and_batch_axes(x)
Example #14
    def body(self, features):
        hparams = self.hparams
        input_shape = common_layers.shape_list(features['inputs'])
        # NOTE: features['inputs'] is [batch, time, height, width, channels],
        # so the frame_width/frame_height names below appear swapped relative
        # to that layout; they are at least used consistently.
        batch_size, _, frame_width, frame_height, frame_channels = input_shape  # pylint: disable=unused-variable

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            tf.to_float(features['inputs']))
        target_frames = common_video.swap_time_and_batch_axes(
            features['targets'])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, 'input_action', batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, 'target_action', batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        # TODO(blazej) enable rewards.
        # input_rewards = self.get_input_if_exists(
        #     features, 'input_reward', batch_size, hparams.video_num_input_frames)
        # target_rewards = self.get_input_if_exists(
        #     features, 'target_reward', batch_size,hparams.video_num_target_frames)
        # all_rewards = tf.concat([input_rewards, target_rewards], axis=0)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        # Flatten the actions tensor to shape
        # [frames, batch_size, action_dims].
        actions_shape = common_layers.shape_list(all_actions)
        # `reduce` is assumed to be functools.reduce (Python 3).
        all_actions = tf.reshape(all_actions, [
            actions_shape[0], -1,
            reduce(lambda x, y: x * y, actions_shape[2:])
        ])
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        all_frames = tf.unstack(all_frames, axis=0)
        all_actions = tf.unstack(all_actions, axis=0)

        # TODO(blazej) - most likely this downsize is too strong.
        all_frames = [
            tf.image.resize_images(image, (IMG_HEIGHT, IMG_WIDTH),
                                   method=tf.image.ResizeMethod.BICUBIC)
            for image in all_frames
        ]

        enc_out_all, pred_out_all, _, van_on_enc_all = construct_model(
            all_frames,
            all_actions,
            context_frames=hparams.context_frames,
            hparams=hparams,
            is_training=self.is_training)

        enc_pred_loss, _ = calc_loss_psnr(
            enc_out_all[1:],
            pred_out_all,
            'enc_pred_loss',
            hparams=hparams,
            use_l1_loss=hparams.enc_pred_use_l1_loss)

        van_on_enc_loss, _ = calc_loss_psnr(van_on_enc_all,
                                            all_frames[1:],
                                            'van_on_enc_loss',
                                            hparams=hparams)

        enc_pred_loss_scale_delay = max(hparams.enc_pred_loss_scale_delay, 1)
        enc_pred_loss_scale = tf.nn.sigmoid(
            (tf.to_float(tf.train.get_or_create_global_step()) -
             enc_pred_loss_scale_delay) /
            (enc_pred_loss_scale_delay * .1)) * hparams.enc_pred_loss_scale
        tf.summary.scalar('enc_pred_loss_scale', enc_pred_loss_scale)
        epva_loss = enc_pred_loss * enc_pred_loss_scale + van_on_enc_loss
        tf.summary.scalar('epva_loss', epva_loss)

        predictions = tf.stack(van_on_enc_all)

        if hparams.clip_pixel_values:
            predictions = tf.clip_by_value(predictions, 0.0, 1.0)

        # TODO(mbz): clean this up!
        def fix_video_dims_and_concat_on_x_axis(x):
            # Tiles the frames along one spatial axis so the whole video
            # becomes a single summary image.
            x = tf.transpose(x, [1, 3, 4, 0, 2])
            x = tf.reshape(x, [batch_size, frame_height, frame_channels, -1])
            x = tf.transpose(x, [0, 3, 1, 2])
            return x

        frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
        frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
        side_by_side_video = tf.concat([frames_gd, frames_pd], axis=1)
        tf.summary.image('full_video', side_by_side_video)

        predictions = tf.unstack(predictions)
        predictions = [
            tf.image.resize_images(image, (frame_width, frame_height),
                                   method=tf.image.ResizeMethod.BICUBIC)
            for image in predictions
        ]
        predictions = tf.stack(predictions)

        predictions = common_video.swap_time_and_batch_axes(predictions)
        predictions = tf.slice(
            predictions, [0, hparams.video_num_input_frames - 1, 0, 0, 0],
            [-1] * 5)

        return predictions, {'extra': epva_loss}
Example #15
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_stds=latent_stds,
                                         true_frames=all_frames,
                                         gen_frames=gen_images)

        # Visualize predictions in TensorBoard.
        self.visualize_predictions(all_frames[1:], gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove extra dimension.

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        if hparams.internal_loss:
            # Add the MSE loss for the input frames as well. We assume the
            # modality is L2; otherwise the loss would be inconsistent
            # across the frames.
            modality = self.hparams.problem_hparams.target_modality["targets"]
            if modality.__class__.__name__ != "VideoModalityL2Raw":
                raise ValueError("internal loss only works with L2.")
            recon_loss = tf.losses.mean_squared_error(
                all_frames[1:hparams.video_num_input_frames + 1],
                gen_images[:hparams.video_num_input_frames])
            tf.summary.scalar("mse_extra", recon_loss)
            extra_loss += recon_loss

        return_targets = predictions
        if hparams.reward_prediction:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
Example #16
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice, in the latent tower and in the main
        # tower. This makes sure we use the *same* image for both, given
        # how TF queues work.
        # NOT sure if this is required at all. Doesn't hurt though! :)
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        beta = self.get_beta()
        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_stds=latent_stds,
                                         beta=beta,
                                         true_frames=all_frames,
                                         gen_frames=gen_images)

        # Visualize predictions in TensorBoard.
        self.visualize_predictions(all_frames[1:], gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        if self.is_training:
            reward_pred = tf.squeeze(reward_pred,
                                     axis=2)  # Remove extra dimension.

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        return_targets = predictions
        if hparams.reward_prediction:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        if hparams.internal_loss:
            loss = tf.losses.mean_squared_error(all_frames[1:], gen_images)
            extra_loss = {"training": loss + extra_loss}

        return return_targets, extra_loss
Example #17
  def body(self, features):
    hparams = self.hparams
    input_shape = common_layers.shape_list(features['inputs'])
    # NOTE: features['inputs'] is [batch, time, height, width, channels],
    # so the frame_width/frame_height names below appear swapped relative
    # to that layout; they are at least used consistently.
    batch_size, _, frame_width, frame_height, frame_channels = input_shape  # pylint: disable=unused-variable

    # Swap time and batch axes.
    input_frames = common_video.swap_time_and_batch_axes(
        tf.to_float(features['inputs']))
    target_frames = common_video.swap_time_and_batch_axes(features['targets'])

    # Get actions if they exist, otherwise use zeros.
    input_actions = self.get_input_if_exists(
        features, 'input_action', batch_size, hparams.video_num_input_frames)
    target_actions = self.get_input_if_exists(
        features, 'target_action', batch_size, hparams.video_num_target_frames)

    # Get rewards if they exist, otherwise use zeros.
    # TODO(blazej) enable rewards.
    # input_rewards = self.get_input_if_exists(
    #     features, 'input_reward', batch_size, hparams.video_num_input_frames)
    # target_rewards = self.get_input_if_exists(
    #     features, 'target_reward', batch_size,hparams.video_num_target_frames)
    # all_rewards = tf.concat([input_rewards, target_rewards], axis=0)

    all_actions = tf.concat([input_actions, target_actions], axis=0)
    # Flatten the actions tensor to shape [frames, batch_size, action_dims].
    actions_shape = common_layers.shape_list(all_actions)
    # `reduce` is assumed to be functools.reduce (Python 3).
    all_actions = tf.reshape(
        all_actions,
        [actions_shape[0], -1,
         reduce(lambda x, y: x * y, actions_shape[2:])])
    all_frames = tf.concat([input_frames, target_frames], axis=0)

    all_frames = tf.unstack(all_frames, axis=0)
    all_actions = tf.unstack(all_actions, axis=0)

    # TODO(blazej) - most likely this downsize is too strong.
    all_frames = [
        tf.image.resize_images(
            image, (IMG_HEIGHT, IMG_WIDTH),
            method=tf.image.ResizeMethod.BICUBIC)
        for image in all_frames
    ]

    enc_out_all, pred_out_all, _, van_on_enc_all = construct_model(
        all_frames,
        all_actions,
        context_frames=hparams.context_frames,
        hparams=hparams,
        is_training=self.is_training)

    enc_pred_loss, _ = calc_loss_psnr(
        enc_out_all[1:],
        pred_out_all,
        'enc_pred_loss',
        hparams=hparams,
        use_l1_loss=hparams.enc_pred_use_l1_loss)

    van_on_enc_loss, _ = calc_loss_psnr(
        van_on_enc_all,
        all_frames[1:],
        'van_on_enc_loss',
        hparams=hparams)

    enc_pred_loss_scale_delay = max(hparams.enc_pred_loss_scale_delay, 1)
    enc_pred_loss_scale = tf.nn.sigmoid(
        (tf.to_float(tf.train.get_or_create_global_step()
                    ) - enc_pred_loss_scale_delay) /
        (enc_pred_loss_scale_delay * .1)) * hparams.enc_pred_loss_scale
    tf.summary.scalar('enc_pred_loss_scale', enc_pred_loss_scale)
    epva_loss = enc_pred_loss * enc_pred_loss_scale + van_on_enc_loss
    tf.summary.scalar('epva_loss', epva_loss)

    predictions = tf.stack(van_on_enc_all)

    if hparams.clip_pixel_values:
      predictions = tf.clip_by_value(predictions, 0.0, 1.0)

    # TODO(mbz): clean this up!
    def fix_video_dims_and_concat_on_x_axis(x):
      # Tiles the frames along one spatial axis so the whole video becomes
      # a single summary image.
      x = tf.transpose(x, [1, 3, 4, 0, 2])
      x = tf.reshape(x, [batch_size, frame_height, frame_channels, -1])
      x = tf.transpose(x, [0, 3, 1, 2])
      return x

    frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
    frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
    side_by_side_video = tf.concat([frames_gd, frames_pd], axis=1)
    tf.summary.image('full_video', side_by_side_video)

    predictions = tf.unstack(predictions)
    predictions = [
        tf.image.resize_images(
            image, (frame_width, frame_height),
            method=tf.image.ResizeMethod.BICUBIC)
        for image in predictions
    ]
    predictions = tf.stack(predictions)

    predictions = common_video.swap_time_and_batch_axes(predictions)
    predictions = tf.slice(predictions,
                           [0, hparams.video_num_input_frames-1, 0, 0, 0],
                           [-1]*5)

    return predictions, {'extra': epva_loss}