Example #1
0
def latents_to_frames(z_top_interp, level_eps_interp, hparams):
  """Decodes interpolated latents back into postprocessed frames.

  Runs the Glow codec in reverse, mapping the top-level latent together
  with the per-level eps latents [z^1_t .. z^l_t] to image space X_t,
  then converts the result to pixel values.

  Args:
    z_top_interp: top-level latent tensor to decode.
    level_eps_interp: per-level latents, passed as `eps` to the decoder.
    hparams: model hyperparameters forwarded to glow_ops.

  Returns:
    Postprocessed frames decoded from the latents.
  """
  decoded, _, _, _ = glow_ops.encoder_decoder(
      "codec", z_top_interp, hparams, eps=level_eps_interp, reverse=True)
  return glow_ops.postprocess(decoded)
Example #2
0
  def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
    """Samples images from the prior and decodes them to pixel space."""
    del args, kwargs

    # Run the model once with dummy targets so that all variables are
    # created and can be reused below.
    inputs = features["inputs"]
    num_samples = common_layers.shape_list(inputs)[0]
    features["targets"] = tf.zeros(shape=(num_samples, 1, 1, 1))
    _, _ = self(features)  # pylint: disable=not-callable

    reuse_ops = [
        glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout]
    # If eps=None, images are sampled from the prior.
    with arg_scope(reuse_ops, init=False):
      with tf.variable_scope("glow/body", reuse=True):
        predictions, _, _, _ = glow_ops.encoder_decoder(
            "codec", self.z_sample, self.hparams, eps=None, reverse=True,
            temperature=self.temperature)

    return glow_ops.postprocess(predictions, self.hparams.n_bits_x)
Example #3
0
    def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
        """Autoregressively samples a video, one frame per step.

        Builds the graph once (to create variables), then repeatedly decodes
        the current top-level latent `self.z_sample` into a frame while
        threading conditional latents and decoder state across timesteps.
        Mutates `self.z_sample`, `self.level_states`, `self.all_level_latents`
        and `self.all_top_latents` as a side effect.

        Args:
          features: dict with "inputs" and "infer_targets" tensors.
          *args: unused.
          **kwargs: unused.

        Returns:
          Dict with "outputs" (the sampled video), plus zero-valued
          "targets" and "scores" of the same shape, as the decode API expects.
        """
        del args, kwargs

        # Make a copy of features that can be used in the call to self
        # that builds the graph.
        new_features = {}
        new_features["inputs"] = features["inputs"]
        new_features["targets"] = features["infer_targets"]
        _, _ = self(new_features)  # pylint: disable=not-callable

        # Unconditional generation samples a single frame and tiles it later.
        if self.hparams.gen_mode == "unconditional":
            num_target_frames = 1
        else:
            num_target_frames = self.hparams.video_num_target_frames

        # Reuse variables created by the graph-building call above.
        ops = [
            glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout
        ]
        var_scope = tf.variable_scope("next_frame_glow/body", reuse=True)
        all_frames = []

        # If eps=None, images are sampled from the prior.
        with arg_scope(ops, init=False), var_scope:
            for target_frame in range(1, num_target_frames + 1):

                # subscript -> timestep, superscript -> level.
                # self.z_sample equals z^0_{t} (top-level latent)
                # (X_{t}, z^{1..l}_{t}) = Glow(z^0_{t}, z^{1..l}_{t-1})
                # Get current set of cond_latents.
                cond_level, cond_level_latents = get_cond_latents(
                    self.all_level_latents, self.hparams)

                # Decode the current top latent into a frame, conditioned on
                # the previous timestep's level latents and decoder states.
                glow_vals = glow_ops.encoder_decoder(
                    "codec",
                    self.z_sample,
                    self.hparams,
                    eps=None,
                    reverse=True,
                    cond_latents=cond_level_latents,
                    states=self.level_states,
                    condition=cond_level,
                    temperature=self.temperature)
                # Carry the recurrent decoder state forward to the next step.
                predicted_frame, _, curr_latents, self.level_states = glow_vals
                all_frames.append(predicted_frame)
                self.all_level_latents.append(curr_latents)

                # Compute z^0_{t+1} = f(z^0_{t})
                # Skipped on the last frame since no further decode follows.
                if target_frame < num_target_frames:
                    cond_top, cond_top_latents = get_cond_latents(
                        self.all_top_latents, self.hparams)
                    prior_dist = self.top_prior(condition=cond_top,
                                                cond_latents=cond_top_latents)
                    self.z_sample = prior_dist.sample()
                    self.all_top_latents.append(self.z_sample)

        # Stack along time, then move time to axis 1 (batch-major video).
        all_frames = tf.stack(all_frames)
        predicted_video = common_video.swap_time_and_batch_axes(all_frames)

        # The video-decode API requires the predicted video to be the same shape
        # as the target-video. Hence, for unconditional generation,
        # tile across time to ensure same shape.
        if self.hparams.gen_mode == "unconditional":
            predicted_video = tf.tile(
                predicted_video,
                [1, self.hparams.video_num_target_frames, 1, 1, 1])
        predicted_video = glow_ops.postprocess(predicted_video)

        # Output of a single decode / sample.
        # "targets"/"scores" are placeholders shaped like the output, as the
        # decode API expects these keys to be present.
        output_features = {}
        output_features["targets"] = tf.zeros_like(predicted_video)
        output_features["outputs"] = predicted_video
        output_features["scores"] = tf.zeros_like(predicted_video)
        return output_features