def lstm_gaussian(self, inputs, states, hidden_size, output_size, nlayers,
                  name):
  """Dense projection, stacked LSTMs, then two dense heads for a gaussian.

  Args:
    inputs: input tensor fed to the first dense layer.
    states: list of per-layer LSTM states. Entries 0..nlayers-1 are
      overwritten in place with the updated states.
    hidden_size: number of units of the input projection and of each LSTM.
    output_size: number of units of each output head.
    nlayers: number of stacked LSTM layers.
    name: prefix prepended to the name of every layer created here.

  Returns:
    mu: mean of the predicted gaussian (output of the "<name>f2mu" head).
    logvar: log(var) of the predicted gaussian (output of the
      "<name>f2log" head).
    states: the same `states` list, now holding the updated LSTM states.
  """
  hidden = tfl.dense(inputs, hidden_size, activation=None, name="%sf1" % name)

  for layer in range(nlayers):
    layer_name = "%slstm%d" % (name, layer)
    hidden, updated_state = common_video.basic_lstm(
        hidden, states[layer], hidden_size, name=layer_name)
    states[layer] = updated_state

  # Both heads are linear and read the same top-layer LSTM output.
  mu = tfl.dense(hidden, output_size, activation=None, name="%sf2mu" % name)
  logvar = tfl.dense(
      hidden, output_size, activation=None, name="%sf2log" % name)
  return mu, logvar, states
def testBasicLstm(self):
  """Checks that basic_lstm reuses one set of weights across time steps."""
  num_steps = 10
  with tf.Graph().as_default():
    lstm_state = None
    for _ in range(num_steps):
      step_inputs = tf.random_uniform(shape=(32, 16))
      _, lstm_state = common_video.basic_lstm(
          step_inputs, lstm_state, num_units=100, name="basic")

    variable_sizes = [np.prod(v.shape) for v in tf.trainable_variables()]
    num_params = np.sum(variable_sizes)
    # If weights are shared, unrolling does not add parameters:
    # 4 * ((100 + 16)*100 + 100) => 4 * (W_{xh} + W_{hh} + b)
    self.assertEqual(num_params, 46800)
def stacked_lstm(self, inputs, states, hidden_size, output_size, nlayers):
  """Dense input embedding, stacked LSTMs, then a tanh dense output layer.

  Args:
    inputs: input tensor fed to the "af1" dense layer.
    states: list of per-layer LSTM states. Entries 0..nlayers-1 are
      overwritten in place with the updated states.
    hidden_size: number of units of the input embedding and of each LSTM.
    output_size: number of units of the output layer.
    nlayers: number of stacked LSTM layers.

  Returns:
    output: tanh-activated output of the "af2" dense layer.
    states: the same `states` list, now holding the updated LSTM states.
  """
  hidden = tfl.dense(inputs, hidden_size, activation=None, name="af1")

  for layer in range(nlayers):
    layer_name = "alstm%d" % layer
    hidden, updated_state = common_video.basic_lstm(
        hidden, states[layer], hidden_size, name=layer_name)
    states[layer] = updated_state

  output = tfl.dense(hidden, output_size, activation=tf.nn.tanh, name="af2")
  return output, states
def lstm_gaussian(self, inputs, states, hidden_size, output_size, nlayers):
  """Dense projection, stacked LSTMs, then two dense heads for a gaussian.

  Args:
    inputs: input tensor fed to the "bf1" dense layer.
    states: list of per-layer LSTM states. Entries 0..nlayers-1 are
      overwritten in place with the updated states.
    hidden_size: number of units of the input projection and of each LSTM.
    output_size: number of units of each output head.
    nlayers: number of stacked LSTM layers.

  Returns:
    mu: mean of the predicted gaussian (output of the "bf2mu" head).
    logvar: log(var) of the predicted gaussian (output of the "bf2log"
      head).
    states: the same `states` list, now holding the updated LSTM states.
  """
  hidden = tfl.dense(inputs, hidden_size, activation=None, name="bf1")

  for layer in range(nlayers):
    layer_name = "blstm%d" % layer
    hidden, updated_state = common_video.basic_lstm(
        hidden, states[layer], hidden_size, name=layer_name)
    states[layer] = updated_state

  # Both heads are linear and read the same top-layer LSTM output.
  mu = tfl.dense(hidden, output_size, activation=None, name="bf2mu")
  logvar = tfl.dense(hidden, output_size, activation=None, name="bf2log")
  return mu, logvar, states
def stacked_lstm(self, inputs, states, hidden_size, output_size, nlayers):
  """Embeds the input, runs it through stacked LSTMs and projects the result.

  Args:
    inputs: input tensor fed to the "af1" dense layer.
    states: list of per-layer LSTM states. Entries 0..nlayers-1 are
      overwritten in place with the updated states.
    hidden_size: number of units of the input embedding and of each LSTM.
    output_size: number of units of the output layer.
    nlayers: number of stacked LSTM layers.

  Returns:
    projected: tanh-activated output of the "af2" dense layer.
    states: the same `states` list, now holding the updated LSTM states.
  """
  embedded = tfl.dense(inputs, hidden_size, activation=None, name="af1")

  lstm_out = embedded
  for depth in range(nlayers):
    lstm_out, new_state = common_video.basic_lstm(
        lstm_out, states[depth], hidden_size, name="alstm%d" % depth)
    states[depth] = new_state

  projected = tfl.dense(
      lstm_out, output_size, activation=tf.nn.tanh, name="af2")
  return projected, states
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_features):
  """Predicts the next frame from the current frame and sampled latents.

  Args:
    frames: sequence of frame tensors; only `frames[0]` is used.
    actions: sequence of action tensors; only `actions[0]` is used.
    rewards: sequence of reward tensors; only `rewards[0]` is used.
    target_frame: unused (discarded immediately).
    internal_states: None on the first call, otherwise the tuple
      `(frame_index, lstm_state, cond_latent_state, prior_latent_state,
      gen_prior_video)` returned by the previous call.
    video_features: pair `(z_mu, log_sigma_sq)`; each is indexed by
      `frame_index` to obtain the latent parameters of the current step.

  Returns:
    A 6-tuple `(pred_image, pred_reward, pred_action, pred_value,
    extra_loss, internal_states)`. `pred_reward` is None when the model has
    no rewards; `pred_action` and `pred_value` are always None;
    `extra_loss` is always 0.0.

  Raises:
    NotImplementedError: if `use_vae` is off or `use_gan` is on, or if the
      model has predicted actions or values.
  """
  del target_frame

  if not self.hparams.use_vae or self.hparams.use_gan:
    raise NotImplementedError("Only supporting VAE for now.")

  if self.has_pred_actions or self.has_values:
    raise NotImplementedError(
        "Parameter sharing with policy not supported.")

  image, action, reward = frames[0], actions[0], rewards[0]
  latent_dims = self.hparams.z_dim
  batch_size = common_layers.shape_list(image)[0]

  if internal_states is None:
    # First step: start from empty recurrent state.
    frame_index = 0
    lstm_state = [None] * 7
    cond_latent_state, prior_latent_state = None, None
    gen_prior_video = []
  else:
    (frame_index, lstm_state, cond_latent_state, prior_latent_state,
     gen_prior_video) = internal_states

  z_mu, log_sigma_sq = video_features
  z_mu, log_sigma_sq = z_mu[frame_index], log_sigma_sq[frame_index]

  # Sample latents using a gaussian centered at conditional mu and std.
  latent = common_video.get_gaussian_tensor(z_mu, log_sigma_sq)

  # Sample prior latents from isotropic normal distribution.
  prior_latent = tf.random_normal(tf.shape(latent), dtype=tf.float32)

  # LSTM that encodes correlations between conditional latents.
  # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
  enc_cond_latent, cond_latent_state = common_video.basic_lstm(
      latent, cond_latent_state, latent_dims, name="cond_latent")

  # LSTM that encodes correlations between prior latents.
  enc_prior_latent, prior_latent_state = common_video.basic_lstm(
      prior_latent, prior_latent_state, latent_dims, name="prior_latent")

  # Run the conditional and the prior branch through the tower in a single
  # pass by stacking them along the batch axis.
  all_latents = tf.concat([enc_cond_latent, enc_prior_latent], axis=0)
  all_image = tf.concat([image, image], 0)
  all_action = tf.concat([action, action], 0) if self.has_actions else None

  all_pred_images, lstm_state = self.construct_predictive_tower(
      all_image, None, all_action, lstm_state, all_latents,
      concat_latent=True)

  # First half of the batch is the conditional branch, second half the prior.
  cond_pred_images = all_pred_images[:batch_size]
  prior_pred_images = all_pred_images[batch_size:]

  # use_vae is guaranteed True by the guard above, so this effectively
  # selects on is_training.
  if self.is_training and self.hparams.use_vae:
    pred_image = cond_pred_images
  else:
    pred_image = prior_pred_images

  gen_prior_video.append(prior_pred_images)
  internal_states = (frame_index + 1, lstm_state, cond_latent_state,
                     prior_latent_state, gen_prior_video)

  if not self.has_rewards:
    # Same 6-tuple layout as the reward path below, so callers can unpack
    # the result uniformly (this path previously returned only 4 values).
    return pred_image, None, None, None, 0.0, internal_states

  pred_reward = self.reward_prediction(pred_image, action, reward, latent)
  return pred_image, pred_reward, None, None, 0.0, internal_states
def construct_model(self, images, actions, rewards):
  """Model that takes in images and returns predictions.

  Side effect: stores the stacked prior-branch predictions on
  `self.gen_prior_video`.

  Args:
    images: tensor of frames, unstacked along axis 0 into per-step frames.
      (Original docs: list of 4-D Tensors indexed by time,
      (batch_size, width, height, channels).)
    actions: tensor of actions, unstacked along axis 0 into per-step
      actions. (Original docs: each action should be in the shape ?x1xZ.)
    rewards: tensor of rewards, unstacked along axis 0 into per-step
      rewards. (Original docs: each reward should be in the shape ?x1xZ.)

  Returns:
    video: predicted frames stacked along axis 0. The conditional-latent
      predictions when training with use_vae, otherwise the prior-latent
      predictions.
    fake_rewards: the input rewards for all steps but the last, stacked
      along axis 0. These are passed through, not predicted.
    latent_means: list with the encoder `z_mu` of every step.
    latent_stds: list with the encoder `z_log_sigma_sq` of every step
      (presumably log-variances rather than standard deviations, going by
      the variable name).

  Raises:
    ValueError: If neither self.hparams.use_vae nor self.hparams.use_gan is
      set, or if self.hparams.gan_optimization is not "joint" or
      "sequential".
  """
  if not self.hparams.use_vae and not self.hparams.use_gan:
    raise ValueError(
        "Set at least one of use_vae or use_gan to be True")
  if self.hparams.gan_optimization not in ["joint", "sequential"]:
    raise ValueError(
        "self.hparams.gan_optimization should be either joint "
        "or sequential got %s" % self.hparams.gan_optimization)

  images = tf.unstack(images, axis=0)
  actions = tf.unstack(actions, axis=0)
  rewards = tf.unstack(rewards, axis=0)

  latent_dims = self.hparams.z_dim
  context_frames = self.hparams.video_num_input_frames
  seq_len = len(images)
  input_shape = common_layers.shape_list(images[0])
  batch_size = input_shape[0]

  # Model does not support reward-conditioned frame generation.
  fake_rewards = rewards[:-1]

  # Concatenate x_{t-1} and x_{t} along depth and encode it to
  # produce the mean and standard deviation of z_{t-1}
  image_pairs = tf.concat([images[:seq_len - 1],
                           images[1:seq_len]], axis=-1)

  z_mu, z_log_sigma_sq = self.encoder(image_pairs)
  # Unstack z_mu and z_log_sigma_sq along the time dimension.
  z_mu = tf.unstack(z_mu, axis=0)
  z_log_sigma_sq = tf.unstack(z_log_sigma_sq, axis=0)
  iterable = zip(images[:-1], actions[:-1], fake_rewards,
                 z_mu, z_log_sigma_sq)

  # Initialize LSTM State
  lstm_state = [None] * 7
  gen_cond_video, gen_prior_video = [], []
  latent_means, latent_stds = [], []
  pred_image = tf.zeros_like(images[0])
  prior_latent_state, cond_latent_state = None, None
  train_mode = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Create scheduled sampling function
  ss_func = self.get_scheduled_sample_func(batch_size)

  with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):

    for step, (image, action, reward, mu, log_sigma_sq) in enumerate(iterable):  # pylint:disable=line-too-long
      # Sample latents using a gaussian centered at conditional mu and std.
      latent = common_video.get_gaussian_tensor(mu, log_sigma_sq)

      # Sample prior latents from isotropic normal distribution.
      prior_latent = tf.random_normal(tf.shape(latent), dtype=tf.float32)

      # LSTM that encodes correlations between conditional latents.
      # Pg 22 in https://arxiv.org/pdf/1804.01523.pdf
      enc_cond_latent, cond_latent_state = common_video.basic_lstm(
          latent, cond_latent_state, latent_dims, name="cond_latent")

      # LSTM that encodes correlations between prior latents.
      enc_prior_latent, prior_latent_state = common_video.basic_lstm(
          prior_latent, prior_latent_state, latent_dims, name="prior_latent")

      # Scheduled Sampling
      done_warm_start = step > context_frames - 1
      groundtruth_items = [image]
      generated_items = [pred_image]
      (input_image,) = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Run the conditional and the prior branch through the tower in one
      # pass by stacking them along the batch axis.
      all_latents = tf.concat([enc_cond_latent, enc_prior_latent], axis=0)
      all_image = tf.concat([input_image, input_image], axis=0)
      all_action = tf.concat([action, action], axis=0)
      all_reward = tf.concat([reward, reward], axis=0)

      all_pred_images, lstm_state, _ = self.construct_predictive_tower(
          all_image, all_reward, all_action, lstm_state, all_latents,
          concat_latent=True)

      # First half of the batch is the conditional branch, second half the
      # prior branch.
      cond_pred_images = all_pred_images[:batch_size]
      prior_pred_images = all_pred_images[batch_size:]

      # The chosen prediction is the generated candidate for the next
      # step's scheduled sampling.
      if train_mode and self.hparams.use_vae:
        pred_image = cond_pred_images
      else:
        pred_image = prior_pred_images

      gen_cond_video.append(cond_pred_images)
      gen_prior_video.append(prior_pred_images)
      latent_means.append(mu)
      latent_stds.append(log_sigma_sq)

  gen_cond_video = tf.stack(gen_cond_video, axis=0)
  self.gen_prior_video = tf.stack(gen_prior_video, axis=0)
  fake_rewards = tf.stack(fake_rewards, axis=0)

  if train_mode and self.hparams.use_vae:
    return gen_cond_video, fake_rewards, latent_means, latent_stds
  else:
    return self.gen_prior_video, fake_rewards, latent_means, latent_stds