Code example #1
  def top_cond_prior(self, name, cond_top_latents):
    """Maps the conditional top latents to a distribution.

    Args:
      name: variable scope.
      cond_top_latents: Tensor or a list of tensors.
                        Latent variables from the previous time-step(s).
                        If the encoder is "pointwise" or "conv_lstm", this is
                        a single tensor.
                        If "conv_net" or "conv3d_net", this is a list of
                        tensors; for "conv_net" its length must equal
                        hparams.num_cond_latents (plus one if
                        hparams.cond_first_frame is set).
    Returns:
      cond_dist: tfp.distributions.Normal, the conditional prior over the
                 top-level latent at the current time-step.
    Raises:
      ValueError: If cond_top_latents is not of the expected length.
    """
    with tf.variable_scope("top", reuse=tf.AUTO_REUSE):
      if self.hparams.latent_dist_encoder == "pointwise":
        last_latent = cond_top_latents
        top = glow_ops.scale_gaussian_prior(
            name, cond_top_latents, trainable=self.hparams.learn_top_scale)
      elif self.hparams.latent_dist_encoder == "conv_net":
        num_cond_latents = (self.hparams.num_cond_latents +
                            int(self.hparams.cond_first_frame))
        if len(cond_top_latents) != num_cond_latents:
          raise ValueError(
              "Expected length of cond_top_latents %d, got %d"
              % (num_cond_latents, len(cond_top_latents)))
        last_latent = cond_top_latents[-1]
        output_channels = common_layers.shape_list(last_latent)[-1]
        cond_top_latents = tf.concat(cond_top_latents, axis=-1)

        # Maps the latent-stack to a distribution.
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        top = glow_ops.latent_to_dist(
            name, cond_top_latents, hparams=self.hparams,
            output_channels=output_channels)
      elif self.hparams.latent_dist_encoder == "conv_lstm":
        last_latent = cond_top_latents
        output_channels = common_layers.shape_list(cond_top_latents)[-1]
        # (h_t, c_t) = LSTM(z_{t-1}; (h_{t-1}, c_{t-1}))
        # (mu_t, sigma_t) = conv(h_t)
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        _, self.top_state = common_video.conv_lstm_2d(
            cond_top_latents, self.top_state, self.hparams.latent_encoder_width,
            kernel_size=3, name="conv_lstm")
        top = glow_ops.single_conv_dist(
            name, self.top_state.h, output_channels=output_channels)
      elif self.hparams.latent_dist_encoder == "conv3d_net":
        last_latent = cond_top_latents[-1]
        cond_top_latents = tf.stack(cond_top_latents, axis=1)
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        top = glow_ops.temporal_latent_to_dist(
            "conv3d", cond_top_latents, self.hparams)

      # mu(z_{t}) = z_{t-1} + latent_encoder(z_{cond})
      if self.hparams.latent_skip:
        top = tfp.distributions.Normal(last_latent + top.loc, top.scale)
    return top
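When hparams.latent_skip is set, the block at the end of code example #1 only shifts the mean of the encoder's output distribution by the previous latent, mu(z_t) = z_{t-1} + latent_encoder(z_cond); the scale is left untouched. The standalone sketch below reproduces just that step with tfp.distributions.Normal. It is not part of the repository: last_latent, encoder_loc and encoder_scale are made-up constants standing in for z_{t-1} and the encoder output, and a TF1-style graph/session setup is assumed, as in the snippets above.

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tf.disable_v2_behavior()

with tf.Graph().as_default():
  # Made-up stand-ins for z_{t-1} and latent_encoder(z_cond).
  last_latent = tf.constant(np.zeros((1, 4, 4, 8), dtype=np.float32))
  encoder_loc = tf.constant(np.full((1, 4, 4, 8), 0.5, dtype=np.float32))
  encoder_scale = tf.constant(np.ones((1, 4, 4, 8), dtype=np.float32))

  top = tfp.distributions.Normal(encoder_loc, encoder_scale)
  # Residual ("latent skip") parameterization: only the mean is shifted.
  top = tfp.distributions.Normal(last_latent + top.loc, top.scale)

  with tf.Session() as sess:
    loc, scale = sess.run([top.loc, top.scale])
    print(loc[0, 0, 0, 0], scale[0, 0, 0, 0])  # -> 0.5 1.0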
Code example #2
  def test_scale_gaussian_prior(self):
    with tf.Graph().as_default():
      rng = np.random.RandomState(0)
      img_shape = (16, 2, 2, 2)
      x_rand = np.asarray(rng.randint(0, 10, img_shape), dtype=np.float32)
      z_rand = np.asarray(rng.randint(0, 10, img_shape), dtype=np.float32)
      x_t = tf.convert_to_tensor(x_rand)
      z_t = tf.convert_to_tensor(z_rand)
      dist = glow_ops.scale_gaussian_prior(
          "scale_gaussian_prior", z_t, x_t, trainable=True)
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        mean, scale = sess.run([dist.loc, dist.scale])
        # At initialization the prior should be centered at z with unit scale.
        self.assertTrue(np.allclose(mean, z_rand))
        self.assertTrue(np.allclose(scale, 1.0))
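Code example #2 pins down the behaviour expected from glow_ops.scale_gaussian_prior at initialization: the returned Normal is centered at z with scale 1.0 before any training has happened. A hypothetical minimal analogue that would satisfy those assertions is sketched below; toy_scale_gaussian_prior is an invented name and this is not the library's actual implementation, just one way to build a prior whose mean is z and whose scale comes from a trainable per-channel log-scale initialized to zero (so exp(0) = 1.0 at initialization).

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

def toy_scale_gaussian_prior(name, z, x, trainable=True):
  """Toy stand-in: Normal with mean z and a trainable per-channel scale."""
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    channels = z.get_shape().as_list()[-1]
    # Zero-initialized log-scale gives exp(0) = 1.0 at initialization,
    # which is exactly what the assertions above check for.
    log_scale = tf.get_variable(
        "log_scale", shape=[channels], dtype=tf.float32,
        initializer=tf.zeros_initializer(), trainable=trainable)
    # x is only broadcast against here; a real conditional prior would let
    # x modulate the scale (and possibly the mean) as well.
    scale = tf.exp(log_scale) * tf.ones_like(x)
    return tfp.distributions.Normal(loc=z, scale=scale)

Swapping this toy function into the test body above yields mean == z_rand and scale == 1.0 at initialization, matching the two assertions.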