def latents_to_frames(z_top_interp, level_eps_interp, hparams):
  """Decodes latents to frames."""
  # Decode [z^1_t, z^2_t .. z^l_t] to [X_t]
  images, _, _, _ = glow_ops.encoder_decoder(
      "codec", z_top_interp, hparams, eps=level_eps_interp, reverse=True)
  images = glow_ops.postprocess(images)
  return images
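# A minimal usage sketch (an assumption for illustration, not library code):
# linearly interpolate between the top-level latents of two frames, then
# decode the blended latents back to images with latents_to_frames. The
# helper name, argument names, and the tiling of the per-level eps are
# illustrative choices.
def interpolate_top_latents_sketch(z_top_first, z_top_last, level_eps,
                                   hparams, coeffs):
  # Blend z^0 at every interpolation coefficient and stack along batch.
  z_interp = tf.concat(
      [c * z_top_last + (1.0 - c) * z_top_first for c in coeffs], axis=0)
  # Reuse the same per-level eps for every interpolated point so that the
  # low-level structure stays fixed while the top-level latent varies.
  eps_interp = [tf.tile(eps, [len(coeffs), 1, 1, 1]) for eps in level_eps]
  return latents_to_frames(z_interp, eps_interp, hparams)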
def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
  del args, kwargs
  x = features["inputs"]
  batch_size = common_layers.shape_list(x)[0]
  # Run a forward pass with dummy targets to build the graph and create the
  # variables that are reused below.
  features["targets"] = tf.zeros(shape=(batch_size, 1, 1, 1))
  _, _ = self(features)  # pylint: disable=not-callable

  ops = [glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout]
  var_scope = tf.variable_scope("glow/body", reuse=True)
  # If eps=None, images are sampled from the prior.
  with arg_scope(ops, init=False), var_scope:
    predictions, _, _, _ = glow_ops.encoder_decoder(
        "codec", self.z_sample, self.hparams, eps=None, reverse=True,
        temperature=self.temperature)

  # Map the decoded floats back to uint8 pixel values.
  return glow_ops.postprocess(predictions, self.hparams.n_bits_x)
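# Sketch of the sampling step that produces self.z_sample (an assumption for
# illustration; in the model the prior is constructed during the forward
# pass). Scaling the standard-normal prior by `temperature` trades sample
# diversity (high temperature) for sample fidelity (low temperature).
import tensorflow_probability as tfp

def sample_top_latent_sketch(latent_shape, temperature):
  prior = tfp.distributions.Normal(
      loc=tf.zeros(latent_shape), scale=temperature * tf.ones(latent_shape))
  return prior.sample()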
def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
  del args, kwargs
  # Make a copy of features that can be used in the call to self
  # that builds the graph.
  new_features = {}
  new_features["inputs"] = features["inputs"]
  new_features["targets"] = features["infer_targets"]
  _, _ = self(new_features)  # pylint: disable=not-callable

  if self.hparams.gen_mode == "unconditional":
    num_target_frames = 1
  else:
    num_target_frames = self.hparams.video_num_target_frames

  ops = [glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout]
  var_scope = tf.variable_scope("next_frame_glow/body", reuse=True)
  all_frames = []

  # If eps=None, images are sampled from the prior.
  with arg_scope(ops, init=False), var_scope:
    for target_frame in range(1, num_target_frames + 1):
      # subscript -> timestep, superscript -> level.
      # self.z_sample equals z^0_{t} (top-level latent)
      # (X_{t}, z^{1..l}_{t}) = Glow(z^0_{t}, z^{1..l}_{t-1})

      # Get current set of cond_latents.
      cond_level, cond_level_latents = get_cond_latents(
          self.all_level_latents, self.hparams)
      glow_vals = glow_ops.encoder_decoder(
          "codec", self.z_sample, self.hparams, eps=None, reverse=True,
          cond_latents=cond_level_latents, states=self.level_states,
          condition=cond_level, temperature=self.temperature)
      predicted_frame, _, curr_latents, self.level_states = glow_vals
      all_frames.append(predicted_frame)
      self.all_level_latents.append(curr_latents)

      # Compute z^0_{t+1} = f(z^0_{t})
      if target_frame < num_target_frames:
        cond_top, cond_top_latents = get_cond_latents(
            self.all_top_latents, self.hparams)
        prior_dist = self.top_prior(
            condition=cond_top, cond_latents=cond_top_latents)
        self.z_sample = prior_dist.sample()
        self.all_top_latents.append(self.z_sample)

  all_frames = tf.stack(all_frames)
  predicted_video = common_video.swap_time_and_batch_axes(all_frames)

  # The video-decode API requires the predicted video to be the same shape
  # as the target-video. Hence, for unconditional generation, tile across
  # time to ensure the same shape.
  if self.hparams.gen_mode == "unconditional":
    predicted_video = tf.tile(
        predicted_video, [1, self.hparams.video_num_target_frames, 1, 1, 1])
  predicted_video = glow_ops.postprocess(predicted_video)

  # Output of a single decode / sample.
  output_features = {}
  output_features["targets"] = tf.zeros_like(predicted_video)
  output_features["outputs"] = predicted_video
  output_features["scores"] = tf.zeros_like(predicted_video)
  return output_features
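# Sketch of the get_cond_latents contract assumed by infer above (a
# simplified stand-in, not the real helper; next_frame_glow's version also
# handles several latent_dist_encoder variants and a pretraining schedule,
# and the hparam name `num_cond_latents` is an assumption here). It returns
# a boolean `condition` plus the most recent latents to condition on, with
# no conditioning before any latents exist.
def get_cond_latents_sketch(all_latents, hparams):
  if hparams.gen_mode != "conditional" or not all_latents:
    return tf.constant(False, dtype=tf.bool), None
  num_cond = getattr(hparams, "num_cond_latents", 1)
  return tf.constant(True, dtype=tf.bool), all_latents[-num_cond:]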