Example #1
  def video_features(
      self, all_frames, all_actions, all_rewards, all_raw_frames):
    """Video wide latent."""
    del all_actions, all_rewards, all_raw_frames

    hparams = self.hparams
    frames = tf.stack(all_frames, axis=1)
    mean, std = self.construct_latent_tower(frames, time_axis=1)
    tower_output = tf.concat([mean, std], axis=-1)
    tower_output_shape = common_layers.shape_list(tower_output)
    batch_size = tower_output_shape[0]

    if not self.is_training:
      rand = tf.random_uniform([batch_size, hparams.bottleneck_bits])
      d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      x = tfl.flatten(tower_output)
      x = tfl.dense(x, hparams.bottleneck_bits, name="bits_enc")
      x_shape = common_layers.shape_list(x)
      x += tf.truncated_normal(x_shape, mean=0.0, stddev=0.2)
      x = tf.tanh(x)
      noise = tf.random_uniform(x_shape)
      noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0
      x *= noise
      d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
      p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
      d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)

    decoded_bits = common_video.encode_to_shape(
        d, tower_output_shape, "bits_dec")
    return [decoded_bits, None, None]
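The line d = x + tf.stop_gradient(...) above is a straight-through estimator: the forward pass emits hard {-1, 1} bits while the gradient flows through the continuous x as if the binarization were the identity. A minimal sketch of the trick in isolation (illustrative, not part of the original):

import tensorflow as tf

x = tf.constant([-0.7, 0.2, 1.3])
hard = 2.0 * tf.to_float(tf.less(0.0, x)) - 1.0  # sign of x, in {-1, 1}
# Forward value equals hard; the backward gradient d(d)/dx is exactly 1.
d = x + tf.stop_gradient(hard - x)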
Example #3
 def bottleneck(self, x):
   hparams = self.hparams
   b, _ = super(AutoencoderDualDiscrete, self).bottleneck(x)
   if hparams.mode == tf.estimator.ModeKeys.EVAL:
     return b, 0.0
   bt, bi = tf.split(b, 2, axis=0)
   if self.hparams.mode != tf.estimator.ModeKeys.TRAIN:
     return tf.concat([bi, bi], axis=0), 0.0
   # Share the first hparams.bottleneck_shared_bits.
   shared = (bt + bi) / 2  # -1 if both -1, 1 if both were 1, 0 if disagree.
   rand = tf.random_uniform(common_layers.shape_list(bt))
   br = tf.where(rand < 0.5, bt, bi)  # Break ties at random.
   bs = tf.where(tf.equal(shared, 0), br, shared)
   bs = tf.concat([bs, bs], axis=0)
   n = hparams.bottleneck_shared_bits
   step = tf.train.get_global_step()
   zero = tf.constant(0, dtype=tf.int64)
   if step is None:
     step = zero
   step = tf.maximum(zero, step - hparams.bottleneck_shared_bits_start_warmup)
   f = common_layers.inverse_lin_decay(
       hparams.bottleneck_shared_bits_stop_warmup, min_value=0.1, step=step)
   n = tf.where(step > 1, n * f, n)
   n = tf.cast(n, tf.int64)
   b_shape = common_layers.shape_list(b)
   b = tf.concat([bs[..., :n], b[..., n:]], axis=-1)
   b = tf.reshape(b, b_shape)
   return b, 0.0
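Both branches above lean on common_layers.inverse_lin_decay for warmup scheduling. A rough sketch of its assumed behavior, a scalar that ramps linearly from min_value up to 1.0 as the global step approaches max_step (the real helper lives in tensor2tensor.layers.common_layers):

import tensorflow as tf

def inverse_lin_decay_sketch(max_step, min_value=0.01, step=None):
  # Illustrative approximation, not the library implementation.
  if step is None:
    step = tf.train.get_global_step()
  progress = tf.minimum(tf.to_float(step) / float(max_step), 1.0)
  return progress * (1.0 - min_value) + min_value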
def dae(x, hparams, name):
  with tf.variable_scope(name):
    m = tf.layers.dense(x, hparams.v_size, name="mask")
    if hparams.softmax_k > 0:
      m, kl = top_k_softmax(m, hparams.softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)
    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = hparams.kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)
    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9),
                          lambda: temperature,
                          lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = - tf.reduce_max(logsm, axis=-1)
    if _DO_SUMMARIES:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))
    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, hparams.v_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = - tf.reduce_mean(d_variance)
    ret = s
    if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
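The gumbel_sample helper used for the Gumbel-softmax draw is not shown in this snippet; a minimal sketch, assuming standard Gumbel(0, 1) noise with the uniform draw kept away from 0 and 1 so neither log can overflow:

import tensorflow as tf

def gumbel_sample(shape):
  # -log(-log(U)) with U ~ Uniform(eps, 1 - eps) is Gumbel(0, 1) noise.
  uniform = tf.random_uniform(shape, minval=1e-5, maxval=1.0 - 1e-5)
  return -tf.log(-tf.log(uniform))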
def bit_vae(x, hparams, name):
    with tf.variable_scope(name):
        bity = tf.layers.dense(x, hparams.z_size, name="bity")
        dev = common_layers.inverse_lin_decay(hparams.startup_steps) * 1.5
        noise = tf.random_normal(tf.shape(bity), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(bity + noise)
        tf.summary.histogram("bit", tf.reshape(y, [-1]))

        def discrete_y():
            d = tf.to_float(tf.less(0.5, y))
            return tf.stop_gradient(d) + y - tf.stop_gradient(y)

        y = tf.cond(tf.less(tf.train.get_global_step(), hparams.startup_steps),
                    lambda: y, discrete_y)
        # Flatten and predict for loss.
        y_flat = tf.reshape(y, [-1, hparams.z_size, 1, 1])
        hsize = hparams.hidden_size
        hparams.hidden_size = hsize // 2
        emb0 = tf.get_variable("emb0", [hparams.hidden_size])
        emb1 = tf.get_variable("emb1", [hparams.hidden_size])
        emb0 = tf.reshape(emb0, [1, 1, 1, hparams.hidden_size])
        emb1 = tf.reshape(emb1, [1, 1, 1, hparams.hidden_size])
        y_emb = y_flat * emb1 + (1 - y_flat) * emb0
        y_logit = decode(None, None, y_emb, None, None, hparams, "dbit")
        hparams.hidden_size = hsize
        y_pred = tf.nn.log_softmax(tf.layers.dense(y_logit, 2, name="y_pred"))
        y_flat = tf.reshape(y_flat, [-1])
        y_pred = tf.reshape(y_pred, [-1, 2])
        loss = -(y_flat * y_pred[:, 1] + (1 - y_flat) * y_pred[:, 0])
        # Get the final z and return.
        z = tf.layers.dense(y, hparams.z_size, name="after_bit")
        return z, tf.reduce_mean(loss)
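common_layers.saturating_sigmoid maps logits to [0, 1] with the endpoints actually reachable, which lets the bits saturate to exact 0s and 1s. A sketch of the usual tensor2tensor definition (stated here as an assumption):

import tensorflow as tf

def saturating_sigmoid(x):
  # 1.2 * sigmoid(x) - 0.1, clipped to [0, 1]; the endpoints are reachable.
  y = tf.sigmoid(x)
  return tf.minimum(1.0, tf.maximum(0.0, 1.2 * y - 0.1))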
def nearest_neighbor(x,
                     means,
                     block_v_size,
                     random_top_k=1,
                     soft_em=False,
                     soft_em_startup_steps=10000,
                     inv_temp=1.0,
                     ema_count=None):
    """Find the nearest element in means to elements in x.

  Args:
    x: Batch of encoder continuous latent states sliced/projected into shape
      [-1, num_blocks, block_dim].
    means: Embedding table of shpae [num_blocks, block_v_size, block_dim].
    block_v_size: Number of table entries per block.
    random_top_k: Noisy top-k if this is bigger than 1 (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    soft_em_startup_steps: Number of steps before soft_em activates
      (Default: 10000).
    inv_temp: Inverse temperature for soft EM (Default: 1.)
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).

  Returns:
    Tensor with nearest element in mean encoded in one-hot notation.
  """
    x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keep_dims=True)
    means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keep_dims=True)
    scalar_prod = tf.matmul(tf.transpose(x, perm=[1, 0, 2]),
                            tf.transpose(means, perm=[0, 2, 1]))
    scalar_prod = tf.transpose(scalar_prod, perm=[1, 0, 2])
    dist = (x_norm_sq + tf.transpose(means_norm_sq, perm=[2, 0, 1])
            - 2 * scalar_prod)

    # computing cluster probabilities
    if soft_em:
        ema_count = tf.expand_dims(ema_count, 0)
        c_probs = ema_count / tf.reduce_sum(ema_count, 2, keepdims=True)
        c_probs = tf.where(
            tf.less(tf.to_int32(tf.train.get_global_step()),
                    soft_em_startup_steps),
            tf.ones_like(c_probs, dtype=tf.float32), c_probs)
        mask = common_layers.inverse_lin_decay(2 * soft_em_startup_steps)
        c_probs = mask * c_probs + (1 - mask) * tf.ones_like(c_probs)
        nearest_hot = tf.exp(-inv_temp * dist) * c_probs
        nearest_hot /= tf.reduce_sum(nearest_hot, 2, keepdims=True)
    else:
        if random_top_k > 1:
            _, top_k_idx = tf.nn.top_k(-dist, k=random_top_k)
            nearest_idx = tf.gather(top_k_idx,
                                    tf.random_uniform([1],
                                                      minval=0,
                                                      maxval=random_top_k - 1,
                                                      dtype=tf.int32),
                                    axis=-1)
        else:
            nearest_idx = tf.argmax(-dist, axis=-1)
        nearest_hot = tf.one_hot(nearest_idx, block_v_size)
    return nearest_hot
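A toy usage sketch for nearest_neighbor; the shapes follow the docstring and the values are made up for illustration:

import tensorflow as tf

# 6 encoder vectors, 2 blocks, block_dim 4, codebook of 8 entries per block.
x = tf.random_normal([6, 2, 4])
means = tf.random_normal([2, 8, 4])
one_hot_assignments = nearest_neighbor(x, means, block_v_size=8)  # [6, 2, 8]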
Example #8
def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin"):
  if mode == "lin":
    alpha_p = common_layers.inverse_lin_decay(steps) + 0.001
  else:
    alpha_p = common_layers.inverse_exp_decay(steps) + 0.001
  alpha_p = alpha_p * (max_prob - min_prob) + min_prob
  alpha = tf.random_uniform(tf.shape(x1))
  alpha = tf.to_float(tf.less(alpha, alpha_p))
  return alpha * x1 + (1.0 - alpha) * x2
 def decoder(self, x, encoder_layers=None):
   with tf.variable_scope("decoder"):
     hparams = self.hparams
     is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
     kernel, strides = self._get_kernel_and_strides()
     residual_kernel = (hparams.residual_kernel_height,
                        hparams.residual_kernel_width)
     residual_kernel1d = (hparams.residual_kernel_height, 1)
     residual_kernel = residual_kernel1d if self.is1d else residual_kernel
     residual_conv = tf.layers.conv2d
     if hparams.residual_use_separable_conv:
       residual_conv = tf.layers.separable_conv2d
     # Up-convolutions.
     for i in range(hparams.num_hidden_layers):
       j = hparams.num_hidden_layers - i - 1
       if is_training:
         nomix_p = common_layers.inverse_lin_decay(
             int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
         if common_layers.should_generate_summaries():
           tf.summary.scalar("nomix_p_%d" % j, nomix_p)
       filters = hparams.hidden_size * 2**j
       filters = min(filters, hparams.max_hidden_size)
       with tf.variable_scope("layer_%d" % i):
         j = hparams.num_hidden_layers - i - 1
         x = tf.layers.conv2d_transpose(
             x,
             filters,
             kernel,
             strides=strides,
             padding="SAME",
             activation=common_layers.belu,
             name="strided")
         y = x
         for r in range(hparams.num_residual_layers):
           residual_filters = filters
           if r < hparams.num_residual_layers - 1:
             residual_filters = int(
                 filters * hparams.residual_filter_multiplier)
           y = residual_conv(
               y,
               residual_filters,
               residual_kernel,
               padding="SAME",
               activation=common_layers.belu,
               name="residual_%d" % r)
         x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
         x = common_layers.layer_norm(x, name="ln")
         x = common_attention.add_timing_signal_nd(x)
         if encoder_layers is not None:
           enc_x = encoder_layers[j]
           enc_shape = common_layers.shape_list(enc_x)
           x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
           if is_training:  # Mix at the beginning of training.
             rand = tf.random_uniform(common_layers.shape_list(x_mix))
             x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
           x = x_mix
     return x
Example #11
 def dropout(self, x):
   is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
   hparams = self.hparams
   if hparams.dropout <= 0.0 or not is_training:
     return x
   warm_step = hparams.bottleneck_warmup_steps * 2**hparams.num_hidden_layers
   dropout = common_layers.inverse_lin_decay(warm_step // 2) * hparams.dropout
   return common_layers.dropout_with_broadcast_dims(
       x, 1.0 - dropout, broadcast_dims=[-1])
def vae_transformer_internal(inputs, targets, target_space, hparams):
    """VAE Transformer, main step used for training."""
    with tf.variable_scope("vae_transformer"):
        is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
        # Prepare inputs, targets, and k.
        inputs = common_layers.flatten4d3d(inputs)
        input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
        inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
        inputs.set_shape([None, None, hparams.hidden_size])
        targets = common_layers.flatten4d3d(targets)
        k = 2**hparams.num_compress_steps
        inputs, targets = common_layers.pad_to_same_length(
            inputs, targets, final_length_divisible_by=k)
        inputs = encode(inputs, target_space, hparams, "input_enc")

        # Dropout targets or swap for zeros 5% of the time.
        targets_nodrop = targets
        max_prestep = hparams.kl_warmup_steps
        prob_targets = 0.95 if is_training else 1.0
        targets_dropout_max = common_layers.inverse_lin_decay(
            max_prestep) - 0.01
        targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
        targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                          lambda: targets, lambda: tf.zeros_like(targets))
        targets = targets_nodrop

        # Compress and vae.
        z = tf.get_variable("z", [hparams.hidden_size])
        z = tf.reshape(z, [1, 1, 1, -1])
        z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])

        z = attend(z, inputs, hparams, "z_attendsi")
        z = ffn(z, hparams, "zff2")
        z = attend(z, targets, hparams, "z_attendst2")
        z = ffn(z, hparams, "zff3")
        z, kl_loss, _, _ = vae(z, hparams, name="vae")
        z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")

        # z, kl_loss, _, _ = vae_compress(
        #     tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
        #     hparams, "vae_compress", "vae_decompress")

        decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
        (decoder_input, decoder_self_attention_bias) = (
            transformer.transformer_prepare_decoder(decoder_in, hparams))
        ret = transformer.transformer_decoder(decoder_input, inputs,
                                              decoder_self_attention_bias,
                                              None, hparams)

        kl_loss *= common_layers.inverse_exp_decay(
            int(max_prestep * 1.5)) * 5.0
        losses = {"kl": kl_loss}
        return tf.expand_dims(ret, axis=2), losses
Example #13
def get_anneal_mask(hparams):
  """Get anneal and kl mask."""
  startup = hparams.kl_startup_steps
  anneal = hparams.kl_anneal_steps
  global_step = tf.train.get_global_step()
  min_value = hparams.anneal_min_value
  step = tf.maximum(global_step - startup, 0)
  anneal = common_layers.inverse_lin_decay(
      anneal, min_value=min_value, step=step)
  kl_mask = tf.less(startup, tf.to_int32(global_step))
  kl_mask = tf.cast(kl_mask, tf.float32)
  return anneal, kl_mask
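A typical way the two outputs combine into a KL term (illustrative; raw_kl is a stand-in name for the unscaled KL divergence computed elsewhere):

anneal, kl_mask = get_anneal_mask(hparams)
kl_loss = anneal * kl_mask * raw_kl  # annealed and gated after kl_startup_steps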
def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin", simple=False):
    """Mix starting with x2, mixing mixing, going towards x1."""
    if mode == "lin":
        alpha_p = common_layers.inverse_lin_decay(steps)
    else:
        alpha_p = common_layers.inverse_exp_decay(steps)
    alpha_p = alpha_p * (max_prob - min_prob) + min_prob
    if simple:
        return alpha_p * x1 + (1.0 - alpha_p) * x2
    alpha = tf.random_uniform(tf.shape(x1))
    alpha = tf.to_float(tf.less(alpha, alpha_p))
    return alpha * x1 + (1.0 - alpha) * x2
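A quick usage sketch: with simple=True the result is a deterministic convex combination, otherwise every element independently takes x1 with probability alpha_p:

import tensorflow as tf

x1 = tf.ones([2, 3])
x2 = tf.zeros([2, 3])
smooth = mix(x1, x2, steps=1000, simple=True)  # scalar interpolation
noisy = mix(x1, x2, steps=1000)  # per-element Bernoulli(alpha_p) choice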
def dae(x, hparams, name):
    with tf.variable_scope(name):
        m = tf.layers.dense(x, hparams.v_size, name="mask")
        logsm = tf.nn.log_softmax(m)
        # Gumbel-softmax sample.
        gumbel_samples = gumbel_sample(tf.shape(m))
        steps = hparams.kl_warmup_steps
        gumbel_samples *= common_layers.inverse_exp_decay(steps) * 0.1
        temperature = 1.2 - common_layers.inverse_lin_decay(steps)
        s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
        m = tf.nn.softmax(m)
        kl = -tf.reduce_max(logsm, axis=-1)
        tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
        return m, s, tf.reduce_mean(kl)
Example #16
def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
  """Compress, then VAE."""
  mix_k = 8
  with tf.variable_scope(compress_name, reuse=reuse):
    cur = compress(x, None, hparams, "compress")
    # Convolve and ReLu to get state.
    cur = common_layers.conv_block(
        cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
    # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
    z, kl_loss = dvae(cur, None, hparams, name="dvae")
    z1, kl_loss1 = top_k_experts(cur, mix_k, hparams)
    mu, log_sigma = None, None

    # Mix expert-selection and flat selection.
    alpha_p = common_layers.inverse_lin_decay(60000) + 0.001
    z = alpha_p * z1 + (1 - alpha_p) * z
    kl_loss += kl_loss1

  # Compress context.
  with tf.variable_scope(compress_name, reuse=reuse):
    compress_c = compress(c, None, hparams, "compress_context")
    c_z = tf.layers.dense(compress_c, hparams.v_size, name="mask_context")
    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=z, logits=c_z)

  # If not training, use the predicted z instead of the autoregressive one.
  # if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
  # z = mix(c_z, z, 50000, max_prob=0.3, mode="exp")
  # z, _ = top_k_softmax(c_z, mix_k)

  with tf.variable_scope(decompress_name, reuse=reuse):
    # Decompress.
    z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")

    # Leak at the beginning to help train.
    z = mix(z, cur, 30000)

    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
      z = decompress_step(z, c, hparams, i > 0, "decompress_step_%d" % j)
    return z, kl_loss + 0.0001 * reconstruct_loss, mu, log_sigma
Example #17
    def body(self, features):
        hparams = self.hparams
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            x = features["targets"]
            shape = common_layers.shape_list(x)
            is1d = shape[2] == 1
            self.is1d = is1d
            # Run encoder.
            x = self.encoder(x)
            # Bottleneck (mix during early training, not too important but stable).
            b, b_loss = self.bottleneck(x)
            self._cur_bottleneck_tensor = b
            b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
            b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps,
                                  is_training)
            if hparams.add_gan_loss:
                # Add a purely sampled batch on which we'll compute the GAN loss.
                g = self.unbottleneck(self.sample(),
                                      common_layers.shape_list(x)[-1],
                                      reuse=True)
                b = tf.concat([g, b], axis=0)
            # With probability bottleneck_max_prob use the bottleneck, otherwise x.
            if hparams.bottleneck_max_prob < 1.0:
                x = tf.where(
                    tf.less(tf.random_uniform([]),
                            hparams.bottleneck_max_prob), b, x)
            else:
                x = b
        else:
            if self._cur_bottleneck_tensor is None:
                b = self.sample()
            else:
                b = self._cur_bottleneck_tensor
            res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
            res_size = min(res_size, hparams.max_hidden_size)
            x = self.unbottleneck(b, res_size)
        # Run decoder.
        x = self.decoder(x)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            return x, {"bottleneck_loss": 0.0}
        # Cut to the right size and mix before returning.
        res = x[:, :shape[1], :shape[2], :]
        # Add GAN loss if requested.
        gan_loss = 0.0
        if hparams.add_gan_loss:
            # Split back if we added a purely sampled batch.
            res_gan, res = tf.split(res, 2, axis=0)
            num_channels = self.hparams.problem.num_channels
            res_rgb = common_layers.convert_real_to_rgb(
                tf.nn.sigmoid(
                    tf.layers.dense(res_gan, num_channels, name="gan_rgb")))
            tf.summary.image("gan", tf.cast(res_rgb, tf.uint8), max_outputs=1)
            orig_rgb = tf.to_float(features["targets_raw"])

            def discriminate(x):
                return self.discriminator(x, is_training=is_training)

            gan_loss = common_layers.sliced_gan_loss(
                orig_rgb, reverse_gradient(res_rgb), discriminate,
                self.hparams.num_sliced_vecs)
            gan_loss *= common_layers.inverse_lin_decay(
                hparams.bottleneck_warmup_steps)
        # Mix the final result and return.
        res = common_layers.mix(res, features["targets"],
                                hparams.bottleneck_warmup_steps // 2,
                                is_training)
        return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
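reverse_gradient above flips the gradient sign so the generator is pushed away from what the discriminator prefers while both share one loss. A common one-line sketch of such a helper (an assumption, not necessarily the implementation used here):

import tensorflow as tf

def reverse_gradient(x):
  # Forward pass returns x unchanged; the backward pass negates the gradient.
  return tf.stop_gradient(2.0 * x) - x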
Example #18
    def inject_latent(self, layer, features, filters):
        """Inject a deterministic latent based on the target frame."""
        del filters
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        batch_size = layer_shape[0]
        state_size = hparams.latent_predictor_state_size
        lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
        discrete_predict = tf.layers.Dense(256, name="discrete_predict")
        discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

        def add_d(layer, d):
            z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if self.is_predicting:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
            else:
                layer_pred = tf.reshape(
                    layer, [batch_size, prod(layer_shape[1:])])
                prediction = tf.layers.dense(layer_pred,
                                             state_size,
                                             name="istate")
                c_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="cstate")
                m_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="mstate")
                state = (c_state, m_state)
                outputs = []
                for i in range(hparams.bottleneck_bits // 8):
                    output, state = lstm_cell(prediction, state)
                    discrete_logits = discrete_predict(output)
                    discrete_samples = common_layers.sample_with_temperature(
                        discrete_logits, hparams.latent_predictor_temperature)
                    outputs.append(tf.expand_dims(discrete_samples, axis=1))
                    prediction = discrete_embed(
                        tf.one_hot(discrete_samples, 256))
                outputs = tf.concat(outputs, axis=1)
                outputs = discretization.int_to_bit(outputs, 8)
                rand = tf.reshape(outputs,
                                  [batch_size, 1, 1, hparams.bottleneck_bits])
            d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            return add_d(layer, d), 0.0

        # Embed.
        frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                           axis=-1)
        x = tf.layers.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tf.layers.conv2d(x,
                                         filters,
                                         kernel,
                                         activation=common_layers.belu,
                                         strides=(2, 2),
                                         padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
        x0 = tf.tanh(x)
        d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 -
                                  x0)
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0),
                                [batch_size, hparams.bottleneck_bits // 8, 8])
            d_int = discretization.bit_to_int(d_pred, 8)
            tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
            d_hot = tf.one_hot(d_int, 256, axis=-1)
            d_pred = discrete_embed(d_hot)
            layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
            prediction0 = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="istate")
            c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
            m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
            pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred],
                             axis=1)
            state = (c_state, m_state)
            outputs = []
            for i in range(hparams.bottleneck_bits // 8):
                output, state = lstm_cell(pred[:, i, :], state)
                outputs.append(tf.expand_dims(output, axis=1))
            outputs = tf.concat(outputs, axis=1)
            d_int_pred = discrete_predict(outputs)
            pred_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=d_int_pred, labels=d_int)
            pred_loss = tf.reduce_mean(pred_loss)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            x += tf.truncated_normal(common_layers.shape_list(x),
                                     mean=0.0,
                                     stddev=0.2)
            x = tf.tanh(x)
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise,
                                              noise)) - 1.0
            x *= noise
            d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 -
                                     x)
            p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
            d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
        return add_d(layer, d), pred_loss
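discretization.int_to_bit and bit_to_int convert between integers and fixed-width binary expansions, which is how the 8-bit latent chunks above are fed to the LSTM predictor. A sketch of the assumed semantics (least-significant bit first):

import tensorflow as tf

def int_to_bit_sketch(x_int, num_bits):
  # [...] -> [..., num_bits] of 0./1. values, LSB first. Illustrative only.
  x = tf.to_int32(tf.expand_dims(x_int, axis=-1))
  bits = [tf.floormod(tf.floordiv(x, 2**i), 2) for i in range(num_bits)]
  return tf.to_float(tf.concat(bits, axis=-1))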
    def next_frame(self, frames, actions, rewards, target_frame,
                   internal_states, video_extra):
        del rewards, video_extra

        hparams = self.hparams
        filters = hparams.hidden_size
        kernel2 = (4, 4)
        action = actions[-1]
        activation_fn = common_layers.belu
        if self.hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        # Normalize frames.
        frames = [common_layers.standardize_images(f) for f in frames]

        # Stack the inputs.
        if internal_states is not None and hparams.concat_internal_states:
            # Use the first part of the first internal state if asked to concatenate.
            batch_size = common_layers.shape_list(frames[0])[0]
            internal_state = internal_states[0][0][:batch_size, :, :, :]
            stacked_frames = tf.concat(frames + [internal_state], axis=-1)
        else:
            stacked_frames = tf.concat(frames, axis=-1)
        inputs_shape = common_layers.shape_list(stacked_frames)

        # Update internal states early if requested.
        if hparams.concat_internal_states:
            internal_states = self.update_internal_states_early(
                internal_states, frames)

        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            stacked_frames,
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=activation_fn,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        if self.has_actions:
            with tf.variable_scope("policy"):
                x_flat = tf.layers.flatten(x)
                policy_pred = tf.layers.dense(x_flat,
                                              self.hparams.problem.num_actions)
                value_pred = tf.layers.dense(x_flat, 1)
                value_pred = tf.squeeze(value_pred, axis=-1)
        else:
            policy_pred, value_pred = None, None

        # Add embedded action if present.
        if self.has_actions:
            x = common_video.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)

        # Inject latent if present. Only for stochastic models.
        norm_target_frame = common_layers.standardize_images(target_frame)
        x, extra_loss = self.inject_latent(x, frames, norm_target_frame,
                                           action)

        x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        x, internal_states = self.middle_network(x, internal_states)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
                if self.has_actions:
                    x = common_video.inject_additional_input(
                        x, action, "action_enc", hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=activation_fn,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        if hparams.do_autoregressive_rnn:
            # If enabled, we predict the target frame autoregressively using RNNs.
            # To this end, the current prediction is flattened into one long sequence
            # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
            # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
            # as width * height * channels, we only use a number of pixels back,
            # as many as hparams.autoregressive_rnn_lookback.
            with tf.variable_scope("autoregressive_rnn"):
                batch_size = common_layers.shape_list(frames[0])[0]
                # Height, width, channels and lookback are the constants we need.
                h, w = inputs_shape[1], inputs_shape[2]  # e.g. 105, 80 on Atari.
                c = hparams.problem.num_channels
                lookback = hparams.autoregressive_rnn_lookback
                assert (h * w) % lookback == 0, (
                    "Lookback must divide the number of pixels.")
                m = (h * w) // lookback  # Batch size multiplier for the RNN.
                # These are logits that will be used as inputs to the RNN.
                rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
                # They are of shape [batch_size, h, w, c, 64], reshaping now.
                rnn_inputs = tf.reshape(rnn_inputs,
                                        [batch_size * m, lookback * c, 64])
                # Same for the target frame.
                rnn_target = tf.reshape(target_frame,
                                        [batch_size * m, lookback * c])
                # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
                rnn_start_state = tf.nn.relu(
                    tf.layers.dense(tf.nn.relu(tf.layers.flatten(rnn_inputs)),
                                    256,
                                    name="rnn_start_state"))
                # Our RNN function API is on bits, each subpixel has 8 bits.
                total_num_bits = lookback * c * 8
                # We need to provide RNN targets as bits (due to the API).
                rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
                rnn_target_bits = tf.reshape(rnn_target_bits,
                                             [batch_size * m, total_num_bits])
                if self.is_training:
                    # Run the RNN in training mode, add its loss to the losses.
                    rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        target_bits=rnn_target_bits,
                        extra_inputs=rnn_inputs)
                    extra_loss += rnn_loss
                    # We still use non-RNN predictions too in order to guide the network.
                    x = tf.layers.dense(x, c * 256, name="logits")
                    x = tf.reshape(x, [batch_size, h, w, c, 256])
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c, 256])
                    # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
                    x = tf.reshape(tf.nn.log_softmax(x),
                                   [batch_size, h, w, c * 256])
                    rnn_predict = tf.nn.log_softmax(rnn_predict)
                    rnn_predict = tf.reshape(rnn_predict,
                                             [batch_size, h, w, c * 256])
                    alpha = 0.9 * common_layers.inverse_lin_decay(
                        hparams.autoregressive_rnn_warmup_steps)
                    x = alpha * rnn_predict + (1.0 - alpha) * x
                else:
                    # In prediction mode, run the RNN without any targets.
                    bits, _ = discretization.predict_bits_with_lstm(
                        rnn_start_state,
                        128,
                        total_num_bits,
                        extra_inputs=rnn_inputs,
                        temperature=0.0
                    )  # No sampling from this RNN, just greedy.
                    # The output is in bits, get back the predicted pixels.
                    bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
                    ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
                    ints = tf.reshape(ints, [batch_size, h, w, c])
                    x = tf.reshape(tf.one_hot(ints, 256),
                                   [batch_size, h, w, c * 256])
        elif self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        reward_pred = None
        if self.has_rewards:
            # Reward prediction based on middle and final logits.
            reward_pred = tf.concat([x_mid, x_fin], axis=-1)
            reward_pred = tf.nn.relu(
                tf.layers.dense(reward_pred, 128, name="reward_pred"))
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
            reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

        return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def autoencoder_body(self, features):
  """ Customized body function for autoencoders acting on continuous images.
  This is based on tensor2tensor.models.research.AutoencoderBasic.body
  and should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a discrete
  vocabulary and defines the loss on that vocab. It's cool and all, but here we
  prefer expressing the reconstruction loss as an actual continuous likelihood
  function.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  output_activation = (
      tf.nn.softplus if hparams.output_activation == 'softplus' else None)
  input_shape = [None] + common_layers.shape_list(features["inputs"])[1:]

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # In predict mode, we also define TensorFlow Hub modules for all pieces of
    # the autoencoder
    if hparams.encode_psf and 'psf' in features:
      psf_shape = [None] + common_layers.shape_list(features["psf"])[1:]
    # First build encoder spec
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))
      x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs=input_layer, outputs=b)

    def make_model_spec_psf():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      psf_layer = tf.placeholder(tf.float32, shape=psf_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))

      # If we have access to the PSF, we add this information to the encoder
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(tf.signal.irfft2d(tf.cast(psf_layer[...,0], tf.complex64)), axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2 subsampling
        psf_image = tf.roll(psf_image, shift=[input_shape[1], input_shape[2]], axis=[1,2])
        psf_image = tf.image.resize_with_crop_or_pad(psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(psf_image,
                                   hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs={'input':input_layer, 'psf':psf_layer}, outputs=b)

    spec = hub.create_module_spec(make_model_spec_psf if hparams.encode_psf else make_model_spec, drop_collections=['checkpoints'])
    encoder = hub.Module(spec, name="encoder_module")
    hub.register_module_for_export(encoder, "encoder")

    if hparams.encode_psf:
      code = encoder({'input':features["inputs"], 'psf': features['psf']})
    else:
      code = encoder(features["inputs"])
    b_shape = [None, ] + common_layers.shape_list(code)[1:]
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    # Second build decoder spec
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=b_shape)
      x = self.unbottleneck(input_layer, res_size)
      x = self.decoder(x, None)
      reconstr = tf.layers.dense(x, input_shape[-1], name="autoencoder_final",
                                 activation=output_activation)
      hub.add_signature(inputs=input_layer, outputs=reconstr)
      hub.attach_message("stamp_size", tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
      try:
        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale[res] for res in hparams.problem_hparams.resolutions]))
      except AttributeError:
        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))
    spec = hub.create_module_spec(make_model_spec, drop_collections=['checkpoints'])
    decoder = hub.Module(spec, name="decoder_module")
    hub.register_module_for_export(decoder, "decoder")

    reconstr = decoder(code)
    return reconstr, {"bottleneck_loss": 0.0}

  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)

    shape = common_layers.shape_list(labels)
    with tf.variable_scope('encoder_module'):
      x = self.embed(tf.expand_dims(labels, -1))

    if shape[2] == 1:
      self.is1d = True

    # Run encoder.
    with tf.variable_scope('encoder_module'):
      # If we have access to the PSF, we add this information to the encoder
      # Note that we only support single band images so far...
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(tf.signal.irfft2d(tf.cast(features['psf'][...,0], tf.complex64)), axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2 subsampling
        psf_image = tf.roll(psf_image, shift=[input_shape[1], input_shape[2]], axis=[1,2])
        psf_image = tf.image.resize_with_crop_or_pad(psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(psf_image,
                                   hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)

    # Bottleneck.
    with tf.variable_scope('encoder_module'):
      b, b_loss = self.bottleneck(x)

    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    with tf.variable_scope('decoder_module'):
      b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(
          tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    with tf.variable_scope('decoder_module'):
      x = self.unbottleneck(b, res_size)

  # Run decoder.
  with tf.variable_scope('decoder_module'):
    x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  with tf.variable_scope('decoder_module'):
    reconstr = tf.layers.dense(res, shape[-1], name="autoencoder_final",
                               activation=output_activation)

  # We apply an optional apodization of the output before computing the loss.
  if hparams.output_apodization > 0:
    nx = reconstr.get_shape().as_list()[1]
    alpha = 2 * hparams.output_apodization / nx
    from scipy.signal.windows import tukey
    # Create a Tukey window.
    w = tukey(nx, alpha)
    w = np.outer(w, w).reshape((1, nx, nx, 1)).astype('float32')
    # Penalize anything non-zero at the border.
    apo_loss = tf.reduce_mean(
        tf.reduce_sum(((1. - w) * reconstr)**2, axis=[1, 2, 3]))
  else:
    w = 1.0
    apo_loss = 0.

  # We apply the window
  reconstr = reconstr * w

  # Optionally regularize the output further.
  # Anisotropic TV:
  tv = tf.reduce_mean(tf.image.total_variation(reconstr))
  # Smoothed isotropic TV:
  # im_dx, im_dy = tf.image.image_gradients(reconstr)
  # tv = tf.reduce_sum(tf.sqrt(im_dx**2 + im_dy**2 + 1e-6), axis=[1, 2, 3])
  # tv = tf.reduce_mean(tv)

  image_summary("without_psf",tf.reshape(reconstr, labels_shape))
  # Apply channel-wise convolution with the PSF if requested
  if hparams.apply_psf and 'psf' in features:
    output_list = []
    for i in range(shape[3]):
      output_list.append(tf.squeeze(convolve(tf.expand_dims(reconstr[...,i],-1), tf.cast(features['psf'][...,i], tf.complex64),
                          zero_padding_factor=1)))
    reconstr = tf.stack(output_list,axis=-1)
    reconstr = tf.reshape(reconstr,shape)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss,
      "total_variation": hparams.total_variation_loss * tv,
      "apodization_loss": hparams.apodization_loss * apo_loss,
  }

  loglik = loglikelihood_fn(labels, reconstr, features, hparams)
  targets_loss = tf.reduce_mean(-loglik)

  tf.summary.scalar("negloglik", targets_loss)
  tf.summary.scalar("bottleneck_loss", b_loss)

  # Compute final loss
  losses["training"] = targets_loss + b_loss + hparams.bottleneck_l2_factor * xb_loss + hparams.total_variation_loss * tv +  hparams.apodization_loss * apo_loss
  logits = tf.reshape(reconstr, labels_shape)

  image_summary("ae", reconstr)
  image_summary("input", labels)

  return logits, losses
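
The apodization above can be checked outside the graph. Below is a minimal numpy sketch of the same construction; the sizes (nx, apod_pixels) and the random image are illustrative stand-ins for the model tensors, not values from the code above.

import numpy as np
from scipy.signal.windows import tukey

nx, apod_pixels = 64, 4          # illustrative sizes, not from hparams
alpha = 2.0 * apod_pixels / nx   # same taper fraction as above
w = tukey(nx, alpha)             # 1 in the interior, tapering to 0 at edges
w2d = np.outer(w, w).astype('float32')

image = np.random.rand(nx, nx).astype('float32')
apodized = image * w2d           # smoothly zeroed borders
# Penalty on whatever survives outside the window, mirroring apo_loss:
apo_loss = float(np.sum(((1.0 - w2d) * image) ** 2))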
Exemple #22
0
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(
            tf.squared_difference(x_stop, b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
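
The GAN path above relies on reverse_gradient to flip the sign of gradients flowing back from the discriminator. A minimal sketch of that trick, assuming only tf.stop_gradient; the actual common_layers helper may differ in details.

import tensorflow as tf  # TF1-style API, as in these examples

def reverse_gradient_sketch(x, lr=1.0):
  # Identity on the forward pass; gradient scaled by -lr on the backward
  # pass, so whatever the critic minimizes, the upstream net maximizes.
  grad_path = -lr * x
  return grad_path + tf.stop_gradient(x - grad_path)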
def adv_transformer_internal(inputs, targets, target_space, hparams):
    """Adversarial Transformer, main step used for training."""
    with tf.variable_scope("adv_transformer"):
        batch_size = tf.shape(targets)[0]
        targets = tf.reshape(targets, [batch_size, -1, 1])
        intermediate = tf.constant(34 * 1024 - 1)
        intermediate += tf.zeros_like(targets)
        targets = tf.concat([targets, intermediate], axis=2)
        targets = tf.reshape(targets, [batch_size, -1, 1])
        embedding = tf.get_variable("embedding",
                                    [34 * 1024, hparams.hidden_size])
        targets_emb = tf.gather(embedding, targets)

        # Noisy embedded targets.
        targets_noisy = tf.one_hot(targets, 34 * 1024)
        noise_val = hparams.noise_val
        targets_noisy += tf.random_uniform(tf.shape(targets_noisy),
                                           minval=-noise_val,
                                           maxval=noise_val)
        targets_emb_noisy = softmax_embed(targets_noisy, embedding, batch_size,
                                          hparams)

        # Encoder.
        if inputs is not None:
            inputs_emb = common_layers.flatten4d3d(inputs)
            inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
        else:
            ed = None

        # Masking.
        masking = common_layers.inverse_lin_decay(200000)
        masking *= common_layers.inverse_exp_decay(50000)  # Not much at start.
        masking -= tf.random_uniform([]) * 0.4
        masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
        mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
        mask = tf.expand_dims(tf.to_float(mask), 3)
        noise = tf.random_uniform(tf.shape(targets_emb))
        targets_emb = mask * targets_emb + (1.0 - mask) * noise

        # Decoder.
        res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
        res = tf.layers.dense(res_dec, 34 * 1024, name="res_sm")
        res_emb = softmax_embed(res, embedding, batch_size, hparams)

        # Extra steps.
        extra_step_prob = masking * 0.6 + 0.3
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            extra_step_prob = 1.0
        for _ in xrange(hparams.extra_steps):

            def another_step(emb):
                res_dec = decode(inputs,
                                 ed,
                                 emb,
                                 hparams,
                                 "decoder",
                                 reuse=True)
                res = tf.layers.dense(res_dec,
                                      34 * 1024,
                                      name="res_sm",
                                      reuse=True)
                return softmax_embed(res, embedding, batch_size, hparams), res

            res_emb, res = tf.cond(tf.less(tf.random_uniform([]),
                                           extra_step_prob),
                                   lambda e=res_emb: another_step(e),
                                   lambda: (res_emb, res))

        # Adversary.
        delta = masking * hparams.delta_max
        true_logit = adversary(tf.stop_gradient(targets_emb_noisy),
                               tf.stop_gradient(inputs + inputs_emb), hparams,
                               "adversary")
        gen_logit = adversary(reverse_gradient(res_emb, delta),
                              tf.stop_gradient(inputs + inputs_emb),
                              hparams,
                              "adversary",
                              reuse=True)
        losses = {"adv": gen_logit - true_logit}
        res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
        return res, losses
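
softmax_embed is called but not shown in this snippet. A plausible minimal sketch, assuming it turns logits over the 34 * 1024 vocabulary into the expected embedding; the signature is inferred from the calls above and the real helper may differ.

import tensorflow as tf

def softmax_embed_sketch(logits, embedding, batch_size, hparams):
  # Soft lookup: the expected embedding under the softmax distribution.
  probs = tf.nn.softmax(tf.reshape(logits, [-1, 34 * 1024]))
  emb = tf.matmul(probs, embedding)
  return tf.reshape(emb, [batch_size, -1, 1, hparams.hidden_size])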
Exemple #24
0
 def decoder(self, x, encoder_layers=None):
   with tf.variable_scope("decoder"):
     hparams = self.hparams
     is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
     kernel, strides = self._get_kernel_and_strides()
     residual_kernel = (hparams.residual_kernel_height,
                        hparams.residual_kernel_width)
     residual_kernel1d = (hparams.residual_kernel_height, 1)
     residual_kernel = residual_kernel1d if self.is1d else residual_kernel
     residual_conv = tf.layers.conv2d
     if hparams.residual_use_separable_conv:
       residual_conv = tf.layers.separable_conv2d
     # Up-convolutions.
     for i in range(hparams.num_hidden_layers):
       j = hparams.num_hidden_layers - i - 1
       if is_training:
         nomix_p = common_layers.inverse_lin_decay(
             int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
         if common_layers.should_generate_summaries():
           tf.summary.scalar("nomix_p_%d" % j, nomix_p)
       filters = hparams.hidden_size * 2**j
       filters = min(filters, hparams.max_hidden_size)
       with tf.variable_scope("layer_%d" % i):
         j = hparams.num_hidden_layers - i - 1
         x = tf.layers.conv2d_transpose(
             x,
             filters,
             kernel,
             strides=strides,
             padding="SAME",
             activation=common_layers.belu,
             name="strided")
         y = x
         for r in range(hparams.num_residual_layers):
           residual_filters = filters
           if r < hparams.num_residual_layers - 1:
             residual_filters = int(
                 filters * hparams.residual_filter_multiplier)
           y = residual_conv(
               y,
               residual_filters,
               residual_kernel,
               padding="SAME",
               activation=common_layers.belu,
               name="residual_%d" % r)
         x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
         x = common_layers.layer_norm(x, name="ln")
         x = common_attention.add_timing_signal_nd(x)
         if encoder_layers is not None:
           enc_x = encoder_layers[j]
           enc_shape = common_layers.shape_list(enc_x)
           x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
           if is_training:  # Mix at the beginning of training.
             rand = tf.random_uniform(common_layers.shape_list(x_mix))
             x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
           if hparams.gan_loss_factor != 0:
             x_gan = x[enc_shape[0]:, :enc_shape[1], :enc_shape[2], :]
             x = tf.concat([x_mix, x_gan], axis=0)
           else:
             x = x_mix
     return x
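
The mixing probabilities above come from the inverse_lin_decay and inverse_exp_decay warmup schedules. A minimal numpy sketch of their shapes; the tensor2tensor versions read the global step tensor and may differ in details.

import numpy as np

def inverse_lin_decay_sketch(max_step, step, min_value=0.01):
  # Ramps linearly from min_value up to 1.0 at max_step, then stays at 1.
  return np.clip(step / float(max_step), min_value, 1.0)

def inverse_exp_decay_sketch(max_step, step, min_value=0.01):
  # Ramps exponentially from min_value at step 0 up to 1.0 at max_step.
  inv_base = np.exp(np.log(min_value) / float(max_step))
  return inv_base ** np.maximum(float(max_step) - step, 0.0)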
def gumbel_softmax_discrete_bottleneck(x,
                                       bottleneck_bits,
                                       beta=0.25,
                                       decay=0.999,
                                       epsilon=1e-5,
                                       startup_steps=15000,
                                       hard=False,
                                       summary=True):
  """VQ-VAE using Gumbel-Softmax.

  Different from `gumbel_softmax()` function as
  this function calculates the KL by using the discrete entropy
  instead of taking the argmax, and it also uses an exponential moving average
  to update the codebook while the `gumbel_softmax()` function includes no
  codebook update.

  Args:
    x: A `float`-like `Tensor` containing the latent vectors to be compared to
      the codebook, whose squared difference is used as the Gumbel-Softmax
      logits.
    bottleneck_bits: An `int` that sets the size of the bottleneck in `log_2`.
    beta: Beta factor for commitment loss (Default: 0.25).
    decay: Decay factor for exponential moving average (Default: 0.999).
    epsilon: Small value to avoid dividing by zero in EMA update
      (Default: 1e-5).
    startup_steps: Number of steps for KL warmup (Default: 15000).
    hard: When `True`, we use hard Gumbel-Softmax samples and force
      discrete latents by taking the argmax. When `False`, we use soft samples,
      which we treat as codebook weights (Default: False).
    summary: When `True`, we save histogram summaries of the KL term (Default:
      True).

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments. When `hard == True`, this is one-hot, containing the arg-max
      of the Gumbel-Softmax samples (and we use the straight-through gradient).
      Otherwise, it contains the Gumbel-Softmax samples exactly, which are
      values from the `(K-1)`-simplex where `K` is the bottleneck size.
    loss: The loss, which is the sum of the KL between the Gumbel-Softmax and
      the uniform prior and the commitment loss multiplied by the beta factor.
      We approximate the KL by using the entropy of a categorical distribution
      instead of the Gumbel-Softmax.

  """
  bottleneck_size = 2**bottleneck_bits
  x_shape = common_layers.shape_list(x)
  hidden_size = x_shape[-1]
  means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size, hidden_size)
  x = tf.reshape(x, [-1, hidden_size])

  bottleneck_size = common_layers.shape_list(means)[0]
  x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
  means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
  scalar_prod = tf.matmul(x, means, transpose_b=True)
  dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod

  class_probs = tf.nn.softmax(dist)
  log_class_probs = tf.nn.log_softmax(dist)
  gumbel_samples = gumbel_sample(common_layers.shape_list(dist))
  gumbel_samples *= common_layers.inverse_exp_decay(startup_steps // 5) * 0.5
  temperature = 1.2 - common_layers.inverse_lin_decay(startup_steps)

  # 10% of the time keep reasonably high temperature to keep learning.
  temperature = tf.cond(
      tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
      lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
  gumbel_softmax_samples = tf.nn.softmax(
      (log_class_probs + gumbel_samples) / temperature)

  # Calculate KL between q and a uniform prior.
  kl = tf.reduce_sum(class_probs * (log_class_probs -
                                    tf.log(1.0/bottleneck_size)), -1)
  if summary:
    tf.summary.histogram("KL", tf.reshape(kl, [-1]))

  # Straight-through gradient estimation when we're using hard assignments.
  if hard:
    x_means_idx = tf.reshape(tf.argmax(gumbel_softmax_samples, axis=-1), [-1])
    x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
    x_means_assignments = gumbel_softmax_samples + tf.stop_gradient(
        x_means_hot - gumbel_softmax_samples)
  else:
    x_means_assignments = gumbel_softmax_samples
  x_means_assignments_flat = tf.reshape(
      x_means_assignments, [-1, bottleneck_size])
  x_means = tf.matmul(x_means_assignments_flat, means)
  commitment_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))

  # Update the ema variables.
  updated_ema_count = moving_averages.assign_moving_average(
      ema_count,
      tf.reduce_sum(
          tf.reshape(x_means_assignments, shape=[-1, bottleneck_size]), axis=0),
      decay,
      zero_debias=False)

  dw = tf.matmul(x_means_assignments, x, transpose_a=True)
  updated_ema_means = tf.identity(moving_averages.assign_moving_average(
      ema_means, dw, decay, zero_debias=False))
  n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
  updated_ema_count = (
      (updated_ema_count + epsilon) / (n + bottleneck_size * epsilon) * n)
  updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
  with tf.control_dependencies([commitment_loss]):
    update_means = means.assign(updated_ema_means)
    with tf.control_dependencies([update_means]):
      loss = beta * commitment_loss

  # Add KL loss.
  loss += tf.reduce_mean(kl)

  x_means_assignments = tf.reshape(
      x_means_assignments, x_shape[:-1] + [bottleneck_size])
  return x_means_assignments, loss
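
gumbel_sample supplies the noise for the Gumbel-Softmax above. A minimal sketch using the standard inverse-CDF construction; the clipping constants are illustrative.

import tensorflow as tf

def gumbel_sample_sketch(shape):
  # Gumbel(0, 1) noise via -log(-log(U)), U ~ Uniform(0, 1); the interval
  # is clipped away from 0 and 1 to keep both logs finite.
  u = tf.random_uniform(shape, minval=1e-5, maxval=1.0 - 1e-5)
  return -tf.log(-tf.log(u))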
def gumbel_softmax(x,
                   name,
                   z_size,
                   mode,
                   softmax_k=0,
                   kl_warmup_steps=150000,
                   summary=True):
  """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    mode: Mode indicating whether we are training or testing, for bottlenecks
      that behave differently in each case (Default: None).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for KL warmup (Default: 150000).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding function, discrete code and loss.
  """
  with tf.variable_scope(name):
    m = tf.layers.dense(x, 2**z_size, name="mask")
    if softmax_k > 0:
      m, kl = top_k_softmax(m, softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)

    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(
        tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
        lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = -tf.reduce_max(logsm, axis=-1)

    if summary:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = -tf.reduce_mean(d_variance)
    ret = s

    if mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
Exemple #27
0
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame."""
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix bits from the latent with predicted bits on the forward
            # pass, as noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Equal to bits_pred on the forward pass, bits_clean on the
                # backward one.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Take each bit from the predicted samples with probability
                # bit_p.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
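
The call to discretization.tanh_discrete_bottleneck hides the core binarization step, which also appears inline in the later examples. A minimal sketch of just that step, with noise injection and warmup omitted; the layer name bottleneck_sketch is illustrative.

import tensorflow as tf
tfl = tf.layers

def tanh_discrete_sketch(x, bottleneck_bits):
  # Saturating encode, then binarize to +/-1 with a straight-through
  # gradient: the forward pass sees the sign, the backward pass sees tanh.
  x = tf.tanh(tfl.dense(x, bottleneck_bits, name="bottleneck_sketch"))
  sign = 2.0 * tf.to_float(tf.less(0.0, x)) - 1.0
  return x + tf.stop_gradient(sign - x)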
Exemple #28
0
def gumbel_softmax(x,
                   name,
                   z_size,
                   mode,
                   softmax_k=0,
                   kl_warmup_steps=150000,
                   summary=True):
    """Gumbel softmax discretization bottleneck.

  Args:
    x: Input to the discretization bottleneck.
    name: Name for the bottleneck scope.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding function, discrete code and loss.
  """
    with tf.variable_scope(name):
        m = tf.layers.dense(x, 2**z_size, name="mask")
        if softmax_k > 0:
            m, kl = top_k_softmax(m, softmax_k)
            return m, m, 1.0 - tf.reduce_mean(kl)
        logsm = tf.nn.log_softmax(m)

        # Gumbel-softmax sample.
        gumbel_samples = gumbel_sample(common_layers.shape_list(m))
        steps = kl_warmup_steps
        gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
        temperature = 1.2 - common_layers.inverse_lin_decay(steps)

        # 10% of the time keep reasonably high temperature to keep learning.
        temperature = tf.cond(
            tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
            lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
        s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
        m = tf.nn.softmax(m)
        kl = -tf.reduce_max(logsm, axis=-1)

        if summary:
            tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

        # Calculate the argmax and construct hot vectors.
        maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
        maxvhot = tf.stop_gradient(tf.one_hot(maxvec, 2**z_size))

        # Add losses that prevent too few being used.
        distrib = tf.reshape(logsm, [-1, 2**z_size]) * maxvhot
        d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
        d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
        d_dev = -tf.reduce_mean(d_variance)
        ret = s

        if mode != tf.contrib.learn.ModeKeys.TRAIN:
            ret = tf.reshape(maxvhot,
                             common_layers.shape_list(s))  # Just hot @eval.
        return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2*2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # After 300K steps we train only the extra model of latents; while doing
    # so, we gradually decrease the lr for the other weights (annealed to
    # zero by 400K steps).
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
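
A pattern that recurs throughout ae_transformer_internal and the autoencoder bodies is annealed mixing: a per-example coin flip that feeds the ground-truth tensor early in training and the model's own tensor later. A minimal standalone sketch, with all names illustrative.

import tensorflow as tf

def anneal_mix_sketch(learned, teacher, step, warmup_steps):
  # Probability of taking `learned` ramps from 0 to 0.95 over warmup_steps.
  p = tf.minimum(tf.to_float(step) / float(warmup_steps), 1.0) * 0.95
  batch_size = tf.shape(learned)[0]
  take_learned = tf.less(tf.random_uniform([batch_size]), p)
  return tf.where(take_learned, learned, teacher)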
Exemple #30
0
    def inject_latent(self, layer, features, filters):
        """Inject a deterministic latent based on the target frame."""
        del filters
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)

        def add_d(layer, d):
            z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if self.is_predicting:
            layer_shape = common_layers.shape_list(layer)
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
            else:
                rand = tf.random_uniform(layer_shape[:-3] +
                                         [1, 1, hparams.bottleneck_bits])
            d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            return add_d(layer, d), 0.0

        # Embed.
        frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                           axis=-1)
        x = tf.layers.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tf.layers.conv2d(x,
                                         filters,
                                         kernel,
                                         activation=common_layers.belu,
                                         strides=(2, 2),
                                         padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        x = tf.tanh(
            tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck"))
        d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise,
                                              noise)) - 1.0
            x *= noise
            d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 -
                                     x)
            p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
            d = tf.where(tf.less(tf.random_uniform([]), p), d, x)
        return add_d(layer, d), 0.0
  def inject_latent(self, layer, inputs, target, action):
    """Inject a deterministic latent based on the target frame."""
    hparams = self.hparams
    final_filters = common_layers.shape_list(layer)[-1]
    filters = hparams.hidden_size
    kernel = (4, 4)
    layer_shape = common_layers.shape_list(layer)

    def add_bits(layer, bits):
      z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
      if not hparams.complex_addn:
        return layer + z_mul
      layer *= tf.nn.sigmoid(z_mul)
      z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
      layer += z_add
      return layer

    if not self.is_training:
      if hparams.full_latent_tower:
        rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
        bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
      else:
        bits, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
      return add_bits(layer, bits), 0.0

    # Embed.
    frames = tf.concat(inputs + [target], axis=-1)
    x = tfl.dense(
        frames, filters, name="latent_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Add embedded action if present.
    if action is not None:
      x = common_video.inject_additional_input(
          x, action, "action_enc_latent", hparams.action_injection)

    if hparams.full_latent_tower:
      for i in range(hparams.num_compress_steps):
        with tf.variable_scope("latent_downstride%d" % i):
          x = common_layers.make_even_size(x)
          if i < hparams.filter_double_steps:
            filters *= 2
          x = common_attention.add_timing_signal_nd(x)
          x = tfl.conv2d(x, filters, kernel,
                         activation=common_layers.belu,
                         strides=(2, 2), padding="SAME")
          x = common_layers.layer_norm(x)
    else:
      x = common_layers.double_discriminator(x)
      x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

    bits, bits_clean = discretization.tanh_discrete_bottleneck(
        x, hparams.bottleneck_bits, hparams.bottleneck_noise,
        hparams.discretize_warmup_steps, hparams.mode)
    if not hparams.full_latent_tower:
      _, pred_loss = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          target_bits=bits_clean)
      # Mix bits from the latent with predicted bits on the forward pass,
      # as noise.
      if hparams.latent_rnn_max_sampling > 0.0:
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          bits_pred, _ = discretization.predict_bits_with_lstm(
              layer, hparams.latent_predictor_state_size,
              hparams.bottleneck_bits,
              temperature=hparams.latent_predictor_temperature)
          bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
        # Equal to bits_pred on the forward pass, bits_clean on the backward one.
        bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
        # Take each bit from the predicted samples with probability bit_p.
        which_bit = tf.random_uniform(common_layers.shape_list(bits))
        bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
        bit_p *= hparams.latent_rnn_max_sampling
        bits = tf.where(which_bit < bit_p, bits_pred, bits)

    res = add_bits(layer, bits)
    # During training, sometimes skip the latent to help action-conditioning.
    res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
    res_p *= hparams.latent_use_max_probability
    res_rand = tf.random_uniform([layer_shape[0]])
    res = tf.where(res_rand < res_p, res, layer)
    return res, pred_loss
    def body(self, features):
        hparams = self.hparams
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            labels = features["targets_raw"]
            vocab_size = self._problem_hparams.target_modality.top_dimensionality
            shape = common_layers.shape_list(labels)
            x = tf.one_hot(labels, vocab_size)
            x = self.embed(x)
            target_codes = x
            is1d = shape[2] == 1
            self.is1d = is1d
            # Run encoder.
            x, encoder_layers = self.encoder(x)
            # Bottleneck.
            b, b_loss = self.bottleneck(x)
            xb_loss = 0.0
            b_shape = common_layers.shape_list(b)
            self._cur_bottleneck_tensor = b
            b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
            if not is_training:
                x = b
            else:
                l = 2**hparams.num_hidden_layers
                warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
                nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
                if common_layers.should_generate_summaries():
                    tf.summary.scalar("nomix_p_bottleneck", nomix_p)
                rand = tf.random_uniform(common_layers.shape_list(x))
                # This is the distance between b and x. Having this as loss helps learn
                # the bottleneck function, but if we back-propagated to x it would be
                # minimized by just setting x=0 and b=0 -- so we don't want too much
                # of the influence of this, and we stop-gradient to not zero-out x.
                x_stop = tf.stop_gradient(x)
                xb_loss = tf.reduce_mean(
                    tf.reduce_sum(tf.square(x_stop - b), axis=-1))
                # To prevent this loss from exploding we clip at 1, but anneal clipping.
                clip_max = 1.0 / common_layers.inverse_exp_decay(
                    warm_step, min_value=0.001)
                xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
                xb_loss *= clip_max / xb_clip
                x = tf.where(tf.less(rand, nomix_p), b, x)
            if hparams.gan_loss_factor != 0.0:
                # Add a purely sampled batch on which we'll compute the GAN loss.
                g = self.unbottleneck(self.sample(shape=b_shape),
                                      common_layers.shape_list(x)[-1],
                                      reuse=True)
                x = tf.concat([g, x], axis=0)
                encoder_layers = [
                    tf.concat([l, l], axis=0) for l in encoder_layers
                ]
        else:
            if self._cur_bottleneck_tensor is None:
                b = self.sample()
            else:
                b = self._cur_bottleneck_tensor
            res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
            res_size = min(res_size, hparams.max_hidden_size)
            x = self.unbottleneck(b, res_size)
        # Run decoder.
        x = self.decoder(x, encoder_layers)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            return x, {"bottleneck_loss": 0.0}
        # Cut to the right size and mix before returning.
        res = x[:, :shape[1], :shape[2], :]

        # Final dense layer.
        res = tf.layers.dense(res,
                              self.num_channels * hparams.hidden_size,
                              name="res_dense")

        output_shape = common_layers.shape_list(res)[:-1] + [
            self.num_channels, self.hparams.hidden_size
        ]
        res = tf.reshape(res, output_shape)

        if hparams.gan_loss_factor != 0.0:
            res_gan, res = tf.split(res, 2, axis=0)

        # Losses.
        losses = {
            "bottleneck_extra": b_loss,
            "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
        }

        if hparams.use_vq_loss:
            vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.2,
                min_value=hparams.vq_temperature * 2)
            if hparams.mode != tf.estimator.ModeKeys.TRAIN:
                vq_temperature = None
            with tf.variable_scope("vq_loss"):
                (reconstr, _, target_codes, code_loss,
                 targets_loss) = discretization.vq_loss(
                     res, labels, vocab_size, temperature=vq_temperature)
            losses["code_loss"] = code_loss * hparams.code_loss_factor
            losses["training"] = targets_loss
        else:
            reconstr = tf.layers.dense(res,
                                       vocab_size,
                                       name="autoencoder_final")
            targets_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=reconstr, labels=labels)
            losses["training"] = targets_loss

        # GAN losses.
        if hparams.gan_loss_factor != 0.0:
            update_means_factor = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps, min_value=0.0001)
            if hparams.use_vq_loss:
                with tf.variable_scope("vq_loss", reuse=True):
                    update_means = tf.less(tf.random_uniform([]),
                                           update_means_factor)
                    reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                        res_gan,
                        labels,
                        vocab_size,
                        do_update=update_means,
                        temperature=vq_temperature)
                    code_loss_gan *= hparams.code_loss_factor * update_means_factor
                    losses["code_loss_gan"] = code_loss_gan
            else:
                reconstr_gan = tf.layers.dense(res_gan,
                                               vocab_size,
                                               name="autoencoder_final",
                                               reuse=True)
                reconstr_gan = tf.nn.log_softmax(reconstr_gan)
                if is_training and hparams.gumbel_temperature > 0.0:
                    gumbel_samples = discretization.gumbel_sample(
                        common_layers.shape_list(reconstr_gan))
                    gumbel_samples *= hparams.gumbel_noise_factor
                    reconstr_gan += gumbel_samples
                    reconstr_sample = latent_layers.multinomial_sample(
                        reconstr_gan, temperature=hparams.gumbel_temperature)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 hparams.gumbel_temperature)
                else:
                    reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 0.1)  # Sharpen a bit.
                # Use 1-hot forward, softmax backward.
                reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
                reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
                # Embed to codes.
                gan_codes = self.embed(reconstr_gan)

        # Add GAN loss if requested.
        gan_loss = 0.0
        if hparams.gan_loss_factor != 0.0:
            self.image_summary("gan", reconstr_gan)

            def discriminate(x):
                return self.discriminator(x, is_training=is_training)

            tc_shape = common_layers.shape_list(target_codes)
            if len(tc_shape) > 4:
                target_codes = tf.reshape(
                    target_codes,
                    tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
                gan_codes = tf.reshape(
                    gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_lr = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.5)
            rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
            gan_loss = common_layers.sliced_gan_loss(
                target_codes, rev_grad_gan_codes, discriminate,
                self.hparams.num_sliced_vecs)
            gan_loss *= hparams.gan_loss_factor * update_means_factor
            losses["gan_loss"] = -gan_loss

        self.image_summary("ae", reconstr)
        logits = reconstr
        return logits, losses
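
The GAN branch above uses a "1-hot forward, softmax backward" estimator so that discrete samples still pass gradients to the logits. Isolated as a minimal sketch; the temperature value is illustrative.

import tensorflow as tf

def st_one_hot_sketch(logits, vocab_size, temperature=0.1):
  # Forward pass: the exact one-hot of the argmax. Backward pass: the
  # gradients of a sharpened softmax, as in the reconstr_gan code above.
  soft = tf.nn.softmax(logits / temperature)
  hard = tf.one_hot(tf.argmax(logits, axis=-1), vocab_size)
  return soft + tf.stop_gradient(hard - soft)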
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0,
                            means=None,
                            ema_count=None,
                            ema_means=None):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2 * 2048, "vc", means, ema_count, ema_means)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc", means,
                                     ema_count, ema_means)
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc",
                                            means, ema_count, ema_means)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc", means,
                                    ema_count, ema_means)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in range(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training only the extra model of latents after 300K steps.
    # Before that, we decrease the lr for the other weights (annealed to 400K).
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
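
The gates above (pc for mixing bottlenecked and raw latents, ptc, and the masking schedule) all lean on common_layers.inverse_exp_decay and inverse_lin_decay. A small NumPy sketch of their usual semantics, assuming both warm up from min_value at step 0 to 1.0 at max_step:

import numpy as np

def inverse_lin_decay_np(max_step, step, min_value=0.01):
  """Rises linearly from min_value to 1.0, reached at max_step (sketch)."""
  progress = min(float(step) / max_step, 1.0)
  return progress * (1.0 - min_value) + min_value

def inverse_exp_decay_np(max_step, step, min_value=0.01):
  """Rises exponentially from min_value to 1.0, reached at max_step."""
  inv_base = np.exp(np.log(min_value) / max_step)
  return inv_base ** max(float(max_step) - step, 0.0)

# The mixing probability pc above starts near 0 and approaches 0.95 as
# training passes hparams.startup_steps (here 10000):
for step in [0, 5000, 10000]:
  print(step, round(inverse_exp_decay_np(10000, step) * 0.95, 4))
  # -> 0.0095, 0.095, 0.95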
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            beam_size, cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  orig_targets = targets
  batch_size = tf.shape(orig_targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([]), pc)
      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
      losses["extra"] = vc_loss * tf.to_float(cond)
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        t_pred = decode_transformer(
            inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
        t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=t_bit, logits=t_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        targets_rand = tf.random_uniform(tf.shape(targets_c))
        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2 * 2048, "vc")
      else:
        latent_len = tf.shape(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc")
        t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
          cache = cache[0, :, :]
          cache = tf.reshape(cache, [1, latent_len, 1])
          cache = tf.tile(cache, [beam_size, 1, 1])
        t_c = embed(cache)
    # Postprocess.
    d = t_c
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in range(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, tf.shape(t_c)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
      res = tf.cond(all_masked, refine_res, lambda: res)
  return res, losses, cache
Exemple #35
0
def autoencoder_body(self, features):
    """ Customized body function for autoencoders acting on continuous images.
  This is based on tensor2tensor.models.research.AutoencoderBasic.body
  and should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a discrete
  vocabulary and defines the loss on that vocab. It's cool and all, but here we
  prefer expressing the reconstruction loss as an actual continuous likelihood
  function.
  """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    output_activation = (tf.nn.softplus
                         if hparams.output_activation == 'softplus' else None)
    input_shape = [
        None,
    ] + common_layers.shape_list(features["inputs"])[1:]

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        # In predict mode, we also define TensorFlow Hub modules for all
        # pieces of the autoencoder.
        # First, build the encoder spec.
        def make_model_spec():
            input_layer = tf.placeholder(tf.float32, shape=input_shape)
            x = self.embed(tf.expand_dims(input_layer, -1))
            x, encoder_layers = self.encoder(x)
            b, b_loss = self.bottleneck(x)
            hub.add_signature(inputs=input_layer, outputs=b)

        def make_model_spec_psf():
            input_layer = tf.placeholder(tf.float32, shape=input_shape)
            psf_layer = tf.placeholder(tf.float32, shape=input_shape)
            x = self.embed(tf.expand_dims(input_layer, -1))

            # If we have access to the PSF, we add this information to the encoder
            if hparams.encode_psf and 'psf' in features:
                net_psf = tf.layers.conv2d(psf_layer,
                                           hparams.hidden_size // 4,
                                           5,
                                           padding='same',
                                           name="psf_embed_1")
                net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
                x, encoder_layers = self.encoder(
                    tf.concat([x, net_psf], axis=-1))
            else:
                x, encoder_layers = self.encoder(x)
            b, b_loss = self.bottleneck(x)
            hub.add_signature(inputs={
                'input': input_layer,
                'psf': psf_layer
            },
                              outputs=b)

        spec = hub.create_module_spec(
            make_model_spec_psf if hparams.encode_psf else make_model_spec,
            drop_collections=['checkpoints'])
        encoder = hub.Module(spec, name="encoder_module")
        hub.register_module_for_export(encoder, "encoder")

        if hparams.encode_psf:
            code = encoder({
                'input': features["inputs"],
                'psf': features['psf']
            })
        else:
            code = encoder(features["inputs"])
        b_shape = [
            None,
        ] + common_layers.shape_list(code)[1:]
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)

        # Second, build the decoder spec.
        def make_model_spec():
            input_layer = tf.placeholder(tf.float32, shape=b_shape)
            x = self.unbottleneck(input_layer, res_size)
            x = self.decoder(x, None)
            reconstr = tf.layers.dense(x,
                                       self.num_channels,
                                       name="autoencoder_final",
                                       activation=output_activation)
            hub.add_signature(inputs=input_layer, outputs=reconstr)
            hub.attach_message(
                "stamp_size",
                tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
            hub.attach_message(
                "pixel_size",
                tf.train.FloatList(
                    value=[hparams.problem_hparams.pixel_scale]))

        spec = hub.create_module_spec(make_model_spec,
                                      drop_collections=['checkpoints'])
        decoder = hub.Module(spec, name="decoder_module")
        hub.register_module_for_export(decoder, "decoder")

        reconstr = decoder(code)
        return reconstr, {"bottleneck_loss": 0.0}

    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
            or self._encode_on_predict):
        labels = features["targets_raw"]
        labels_shape = common_layers.shape_list(labels)

        shape = common_layers.shape_list(labels)
        with tf.variable_scope('encoder_module'):
            x = self.embed(tf.expand_dims(labels, -1))

        if shape[2] == 1:
            self.is1d = True

        # Run encoder.
        with tf.variable_scope('encoder_module'):
            # If we have access to the PSF, we add this information to the encoder
            if hparams.encode_psf and 'psf' in features:
                net_psf = tf.layers.conv2d(features['psf'],
                                           hparams.hidden_size // 4,
                                           5,
                                           padding='same',
                                           name="psf_embed_1")
                net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
                x, encoder_layers = self.encoder(
                    tf.concat([x, net_psf], axis=-1))
            else:
                x, encoder_layers = self.encoder(x)

        # Bottleneck.
        with tf.variable_scope('encoder_module'):
            b, b_loss = self.bottleneck(x)

        xb_loss = 0.0
        b_shape = common_layers.shape_list(b)
        self._cur_bottleneck_tensor = b
        res_size = common_layers.shape_list(x)[-1]
        with tf.variable_scope('decoder_module'):
            b = self.unbottleneck(b, res_size)
        if not is_training:
            x = b
        else:
            l = 2**hparams.num_hidden_layers
            warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
            nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
            if common_layers.should_generate_summaries():
                tf.summary.scalar("nomix_p_bottleneck", nomix_p)
            rand = tf.random_uniform(common_layers.shape_list(x))
            # This is the distance between b and x. Having this as loss helps learn
            # the bottleneck function, but if we back-propagated to x it would be
            # minimized by just setting x=0 and b=0 -- so we don't want too much
            # of the influence of this, and we stop-gradient to not zero-out x.
            x_stop = tf.stop_gradient(x)
            xb_loss = tf.reduce_mean(
                tf.reduce_sum(tf.squared_difference(x_stop, b), axis=-1))
            # To prevent this loss from exploding we clip at 1, but anneal clipping.
            clip_max = 1.0 / common_layers.inverse_exp_decay(warm_step,
                                                             min_value=0.001)
            xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
            xb_loss *= clip_max / xb_clip
            x = tf.where(tf.less(rand, nomix_p), b, x)
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        self._cur_bottleneck_tensor = b
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)

        with tf.variable_scope('decoder_module'):
            x = self.unbottleneck(b, res_size)

    # Run decoder.
    with tf.variable_scope('decoder_module'):
        x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        res = x[:, :shape[1], :shape[2], :]

    with tf.variable_scope('decoder_module'):
        reconstr = tf.layers.dense(res,
                                   self.num_channels,
                                   name="autoencoder_final",
                                   activation=output_activation)

    # Apply channel-wise convolution with the PSF if requested
    # TODO: Handle multiple bands
    if hparams.apply_psf and 'psf' in features:
        if self.num_channels > 1:
            raise NotImplementedError
        rec_padded = tf.pad(
            reconstr[:, :, :, 0],
            [[0, 0], [0, int(hparams.psf_convolution_pad_factor * shape[1])],
             [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
        psf_padded = tf.pad(
            features['psf'][..., 0],
            [[0, 0], [0, int(hparams.psf_convolution_pad_factor * shape[1])],
             [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
        reconstr = tf.expand_dims(tf.spectral.irfft2d(
            tf.spectral.rfft2d(rec_padded) *
            tf.cast(tf.abs(tf.spectral.rfft2d(psf_padded)), tf.complex64)),
                                  axis=-1)
        reconstr = reconstr[:, :shape[1], :shape[2], :]

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    loglik = loglikelihood_fn(labels, reconstr, features, hparams)
    targets_loss = tf.reduce_mean(-loglik)

    tf.summary.scalar("negloglik", targets_loss)
    tf.summary.scalar("bottleneck_loss", b_loss)

    losses["training"] = targets_loss
    logits = tf.reshape(reconstr, labels_shape)

    image_summary("ae", reconstr)
    image_summary("input", labels)

    return logits, losses
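
The PSF branch above convolves the reconstruction with the point spread function in Fourier space: both images are zero-padded by a pad factor, multiplied in the rFFT domain (using the modulus of the PSF spectrum, which discards its phase), and cropped back to the stamp size. A single-channel NumPy sketch of the same pattern, with illustrative names:

import numpy as np

def psf_convolve_sketch(image, psf, pad_factor=1.0):
  """FFT convolution mirroring the block above (sketch, one channel)."""
  h, w = image.shape
  ph, pw = int(pad_factor * h), int(pad_factor * w)
  img_p = np.pad(image, [(0, ph), (0, pw)], mode='constant')
  psf_p = np.pad(psf, [(0, ph), (0, pw)], mode='constant')
  # Multiply spectra; |rfft2(psf)| drops the PSF phase as in the code above.
  out = np.fft.irfft2(np.fft.rfft2(img_p) * np.abs(np.fft.rfft2(psf_p)),
                      s=img_p.shape)
  return out[:h, :w]  # Crop back to the original stamp.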
Exemple #36
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets = targets
        original_targets_shape = tf.shape(original_targets)
        if hparams.task == "image":
            targets = cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            if inputs is not None:
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = contrib.framework().argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        # Add positional information
        targets_shape = common_layers.shape_list(targets)
        targets = tf.reshape(
            targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
        targets = common_attention.add_positional_embedding(
            targets, hparams.max_length, name="targets_position")
        targets = tf.reshape(targets, shape=targets_shape)
        if hparams.word_dropout:
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets

        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    tf.squared_difference(inputs_c, targets_c)) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        d_shape = common_layers.shape_list(d)
        d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
        d = common_attention.add_positional_embedding(d,
                                                      hparams.max_length,
                                                      name="latents_position")
        d = tf.reshape(d, shape=d_shape)

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        else:
            targets = d

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing loss with
    # respect to the original (unpadded) targets, so we remove the extra
    # elements here.
    res = res[:, :original_targets_shape[1], :, :]

    data_dim = common_layers.shape_list(res)[1]
    latent_dim = common_layers.shape_list(targets_c)[1]
    return res, losses, cache, data_dim, latent_dim
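
Two noise tricks above deserve a closer look: word shuffle adds bounded uniform noise to token positions and sorts, so each token moves only a few places, and word dropout zeroes target entries at random. A small NumPy sketch of both (names are illustrative):

import numpy as np

def word_shuffle_sketch(tokens, rate, rng=np.random):
  """Permutes tokens by sorting noisy positions, as in the snippet above."""
  idx = np.arange(len(tokens), dtype=np.float32)
  idx += rng.uniform(0.0, 1.0 + rate, size=idx.shape)
  return [tokens[i] for i in np.argsort(idx)]

def word_dropout_sketch(targets, rate, rng=np.random):
  """Keeps each entry with probability 1 - rate, else zeroes it."""
  mask = rng.uniform(size=targets.shape) > rate
  return np.where(mask, targets, 0.0)

print(word_shuffle_sketch(list("abcdef"), rate=2.0))  # A local shuffle.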
Exemple #37
0
def gumbel_softmax_discrete_bottleneck(x,
                                       bottleneck_bits,
                                       beta=0.25,
                                       decay=0.999,
                                       epsilon=1e-5,
                                       startup_steps=15000,
                                       hard=False,
                                       summary=True):
    """VQ-VAE using Gumbel-Softmax.

  Unlike the `gumbel_softmax()` function, this function calculates the KL
  using the discrete entropy instead of taking the argmax, and it uses an
  exponential moving average to update the codebook, whereas
  `gumbel_softmax()` includes no codebook update.

  Args:
    x: A `float`-like `Tensor` containing the latent vectors to be compared to
      the codebook, whose squared difference is used as the Gumbel-Softmax
      logits.
    bottleneck_bits: An `int` that sets the size of the bottleneck in `log_2`.
    beta: Beta factor for commitment loss (Default: 0.25).
    decay: Decay factor for exponential moving average (Default: 0.999).
    epsilon: Small value to avoid dividing by zero in EMA update
      (Default: 1e-5).
    startup_steps: Number of steps for KL warmup (Default: 15000).
    hard: When `True`, we use hard Gumbel-Softmax samples and force
      discrete latents by taking the argmax. When `False`, we use soft samples,
      which we treat as codebook weights (Default: False).
    summary: When `True`, we save histogram summaries of the KL term (Default:
      True).

  Returns:
    x_means_assignments: A `float`-like `Tensor` containing the codebook
      assignments. When `hard == True`, this is one-hot, containing the arg-max
      of the Gumbel-Softmax samples (and we use the straight-through gradient).
      Otherwise, it contains the Gumbel-Softmax samples exactly, which are
      values from the `(K-1)`-simplex where `K` is the bottleneck size.
    loss: The loss, which is the sum of the KL between the Gumbel-Softmax and
      the uniform prior and the commitment loss multiplied by the beta factor.
      We approximate the KL by using the entropy of a categorical distribution
      instead of the Gumbel-Softmax samples.

  """
    bottleneck_size = 2**bottleneck_bits
    x_shape = common_layers.shape_list(x)
    hidden_size = x_shape[-1]
    means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size,
                                                    hidden_size)
    x = tf.reshape(x, [-1, hidden_size])

    bottleneck_size = common_layers.shape_list(means)[0]
    x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
    means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
    scalar_prod = tf.matmul(x, means, transpose_b=True)
    dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod

    class_probs = tf.nn.softmax(dist)
    log_class_probs = tf.nn.log_softmax(dist)
    gumbel_samples = gumbel_sample(common_layers.shape_list(dist))
    gumbel_samples *= common_layers.inverse_exp_decay(startup_steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(startup_steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(
        tf.less(tf.random_uniform([]), 0.9), lambda: temperature,
        lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    gumbel_softmax_samples = tf.nn.softmax(
        (log_class_probs + gumbel_samples) / temperature)

    # Calculate KL between q and a uniform prior.
    kl = tf.reduce_sum(
        class_probs * (log_class_probs - tf.log(1.0 / bottleneck_size)), -1)
    if summary:
        tf.summary.histogram("KL", tf.reshape(kl, [-1]))

    # Straight-through gradient estimation when we're using hard assignments.
    if hard:
        x_means_idx = tf.reshape(tf.argmax(gumbel_softmax_samples, axis=-1),
                                 [-1])
        x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
        x_means_assignments = gumbel_softmax_samples + tf.stop_gradient(
            x_means_hot - gumbel_softmax_samples)
    else:
        x_means_assignments = gumbel_softmax_samples
    x_means_assignments_flat = tf.reshape(x_means_assignments,
                                          [-1, bottleneck_size])
    x_means = tf.matmul(x_means_assignments_flat, means)
    commitment_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))

    # Update the ema variables.
    updated_ema_count = moving_averages.assign_moving_average(
        ema_count,
        tf.reduce_sum(tf.reshape(x_means_assignments,
                                 shape=[-1, bottleneck_size]),
                      axis=0),
        decay,
        zero_debias=False)

    dw = tf.matmul(x_means_assignments, x, transpose_a=True)
    updated_ema_means = tf.identity(
        moving_averages.assign_moving_average(ema_means,
                                              dw,
                                              decay,
                                              zero_debias=False))
    n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
    updated_ema_count = ((updated_ema_count + epsilon) /
                         (n + bottleneck_size * epsilon) * n)
    updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)
    with tf.control_dependencies([commitment_loss]):
        update_means = means.assign(updated_ema_means)
        with tf.control_dependencies([update_means]):
            loss = beta * commitment_loss

    # Add KL loss.
    loss += tf.reduce_mean(kl)

    x_means_assignments = tf.reshape(x_means_assignments,
                                     x_shape[:-1] + [bottleneck_size])
    return x_means_assignments, loss
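
A NumPy sketch of the sampling step at the heart of the function above: logits plus Gumbel noise, then a softmax at temperature tau, with an argmax one-hot in the hard case. The straight-through gradient replacement is a graph-construction detail with no NumPy analogue, so only forward values are shown:

import numpy as np

def gumbel_softmax_sample_sketch(logits, tau, hard=False, rng=np.random):
  """Forward pass of Gumbel-Softmax sampling (sketch only)."""
  g = -np.log(-np.log(rng.uniform(size=logits.shape)))  # Gumbel(0, 1) noise.
  z = (logits + g) / tau
  z -= z.max(axis=-1, keepdims=True)  # For numerical stability.
  y = np.exp(z)
  y /= y.sum(axis=-1, keepdims=True)  # Soft assignments on the simplex.
  if hard:
    return np.eye(logits.shape[-1])[y.argmax(-1)]  # One-hot assignments.
  return y

# At low temperature the soft sample concentrates on a single code:
print(gumbel_softmax_sample_sketch(np.array([[2.0, 0.5, 0.1]]), tau=0.1))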
Exemple #38
0
    def get_scheduled_sample_func(self, batch_size):
        """Creates a function for scheduled sampling based on given hparams."""
        with tf.variable_scope("scheduled_sampling_func", reuse=tf.AUTO_REUSE):
            iter_num = self.get_iteration_num()

            # Simple function to bypass scheduled sampling in ground-truth-only
            # or prediction-only modes.
            def scheduled_sampling_simple(ground_truth_x, generated_x,
                                          batch_size, scheduled_sample_var):
                del batch_size
                if scheduled_sample_var:
                    return ground_truth_x
                return generated_x

            mode = self.hparams.scheduled_sampling_mode
            if mode == "ground_truth_only":
                scheduled_sampling_func = scheduled_sampling_simple
                scheduled_sampling_func_var = True
            elif mode == "prediction_only":
                scheduled_sampling_func = scheduled_sampling_simple
                scheduled_sampling_func_var = False
            elif mode == "prob":
                decay_steps = self.hparams.scheduled_sampling_decay_steps
                probability = tf.train.polynomial_decay(
                    1.0, iter_num, decay_steps, 0.0)
                scheduled_sampling_func = common_video.scheduled_sample_prob
                scheduled_sampling_func_var = probability
            elif mode == "prob_inverse_exp":
                decay_steps = self.hparams.scheduled_sampling_decay_steps
                probability = common_layers.inverse_exp_decay(decay_steps,
                                                              step=iter_num)
                probability *= self.hparams.scheduled_sampling_max_prob
                probability = 1.0 - probability
                scheduled_sampling_func = common_video.scheduled_sample_prob
                scheduled_sampling_func_var = probability
            elif mode == "prob_inverse_lin":
                decay_steps = self.hparams.scheduled_sampling_decay_steps
                probability = common_layers.inverse_exp_decay(
                    decay_steps // 4, step=iter_num)  # Very low at start.
                probability *= common_layers.inverse_lin_decay(decay_steps,
                                                               step=iter_num)
                probability *= self.hparams.scheduled_sampling_max_prob
                probability = 1.0 - probability
                scheduled_sampling_func = common_video.scheduled_sample_prob
                scheduled_sampling_func_var = probability
            elif mode == "count":
                # Calculate number of ground-truth frames to pass in.
                k = self.hparams.scheduled_sampling_k
                num_ground_truth = tf.to_int32(
                    tf.round(
                        tf.to_float(batch_size) *
                        (k /
                         (k + tf.exp(tf.to_float(iter_num) / tf.to_float(k)))))
                )
                scheduled_sampling_func = common_video.scheduled_sample_count
                scheduled_sampling_func_var = num_ground_truth
            else:
                raise ValueError("unknown scheduled sampling method: %s" %
                                 mode)

            if isinstance(scheduled_sampling_func_var, tf.Tensor):
                tf.summary.scalar("scheduled_sampling_var",
                                  scheduled_sampling_func_var)
            partial_func = functools.partial(
                scheduled_sampling_func,
                batch_size=batch_size,
                scheduled_sample_var=scheduled_sampling_func_var)
            return partial_func
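
In the "count" mode above, the number of ground-truth frames fed to the model follows an inverse-sigmoid decay in the iteration number, roughly batch_size * k / (k + exp(iter_num / k)). A quick worked sketch of the schedule:

import numpy as np

def num_ground_truth_sketch(batch_size, iter_num, k):
  """Inverse-sigmoid schedule from the "count" mode above (sketch)."""
  return int(round(batch_size * k / (k + np.exp(iter_num / float(k)))))

# With batch_size=32 and k=1000: almost all ground truth early on, then a
# fairly sharp hand-off to model predictions around iter ~ k * ln(k).
for it in [0, 1000, 5000, 10000]:
  print(it, num_ground_truth_sketch(32, it, 1000.0))  # 32, 32, 28, 1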
Exemple #39
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None):
    """Main step used for training."""
    # Encoder.
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        # Compress and bottleneck.
        latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
            x=targets_c, hparams=hparams)
        latents_dense = vq_discrete_unbottleneck(latents_discrete_hot,
                                                 hparams=hparams)
        latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)
        latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
        tf.summary.histogram("codes",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
        losses["extra"] = extra_loss

        # Extra loss predicting latent code from input.
        latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                          "extra")
        latent_pred_loss = get_latent_pred_loss(latents_pred,
                                                latents_discrete_hot, hparams)
        losses["latent_pred"] = tf.reduce_mean(latent_pred_loss)
    else:
        latent_len = common_layers.shape_list(targets_c)[1]
        embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed,
                                          hparams)
        cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
        latents_dense = embed(cache_hot)

    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Decompressing the dense latents
    for i in range(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

    masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
    masking *= common_layers.inverse_exp_decay(hparams.mask_startup_steps //
                                               4)  # Not much at start.
    masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = 1.0
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 3)

    # targets is always [batch, length, 1, depth]
    targets = mask * targets + (1.0 - mask) * d

    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
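
The line latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c) above is the straight-through estimator for vector quantization: the forward value is the quantized latent, while the gradient flows into targets_c as if quantization were the identity. A minimal TF sketch of the identity, with tf.round standing in for the VQ lookup:

x = tf.constant([0.3, -0.8])
quantized = tf.round(x)  # Stand-in for the codebook lookup (sketch only).
y = x + tf.stop_gradient(quantized - x)
# Forward: y evaluates to quantized. Backward: dy/dx == 1 everywhere,
# because the stop_gradient term is treated as a constant.
grad = tf.gradients(tf.reduce_sum(y), x)[0]  # == [1.0, 1.0]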
Exemple #40
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
    if hparams.do_ae:
        # flatten here
        original_targets_shape = tf.shape(targets)
        if hparams.task == "image":
            targets = cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        targets_c = compress(targets, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
                x=targets_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)
                losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                                       tf.to_float(cond))
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    (inputs_c - targets_c)**2) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _ = hparams.bottleneck(
                            x=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            name="vc",
                            mode=hparams.mode)
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _ = hparams.bottleneck(
                    x=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc",
                    mode=hparams.mode)
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed = hparams.bottleneck(
                    x=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
        pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
        latents_dense = tf.pad(latents_dense,
                               [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        if hparams.task == "translate":
            targets = tf.concat([tf.reverse(latents_dense, [1]), targets],
                                axis=1)

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.task == "translate":
            res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None):
  """Main step used for training."""
  # Encoder.
  inputs = common_layers.flatten4d3d(inputs)
  inputs, ed = encode(inputs, target_space, hparams, "input_enc")

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
  targets, _ = common_layers.pad_to_same_length(
      targets,
      max_targets_len_from_inputs,
      final_length_divisible_by=2**hparams.num_compress_steps)
  targets_c = compress(targets, hparams, "compress")
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    # Compress and bottleneck.
    latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
        x=targets_c, hparams=hparams)
    latents_dense = vq_discrete_unbottleneck(
        latents_discrete_hot, hparams=hparams)
    latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)
    latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
    tf.summary.histogram("codes", tf.reshape(latents_discrete[:, 0, :], [-1]))
    losses["extra"] = extra_loss

    # Extra loss predicting latent code from input.
    latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                      "extra")
    latent_pred_loss = get_latent_pred_loss(latents_pred, latents_discrete_hot,
                                            hparams)
    losses["latent_pred"] = tf.reduce_mean(latent_pred_loss)
  else:
    latent_len = common_layers.shape_list(targets_c)[1]
    embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed,
                                    hparams)
    cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed(cache_hot)

  # Postprocess.
  d = latents_dense
  pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
  pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
  latents_dense = tf.pad(latents_dense, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

  # Decompressing the dense latents
  for i in range(hparams.num_compress_steps):
    j = hparams.num_compress_steps - i - 1
    d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
    d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

  masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
  masking *= common_layers.inverse_exp_decay(
      hparams.mask_startup_steps // 4)  # Not much at start.
  masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    masking = 1.0
  mask = tf.less(masking,
                 tf.random_uniform(common_layers.shape_list(targets)[:-1]))
  mask = tf.expand_dims(tf.to_float(mask), 3)

  # targets is always [batch, length, 1, depth]
  targets = mask * targets + (1.0 - mask) * d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
Exemple #42
0
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
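      # Split the purely sampled GAN half (concatenated onto x above) back out.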
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise ValueError("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
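        # Video codes are rank-5; merge the last two axes before discriminating.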
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
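      # Minimizing -gan_loss trains the discriminator to maximize the sliced
      # separation; the reversed gradient on gan_codes trains the generator
      # against it.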
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
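The reverse_gradient helper used above is not defined in this snippet. The standard gradient-reversal trick it refers to can be sketched as follows; this is the common formulation, not necessarily the library's exact implementation:

import tensorflow as tf

def reverse_gradient(x, lr=1.0):
  """Identity on the forward pass; scales the gradient by -lr on the
  backward pass (a gradient-reversal layer)."""
  # Forward: -lr * x + (1 + lr) * x == x.
  # Backward: only the -lr * x term carries a gradient, so dy/dx == -lr.
  return -lr * x + tf.stop_gradient((1.0 + lr) * x)

Applied to gan_codes, this lets a single minimization step train the discriminator on its own objective while pushing the generator in the opposite direction.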
Example #43
  def get_scheduled_sample_func(self, batch_size):
    """Creates a function for scheduled sampling based on given hparams."""
    with tf.variable_scope("scheduled_sampling_func", reuse=tf.AUTO_REUSE):
      iter_num = self.get_iteration_num()

      # Simple function to bypass scheduled sampling in ground-truth-only or
      # prediction-only modes.
      def scheduled_sampling_simple(ground_truth_x, generated_x,
                                    batch_size, scheduled_sample_var):
        del batch_size
        if scheduled_sample_var:
          return ground_truth_x
        return generated_x

      mode = self.hparams.scheduled_sampling_mode
      if mode == "ground_truth_only":
        scheduled_sampling_func = scheduled_sampling_simple
        scheduled_sampling_func_var = True
      elif mode == "prediction_only":
        scheduled_sampling_func = scheduled_sampling_simple
        scheduled_sampling_func_var = False
      elif mode == "prob":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = tf.train.polynomial_decay(
            1.0, iter_num, decay_steps, 0.0)
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "prob_inverse_exp":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = common_layers.inverse_exp_decay(
            decay_steps, step=iter_num)
        probability *= self.hparams.scheduled_sampling_max_prob
        probability = 1.0 - probability
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "prob_inverse_lin":
        decay_steps = self.hparams.scheduled_sampling_decay_steps
        probability = common_layers.inverse_exp_decay(
            decay_steps // 4, step=iter_num)  # Very low at start.
        probability *= common_layers.inverse_lin_decay(
            decay_steps, step=iter_num)
        probability *= self.hparams.scheduled_sampling_max_prob
        probability = 1.0 - probability
        scheduled_sampling_func = common_video.scheduled_sample_prob
        scheduled_sampling_func_var = probability
      elif mode == "count":
        # Calculate number of ground-truth frames to pass in.
        k = self.hparams.scheduled_sampling_k
        num_ground_truth = tf.to_int32(
            tf.round(
                tf.to_float(batch_size) *
                (k / (k + tf.exp(tf.to_float(iter_num) / tf.to_float(k))))))
        scheduled_sampling_func = common_video.scheduled_sample_count
        scheduled_sampling_func_var = num_ground_truth
      else:
        raise ValueError("unknown scheduled sampling method: %s" % mode)

      if isinstance(scheduled_sampling_func_var, tf.Tensor):
        tf.summary.scalar("scheduled_sampling_var", scheduled_sampling_func_var)
      partial_func = functools.partial(
          scheduled_sampling_func,
          batch_size=batch_size,
          scheduled_sample_var=scheduled_sampling_func_var)
      return partial_func
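The "count" mode above implements inverse-sigmoid decay: at iteration i with decay constant k, the expected fraction of ground-truth frames is k / (k + exp(i / k)). A quick standalone check of the schedule (plain Python; k = 1000 is a hypothetical value):

import math

def gt_fraction(iter_num, k=1000.0):
  """Fraction of ground-truth frames under inverse-sigmoid decay."""
  return k / (k + math.exp(iter_num / k))

# Decays smoothly from ~1.0 toward 0.0:
# gt_fraction(0)     -> 0.999
# gt_fraction(5000)  -> 0.871
# gt_fraction(10000) -> 0.043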