def video_features(
    self, all_frames, all_actions, all_rewards, all_raw_frames):
  """Video wide latent."""
  del all_actions, all_rewards, all_raw_frames
  mean, std = self.construct_latent_tower(all_frames, time_axis=0)
  latent = common_video.get_gaussian_tensor(mean, std)
  return [latent, mean, std]
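# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): common_video.get_gaussian_tensor
# is the VAE reparameterization step. A minimal sketch of what such a helper
# computes, assuming the (mean, log-variance) parameterization that the
# posterior networks below use; see tensor2tensor's common_video for the
# actual contract.
import tensorflow as tf


def _get_gaussian_tensor_sketch(mean, log_var):
  """z = mean + sigma * eps with eps ~ N(0, I).

  Writing the sample as a deterministic function of (mean, log_var) plus
  parameter-free noise lets gradients flow into mean and log_var.
  """
  eps = tf.random_normal(tf.shape(mean), dtype=tf.float32)
  return mean + tf.exp(log_var / 2.0) * eps
# ---------------------------------------------------------------------------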
def construct_model(self, images, actions, rewards):
  images = tf.unstack(images, axis=0)
  actions = tf.unstack(actions, axis=0)
  rewards = tf.unstack(rewards, axis=0)

  batch_size = common_layers.shape_list(images[0])[0]
  context_frames = self.hparams.video_num_input_frames

  # Predicted images and rewards.
  gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []

  # LSTM states.
  lstm_state = [None] * 7

  # Create scheduled sampling function.
  ss_func = self.get_scheduled_sample_func(batch_size)

  pred_image = tf.zeros_like(images[0])
  pred_reward = tf.zeros_like(rewards[0])
  latent = None
  for timestep, image, action, reward in zip(
      range(len(images) - 1), images[:-1], actions[:-1], rewards[:-1]):
    # Scheduled Sampling
    done_warm_start = timestep > context_frames - 1
    groundtruth_items = [image, reward]
    generated_items = [pred_image, pred_reward]
    input_image, input_reward = self.get_scheduled_sample_inputs(
        done_warm_start, groundtruth_items, generated_items, ss_func)

    # Latent
    # TODO(mbz): should we use input_image instead of image?
    latent_images = tf.stack([image, images[timestep + 1]], axis=0)
    latent_mean, latent_std = self.construct_latent_tower(
        latent_images, time_axis=0)
    latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
    latent_means.append(latent_mean)
    latent_stds.append(latent_std)

    # Prediction
    pred_image, lstm_state, _ = self.construct_predictive_tower(
        input_image, input_reward, action, lstm_state, latent)

    if self.hparams.reward_prediction:
      pred_reward = self.reward_prediction(
          pred_image, input_reward, action, latent)
      pred_reward = common_video.decode_to_shape(
          pred_reward, common_layers.shape_list(input_reward), "reward_dec")
    else:
      pred_reward = input_reward

    gen_images.append(pred_image)
    gen_rewards.append(pred_reward)

  gen_images = tf.stack(gen_images, axis=0)
  gen_rewards = tf.stack(gen_rewards, axis=0)

  return gen_images, gen_rewards, latent_means, latent_stds
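# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): get_scheduled_sample_func /
# get_scheduled_sample_inputs are defined elsewhere on the class. One common
# scheme (Bengio et al., 2015) feeds ground truth with a probability that
# decays over training, e.g. via an inverse-sigmoid schedule. A numpy sketch
# under that assumption:
import numpy as np


def _scheduled_sample_sketch(ground_truth, generated, train_step, rng,
                             k=1000.0):
  """Per-example choice between ground-truth and generated inputs."""
  # Probability of feeding ground truth; decays toward 0 as training runs.
  prob = k / (k + np.exp(train_step / k))
  batch_size = ground_truth.shape[0]
  use_truth = rng.uniform(size=batch_size) < prob  # Bernoulli mask
  mask = use_truth.reshape((batch_size,) + (1,) * (ground_truth.ndim - 1))
  return np.where(mask, ground_truth, generated)


_rng = np.random.default_rng(0)
_gt, _gen = np.ones((4, 8, 8, 3)), np.zeros((4, 8, 8, 3))
_mixed = _scheduled_sample_sketch(_gt, _gen, train_step=500, rng=_rng)
# ---------------------------------------------------------------------------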
def video_features(
    self, all_frames, all_actions, all_rewards, all_raw_frames):
  """Video wide latent."""
  del all_actions, all_rewards, all_raw_frames
  if not self.hparams.stochastic_model:
    return None, None, None
  frames = tf.stack(all_frames, axis=1)
  mean, std = self.construct_latent_tower(frames, time_axis=1)
  latent = common_video.get_gaussian_tensor(mean, std)
  return [latent, mean, std]
def inject_latent(self, layer, features, filters):
  """Inject a VAE-style latent."""
  # Latent for stochastic model.
  full_video = tf.concat(
      [features["inputs_raw"], features["targets_raw"]], axis=1)
  latent_mean, latent_std = self.construct_latent_tower(
      full_video, time_axis=1)
  latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
  latent = tf.layers.flatten(latent)
  latent = tf.expand_dims(latent, axis=1)
  latent = tf.expand_dims(latent, axis=1)
  latent_mask = tf.layers.dense(latent, filters, name="latent_mask")
  zeros_mask = tf.zeros(
      common_layers.shape_list(layer)[:-1] + [filters], dtype=tf.float32)
  layer = tf.concat([layer, latent_mask + zeros_mask], axis=-1)
  extra_loss = self.get_extra_loss(latent_mean, latent_std)
  return layer, extra_loss
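# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): the `latent_mask + zeros_mask`
# idiom above broadcasts a per-example latent of shape (B, 1, 1, F) across
# the spatial grid of `layer`, so it can be concatenated channel-wise at
# every position. The same broadcast in numpy:
import numpy as np

_batch, _height, _width, _filters = 2, 4, 4, 8
_latent_mask = np.random.randn(_batch, 1, 1, _filters)  # one vector/example
_zeros_mask = np.zeros((_batch, _height, _width, _filters))
_tiled = _latent_mask + _zeros_mask  # adding zeros tiles it spatially
assert _tiled.shape == (_batch, _height, _width, _filters)

_layer = np.random.randn(_batch, _height, _width, 16)
_augmented = np.concatenate([_layer, _tiled], axis=-1)  # channels: 16 + 8
assert _augmented.shape == (_batch, _height, _width, 24)
# ---------------------------------------------------------------------------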
def inject_latent(self, layer, inputs, target, action):
  """Inject a VAE-style latent."""
  del action
  # Latent for stochastic model.
  filters = 128
  full_video = tf.stack(inputs + [target], axis=1)
  latent_mean, latent_std = self.construct_latent_tower(
      full_video, time_axis=1)
  latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
  latent = tfl.flatten(latent)
  latent = tf.expand_dims(latent, axis=1)
  latent = tf.expand_dims(latent, axis=1)
  latent_mask = tfl.dense(latent, filters, name="latent_mask")
  zeros_mask = tf.zeros(
      common_layers.shape_list(layer)[:-1] + [filters], dtype=tf.float32)
  layer = tf.concat([layer, latent_mask + zeros_mask], axis=-1)
  extra_loss = self.get_kl_loss([latent_mean], [latent_std])
  return layer, extra_loss
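# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): get_kl_loss is defined on the
# model class. For a diagonal Gaussian posterior against a standard-normal
# prior, the KL term has the usual closed form; a numpy sketch, assuming the
# second argument is a log-variance:
import numpy as np


def _kl_to_standard_normal(mu, log_var):
  """KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dimensions."""
  return -0.5 * np.sum(1.0 + log_var - np.square(mu) - np.exp(log_var),
                       axis=-1)


# When posterior == prior, the KL is exactly zero.
assert np.allclose(_kl_to_standard_normal(np.zeros((4, 16)),
                                          np.zeros((4, 16))), 0.0)
# ---------------------------------------------------------------------------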
def construct_model(self, images, actions, rewards):
  """Build convolutional lstm video predictor using CDNA, or DNA.

  Args:
    images: list of tensors of ground truth image sequences
            there should be a 4D image ?xWxHxC for each timestep
    actions: list of action tensors
             each action should be in the shape ?x1xZ
    rewards: list of reward tensors
             each reward should be in the shape ?x1xZ
  Returns:
    gen_images: predicted future image frames
    gen_rewards: predicted future rewards
    latent_mean: mean of approximated posterior
    latent_std: std of approximated posterior

  Raises:
    ValueError: if the reward prediction buffer size is bigger than the
      number of context frames.
  """
  context_frames = self.hparams.video_num_input_frames
  buffer_size = self.hparams.reward_prediction_buffer_size
  if buffer_size == 0:
    buffer_size = context_frames
  if buffer_size > context_frames:
    raise ValueError("Buffer size is bigger than context frames %d %d." %
                     (buffer_size, context_frames))

  batch_size = common_layers.shape_list(images)[1]
  ss_func = self.get_scheduled_sample_func(batch_size)

  def process_single_frame(prev_outputs, inputs):
    """Process a single frame of the video."""
    cur_image, input_reward, action = inputs
    time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

    generated_items = [prev_image]
    groundtruth_items = [cur_image]
    done_warm_start = tf.greater(time_step, context_frames - 1)
    input_image, = self.get_scheduled_sample_inputs(
        done_warm_start, groundtruth_items, generated_items, ss_func)

    # Prediction
    pred_image, lstm_states = self.construct_predictive_tower(
        input_image, None, action, lstm_states, latent)

    if self.hparams.reward_prediction:
      reward_input_image = pred_image
      if self.hparams.reward_prediction_stop_gradient:
        reward_input_image = tf.stop_gradient(reward_input_image)
      with tf.control_dependencies([time_step]):
        frame_buf = [reward_input_image] + frame_buf[:-1]
      pred_reward = self.reward_prediction(frame_buf, None, action, latent)
      pred_reward = common_video.decode_to_shape(
          pred_reward, common_layers.shape_list(input_reward), "reward_dec")
    else:
      pred_reward = prev_reward

    time_step += 1
    outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

    return outputs

  # Latent tower
  latent = None
  if self.hparams.stochastic_model:
    latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0)
    latent = common_video.get_gaussian_tensor(latent_mean, latent_std)

  # HACK: Do first step outside to initialize all the variables
  lstm_states = [None] * 7
  frame_buffer = [tf.zeros_like(images[0])] * buffer_size
  inputs = images[0], rewards[0], actions[0]
  prev_outputs = (tf.constant(0), tf.zeros_like(images[0]),
                  tf.zeros_like(rewards[0]), frame_buffer, lstm_states)

  initializers = process_single_frame(prev_outputs, inputs)
  first_gen_images = tf.expand_dims(initializers[1], axis=0)
  first_gen_rewards = tf.expand_dims(initializers[2], axis=0)

  inputs = (images[1:-1], rewards[1:-1], actions[1:-1])

  outputs = tf.scan(process_single_frame, inputs, initializers)
  gen_images, gen_rewards = outputs[1:3]

  gen_images = tf.concat((first_gen_images, gen_images), axis=0)
  gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)

  if self.hparams.stochastic_model:
    return gen_images, gen_rewards, [latent_mean], [latent_std]
  else:
    return gen_images, gen_rewards, None, None
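# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): tf.scan threads the output
# tuple of each step back in as `prev_outputs` for the next element, and the
# eager first step above doubles as variable creation and as the
# `initializer` structure. A minimal TF1 sketch of the same pattern:
import tensorflow as tf


def _scan_step(prev, x):
  """`prev` is the previous step's output tuple; `x` the current element."""
  count, total = prev
  return count + 1, total + x


_elems = tf.constant([1.0, 2.0, 3.0])
_init = (tf.constant(0), tf.constant(0.0))
_counts, _totals = tf.scan(_scan_step, _elems, initializer=_init)
# _counts == [1, 2, 3]; _totals == [1.0, 3.0, 6.0]
# ---------------------------------------------------------------------------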
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_features):
  del target_frame

  if not self.hparams.use_vae or self.hparams.use_gan:
    raise NotImplementedError("Only supporting VAE for now.")

  if self.has_pred_actions or self.has_values:
    raise NotImplementedError("Parameter sharing with policy not supported.")

  image, action, reward = frames[0], actions[0], rewards[0]
  latent_dims = self.hparams.z_dim
  batch_size = common_layers.shape_list(image)[0]

  if internal_states is None:
    # Initialize LSTM State
    frame_index = 0
    lstm_state = [None] * 7
    cond_latent_state, prior_latent_state = None, None
    gen_prior_video = []
  else:
    (frame_index, lstm_state, cond_latent_state,
     prior_latent_state, gen_prior_video) = internal_states

  z_mu, log_sigma_sq = video_features
  z_mu, log_sigma_sq = z_mu[frame_index], log_sigma_sq[frame_index]

  # Sample latents using a gaussian centered at conditional mu and std.
  latent = common_video.get_gaussian_tensor(z_mu, log_sigma_sq)

  # Sample prior latents from isotropic normal distribution.
  prior_latent = tf.random_normal(tf.shape(latent), dtype=tf.float32)

  # LSTM that encodes correlations between conditional latents.
  # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
  enc_cond_latent, cond_latent_state = common_video.basic_lstm(
      latent, cond_latent_state, latent_dims, name="cond_latent")

  # LSTM that encodes correlations between prior latents.
  enc_prior_latent, prior_latent_state = common_video.basic_lstm(
      prior_latent, prior_latent_state, latent_dims, name="prior_latent")

  all_latents = tf.concat([enc_cond_latent, enc_prior_latent], axis=0)
  all_image = tf.concat([image, image], 0)
  all_action = tf.concat([action, action], 0) if self.has_actions else None

  all_pred_images, lstm_state = self.construct_predictive_tower(
      all_image, None, all_action, lstm_state, all_latents,
      concat_latent=True)

  cond_pred_images, prior_pred_images = \
      all_pred_images[:batch_size], all_pred_images[batch_size:]

  if self.is_training and self.hparams.use_vae:
    pred_image = cond_pred_images
  else:
    pred_image = prior_pred_images

  gen_prior_video.append(prior_pred_images)
  internal_states = (frame_index + 1, lstm_state, cond_latent_state,
                     prior_latent_state, gen_prior_video)

  if not self.has_rewards:
    # Match the 6-tuple contract of the reward-predicting return below.
    return pred_image, None, None, None, 0.0, internal_states

  pred_reward = self.reward_prediction(pred_image, action, reward, latent)
  return pred_image, pred_reward, None, None, 0.0, internal_states
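# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): concatenating the conditional
# and prior latents along the batch axis runs both through one shared-weight
# pass of the predictive tower; slicing at batch_size recovers the two
# predictions. The idea in isolation:
import numpy as np

_batch_size, _latent_dims = 4, 8
_cond = np.random.randn(_batch_size, _latent_dims)
_prior = np.random.randn(_batch_size, _latent_dims)

_all_latents = np.concatenate([_cond, _prior], axis=0)  # doubled batch
_all_outputs = _all_latents * 2.0  # stand-in for the shared tower
_cond_out, _prior_out = _all_outputs[:_batch_size], _all_outputs[_batch_size:]
assert np.allclose(_cond_out, _cond * 2.0)
# ---------------------------------------------------------------------------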
def construct_model(self, images, actions, rewards):
  """Model that takes in images and returns predictions.

  Args:
    images: list of 4-D Tensors indexed by time.
            (batch_size, width, height, channels)
    actions: list of action tensors
             each action should be in the shape ?x1xZ
    rewards: list of reward tensors
             each reward should be in the shape ?x1xZ

  Returns:
    video: list of 4-D predicted frames.
    all_rewards: predicted rewards.
    latent_means: list of gaussian means conditioned on the input at
                  every frame.
    latent_stds: list of gaussian stds conditioned on the input at
                 every frame.

  Raises:
    ValueError: if neither self.hparams.use_vae nor self.hparams.use_gan
      is set to True, or if gan_optimization is not "joint" or "sequential".
  """
  if not self.hparams.use_vae and not self.hparams.use_gan:
    raise ValueError("Set at least one of use_vae or use_gan to be True")
  if self.hparams.gan_optimization not in ["joint", "sequential"]:
    raise ValueError("self.hparams.gan_optimization should be either joint "
                     "or sequential got %s" % self.hparams.gan_optimization)

  images = tf.unstack(images, axis=0)
  actions = tf.unstack(actions, axis=0)
  rewards = tf.unstack(rewards, axis=0)

  latent_dims = self.hparams.z_dim
  context_frames = self.hparams.video_num_input_frames
  seq_len = len(images)
  input_shape = common_layers.shape_list(images[0])
  batch_size = input_shape[0]

  # Model does not support reward-conditioned frame generation.
  fake_rewards = rewards[:-1]

  # Concatenate x_{t-1} and x_{t} along depth and encode it to
  # produce the mean and standard deviation of z_{t-1}.
  image_pairs = tf.concat([images[:seq_len - 1], images[1:seq_len]], axis=-1)

  z_mu, z_log_sigma_sq = self.encoder(image_pairs)
  # Unstack z_mu and z_log_sigma_sq along the time dimension.
  z_mu = tf.unstack(z_mu, axis=0)
  z_log_sigma_sq = tf.unstack(z_log_sigma_sq, axis=0)
  iterable = zip(images[:-1], actions[:-1], fake_rewards,
                 z_mu, z_log_sigma_sq)

  # Initialize LSTM State
  lstm_state = [None] * 7
  gen_cond_video, gen_prior_video, all_rewards, latent_means, latent_stds = \
      [], [], [], [], []
  pred_image = tf.zeros_like(images[0])
  prior_latent_state, cond_latent_state = None, None
  train_mode = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Create scheduled sampling function
  ss_func = self.get_scheduled_sample_func(batch_size)

  with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
    for step, (image, action, reward, mu, log_sigma_sq) in enumerate(
        iterable):
      # Sample latents using a gaussian centered at conditional mu and std.
      latent = common_video.get_gaussian_tensor(mu, log_sigma_sq)

      # Sample prior latents from isotropic normal distribution.
      prior_latent = tf.random_normal(tf.shape(latent), dtype=tf.float32)

      # LSTM that encodes correlations between conditional latents.
      # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
      enc_cond_latent, cond_latent_state = common_video.basic_lstm(
          latent, cond_latent_state, latent_dims, name="cond_latent")

      # LSTM that encodes correlations between prior latents.
      enc_prior_latent, prior_latent_state = common_video.basic_lstm(
          prior_latent, prior_latent_state, latent_dims, name="prior_latent")

      # Scheduled Sampling
      done_warm_start = step > context_frames - 1
      groundtruth_items = [image]
      generated_items = [pred_image]
      input_image, = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      all_latents = tf.concat([enc_cond_latent, enc_prior_latent], axis=0)
      all_image = tf.concat([input_image, input_image], axis=0)
      all_action = tf.concat([action, action], axis=0)
      # Named all_reward (singular) to avoid shadowing the all_rewards list
      # initialized above.
      all_reward = tf.concat([reward, reward], axis=0)

      all_pred_images, lstm_state, _ = self.construct_predictive_tower(
          all_image, all_reward, all_action, lstm_state, all_latents,
          concat_latent=True)

      cond_pred_images, prior_pred_images = \
          all_pred_images[:batch_size], all_pred_images[batch_size:]

      if train_mode and self.hparams.use_vae:
        pred_image = cond_pred_images
      else:
        pred_image = prior_pred_images

      gen_cond_video.append(cond_pred_images)
      gen_prior_video.append(prior_pred_images)
      latent_means.append(mu)
      latent_stds.append(log_sigma_sq)

  gen_cond_video = tf.stack(gen_cond_video, axis=0)
  self.gen_prior_video = tf.stack(gen_prior_video, axis=0)
  fake_rewards = tf.stack(fake_rewards, axis=0)

  if train_mode and self.hparams.use_vae:
    return gen_cond_video, fake_rewards, latent_means, latent_stds
  else:
    return self.gen_prior_video, fake_rewards, latent_means, latent_stds
def construct_model(self, images, actions, rewards):
  """Builds the stochastic model.

  The model first encodes all the images (x_t) in the sequence
  using the encoder. Let's call the output e_t. Then it predicts the
  latent state of the next frame using a recurrent posterior network
  z ~ q(z|e_{0:t}) = N(mu(e_{0:t}), sigma(e_{0:t})).
  Another recurrent network predicts the embedding of the next frame
  using the approximated posterior e_{t+1} = p(e_{t+1}|e_{0:t}, z).
  Finally, the decoder decodes e_{t+1} into x_{t+1}.
  Skip connections from encoder to decoder help with reconstruction.

  Args:
    images: tensor of ground truth image sequences
    actions: NOT used list of action tensors
    rewards: NOT used list of reward tensors

  Returns:
    gen_images: generated images
    fake_reward_prediction: input rewards as reward prediction!
    pred_mu_pos, pred_mu_prior: predicted means of posterior and prior
    pred_logvar_pos, pred_logvar_prior: predicted log(var) of posterior
      and prior
  """
  # Model does not support action-conditioning and reward prediction.
  fake_reward_prediction = rewards
  del actions, rewards

  z_dim = self.hparams.z_dim
  g_dim = self.hparams.g_dim
  rnn_size = self.hparams.rnn_size
  prior_rnn_layers = self.hparams.prior_rnn_layers
  posterior_rnn_layers = self.hparams.posterior_rnn_layers
  predictor_rnn_layers = self.hparams.predictor_rnn_layers
  context_frames = self.hparams.video_num_input_frames
  has_batchnorm = self.hparams.has_batchnorm

  seq_len, batch_size, _, _, color_channels = common_layers.shape_list(
      images)

  # LSTM initial states.
  prior_states = [None] * prior_rnn_layers
  posterior_states = [None] * posterior_rnn_layers
  predictor_states = [None] * predictor_rnn_layers

  tf.logging.info(">>>> Encoding")
  # Encoding:
  enc_images, enc_skips = [], []
  images = tf.unstack(images, axis=0)
  for image in images:
    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
      enc, skips = self.encoder(image, g_dim, has_batchnorm=has_batchnorm)
      enc = tfl.flatten(enc)
      enc_images.append(enc)
      enc_skips.append(skips)

  tf.logging.info(">>>> Prediction")
  # Prediction
  pred_mu_pos = []
  pred_logvar_pos = []
  pred_mu_prior = []
  pred_logvar_prior = []
  gen_images = []
  for i in range(1, seq_len):
    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
      # current encoding
      if self.is_training or len(gen_images) < context_frames:
        h_current = enc_images[i - 1]
      else:
        h_current, _ = self.encoder(gen_images[-1], g_dim)
        h_current = tfl.flatten(h_current)

      # target encoding
      h_target = enc_images[i]

    with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
      # Prior parameters
      if self.hparams.learned_prior:
        mu_prior, logvar_prior, prior_states = self.lstm_gaussian(
            h_current, prior_states, rnn_size, z_dim, prior_rnn_layers,
            "prior")
      else:
        mu_prior = tf.zeros((batch_size, z_dim))
        logvar_prior = tf.zeros((batch_size, z_dim))

      # Only use the posterior at training time.
      if self.is_training or len(gen_images) < context_frames:
        mu_pos, logvar_pos, posterior_states = self.lstm_gaussian(
            h_target, posterior_states, rnn_size, z_dim,
            posterior_rnn_layers, "posterior")
        # Sample z from posterior distribution
        z = common_video.get_gaussian_tensor(mu_pos, logvar_pos)
      else:
        mu_pos = tf.zeros_like(mu_prior)
        logvar_pos = tf.zeros_like(logvar_prior)
        z = common_video.get_gaussian_tensor(mu_prior, logvar_prior)

      # Predict output encoding
      h_pred, predictor_states = self.stacked_lstm(
          tf.concat([h_current, z], axis=1), predictor_states,
          rnn_size, g_dim, predictor_rnn_layers)

      pred_mu_pos.append(tf.identity(mu_pos, "mu_pos"))
      pred_logvar_pos.append(tf.identity(logvar_pos, "logvar_pos"))
      pred_mu_prior.append(tf.identity(mu_prior, "mu_prior"))
      pred_logvar_prior.append(tf.identity(logvar_prior, "logvar_prior"))

    with tf.variable_scope("decoding", reuse=tf.AUTO_REUSE):
      skip_index = min(context_frames - 1, i - 1)
      h_pred = tf.reshape(h_pred, [batch_size, 1, 1, g_dim])
      if self.hparams.has_skips:
        x_pred = self.decoder(
            h_pred, color_channels,
            skips=enc_skips[skip_index], has_batchnorm=has_batchnorm)
      else:
        x_pred = self.decoder(
            h_pred, color_channels, has_batchnorm=has_batchnorm)
      gen_images.append(x_pred)

  tf.logging.info(">>>> Done")
  gen_images = tf.stack(gen_images, axis=0)
  return {
      "gen_images": gen_images,
      "fake_reward_prediction": fake_reward_prediction,
      "pred_mu_pos": pred_mu_pos,
      "pred_logvar_pos": pred_logvar_pos,
      "pred_mu_prior": pred_mu_prior,
      "pred_logvar_prior": pred_logvar_prior,
  }
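# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): with learned_prior enabled, the
# training loss (computed elsewhere) compares the posterior
# N(mu_pos, exp(logvar_pos)) against the learned prior
# N(mu_prior, exp(logvar_prior)) instead of N(0, I). The closed-form KL
# between two diagonal Gaussians, as a numpy sketch:
import numpy as np


def _kl_diagonal_gaussians(mu_q, logvar_q, mu_p, logvar_p):
  """KL(N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p))) per example."""
  return 0.5 * np.sum(
      logvar_p - logvar_q
      + (np.exp(logvar_q) + np.square(mu_q - mu_p)) / np.exp(logvar_p)
      - 1.0, axis=-1)


_mu, _logvar = np.random.randn(4, 16), np.random.randn(4, 16)
# KL of a distribution with itself is zero.
assert np.allclose(_kl_diagonal_gaussians(_mu, _logvar, _mu, _logvar), 0.0)
# ---------------------------------------------------------------------------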
def construct_model(self, images, actions, rewards):
  """Build convolutional lstm video predictor using CDNA, or DNA.

  Args:
    images: list of tensors of ground truth image sequences
            there should be a 4D image ?xWxHxC for each timestep
    actions: list of action tensors
             each action should be in the shape ?x1xZ
    rewards: list of reward tensors
             each reward should be in the shape ?x1xZ
  Returns:
    gen_images: predicted future image frames
    gen_rewards: predicted future rewards
    latent_mean: mean of approximated posterior
    latent_std: std of approximated posterior

  Raises:
    ValueError: if the reward prediction buffer size is bigger than the
      number of context frames.
  """
  context_frames = self.hparams.video_num_input_frames
  buffer_size = self.hparams.reward_prediction_buffer_size
  if buffer_size == 0:
    buffer_size = context_frames
  if buffer_size > context_frames:
    raise ValueError("Buffer size is bigger than context frames %d %d." %
                     (buffer_size, context_frames))

  batch_size = common_layers.shape_list(images[0])[0]
  ss_func = self.get_scheduled_sample_func(batch_size)

  def process_single_frame(prev_outputs, inputs):
    """Process a single frame of the video."""
    cur_image, input_reward, action = inputs
    time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

    # Sample from softmax (by argmax). This is a no-op for non-softmax loss.
    prev_image = self.get_sampled_frame(prev_image)

    generated_items = [prev_image]
    groundtruth_items = [cur_image]
    done_warm_start = tf.greater(time_step, context_frames - 1)
    input_image, = self.get_scheduled_sample_inputs(
        done_warm_start, groundtruth_items, generated_items, ss_func)

    # Prediction
    pred_image, lstm_states, _ = self.construct_predictive_tower(
        input_image, None, action, lstm_states, latent)

    if self.hparams.reward_prediction:
      reward_input_image = self.get_sampled_frame(pred_image)
      if self.hparams.reward_prediction_stop_gradient:
        reward_input_image = tf.stop_gradient(reward_input_image)
      with tf.control_dependencies([time_step]):
        frame_buf = [reward_input_image] + frame_buf[:-1]
      pred_reward = self.reward_prediction(frame_buf, None, action, latent)
      pred_reward = common_video.decode_to_shape(
          pred_reward, common_layers.shape_list(input_reward), "reward_dec")
    else:
      pred_reward = prev_reward

    time_step += 1
    outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

    return outputs

  # Latent tower
  latent = None
  if self.hparams.stochastic_model:
    latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0)
    latent = common_video.get_gaussian_tensor(latent_mean, latent_std)

  # HACK: Do first step outside to initialize all the variables
  lstm_states = [None] * (5 if self.hparams.small_mode else 7)
  frame_buffer = [tf.zeros_like(images[0])] * buffer_size
  inputs = images[0], rewards[0], actions[0]
  init_image_shape = common_layers.shape_list(images[0])
  if self.is_per_pixel_softmax:
    init_image_shape[-1] *= 256
  init_image = tf.zeros(init_image_shape, dtype=images.dtype)
  prev_outputs = (tf.constant(0), init_image, tf.zeros_like(rewards[0]),
                  frame_buffer, lstm_states)

  initializers = process_single_frame(prev_outputs, inputs)
  first_gen_images = tf.expand_dims(initializers[1], axis=0)
  first_gen_rewards = tf.expand_dims(initializers[2], axis=0)

  inputs = (images[1:-1], rewards[1:-1], actions[1:-1])

  outputs = tf.scan(process_single_frame, inputs, initializers)
  gen_images, gen_rewards = outputs[1:3]

  gen_images = tf.concat((first_gen_images, gen_images), axis=0)
  gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)

  if self.hparams.stochastic_model:
    return gen_images, gen_rewards, [latent_mean], [latent_std]
  else:
    return gen_images, gen_rewards, None, None
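# ---------------------------------------------------------------------------
# Illustrative aside (not part of the model): these construct_model variants
# expect time-major inputs (time on axis 0), matching the
# tf.unstack(..., axis=0) calls and the images[1:-1] slicing above. A
# hypothetical call with made-up shapes; the trailing action/reward
# dimensions depend on the problem definition:
import tensorflow as tf

_T, _B, _H, _W, _C = 10, 8, 64, 64, 3
_images = tf.placeholder(tf.float32, [_T, _B, _H, _W, _C])
_actions = tf.placeholder(tf.float32, [_T, _B, 1, 4])  # ?x1xZ per step
_rewards = tf.placeholder(tf.float32, [_T, _B, 1, 1])
# gen_images, gen_rewards, latent_means, latent_stds = model.construct_model(
#     _images, _actions, _rewards)
# ---------------------------------------------------------------------------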