Example #1
 def video_features(self, all_frames, all_actions, all_rewards,
                    all_raw_frames):
     """Video wide latent."""
     del all_actions, all_rewards, all_raw_frames
     mean, std = self.construct_latent_tower(all_frames, time_axis=0)
     latent = common_video.get_gaussian_tensor(mean, std)
     return [latent, mean, std]
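These snippets are methods lifted from the tensor2tensor video models, so they are not self-contained. A minimal sketch of the imports they appear to assume (TF1-era API; tfl is the alias some of the later examples use):

import tensorflow as tf

from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video

# Alias used by some of the snippets below.
tfl = tf.layers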
Example #2
    def construct_model(self, images, actions, rewards):
        images = tf.unstack(images, axis=0)
        actions = tf.unstack(actions, axis=0)
        rewards = tf.unstack(rewards, axis=0)

        batch_size = common_layers.shape_list(images[0])[0]
        context_frames = self.hparams.video_num_input_frames

        # Predicted images and rewards.
        gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []

        # LSTM states.
        lstm_state = [None] * 7

        # Create scheduled sampling function
        ss_func = self.get_scheduled_sample_func(batch_size)

        pred_image = tf.zeros_like(images[0])
        pred_reward = tf.zeros_like(rewards[0])
        latent = None
        for timestep, image, action, reward in zip(range(len(images) - 1),
                                                   images[:-1], actions[:-1],
                                                   rewards[:-1]):
            # Scheduled Sampling
            done_warm_start = timestep > context_frames - 1
            groundtruth_items = [image, reward]
            generated_items = [pred_image, pred_reward]
            input_image, input_reward = self.get_scheduled_sample_inputs(
                done_warm_start, groundtruth_items, generated_items, ss_func)

            # Latent
            # TODO(mbz): should we use input_image instead of image?
            latent_images = tf.stack([image, images[timestep + 1]], axis=0)
            latent_mean, latent_std = self.construct_latent_tower(
                latent_images, time_axis=0)
            latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
            latent_means.append(latent_mean)
            latent_stds.append(latent_std)

            # Prediction
            pred_image, lstm_state = self.construct_predictive_tower(
                input_image, input_reward, action, lstm_state, latent)

            if self.hparams.reward_prediction:
                pred_reward = self.reward_prediction(pred_image, input_reward,
                                                     action, latent)
                pred_reward = common_video.decode_to_shape(
                    pred_reward, common_layers.shape_list(input_reward),
                    "reward_dec")
            else:
                pred_reward = input_reward

            gen_images.append(pred_image)
            gen_rewards.append(pred_reward)

        gen_images = tf.stack(gen_images, axis=0)
        gen_rewards = tf.stack(gen_rewards, axis=0)

        return gen_images, gen_rewards, latent_means, latent_stds
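common_video.get_gaussian_tensor(mean, std) draws the latent with the reparameterization trick, so the sample stays differentiable with respect to the predicted distribution parameters. A standalone sketch of the idea, assuming the second argument is a log-variance (the exact tensor2tensor convention may differ):

import tensorflow as tf

def sample_gaussian(mean, log_var):
    # z = mean + exp(log_var / 2) * eps, with eps ~ N(0, I).
    eps = tf.random_normal(tf.shape(mean), dtype=tf.float32)
    return mean + tf.exp(log_var / 2.0) * eps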
Example #3
  def construct_model(self, images, actions, rewards):
    images = tf.unstack(images, axis=0)
    actions = tf.unstack(actions, axis=0)
    rewards = tf.unstack(rewards, axis=0)

    batch_size = common_layers.shape_list(images[0])[0]
    context_frames = self.hparams.video_num_input_frames

    # Predicted images and rewards.
    gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []

    # LSTM states.
    lstm_state = [None] * 7

    # Create scheduled sampling function
    ss_func = self.get_scheduled_sample_func(batch_size)

    pred_image = tf.zeros_like(images[0])
    pred_reward = tf.zeros_like(rewards[0])
    latent = None
    for timestep, image, action, reward in zip(
        range(len(images)-1), images[:-1], actions[:-1], rewards[:-1]):
      # Scheduled Sampling
      done_warm_start = timestep > context_frames - 1
      groundtruth_items = [image, reward]
      generated_items = [pred_image, pred_reward]
      input_image, input_reward = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Latent
      # TODO(mbz): should we use input_image instead of image?
      latent_images = tf.stack([image, images[timestep+1]], axis=0)
      latent_mean, latent_std = self.construct_latent_tower(
          latent_images, time_axis=0)
      latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
      latent_means.append(latent_mean)
      latent_stds.append(latent_std)

      # Prediction
      pred_image, lstm_state, _ = self.construct_predictive_tower(
          input_image, input_reward, action, lstm_state, latent)

      if self.hparams.reward_prediction:
        pred_reward = self.reward_prediction(
            pred_image, input_reward, action, latent)
        pred_reward = common_video.decode_to_shape(
            pred_reward, common_layers.shape_list(input_reward), "reward_dec")
      else:
        pred_reward = input_reward

      gen_images.append(pred_image)
      gen_rewards.append(pred_reward)

    gen_images = tf.stack(gen_images, axis=0)
    gen_rewards = tf.stack(gen_rewards, axis=0)

    return gen_images, gen_rewards, latent_means, latent_stds
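construct_model mixes ground-truth and generated frames through get_scheduled_sample_inputs once done_warm_start becomes true, so the model gradually learns to consume its own predictions. A minimal sketch of probability-based scheduled sampling (a hypothetical helper, not the tensor2tensor implementation), keeping the ground-truth item for each batch element with probability keep_prob:

import tensorflow as tf

def scheduled_sample(groundtruth, generated, batch_size, keep_prob):
    # Per-example coin flip between the true frame and the model's own output.
    keep = tf.random_uniform([batch_size]) < keep_prob
    return tf.where(keep, groundtruth, generated)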
Example #4
 def video_features(self, all_frames, all_actions, all_rewards,
                    all_raw_frames):
     """Video wide latent."""
     del all_actions, all_rewards, all_raw_frames
     if not self.hparams.stochastic_model:
         return None, None, None
     frames = tf.stack(all_frames, axis=1)
     mean, std = self.construct_latent_tower(frames, time_axis=1)
     latent = common_video.get_gaussian_tensor(mean, std)
     return [latent, mean, std]
Example #5
 def video_features(
     self, all_frames, all_actions, all_rewards, all_raw_frames):
   """Video wide latent."""
   del all_actions, all_rewards, all_raw_frames
   if not self.hparams.stochastic_model:
     return None, None, None
   frames = tf.stack(all_frames, axis=1)
   mean, std = self.construct_latent_tower(frames, time_axis=1)
   latent = common_video.get_gaussian_tensor(mean, std)
   return [latent, mean, std]
Example #6
  def inject_latent(self, layer, features, filters):
      """Inject a VAE-style latent."""
     # Latent for stochastic model
     full_video = tf.concat(
         [features["inputs_raw"], features["targets_raw"]], axis=1)
     latent_mean, latent_std = self.construct_latent_tower(full_video,
                                                           time_axis=1)
     latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
     latent = tf.layers.flatten(latent)
     latent = tf.expand_dims(latent, axis=1)
     latent = tf.expand_dims(latent, axis=1)
     latent_mask = tf.layers.dense(latent, filters, name="latent_mask")
     zeros_mask = tf.zeros(common_layers.shape_list(layer)[:-1] + [filters],
                           dtype=tf.float32)
     layer = tf.concat([layer, latent_mask + zeros_mask], axis=-1)
     extra_loss = self.get_extra_loss(latent_mean, latent_std)
     return layer, extra_loss
Example #7
  def inject_latent(self, layer, inputs, target):
     """Inject a VAE-style latent."""
     # Latent for stochastic model
     filters = 128
     full_video = tf.stack(inputs + [target], axis=1)
     latent_mean, latent_std = self.construct_latent_tower(full_video,
                                                           time_axis=1)
     latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
     latent = tfl.flatten(latent)
     latent = tf.expand_dims(latent, axis=1)
     latent = tf.expand_dims(latent, axis=1)
     latent_mask = tfl.dense(latent, filters, name="latent_mask")
     zeros_mask = tf.zeros(common_layers.shape_list(layer)[:-1] + [filters],
                           dtype=tf.float32)
     layer = tf.concat([layer, latent_mask + zeros_mask], axis=-1)
     extra_loss = self.get_kl_loss([latent_mean], [latent_std])
     return layer, extra_loss
Example #8
  def inject_latent(self, layer, inputs, target, action):
   """Inject a VAE-style latent."""
   del action
   # Latent for stochastic model
   filters = 128
   full_video = tf.stack(inputs + [target], axis=1)
   latent_mean, latent_std = self.construct_latent_tower(
       full_video, time_axis=1)
   latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
   latent = tfl.flatten(latent)
   latent = tf.expand_dims(latent, axis=1)
   latent = tf.expand_dims(latent, axis=1)
   latent_mask = tfl.dense(latent, filters, name="latent_mask")
   zeros_mask = tf.zeros(
       common_layers.shape_list(layer)[:-1] + [filters], dtype=tf.float32)
   layer = tf.concat([layer, latent_mask + zeros_mask], axis=-1)
   extra_loss = self.get_kl_loss([latent_mean], [latent_std])
   return layer, extra_loss
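The get_kl_loss([latent_mean], [latent_std]) call in the last two variants penalizes the posterior's divergence from a standard normal prior. Assuming the second argument is a log-variance, as the naming in the other examples suggests, the closed-form term looks like this sketch:

import tensorflow as tf

def kl_to_standard_normal(mu, log_var):
    # KL( N(mu, exp(log_var)) || N(0, I) ), summed over the latent dimensions.
    return -0.5 * tf.reduce_sum(
        1.0 + log_var - tf.square(mu) - tf.exp(log_var), axis=-1)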
Example #9
    def construct_model(self, images, actions, rewards):
        """Build convolutional lstm video predictor using CDNA, or DNA.

    Args:
      images: list of tensors of ground truth image sequences
              there should be a 4D image ?xWxHxC for each timestep
      actions: list of action tensors
               each action should be in the shape ?x1xZ
      rewards: list of reward tensors
               each reward should be in the shape ?x1xZ
    Returns:
      gen_images: predicted future image frames
      gen_rewards: predicted future rewards
      latent_mean: mean of approximated posterior
      latent_std: std of approximated posterior

    Raises:
      ValueError: if more than 1 mask specified for DNA model.
    """
        context_frames = self.hparams.video_num_input_frames
        buffer_size = self.hparams.reward_prediction_buffer_size
        if buffer_size == 0:
            buffer_size = context_frames
        if buffer_size > context_frames:
            raise ValueError(
                "Buffer size is bigger than context frames %d %d." %
                (buffer_size, context_frames))

        batch_size = common_layers.shape_list(images)[1]
        ss_func = self.get_scheduled_sample_func(batch_size)

        def process_single_frame(prev_outputs, inputs):
            """Process a single frame of the video."""
            cur_image, input_reward, action = inputs
            time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

            generated_items = [prev_image]
            groundtruth_items = [cur_image]
            done_warm_start = tf.greater(time_step, context_frames - 1)
            input_image, = self.get_scheduled_sample_inputs(
                done_warm_start, groundtruth_items, generated_items, ss_func)

            # Prediction
            pred_image, lstm_states = self.construct_predictive_tower(
                input_image, None, action, lstm_states, latent)

            if self.hparams.reward_prediction:
                reward_input_image = pred_image
                if self.hparams.reward_prediction_stop_gradient:
                    reward_input_image = tf.stop_gradient(reward_input_image)
                with tf.control_dependencies([time_step]):
                    frame_buf = [reward_input_image] + frame_buf[:-1]
                pred_reward = self.reward_prediction(frame_buf, None, action,
                                                     latent)
                pred_reward = common_video.decode_to_shape(
                    pred_reward, common_layers.shape_list(input_reward),
                    "reward_dec")
            else:
                pred_reward = prev_reward

            time_step += 1
            outputs = (time_step, pred_image, pred_reward, frame_buf,
                       lstm_states)

            return outputs

        # Latent tower
        latent = None
        if self.hparams.stochastic_model:
            latent_mean, latent_std = self.construct_latent_tower(images,
                                                                  time_axis=0)
            latent = common_video.get_gaussian_tensor(latent_mean, latent_std)

        # HACK: Do first step outside to initialize all the variables
        lstm_states = [None] * 7
        frame_buffer = [tf.zeros_like(images[0])] * buffer_size
        inputs = images[0], rewards[0], actions[0]
        prev_outputs = (tf.constant(0), tf.zeros_like(images[0]),
                        tf.zeros_like(rewards[0]), frame_buffer, lstm_states)

        initializers = process_single_frame(prev_outputs, inputs)
        first_gen_images = tf.expand_dims(initializers[1], axis=0)
        first_gen_rewards = tf.expand_dims(initializers[2], axis=0)

        inputs = (images[1:-1], rewards[1:-1], actions[1:-1])

        outputs = tf.scan(process_single_frame, inputs, initializers)
        gen_images, gen_rewards = outputs[1:3]

        gen_images = tf.concat((first_gen_images, gen_images), axis=0)
        gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)

        if self.hparams.stochastic_model:
            return gen_images, gen_rewards, [latent_mean], [latent_std]
        else:
            return gen_images, gen_rewards, None, None
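Example #9 runs process_single_frame once outside the loop so the variables get created, then hands the same function to tf.scan, which threads the carried state through time and stacks every step's outputs along a new leading axis. The same pattern on a toy accumulator:

import tensorflow as tf

elems = tf.constant([1.0, 2.0, 3.0, 4.0])
# scan feeds each step's output back in as the next step's carry,
# just like the frame loop above.
running_sum = tf.scan(lambda acc, x: acc + x, elems,
                      initializer=tf.constant(0.0))
# running_sum evaluates to [1.0, 3.0, 6.0, 10.0].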
Example #10
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_features):
        del target_frame

        if not self.hparams.use_vae or self.hparams.use_gan:
            raise NotImplementedError("Only supporting VAE for now.")

        if self.has_pred_actions or self.has_values:
            raise NotImplementedError(
                "Parameter sharing with policy not supported.")

        image, action, reward = frames[0], actions[0], rewards[0]
        latent_dims = self.hparams.z_dim
        batch_size = common_layers.shape_list(image)[0]

        if internal_states is None:
            # Initialize LSTM State
            frame_index = 0
            lstm_state = [None] * 7
            cond_latent_state, prior_latent_state = None, None
            gen_prior_video = []
        else:
            (frame_index, lstm_state, cond_latent_state, prior_latent_state,
             gen_prior_video) = internal_states

        z_mu, log_sigma_sq = video_features
        z_mu, log_sigma_sq = z_mu[frame_index], log_sigma_sq[frame_index]

        # Sample latents using a gaussian centered at conditional mu and std.
        latent = common_video.get_gaussian_tensor(z_mu, log_sigma_sq)

        # Sample prior latents from isotropic normal distribution.
        prior_latent = tf.random_normal(tf.shape(latent), dtype=tf.float32)

        # LSTM that encodes correlations between conditional latents.
        # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
        enc_cond_latent, cond_latent_state = common_video.basic_lstm(
            latent, cond_latent_state, latent_dims, name="cond_latent")

        # LSTM that encodes correlations between prior latents.
        enc_prior_latent, prior_latent_state = common_video.basic_lstm(
            prior_latent, prior_latent_state, latent_dims, name="prior_latent")

        all_latents = tf.concat([enc_cond_latent, enc_prior_latent], axis=0)
        all_image = tf.concat([image, image], 0)
        all_action = tf.concat([action, action],
                               0) if self.has_actions else None

        all_pred_images, lstm_state = self.construct_predictive_tower(
            all_image,
            None,
            all_action,
            lstm_state,
            all_latents,
            concat_latent=True)

        cond_pred_images, prior_pred_images = \
          all_pred_images[:batch_size], all_pred_images[batch_size:]

        if self.is_training and self.hparams.use_vae:
            pred_image = cond_pred_images
        else:
            pred_image = prior_pred_images

        gen_prior_video.append(prior_pred_images)
        internal_states = (frame_index + 1, lstm_state, cond_latent_state,
                           prior_latent_state, gen_prior_video)

        if not self.has_rewards:
            return pred_image, None, None, None, 0.0, internal_states

        pred_reward = self.reward_prediction(pred_image, action, reward,
                                             latent)
        return pred_image, pred_reward, None, None, 0.0, internal_states
Example #11
    def construct_model(self, images, actions, rewards):
        """Model that takes in images and returns predictions.

    Args:
      images: list of 4-D Tensors indexed by time.
              (batch_size, width, height, channels)
      actions: list of action tensors
               each action should be in the shape ?x1xZ
      rewards: list of reward tensors
               each reward should be in the shape ?x1xZ

    Returns:
      video: list of 4-D predicted frames.
      all_rewards: predicted rewards.
      latent_means: list of gaussian means conditioned on the input at
                    every frame.
      latent_stds: list of gaussian stds conditioned on the input at
                   every frame.

    Raises:
      ValueError: If not exactly one of self.hparams.vae or self.hparams.gan
                  is set to True.
    """
        if not self.hparams.use_vae and not self.hparams.use_gan:
            raise ValueError(
                "Set at least one of use_vae or use_gan to be True")
        if self.hparams.gan_optimization not in ["joint", "sequential"]:
            raise ValueError(
                "self.hparams.gan_optimization should be either joint "
                "or sequential got %s" % self.hparams.gan_optimization)

        images = tf.unstack(images, axis=0)
        actions = tf.unstack(actions, axis=0)
        rewards = tf.unstack(rewards, axis=0)

        latent_dims = self.hparams.z_dim
        context_frames = self.hparams.video_num_input_frames
        seq_len = len(images)
        input_shape = common_layers.shape_list(images[0])
        batch_size = input_shape[0]

        # Model does not support reward-conditioned frame generation.
        fake_rewards = rewards[:-1]

        # Concatenate x_{t-1} and x_{t} along depth and encode it to
        # produce the mean and standard deviation of z_{t-1}
        image_pairs = tf.concat([images[:seq_len - 1], images[1:seq_len]],
                                axis=-1)

        z_mu, z_log_sigma_sq = self.encoder(image_pairs)
        # Unstack z_mu and z_log_sigma_sq along the time dimension.
        z_mu = tf.unstack(z_mu, axis=0)
        z_log_sigma_sq = tf.unstack(z_log_sigma_sq, axis=0)
        iterable = zip(images[:-1], actions[:-1], fake_rewards, z_mu,
                       z_log_sigma_sq)

        # Initialize LSTM State
        lstm_state = [None] * 7
        gen_cond_video, gen_prior_video, all_rewards, latent_means, latent_stds = \
          [], [], [], [], []
        pred_image = tf.zeros_like(images[0])
        prior_latent_state, cond_latent_state = None, None
        train_mode = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Create scheduled sampling function
        ss_func = self.get_scheduled_sample_func(batch_size)

        with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):

            for step, (image, action, reward, mu,
                       log_sigma_sq) in enumerate(iterable):  # pylint:disable=line-too-long
                # Sample latents using a gaussian centered at conditional mu and std.
                latent = common_video.get_gaussian_tensor(mu, log_sigma_sq)

                # Sample prior latents from isotropic normal distribution.
                prior_latent = tf.random_normal(tf.shape(latent),
                                                dtype=tf.float32)

                # LSTM that encodes correlations between conditional latents.
                # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
                enc_cond_latent, cond_latent_state = common_video.basic_lstm(
                    latent, cond_latent_state, latent_dims, name="cond_latent")

                # LSTM that encodes correlations between prior latents.
                enc_prior_latent, prior_latent_state = common_video.basic_lstm(
                    prior_latent,
                    prior_latent_state,
                    latent_dims,
                    name="prior_latent")

                # Scheduled Sampling
                done_warm_start = step > context_frames - 1
                groundtruth_items = [image]
                generated_items = [pred_image]
                input_image, = self.get_scheduled_sample_inputs(
                    done_warm_start, groundtruth_items, generated_items,
                    ss_func)

                all_latents = tf.concat([enc_cond_latent, enc_prior_latent],
                                        axis=0)
                all_image = tf.concat([input_image, input_image], axis=0)
                all_action = tf.concat([action, action], axis=0)
                all_rewards = tf.concat([reward, reward], axis=0)

                all_pred_images, lstm_state, _ = self.construct_predictive_tower(
                    all_image,
                    all_rewards,
                    all_action,
                    lstm_state,
                    all_latents,
                    concat_latent=True)

                cond_pred_images, prior_pred_images = \
                  all_pred_images[:batch_size], all_pred_images[batch_size:]

                if train_mode and self.hparams.use_vae:
                    pred_image = cond_pred_images
                else:
                    pred_image = prior_pred_images

                gen_cond_video.append(cond_pred_images)
                gen_prior_video.append(prior_pred_images)
                latent_means.append(mu)
                latent_stds.append(log_sigma_sq)

        gen_cond_video = tf.stack(gen_cond_video, axis=0)
        self.gen_prior_video = tf.stack(gen_prior_video, axis=0)
        fake_rewards = tf.stack(fake_rewards, axis=0)

        if train_mode and self.hparams.use_vae:
            return gen_cond_video, fake_rewards, latent_means, latent_stds
        else:
            return self.gen_prior_video, fake_rewards, latent_means, latent_stds
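Concatenating the conditional and prior branches along the batch axis (all_image, all_action, all_latents) lets one construct_predictive_tower call serve both branches with shared weights; slicing at batch_size afterwards recovers the per-branch predictions. The same trick in isolation, with hypothetical shapes:

import tensorflow as tf

batch_size = 8
cond = tf.random_normal([batch_size, 16])
prior = tf.random_normal([batch_size, 16])

both = tf.concat([cond, prior], axis=0)        # one pass, shared weights
out = tf.layers.dense(both, 32, name="tower")  # stand-in for the tower
cond_out, prior_out = out[:batch_size], out[batch_size:]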
Example #12
    def construct_model(self, images, actions, rewards):
        """Builds the stochastic model.

    The model first encodes all the images (x_t) in the sequence
    using the encoder. Let"s call the output e_t. Then it predicts the
    latent state of the next frame using a recurrent posterior network
    z ~ q(z|e_{0:t}) = N(mu(e_{0:t}), sigma(e_{0:t})).
    Another recurrent network predicts the embedding of the next frame
    using the approximated posterior e_{t+1} = p(e_{t+1}|e_{0:t}, z)
    Finally, the decoder decodes e_{t+1} into x_{t+1}.
    Skip connections from encoder to decoder help with reconstruction.

    Args:
      images: tensor of ground truth image sequences
      actions: NOT used list of action tensors
      rewards: NOT used list of reward tensors

    Returns:
      gen_images: generated images
      fakr_rewards: input rewards as reward prediction!
      pred_mu: predited means of posterior
      pred_logvar: predicted log(var) of posterior
    """
        # Model does not support action-conditioning or reward prediction.
        fake_reward_prediction = rewards
        del actions, rewards

        z_dim = self.hparams.z_dim
        g_dim = self.hparams.g_dim
        rnn_size = self.hparams.rnn_size
        prior_rnn_layers = self.hparams.prior_rnn_layers
        posterior_rnn_layers = self.hparams.posterior_rnn_layers
        predictor_rnn_layers = self.hparams.predictor_rnn_layers
        context_frames = self.hparams.video_num_input_frames
        has_batchnorm = self.hparams.has_batchnorm

        seq_len, batch_size, _, _, color_channels = common_layers.shape_list(
            images)

        # LSTM initial states.
        prior_states = [None] * prior_rnn_layers
        posterior_states = [None] * posterior_rnn_layers
        predictor_states = [None] * predictor_rnn_layers

        tf.logging.info(">>>> Encoding")
        # Encoding:
        enc_images, enc_skips = [], []
        images = tf.unstack(images, axis=0)
        for i, image in enumerate(images):
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                enc, skips = self.encoder(image,
                                          g_dim,
                                          has_batchnorm=has_batchnorm)
                enc = tfl.flatten(enc)
                enc_images.append(enc)
                enc_skips.append(skips)

        tf.logging.info(">>>> Prediction")
        # Prediction
        pred_mu_pos = []
        pred_logvar_pos = []
        pred_mu_prior = []
        pred_logvar_prior = []
        gen_images = []
        for i in range(1, seq_len):
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                # current encoding
                if self.is_training or len(gen_images) < context_frames:
                    h_current = enc_images[i - 1]
                else:
                    h_current, _ = self.encoder(gen_images[-1], g_dim)
                    h_current = tfl.flatten(h_current)

                # target encoding
                h_target = enc_images[i]

            with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
                # Prior parameters
                if self.hparams.learned_prior:
                    mu_prior, logvar_prior, prior_states = self.lstm_gaussian(
                        h_current, prior_states, rnn_size, z_dim,
                        prior_rnn_layers, "prior")
                else:
                    mu_prior = tf.zeros((batch_size, z_dim))
                    logvar_prior = tf.zeros((batch_size, z_dim))

                # Only use Posterior if it's training time
                if self.is_training or len(gen_images) < context_frames:
                    mu_pos, logvar_pos, posterior_states = self.lstm_gaussian(
                        h_target, posterior_states, rnn_size, z_dim,
                        posterior_rnn_layers, "posterior")
                    # Sample z from posterior distribution
                    z = common_video.get_gaussian_tensor(mu_pos, logvar_pos)
                else:
                    mu_pos = tf.zeros_like(mu_prior)
                    logvar_pos = tf.zeros_like(logvar_prior)
                    z = common_video.get_gaussian_tensor(
                        mu_prior, logvar_prior)

                # Predict output encoding
                h_pred, predictor_states = self.stacked_lstm(
                    tf.concat([h_current, z], axis=1), predictor_states,
                    rnn_size, g_dim, predictor_rnn_layers)

                pred_mu_pos.append(tf.identity(mu_pos, "mu_pos"))
                pred_logvar_pos.append(tf.identity(logvar_pos, "logvar_pos"))
                pred_mu_prior.append(tf.identity(mu_prior, "mu_prior"))
                pred_logvar_prior.append(
                    tf.identity(logvar_prior, "logvar_prior"))

            with tf.variable_scope("decoding", reuse=tf.AUTO_REUSE):
                skip_index = min(context_frames - 1, i - 1)
                h_pred = tf.reshape(h_pred, [batch_size, 1, 1, g_dim])
                if self.hparams.has_skips:
                    x_pred = self.decoder(h_pred,
                                          color_channels,
                                          skips=enc_skips[skip_index],
                                          has_batchnorm=has_batchnorm)
                else:
                    x_pred = self.decoder(h_pred,
                                          color_channels,
                                          has_batchnorm=has_batchnorm)
                gen_images.append(x_pred)

        tf.logging.info(">>>> Done")
        gen_images = tf.stack(gen_images, axis=0)
        return {
            "gen_images": gen_images,
            "fake_reward_prediction": fake_reward_prediction,
            "pred_mu_pos": pred_mu_pos,
            "pred_logvar_pos": pred_logvar_pos,
            "pred_mu_prior": pred_mu_prior,
            "pred_logvar_prior": pred_logvar_prior
        }
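lstm_gaussian in Example #12 is a recurrent head that emits the mean and log-variance of the next latent. A sketch of one step of the usual construction (a hypothetical helper, not the tensor2tensor code): run an LSTM cell, then project its output to mu and log_var with two dense layers:

import tensorflow as tf

def lstm_gaussian_step(x, state, rnn_size, z_dim, name):
    # One LSTM step followed by two projections parameterizing N(mu, exp(log_var)).
    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, name=name + "_cell")
    if state is None:
        state = cell.zero_state(tf.shape(x)[0], tf.float32)
    out, state = cell(x, state)
    mu = tf.layers.dense(out, z_dim, name=name + "_mu")
    log_var = tf.layers.dense(out, z_dim, name=name + "_logvar")
    return mu, log_var, state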
Example #13
  def construct_model(self,
                      images,
                      actions,
                      rewards):
    """Build convolutional lstm video predictor using CDNA, or DNA.

    Args:
      images: list of tensors of ground truth image sequences
              there should be a 4D image ?xWxHxC for each timestep
      actions: list of action tensors
               each action should be in the shape ?x1xZ
      rewards: list of reward tensors
               each reward should be in the shape ?x1xZ
    Returns:
      gen_images: predicted future image frames
      gen_rewards: predicted future rewards
      latent_mean: mean of approximated posterior
      latent_std: std of approximated posterior

    Raises:
      ValueError: if more than 1 mask specified for DNA model.
    """
    context_frames = self.hparams.video_num_input_frames
    buffer_size = self.hparams.reward_prediction_buffer_size
    if buffer_size == 0:
      buffer_size = context_frames
    if buffer_size > context_frames:
      raise ValueError("Buffer size is bigger than context frames %d %d." %
                       (buffer_size, context_frames))

    batch_size = common_layers.shape_list(images[0])[0]
    ss_func = self.get_scheduled_sample_func(batch_size)

    def process_single_frame(prev_outputs, inputs):
      """Process a single frame of the video."""
      cur_image, input_reward, action = inputs
      time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

      # Sample from softmax (by argmax); this is a no-op for non-softmax loss.
      prev_image = self.get_sampled_frame(prev_image)

      generated_items = [prev_image]
      groundtruth_items = [cur_image]
      done_warm_start = tf.greater(time_step, context_frames - 1)
      input_image, = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Prediction
      pred_image, lstm_states, _ = self.construct_predictive_tower(
          input_image, None, action, lstm_states, latent)

      if self.hparams.reward_prediction:
        reward_input_image = self.get_sampled_frame(pred_image)
        if self.hparams.reward_prediction_stop_gradient:
          reward_input_image = tf.stop_gradient(reward_input_image)
        with tf.control_dependencies([time_step]):
          frame_buf = [reward_input_image] + frame_buf[:-1]
        pred_reward = self.reward_prediction(frame_buf, None, action, latent)
        pred_reward = common_video.decode_to_shape(
            pred_reward, common_layers.shape_list(input_reward), "reward_dec")
      else:
        pred_reward = prev_reward

      time_step += 1
      outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

      return outputs

    # Latent tower
    latent = None
    if self.hparams.stochastic_model:
      latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0)
      latent = common_video.get_gaussian_tensor(latent_mean, latent_std)

    # HACK: Do first step outside to initialize all the variables

    lstm_states = [None] * (5 if self.hparams.small_mode else 7)
    frame_buffer = [tf.zeros_like(images[0])] * buffer_size
    inputs = images[0], rewards[0], actions[0]
    init_image_shape = common_layers.shape_list(images[0])
    if self.is_per_pixel_softmax:
      init_image_shape[-1] *= 256
    init_image = tf.zeros(init_image_shape, dtype=images.dtype)
    prev_outputs = (tf.constant(0),
                    init_image,
                    tf.zeros_like(rewards[0]),
                    frame_buffer,
                    lstm_states)

    initializers = process_single_frame(prev_outputs, inputs)
    first_gen_images = tf.expand_dims(initializers[1], axis=0)
    first_gen_rewards = tf.expand_dims(initializers[2], axis=0)

    inputs = (images[1:-1], rewards[1:-1], actions[1:-1])

    outputs = tf.scan(process_single_frame, inputs, initializers)
    gen_images, gen_rewards = outputs[1:3]

    gen_images = tf.concat((first_gen_images, gen_images), axis=0)
    gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)

    if self.hparams.stochastic_model:
      return gen_images, gen_rewards, [latent_mean], [latent_std]
    else:
      return gen_images, gen_rewards, None, None
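When is_per_pixel_softmax is set in Example #13, the predictive tower emits 256 logits per channel (hence init_image_shape[-1] *= 256) and get_sampled_frame collapses them back to pixel values; the in-code comment says this is done by argmax. A sketch of that collapse under those assumptions:

import tensorflow as tf

def sample_pixels_from_logits(logits, channels):
    # Reshape [..., channels * 256] logits to [..., channels, 256],
    # then take the argmax over the 256 intensity bins.
    shape = tf.shape(logits)
    new_shape = tf.concat([shape[:-1], [channels, 256]], axis=0)
    logits = tf.reshape(logits, new_shape)
    return tf.cast(tf.argmax(logits, axis=-1), tf.float32)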