Example #1
 def top(self, body_output, _):
     # The model body may emit a list of per-step frames; stack them
     # along the time axis before converting back to RGB.
     frames = body_output
     if isinstance(body_output, list):
         frames = tf.stack(body_output, axis=1)
     rgb_frames = common_layers.convert_real_to_rgb(frames)
     common_video.gif_summary("body_output", rgb_frames)
     return tf.expand_dims(rgb_frames, axis=-1)
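Example #1 assumes the body emits real-valued frames in [0, 1]. For reference, a minimal sketch of what the conversion pair used here (and in the bottom methods further below) presumably does; the actual tensor2tensor common_layers helpers may differ in details such as clipping:

import tensorflow as tf

def convert_real_to_rgb(frames):
    # Assumed: scale floats in [0.0, 1.0] up to 0-255 pixel intensities.
    return frames * 255.0

def convert_rgb_to_real(frames):
    # Assumed inverse: rescale raw pixel values into [0.0, 1.0].
    return tf.to_float(frames) / 255.0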
Example #2
 def top(self, body_output, targets):
     num_channels = self._model_hparams.problem.num_channels
     shape = common_layers.shape_list(body_output)
     reshape_shape = shape[:-1] + [num_channels, self.top_dimensionality]
     res = tf.reshape(body_output, reshape_shape)
     # Calculate argmax so as to have a summary with the produced images.
     x = tf.argmax(tf.reshape(res, [-1, self.top_dimensionality]), axis=-1)
     x = tf.reshape(x, shape[:-1] + [num_channels])
     common_video.gif_summary("results", x, max_outputs=1)
     return res
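The reshape above splits the flat output axis into per-channel logits, one logit per possible pixel value, and the argmax recovers concrete pixels for the GIF. A toy shape check, assuming a top_dimensionality of 256:

import tensorflow as tf
from tensor2tensor.layers import common_layers

body_output = tf.zeros([2, 4, 8, 8, 3 * 256])         # [batch, time, height, width, channels * 256]
shape = common_layers.shape_list(body_output)
res = tf.reshape(body_output, shape[:-1] + [3, 256])  # [2, 4, 8, 8, 3, 256]
x = tf.argmax(tf.reshape(res, [-1, 256]), axis=-1)
x = tf.reshape(x, shape[:-1] + [3])                   # [2, 4, 8, 8, 3]: predicted pixel values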
Example #3
  def testGifSummary(self):
    for c in (1, 3):
      images_shape = (1, 12, 48, 64, c)  # batch, time, height, width, channels
      images = np.random.randint(256, size=images_shape).astype(np.uint8)

      with self.test_session():
        summary = common_video.gif_summary(
            "gif", tf.convert_to_tensor(images), fps=10)
        summary_string = summary.eval()

      summary = tf.Summary()
      summary.ParseFromString(summary_string)

      self.assertEqual(1, len(summary.value))
      self.assertTrue(summary.value[0].HasField("image"))
      encoded = summary.value[0].image.encoded_image_string

      self.assertEqual(encoded, common_video._encode_gif(images[0], fps=10))  # pylint: disable=protected-access
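The test evaluates the tensor returned by gif_summary directly, which also suggests how to write an animation to an event file outside a test. A sketch, assuming the TF1 session and FileWriter APIs (the log directory is illustrative, and GIF encoding typically requires ffmpeg to be installed):

import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_video

frames = np.random.randint(256, size=(1, 12, 48, 64, 3)).astype(np.uint8)
summary_t = common_video.gif_summary("frames", tf.convert_to_tensor(frames), fps=10)

with tf.Session() as sess:
    summary_str = sess.run(summary_t)  # serialized tf.Summary proto

writer = tf.summary.FileWriter("/tmp/gif_demo")  # illustrative path
writer.add_summary(summary_str, global_step=0)
writer.close()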
Example #4
 def targets_bottom(self, x):  # pylint: disable=arguments-differ
     common_video.gif_summary("targets_bottom", x)
     return common_layers.convert_rgb_to_real(x)
Example #5
 def bottom(self, x):
     common_video.gif_summary("inputs", x)
     return common_layers.convert_rgb_to_real(x)
Example #6
 def targets_bottom(self, x):
     common_video.gif_summary("targets", x, max_outputs=1)
     x = common_layers.standardize_images(x)
     return x
Example #7
 def targets_bottom(self, x):
   common_video.gif_summary("targets", x, max_outputs=1)
   return x
Example #8
 def bottom(self, x):
   common_video.gif_summary("inputs", x, max_outputs=1)
   return x
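The bottom methods above differ only in how they normalize pixels before the model body. A hedged sketch of what standardize_images presumably does, assuming per-image mean/variance normalization over 4-D [batch, height, width, channels] input (similar in spirit to tf.image.per_image_standardization), in contrast to the plain [0, 1] rescaling of convert_rgb_to_real:

import tensorflow as tf

def standardize_images_sketch(x):
    # Assumed behavior: zero mean and unit variance per image.
    x = tf.to_float(x)
    mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
    var = tf.reduce_mean(tf.square(x - mean), axis=[1, 2, 3], keepdims=True)
    return (x - mean) / tf.maximum(tf.sqrt(var), 1e-6)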
Example #9
    def process(self, inputs, targets):
        # Treat the conditioning frames and the target frames as one sequence.
        all_frames = tf.unstack(inputs, axis=1) + tf.unstack(targets, axis=1)
        hparams = self.hparams

        batch_size = common_layers.shape_list(all_frames[0])[0]

        z_dim = hparams.z_dim
        g_dim = hparams.g_dim
        rnn_size = hparams.rnn_size
        prior_rnn_layers = hparams.prior_rnn_layers
        posterior_rnn_layers = hparams.posterior_rnn_layers
        predictor_rnn_layers = hparams.predictor_rnn_layers

        num_input_frames = hparams.num_input_frames
        num_target_frames = hparams.num_target_frames
        num_all_frames = num_input_frames + num_target_frames

        # Create the RNN cells.
        predictor_cell = self.rnn_model(rnn_size, "predictor", n_layers=predictor_rnn_layers)
        prior_cell = self.rnn_model(rnn_size, "prior", n_layers=prior_rnn_layers)
        posterior_cell = self.rnn_model(rnn_size, "posterior", n_layers=posterior_rnn_layers)

        # Initialize the RNN states.
        predictor_state = predictor_cell.zero_state(batch_size, tf.float32)
        prior_state = prior_cell.zero_state(batch_size, tf.float32)
        posterior_state = posterior_cell.zero_state(batch_size, tf.float32)

        # Encode every frame during training; at inference time only the
        # conditioning frames are available, so encode just those.
        enc_frames, enc_skips = [], []
        for frame in all_frames if self.is_training else all_frames[:num_input_frames]:
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                enc, skip = self.encoder(frame)
                enc_frames.append(enc)
                enc_skips.append(skip)

        # Prediction.
        prior_mus = []
        prior_logvars = []
        posterior_mus = []
        posterior_logvars = []
        predicted_frames = []
        z_positions = []
        skip = None
        if self.is_training:
            for i in range(1, num_all_frames):
                h = enc_frames[i - 1]
                h_target = enc_frames[i]
                if i < num_input_frames:
                    skip = enc_skips[i - 1]
                with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
                    # Posterior conditioned on the target frame, prior on the
                    # previous frame; z is sampled from the posterior.
                    mu, log_var, posterior_state = self.gaussian_rnn(
                        posterior_cell, h_target, posterior_state, z_dim, "posterior")
                    mu_p, log_var_p, prior_state = self.gaussian_rnn(
                        prior_cell, h, prior_state, z_dim, "prior")
                    z = utils.get_gaussian_tensor(mu, log_var)
                    h_pred, predictor_state = self.deterministic_rnn(
                        predictor_cell, tf.concat([h, z], axis=1),
                        predictor_state, g_dim, "predictor")
                with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                    x_pred = self.decoder(h_pred, skip)
                predicted_frames.append(x_pred)
                prior_mus.append(mu_p)
                prior_logvars.append(log_var_p)
                posterior_mus.append(mu)
                posterior_logvars.append(log_var)
                z_positions.append(z)
        else:
            for i in range(1, num_all_frames):
                if i < num_input_frames:
                    h = enc_frames[i - 1]
                    skip = enc_skips[i - 1]
                else:
                    # Past the conditioning frames, feed the previous
                    # prediction back through the encoder.
                    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                        h, _ = self.encoder(predicted_frames[-1])
                mu = log_var = mu_p = log_var_p = None
                if i < num_input_frames:
                    h_target = enc_frames[i]
                    with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
                        mu, log_var, posterior_state = self.gaussian_rnn(
                            posterior_cell, h_target, posterior_state, z_dim, "posterior")
                        mu_p, log_var_p, prior_state = self.gaussian_rnn(
                            prior_cell, h, prior_state, z_dim, "prior")
                        z = utils.get_gaussian_tensor(mu, log_var)
                        _, predictor_state = self.deterministic_rnn(
                            predictor_cell, tf.concat([h, z], axis=1),
                            predictor_state, g_dim, "predictor")
                    # While conditioning, the "prediction" is the ground-truth frame.
                    x_pred = all_frames[i]
                else:
                    with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
                        # No target frame is available, so sample z from the prior.
                        mu_p, log_var_p, prior_state = self.gaussian_rnn(
                            prior_cell, h, prior_state, z_dim, "prior")
                        z = utils.get_gaussian_tensor(mu_p, log_var_p)
                        h_pred, predictor_state = self.deterministic_rnn(
                            predictor_cell, tf.concat([h, z], axis=1),
                            predictor_state, g_dim, "predictor")
                    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                        x_pred = self.decoder(h_pred, skip)
                predicted_frames.append(x_pred)
                prior_mus.append(mu_p)
                prior_logvars.append(log_var_p)
                posterior_mus.append(mu)
                posterior_logvars.append(log_var)
                z_positions.append(z)

        kl_loss = 0

        # Reconstruction loss over all predicted frames.
        recon_loss = l2_loss(
            tf.stack(predicted_frames), tf.stack(all_frames[1:])) * (num_all_frames - 1)

        if self.is_training:
            # KL divergence between the posterior and the learned prior.
            kl_loss = self.get_kl_loss(
                posterior_mus, posterior_logvars, prior_mus, prior_logvars)
        # Keep only the predictions for the target frames as the output.
        pred_outputs = tf.stack(predicted_frames[num_input_frames - 1:], axis=1)
        # Tile single-channel frames to three channels for the GIF summaries.
        rgb_frames = tf.tile(
            common_layers.convert_real_to_rgb(tf.stack(predicted_frames, axis=1)),
            [1, 1, 1, 1, 3])
        all_frames = tf.stack(all_frames, axis=1)
        all_frames_rgb = tf.tile(
            common_layers.convert_real_to_rgb(all_frames), [1, 1, 1, 1, 3])
        common_video.gif_summary("body_output", rgb_frames)
        common_video.gif_summary("all_ground_frames", all_frames_rgb)
        tf.summary.scalar("kl_loss", kl_loss)
        tf.summary.scalar("recon_loss", recon_loss)
        loss = recon_loss + kl_loss

        return pred_outputs, loss, tf.stack(z_positions, axis=1)
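Example #9 relies on two helpers that the snippet does not show. Plausible sketches of both, offered as assumptions rather than the actual utils.get_gaussian_tensor and get_kl_loss:

import tensorflow as tf

def get_gaussian_tensor(mu, log_var):
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
    eps = tf.random_normal(tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * eps

def get_kl_loss(post_mus, post_logvars, prior_mus, prior_logvars):
    # Closed-form KL(q || p) between diagonal Gaussians, summed over time.
    kl = 0.0
    for mu_q, lv_q, mu_p, lv_p in zip(post_mus, post_logvars,
                                      prior_mus, prior_logvars):
        kl += tf.reduce_mean(tf.reduce_sum(
            0.5 * (lv_p - lv_q
                   + (tf.exp(lv_q) + tf.square(mu_q - mu_p)) / tf.exp(lv_p)
                   - 1.0), axis=-1))
    return kl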