def top(self, body_output, _):
  frames = body_output
  if isinstance(body_output, list):
    frames = tf.stack(body_output, axis=1)
  rgb_frames = common_layers.convert_real_to_rgb(frames)
  common_video.gif_summary("body_output", rgb_frames)
  return tf.expand_dims(rgb_frames, axis=-1)
def top(self, body_output, targets):
  num_channels = self._model_hparams.problem.num_channels
  shape = common_layers.shape_list(body_output)
  reshape_shape = shape[:-1] + [num_channels, self.top_dimensionality]
  res = tf.reshape(body_output, reshape_shape)
  # Calculate argmax so as to have a summary with the produced images.
  x = tf.argmax(tf.reshape(res, [-1, self.top_dimensionality]), axis=-1)
  x = tf.reshape(x, shape[:-1] + [num_channels])
  common_video.gif_summary("results", x, max_outputs=1)
  return res
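# Standalone shape sketch of what the top() above computes, using assumed toy
# dimensions (the real values come from the problem's hparams): body_output
# carries, for every pixel and channel, a distribution over
# top_dimensionality intensity values, and the argmax is taken only to log a
# viewable image summary.
import tensorflow as tf

batch, time, height, width = 2, 4, 8, 8
num_channels, top_dimensionality = 3, 256
body_output = tf.zeros(
    [batch, time, height, width, num_channels * top_dimensionality])
res = tf.reshape(
    body_output,
    [batch, time, height, width, num_channels, top_dimensionality])
x = tf.argmax(res, axis=-1)  # -> [batch, time, height, width, num_channels]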
def testGifSummary(self):
  for c in (1, 3):
    images_shape = (1, 12, 48, 64, c)  # batch, time, height, width, channels
    images = np.random.randint(256, size=images_shape).astype(np.uint8)
    with self.test_session():
      summary = common_video.gif_summary(
          "gif", tf.convert_to_tensor(images), fps=10)
      summary_string = summary.eval()
      summary = tf.Summary()
      summary.ParseFromString(summary_string)
      self.assertEqual(1, len(summary.value))
      self.assertTrue(summary.value[0].HasField("image"))
      encoded = summary.value[0].image.encoded_image_string
      self.assertEqual(
          encoded,
          common_video._encode_gif(images[0], fps=10))  # pylint: disable=protected-access
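# A minimal usage sketch, assuming the same TF1-style session API as the test
# above and that common_video is importable from tensor2tensor.layers: the
# serialized summary returned by gif_summary can be written directly to
# TensorBoard. The log directory "/tmp/gif_logs" is only a placeholder.
import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_video

frames = np.random.randint(256, size=(1, 12, 48, 64, 3)).astype(np.uint8)
summary_op = common_video.gif_summary(
    "gif", tf.convert_to_tensor(frames), max_outputs=1, fps=10)
writer = tf.summary.FileWriter("/tmp/gif_logs")
with tf.Session() as sess:
  writer.add_summary(sess.run(summary_op), global_step=0)
writer.close()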
def targets_bottom(self, x):  # pylint: disable=arguments-differ
  common_video.gif_summary("targets_bottom", x)
  return common_layers.convert_rgb_to_real(x)
def bottom(self, x):
  common_video.gif_summary("inputs", x)
  return common_layers.convert_rgb_to_real(x)
def targets_bottom(self, x):
  common_video.gif_summary("targets", x, max_outputs=1)
  x = common_layers.standardize_images(x)
  return x
def targets_bottom(self, x):
  common_video.gif_summary("targets", x, max_outputs=1)
  return x
def bottom(self, x):
  common_video.gif_summary("inputs", x, max_outputs=1)
  return x
def process(self, inputs, targets):
  all_frames = tf.unstack(inputs, axis=1) + tf.unstack(targets, axis=1)
  hparams = self.hparams
  batch_size = common_layers.shape_list(all_frames[0])[0]
  z_dim = hparams.z_dim
  g_dim = hparams.g_dim
  rnn_size = hparams.rnn_size
  prior_rnn_layers = hparams.prior_rnn_layers
  posterior_rnn_layers = hparams.posterior_rnn_layers
  predictor_rnn_layers = hparams.predictor_rnn_layers
  num_input_frames = hparams.num_input_frames
  num_target_frames = hparams.num_target_frames
  num_all_frames = num_input_frames + num_target_frames

  # Create the RNN cells.
  predictor_cell = self.rnn_model(
      rnn_size, "predictor", n_layers=predictor_rnn_layers)
  prior_cell = self.rnn_model(rnn_size, "prior", n_layers=prior_rnn_layers)
  posterior_cell = self.rnn_model(
      rnn_size, "posterior", n_layers=posterior_rnn_layers)

  # Initial RNN states.
  predictor_state = predictor_cell.zero_state(batch_size, tf.float32)
  prior_state = prior_cell.zero_state(batch_size, tf.float32)
  posterior_state = posterior_cell.zero_state(batch_size, tf.float32)

  # Encoding. At training time every frame is encoded; at inference time only
  # the input frames are (predicted frames are encoded inside the loop below).
  enc_frames, enc_skips = [], []
  frames_to_encode = (
      all_frames if self.is_training else all_frames[:num_input_frames])
  for frame in frames_to_encode:
    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
      enc, skip = self.encoder(frame)
      enc_frames.append(enc)
      enc_skips.append(skip)

  # Prediction.
  prior_mus, prior_logvars = [], []
  posterior_mus, posterior_logvars = [], []
  predicted_frames = []
  z_positions = []
  skip = None
  if self.is_training:
    for i in range(1, num_all_frames):
      h = enc_frames[i - 1]
      h_target = enc_frames[i]
      if i < num_input_frames:
        skip = enc_skips[i - 1]
      with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
        mu, log_var, posterior_state = self.gaussian_rnn(
            posterior_cell, h_target, posterior_state, z_dim, "posterior")
        mu_p, log_var_p, prior_state = self.gaussian_rnn(
            prior_cell, h, prior_state, z_dim, "prior")
        z = utils.get_gaussian_tensor(mu, log_var)
        h_pred, predictor_state = self.deterministic_rnn(
            predictor_cell, tf.concat([h, z], axis=1), predictor_state,
            g_dim, "predictor")
      with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        x_pred = self.decoder(h_pred, skip)
      predicted_frames.append(x_pred)
      prior_mus.append(mu_p)
      prior_logvars.append(log_var_p)
      posterior_mus.append(mu)
      posterior_logvars.append(log_var)
      z_positions.append(z)
  else:
    for i in range(1, num_all_frames):
      if i < num_input_frames:
        h = enc_frames[i - 1]
        skip = enc_skips[i - 1]
      else:
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
          h, _ = self.encoder(predicted_frames[-1])
      mu = log_var = mu_p = log_var_p = None
      if i < num_input_frames:
        # Conditioning frames: run the posterior and prior to advance their
        # states, but keep the ground-truth frame as the "prediction".
        h_target = enc_frames[i]
        with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
          mu, log_var, posterior_state = self.gaussian_rnn(
              posterior_cell, h_target, posterior_state, z_dim, "posterior")
          mu_p, log_var_p, prior_state = self.gaussian_rnn(
              prior_cell, h, prior_state, z_dim, "prior")
          z = utils.get_gaussian_tensor(mu, log_var)
          _, predictor_state = self.deterministic_rnn(
              predictor_cell, tf.concat([h, z], axis=1), predictor_state,
              g_dim, "predictor")
        x_pred = all_frames[i]
      else:
        # Future frames: sample z from the learned prior and decode.
        with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
          mu_p, log_var_p, prior_state = self.gaussian_rnn(
              prior_cell, h, prior_state, z_dim, "prior")
          z = utils.get_gaussian_tensor(mu_p, log_var_p)
          h_pred, predictor_state = self.deterministic_rnn(
              predictor_cell, tf.concat([h, z], axis=1), predictor_state,
              g_dim, "predictor")
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
          x_pred = self.decoder(h_pred, skip)
      predicted_frames.append(x_pred)
      prior_mus.append(mu_p)
      prior_logvars.append(log_var_p)
      posterior_mus.append(mu)
      posterior_logvars.append(log_var)
      z_positions.append(z)

  # Reconstruction loss over all predicted frames.
  recon_loss = l2_loss(
      tf.stack(predicted_frames),
      tf.stack(all_frames[1:])) * (num_all_frames - 1)

  # KL loss between the posterior and the learned prior (training only).
  kl_loss = 0
  if self.is_training:
    kl_loss = self.get_kl_loss(
        posterior_mus, posterior_logvars, prior_mus, prior_logvars)

  pred_outputs = tf.stack(predicted_frames[num_input_frames - 1:], axis=1)

  # Summaries: convert to RGB range and tile the channel axis for the GIFs.
  rgb_frames = tf.tile(
      common_layers.convert_real_to_rgb(tf.stack(predicted_frames, axis=1)),
      [1, 1, 1, 1, 3])
  all_frames = tf.stack(all_frames, axis=1)
  all_frames_rgb = tf.tile(
      common_layers.convert_real_to_rgb(all_frames), [1, 1, 1, 1, 3])
  common_video.gif_summary("body_output", rgb_frames)
  common_video.gif_summary("all_ground_frames", all_frames_rgb)
  tf.summary.scalar("kl_loss", kl_loss)
  tf.summary.scalar("recon_loss", recon_loss)

  loss = recon_loss + kl_loss
  return pred_outputs, loss, tf.stack(z_positions, axis=1)
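# A minimal sketch, not the repository implementation, of the two helpers
# process() relies on: utils.get_gaussian_tensor (the reparameterization
# trick) and get_kl_loss (KL divergence between the per-step posterior and
# the learned prior, both diagonal Gaussians parameterized by mean and
# log-variance). Function names, signatures, and reductions here are
# assumptions for illustration.
import tensorflow as tf


def get_gaussian_tensor(mu, log_var):
  """Samples z = mu + sigma * eps with eps ~ N(0, I)."""
  eps = tf.random_normal(tf.shape(mu))
  return mu + tf.exp(0.5 * log_var) * eps


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
  """KL(N(mu_q, var_q) || N(mu_p, var_p)), summed over latent dimensions."""
  kl = 0.5 * (logvar_p - logvar_q
              + (tf.exp(logvar_q) + tf.square(mu_q - mu_p)) / tf.exp(logvar_p)
              - 1.0)
  return tf.reduce_sum(kl, axis=-1)


def get_kl_loss(post_mus, post_logvars, prior_mus, prior_logvars):
  """Sums the batch-averaged KL over all predicted time steps."""
  kls = [tf.reduce_mean(gaussian_kl(mq, lq, mp, lp))
         for mq, lq, mp, lp in zip(post_mus, post_logvars,
                                   prior_mus, prior_logvars)]
  return tf.add_n(kls)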