Example #1
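These four snippets are variants of the conditional-prior code in tensor2tensor's next-frame Glow model. They all assume the TF1-era API plus a few tensor2tensor modules; a plausible shared import header is sketched below (the module paths are an assumption, and examples #2 to #4 appear to live inside glow_ops itself, so unqualified helpers such as scale_gaussian_prior and latent_to_dist would be module-local there):

# Assumed shared imports (tensor2tensor, TF1-style API); paths are a best guess.
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from tensor2tensor.layers import common_layers
from tensor2tensor.layers import common_video
from tensor2tensor.models.research import glow_ops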
  def top_cond_prior(self, name, cond_top_latents):
    """Maps the conditional top latents to a distribution.

    Args:
      name: variable scope.
      cond_top_latents: Tensor or a list of tensors.
                        Latent variables at the previous time-step.
                        If "pointwise", this is a single tensor.
                        If "conv_net", this is a list of tensors with length
                        equal to hparams.num_cond_latents.
    Returns:
      top: instance of tfp.distributions.Normal.
    Raises:
      ValueError: If cond_top_latents are not of the expected length.
    """
    with tf.variable_scope("top", reuse=tf.AUTO_REUSE):
      if self.hparams.latent_dist_encoder == "pointwise":
        last_latent = cond_top_latents
        top = glow_ops.scale_gaussian_prior(
            name, cond_top_latents, trainable=self.hparams.learn_top_scale)
      elif self.hparams.latent_dist_encoder == "conv_net":
        num_cond_latents = (self.hparams.num_cond_latents +
                            int(self.hparams.cond_first_frame))
        if len(cond_top_latents) != num_cond_latents:
          raise ValueError(
              "Expected length of cond_top_latents %d, got %d"
              % (num_cond_latents, len(cond_top_latents)))
        last_latent = cond_top_latents[-1]
        output_channels = common_layers.shape_list(last_latent)[-1]
        cond_top_latents = tf.concat(cond_top_latents, axis=-1)

        # Maps the latent-stack to a distribution.
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        top = glow_ops.latent_to_dist(
            name, cond_top_latents, hparams=self.hparams,
            output_channels=output_channels)
      elif self.hparams.latent_dist_encoder == "conv_lstm":
        last_latent = cond_top_latents
        output_channels = common_layers.shape_list(cond_top_latents)[-1]
        # (h_t, c_t) = LSTM(z_{t-1}; (h_{t-1}, c_{t-1}))
        # (mu_t, sigma_t) = conv(h_t)
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        _, self.top_state = common_video.conv_lstm_2d(
            cond_top_latents, self.top_state, self.hparams.latent_encoder_width,
            kernel_size=3, name="conv_lstm")
        top = glow_ops.single_conv_dist(
            name, self.top_state.h, output_channels=output_channels)
      elif self.hparams.latent_dist_encoder == "conv3d_net":
        last_latent = cond_top_latents[-1]
        cond_top_latents = tf.stack(cond_top_latents, axis=1)
        cond_top_latents = glow_ops.noise_op(cond_top_latents, self.hparams)
        top = glow_ops.temporal_latent_to_dist(
            "conv3d", cond_top_latents, self.hparams)

      # mu(z_{t}) = z_{t-1} + latent_encoder(z_{cond})
      if self.hparams.latent_skip:
        top = tfp.distributions.Normal(last_latent + top.loc, top.scale)
    return top
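The latent_skip branch above recentres the encoder's output distribution on the previous latent, per the comment mu(z_t) = z_{t-1} + latent_encoder(z_cond). A minimal standalone sketch with made-up shapes:

# Illustrative only: shapes and values are placeholders.
last_latent = tf.zeros([1, 8, 8, 4])                       # z_{t-1}
encoded = tfp.distributions.Normal(loc=tf.zeros([1, 8, 8, 4]),
                                   scale=tf.ones([1, 8, 8, 4]))
# Shift the mean by the previous latent; the scale is left untouched.
top = tfp.distributions.Normal(last_latent + encoded.loc, encoded.scale)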
Example #2
def level_cond_prior(prior_dist, z, latent, hparams, state):
    """Returns a conditional prior for each level.

  Args:
    prior_dist: Distribution conditioned on the previous levels.
    z: Tensor, output of the previous levels.
    latent: Tensor or a list of tensors to condition the latent_distribution.
    hparams: next_frame_glow hparams.
    state: Current LSTM state. Used only if hparams.latent_dist_encoder is
           a lstm.
  Raises:
    ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape
                of latent is different from z.
  """
    latent_dist_encoder = hparams.get("latent_dist_encoder", None)
    latent_skip = hparams.get("latent_skip", False)
    if latent_dist_encoder == "pointwise":
        merge_std = hparams.level_scale
        latent_shape = common_layers.shape_list(latent)
        z_shape = common_layers.shape_list(z)
        if latent_shape != z_shape:
            raise ValueError("Expected latent_shape to be %s, got %s" %
                             (latent_shape, z_shape))
        latent_dist = scale_gaussian_prior("latent_prior",
                                           latent,
                                           logscale_factor=3.0)
        cond_dist = merge_level_and_latent_dist(prior_dist,
                                                latent_dist,
                                                merge_std=merge_std)
    elif latent_dist_encoder == "conv_net":
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
        cond_dist = tensor_to_dist(
            "latent_stack",
            latent_stack,
            output_channels=output_channels,
            architecture=hparams.latent_architecture,
            depth=hparams.latent_encoder_depth,
            pre_output_channels=hparams.latent_pre_output_channels,
            width=hparams.latent_encoder_width)
        if latent_skip:
            cond_dist = tf.distributions.Normal(cond_dist.loc + latent[-1],
                                                cond_dist.scale)
    elif latent_dist_encoder == "conv_lstm":
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
        _, state = common_video.conv_lstm_2d(latent_stack,
                                             state,
                                             hparams.latent_encoder_width,
                                             kernel_size=3,
                                             name="conv_lstm")
        cond_dist = tensor_to_dist("state_to_dist",
                                   state.h,
                                   output_channels=output_channels)
        if latent_skip:
            cond_dist = tf.distributions.Normal(cond_dist.loc + latent,
                                                cond_dist.scale)
    else:
        raise ValueError(
            "Unknown latent_dist_encoder %s" % latent_dist_encoder)
    return cond_dist.loc, cond_dist.scale, state
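Because this version returns the distribution's statistics rather than the distribution itself, a caller has to rebuild it. A hypothetical usage sketch (prior_dist, z, latent, hparams and state are assumed to be set up as the docstring describes):

# Rebuild the conditional prior from the returned mean and scale.
loc, scale, state = level_cond_prior(prior_dist, z, latent, hparams, state)
cond_dist = tf.distributions.Normal(loc, scale)
z_sample = cond_dist.sample()                       # draw z_t from the prior
log_prob = tf.reduce_sum(cond_dist.log_prob(z))     # score an observed z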
Example #3
def level_cond_prior(prior_dist, z, latent, hparams, state):
    """Returns a conditional prior for each level.

  Args:
    prior_dist: Distribution conditioned on the previous levels.
    z: Tensor, output of the previous levels.
    latent: Tensor or a list of tensors to condition the latent_distribution.
    hparams: next_frame_glow hparams.
    state: Current LSTM state. Used only if hparams.latent_dist_encoder is
           a lstm.
  Raises:
    ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape
                of latent is different from z.
  """
    latent_dist_encoder = hparams.get("latent_dist_encoder", None)
    latent_skip = hparams.get("latent_skip", False)
    if latent_dist_encoder == "pointwise":
        last_latent = latent
        merge_std = hparams.level_scale
        latent_shape = common_layers.shape_list(latent)
        z_shape = common_layers.shape_list(z)
        if latent_shape != z_shape:
            raise ValueError("Expected latent_shape to be %s, got %s" %
                             (latent_shape, z_shape))
        latent_dist = scale_gaussian_prior("latent_prior",
                                           latent,
                                           logscale_factor=3.0)
        cond_dist = merge_level_and_latent_dist(prior_dist,
                                                latent_dist,
                                                merge_std=merge_std)

    elif latent_dist_encoder == "conv_net":
        output_channels = common_layers.shape_list(z)[-1]
        last_latent = latent[-1]
        latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
        cond_dist = latent_to_dist("latent_stack",
                                   latent_stack,
                                   hparams=hparams,
                                   output_channels=output_channels)

    elif latent_dist_encoder == "conv3d_net":
        last_latent = latent[-1]
        output_channels = common_layers.shape_list(last_latent)[-1]
        num_steps = len(latent)

        # Stack across time.
        cond_latents = tf.stack(latent, axis=1)

        # Concat latents from previous levels across channels.
        prev_latents = tf.tile(tf.expand_dims(prior_dist.loc, axis=1),
                               [1, num_steps, 1, 1, 1])
        cond_latents = tf.concat((cond_latents, prev_latents), axis=-1)
        cond_dist = temporal_latent_to_dist("latent_stack",
                                            cond_latents,
                                            hparams,
                                            output_channels=output_channels)

    elif latent_dist_encoder == "conv_lstm":
        last_latent = latent
        output_channels = common_layers.shape_list(z)[-1]
        latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
        _, state = common_video.conv_lstm_2d(latent_stack,
                                             state,
                                             hparams.latent_encoder_width,
                                             kernel_size=3,
                                             name="conv_lstm")

        cond_dist = single_conv_dist("state_to_dist",
                                     state.h,
                                     output_channels=output_channels)
    else:
        raise ValueError(
            "Unknown latent_dist_encoder %s" % latent_dist_encoder)

    # Skip connection across time: shift the mean by the last latent.
    if latent_skip:
        new_mean = cond_dist.loc + last_latent
        cond_dist = tf.distributions.Normal(new_mean, cond_dist.scale)
    return cond_dist.loc, cond_dist.scale, state
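The tensor bookkeeping in the "conv3d_net" branch is easiest to follow with concrete shapes; a small sketch with made-up sizes:

# Three past latents of shape [batch=2, 8, 8, 4] are stacked across time,
# and the level prior's mean is tiled across time and concatenated on channels.
latent = [tf.zeros([2, 8, 8, 4]) for _ in range(3)]
prior_loc = tf.zeros([2, 8, 8, 4])
cond_latents = tf.stack(latent, axis=1)                          # [2, 3, 8, 8, 4]
prev_latents = tf.tile(tf.expand_dims(prior_loc, axis=1),
                       [1, len(latent), 1, 1, 1])                # [2, 3, 8, 8, 4]
cond_latents = tf.concat((cond_latents, prev_latents), axis=-1)  # [2, 3, 8, 8, 8]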
Example #4
def compute_prior(name, z, latent, hparams, state=None):
    """Distribution on z_t conditioned on z_{t-1} and latent.

  Args:
    name: variable scope.
    z: 4-D Tensor.
    latent: optional,
            if hparams.latent_dist_encoder == "pointwise", this is a list
            of 4-D Tensors of length hparams.num_cond_latents.
            else, this is just a 4-D Tensor
            The first-three dimensions of the latent should be the same as z.
    hparams: next_frame_glow_hparams.
    state: tf.contrib.rnn.LSTMStateTuple.
           the current state of a LSTM used to model the distribution. Used
           only if hparams.latent_dist_encoder = "conv_lstm".
  Returns:
    prior_dist: instance of tf.distributions.Normal
    state: Returns updated state.
  Raises:
    ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape
                of latent is different from z.
  """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        prior_dist = tensor_to_dist("level_prior",
                                    z,
                                    architecture="single_conv")

        # TODO(mechcoder) Refactor into separate sub-functions.
        if latent is not None:
            latent_dist_encoder = hparams.get("latent_dist_encoder", None)
            latent_skip = hparams.get("latent_skip", False)
            if latent_dist_encoder == "pointwise":
                merge_std = hparams.level_scale
                latent_shape = common_layers.shape_list(latent)
                z_shape = common_layers.shape_list(z)
                if latent_shape != z_shape:
                    raise ValueError("Expected latent_shape to be %s, got %s" %
                                     (latent_shape, z_shape))
                latent_dist = scale_gaussian_prior("latent_prior",
                                                   latent,
                                                   logscale_factor=3.0)
                prior_dist = merge_level_and_latent_dist(prior_dist,
                                                         latent_dist,
                                                         merge_std=merge_std)
            elif latent_dist_encoder == "conv_net":
                output_channels = common_layers.shape_list(z)[-1]
                latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
                prior_dist = tensor_to_dist(
                    "latent_stack",
                    latent_stack,
                    output_channels=output_channels,
                    architecture=hparams.latent_architecture,
                    depth=hparams.latent_encoder_depth,
                    pre_output_channels=hparams.latent_pre_output_channels)
                if latent_skip:
                    prior_dist = tf.distributions.Normal(
                        prior_dist.loc + latent[-1], prior_dist.scale)
            elif latent_dist_encoder == "conv_lstm":
                output_channels = common_layers.shape_list(z)[-1]
                latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
                _, state = common_video.conv_lstm_2d(latent_stack,
                                                     state,
                                                     output_channels,
                                                     kernel_size=3,
                                                     name="conv_lstm")
                prior_dist = tensor_to_dist("state_to_dist",
                                            state.h,
                                            output_channels=output_channels)
                if latent_skip:
                    prior_dist = tf.distributions.Normal(
                        prior_dist.loc + latent, prior_dist.scale)
            tf.summary.histogram("split_prior_mean", prior_dist.loc)
            tf.summary.histogram("split_prior_scale", prior_dist.scale)

    return prior_dist, state
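A hypothetical invocation with the "pointwise" encoder, assuming an hparams object that carries latent_dist_encoder, level_scale and latent_skip (names taken from the snippets above):

# Pointwise conditioning requires latent to match z's shape exactly.
z = tf.random_normal([1, 16, 16, 8])
latent = tf.random_normal([1, 16, 16, 8])
prior_dist, state = compute_prior("prior", z, latent, hparams)
z_t = prior_dist.sample()                           # sample z_t from the prior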