Example #1
0
 def body(self, features):
   """Convolutional autoencoder: strided convs down, bottleneck, deconvs up.

   Inputs whose axis 2 has size 1 are treated as 1-D: the kernel width and
   the stride along axis 2 are then fixed to 1.

   Args:
     features: dict of tensors; only `features["targets"]` is read.

   Returns:
     The reconstruction cropped back to the input's sizes on axes 1 and 2,
     blended with `features["targets"]` through `common_layers.mix`.
   """
   hp = self._hparams
   training = hp.mode == tf.estimator.ModeKeys.TRAIN
   targets = features["targets"]
   shape = common_layers.shape_list(targets)
   kernel_2d = (hp.kernel_height, hp.kernel_width)
   is1d = shape[2] == 1
   if is1d:
     kernel, strides = (hp.kernel_height, 1), (2, 1)
   else:
     kernel, strides = kernel_2d, (2, 2)
   num_levels = hp.num_hidden_layers
   multiple = 2**num_levels
   # Pad axis 1 (and axis 2 unless 1-D) so every stride-2 level divides it.
   x, _ = common_layers.pad_to_same_length(
       targets, targets, final_length_divisible_by=multiple, axis=1)
   if not is1d:
     x, _ = common_layers.pad_to_same_length(
         x, x, final_length_divisible_by=multiple, axis=2)
   # Encoder: each level doubles the channel count.
   for level in range(num_levels):
     x = tf.layers.conv2d(
         x, hp.hidden_size * 2**(level + 1), kernel, strides=strides,
         padding="SAME", activation=tf.nn.relu, name="conv_%d" % level)
     x = common_layers.layer_norm(x)
   # Bottleneck output blended with the pre-bottleneck activations.
   b = self.bottleneck(x, hp.hidden_size * 2**num_levels)
   x = common_layers.mix(b, x, hp.bottleneck_warmup_steps, training)
   # Decoder mirrors the encoder, deepest level first.
   for level in reversed(range(num_levels)):
     x = tf.layers.conv2d_transpose(
         x, hp.hidden_size * 2**level, kernel, strides=strides,
         padding="SAME", activation=tf.nn.relu, name="deconv_%d" % level)
     x = common_layers.layer_norm(x)
   reconstruction = x[:, :shape[1], :shape[2], :]
   return common_layers.mix(reconstruction, targets,
                            hp.bottleneck_warmup_steps // 2, training)
Example #2
0
 def body(self, features):
     """Stacked autoencoder body.

     The configured `hparams.num_hidden_layers` is used as the number of
     stacks, and `hparams.num_hidden_layers` is set to 1 while the graph is
     built. The original value is restored on every exit path, including the
     PREDICT early return and exceptions.

     Args:
       features: dict of tensors; `features["targets"]` is read outside
         PREDICT mode.

     Returns:
       In PREDICT mode, the decoder output alone. Otherwise `(res, losses)`:
       the output cropped to the targets' sizes on axes 1 and 2 and mixed
       with the targets, plus the dict of bottleneck losses.
     """
     hparams = self.hparams
     num_stacks = hparams.num_hidden_layers
     hparams.num_hidden_layers = 1
     # try/finally: the PREDICT early return below previously skipped the
     # restore, leaving num_hidden_layers permanently set to 1.
     try:
         is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
         if hparams.mode != tf.estimator.ModeKeys.PREDICT:
             x = features["targets"]
             shape = common_layers.shape_list(x)
             is1d = shape[2] == 1
             self.is1d = is1d
             # Pad axes 1 (and 2 unless 1-D) to a multiple of 2**num_stacks.
             x, _ = common_layers.pad_to_same_length(
                 x, x, final_length_divisible_by=2**num_stacks, axis=1)
             if not is1d:
                 x, _ = common_layers.pad_to_same_length(
                     x, x, final_length_divisible_by=2**num_stacks, axis=2)
             # Run encoder.
             x = self.encoder(x)
             x_size = common_layers.shape_list(x)[-1]
             # Bottleneck (mix during early training, not too important but
             # stable).
             b = self.bottleneck(x)
             b_loss = self.bottleneck_loss(b)
             losses = {"bottleneck0_loss": b_loss}
             # Remaining num_stacks - 1 levels. `losses` is passed in,
             # presumably so full_stack can add per-level entries (confirm).
             b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_size,
                                 losses, is_training, num_stacks - 1)
             b = self.unbottleneck(b, x_size)
             b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps,
                                   is_training)
             # With probability bottleneck_max_prob use the bottleneck,
             # otherwise x.
             if hparams.bottleneck_max_prob < 1.0:
                 x = tf.where(
                     tf.less(tf.random_uniform([]),
                             hparams.bottleneck_max_prob), b, x)
             else:
                 x = b
         else:
             b = self.sample()
             # hparams.num_hidden_layers was set to 1 above.
             res_size = (self.hparams.hidden_size *
                         2**self.hparams.num_hidden_layers)
             res_size = min(res_size, hparams.max_hidden_size)
             x = self.unbottleneck(b, res_size)
         # Run decoder.
         x = self.decoder(x)
         if hparams.mode == tf.estimator.ModeKeys.PREDICT:
             return x
         # Cut to the right size and mix before returning.
         res = x[:, :shape[1], :shape[2], :]
         res = common_layers.mix(res, features["targets"],
                                 hparams.bottleneck_warmup_steps // 2,
                                 is_training)
         return res, losses
     finally:
         hparams.num_hidden_layers = num_stacks
Example #3
0
def isemhash_bottleneck(x,
                        bottleneck_size,
                        bottleneck_noise,
                        discretize_warmup_steps,
                        mode,
                        isemhash_noise_dev=0.5,
                        isemhash_mix_prob=0.5):
    """Improved semantic hashing bottleneck.

    Projects `x` to `bottleneck_size` units, applies a saturating sigmoid and
    thresholds at 0.5 with a straight-through estimator, then rescales the
    bits to {-1, 1}.

    In TRAIN mode, three extra steps apply:
    - truncated-normal noise is added before the sigmoid, when
      `isemhash_noise_dev > 0`;
    - each bit is flipped with probability `bottleneck_noise`;
    - the code is mixed with the continuous `2 * y - 1`.

    Args:
      x: input tensor.
      bottleneck_size: number of code bits produced by the dense layer.
      bottleneck_noise: per-bit flip probability (TRAIN only).
      discretize_warmup_steps: steps argument forwarded to `mix`.
      mode: a `tf.estimator.ModeKeys` value.
      isemhash_noise_dev: stddev of the pre-sigmoid noise.
      isemhash_mix_prob: `max_prob` forwarded to `mix`.

    Returns:
      The bottleneck code tensor.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    with tf.variable_scope("isemhash_bottleneck"):
        logits = tf.layers.dense(x, bottleneck_size, name="dense")
        probs = common_layers.saturating_sigmoid(logits)
        if isemhash_noise_dev > 0 and training:
            jitter = tf.truncated_normal(common_layers.shape_list(logits),
                                         mean=0.0,
                                         stddev=isemhash_noise_dev)
            probs = common_layers.saturating_sigmoid(logits + jitter)
        # Forward value is the hard 0/1 bit; gradient flows through `probs`.
        hard = tf.to_float(tf.less(0.5, probs)) + probs - tf.stop_gradient(probs)
        bits = 2.0 * hard - 1.0  # [0, 1] -> [-1, 1]
        if not training:
            return bits
        # Multiply by -1 where the uniform draw is <= bottleneck_noise.
        flips = tf.random_uniform(common_layers.shape_list(logits))
        flips = 2.0 * tf.to_float(tf.less(bottleneck_noise, flips)) - 1.0
        bits *= flips
        return common_layers.mix(bits,
                                 2.0 * probs - 1.0,
                                 discretize_warmup_steps,
                                 training,
                                 max_prob=isemhash_mix_prob)
Example #4
0
 def body(self, features):
   """Stacked autoencoder body.

   The configured `hparams.num_hidden_layers` is used as the number of stacks,
   and `hparams.num_hidden_layers` is set to 1 while the graph is built. The
   original value is restored on every exit path, including the PREDICT early
   return and exceptions.

   Args:
     features: dict of tensors; `features["targets"]` is read outside
       PREDICT mode.

   Returns:
     In PREDICT mode, the decoder output alone. Otherwise `(res, losses)`:
     the output cropped to the targets' sizes on axes 1 and 2 and mixed with
     the targets, plus the dict of bottleneck losses.
   """
   hparams = self.hparams
   num_stacks = hparams.num_hidden_layers
   hparams.num_hidden_layers = 1
   # try/finally: the PREDICT early return below previously skipped the
   # restore, leaving num_hidden_layers permanently set to 1.
   try:
     is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
     if hparams.mode != tf.estimator.ModeKeys.PREDICT:
       x = features["targets"]
       shape = common_layers.shape_list(x)
       is1d = shape[2] == 1
       self.is1d = is1d
       # Pad axes 1 (and 2 unless 1-D) to a multiple of 2**num_stacks.
       x, _ = common_layers.pad_to_same_length(
           x, x, final_length_divisible_by=2**num_stacks, axis=1)
       if not is1d:
         x, _ = common_layers.pad_to_same_length(
             x, x, final_length_divisible_by=2**num_stacks, axis=2)
       # Run encoder.
       x = self.encoder(x)
       x_size = common_layers.shape_list(x)[-1]
       # Bottleneck (mix during early training, not too important but stable).
       b, b_loss = self.bottleneck(x)
       losses = {"bottleneck0_loss": b_loss}
       # Remaining num_stacks - 1 levels. `losses` is passed in, presumably
       # so full_stack can add per-level entries (confirm).
       b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_bits, losses,
                           is_training, num_stacks - 1)
       b = self.unbottleneck(b, x_size)
       b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps,
                             is_training)
       # With probability bottleneck_max_prob use the bottleneck, otherwise x.
       if hparams.bottleneck_max_prob < 1.0:
         x = tf.where(
             tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b,
             x)
       else:
         x = b
     else:
       b = self.sample()
       # hparams.num_hidden_layers was set to 1 above.
       res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
       res_size = min(res_size, hparams.max_hidden_size)
       x = self.unbottleneck(b, res_size)
     # Run decoder.
     x = self.decoder(x)
     if hparams.mode == tf.estimator.ModeKeys.PREDICT:
       return x
     # Cut to the right size and mix before returning.
     res = x[:, :shape[1], :shape[2], :]
     res = common_layers.mix(res, features["targets"],
                             hparams.bottleneck_warmup_steps // 2, is_training)
     return res, losses
   finally:
     hparams.num_hidden_layers = num_stacks
Example #5
0
  def body(self, features):
    """Run the base autoencoder, then an optional autoregressive head.

    The base reconstruction is concatenated along the last axis with the
    right-shifted targets. In training, those targets are partially replaced
    by zeros through `common_layers.mix`. The result goes through a
    LEFT-padded conv1d head selected by `hparams.autoregressive_mode`.

    Args:
      features: dict of tensors. `features["targets"]` is read, and on the
        first PREDICT step (axis-1 size of 1) it is replaced by zeros.

    Returns:
      `(result, losses)`, where `losses` comes from the base autoencoder.

    Raises:
      ValueError: if `hparams.autoregressive_mode` is not "none", "conv3",
        "conv5" or "sru".
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    ar_mode = hparams.autoregressive_mode
    if ar_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    shape = common_layers.shape_list(basic_result)
    batch, depth = shape[0], shape[3]
    basic1d = tf.reshape(basic_result, [batch, -1, depth])
    predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT
    if predicting:
      # Reuse the reconstruction cached on hparams by the first decode step.
      if not hasattr(hparams, "sampled_basic1d_tensor"):
        hparams.sampled_basic1d_tensor = basic1d
      else:
        basic1d = hparams.sampled_basic1d_tensor
    if common_layers.shape_list(features["targets"])[1] == 1:
      # Length-1 targets are only expected on the first prediction step.
      assert predicting
      features["targets"] = tf.zeros_like(basic_result)
    targets_dropout = common_layers.mix(
        features["targets"], tf.zeros_like(basic_result),
        hparams.bottleneck_warmup_steps, is_training,
        max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True)
    pure_ae_eval = (hparams.mode == tf.estimator.ModeKeys.EVAL and
                    hparams.autoregressive_eval_pure_autoencoder)
    if pure_ae_eval:
      # Optionally evaluate without any autoregressive signal.
      targets_dropout = tf.zeros_like(basic_result)
    shifted = common_layers.shift_right_3d(
        tf.reshape(targets_dropout, [batch, -1, depth]))
    concat1d = tf.concat([basic1d, shifted], axis=-1)
    if hparams.autoregressive_forget_base:
      # Purely autoregressive: drop the reconstruction entirely.
      concat1d = common_layers.shift_right_3d(
          tf.reshape(features["targets"], [batch, -1, depth]))
    # mode -> (conv kernel width, layer name, apply SRU afterwards).
    head_specs = {
        "conv3": (3, "autoregressive_conv3", False),
        "conv5": (5, "autoregressive_conv5", False),
        "sru": (3, "autoregressive_sru_conv3", True),
    }
    spec = head_specs.get(hparams.autoregressive_mode)
    if spec is None:
      raise ValueError("Unsupported autoregressive mode: %s"
                       % hparams.autoregressive_mode)
    width, layer_name, with_sru = spec
    out = common_layers.conv1d(concat1d, depth, width, padding="LEFT",
                               activation=common_layers.belu,
                               name=layer_name)
    if with_sru:
      out = common_layers.sru(out)
    return tf.reshape(out, shape), losses
Example #6
0
 def bottleneck(self, x, res_size):
     """Binary tanh bottleneck followed by a projection to `res_size` units.

     The tanh activations are turned into a +/-1 code with a straight-through
     estimator. The code is mixed (via `common_layers.mix`) with a dropout
     copy of the continuous activations.

     Args:
       x: input tensor.
       res_size: number of units of the output projection.

     Returns:
       The tensor produced by the "unbottleneck" dense layer.
     """
     hp = self._hparams
     squashed = tf.tanh(
         tf.layers.dense(x, hp.bottleneck_size, name="bottleneck"))
     # Forward value: +1 where squashed > 0, else -1; gradient via `squashed`.
     hard = squashed + tf.stop_gradient(
         2 * tf.to_float(tf.less(0.0, squashed)) - 1.0 - squashed)
     soft = tf.nn.dropout(squashed, keep_prob=1.0 - hp.dropout)
     is_training = hp.mode == tf.estimator.ModeKeys.TRAIN
     mixed = common_layers.mix(hard, soft, hp.discretize_warmup_steps,
                               is_training)
     return tf.layers.dense(mixed, res_size, name="unbottleneck")
Example #7
0
 def dropout(self, x):
   """Randomly replace parts of `x` with zeros using `common_layers.mix`.

   Returns `x` unchanged when `hparams.dropout <= 0`. Otherwise `x` is mixed
   with zeros with `max_prob=hparams.dropout` and `broadcast_last=True`
   (presumably one random choice per position across the last axis). The
   plain alternative would be `tf.nn.dropout(x, 1.0 - hparams.dropout)`.
   """
   hp = self.hparams
   if hp.dropout <= 0.0:
     return x
   training = hp.mode == tf.estimator.ModeKeys.TRAIN
   return common_layers.mix(tf.zeros_like(x), x, hp.bottleneck_warmup_steps,
                            training, max_prob=hp.dropout,
                            broadcast_last=True)
Example #8
0
 def body(self, features):
     """Autoencoder body: encoder, bottleneck, decoder.

     In PREDICT mode, the bottleneck cached in `self._cur_bottleneck_tensor`
     is decoded, or a fresh `self.sample()` when it is None.

     Args:
       features: dict of tensors; `features["targets"]` is read outside
         PREDICT mode.

     Returns:
       In PREDICT mode `(decoder_output, {"bottleneck_loss": 0.0})`,
       otherwise `(res, {"bottleneck_loss": b_loss})` with `res` cropped to
       the targets' sizes on axes 1 and 2 and mixed with the targets.
     """
     hparams = self.hparams
     is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
     if hparams.mode == tf.estimator.ModeKeys.PREDICT:
         cached = self._cur_bottleneck_tensor
         b = self.sample() if cached is None else cached
         res_size = min(
             self.hparams.hidden_size * 2**self.hparams.num_hidden_layers,
             hparams.max_hidden_size)
         decoded = self.decoder(self.unbottleneck(b, res_size))
         return decoded, {"bottleneck_loss": 0.0}
     targets = features["targets"]
     shape = common_layers.shape_list(targets)
     self.is1d = shape[2] == 1
     encoded = self.encoder(targets)
     # Bottleneck, blended with the encoder output early in training.
     code = self.bottleneck(encoded)
     self._cur_bottleneck_tensor = code
     b_loss = self.bottleneck_loss(code)
     b = self.unbottleneck(code, common_layers.shape_list(encoded)[-1])
     b = common_layers.mix(b, encoded, hparams.bottleneck_warmup_steps,
                           is_training)
     # Use the bottleneck with probability bottleneck_max_prob.
     if hparams.bottleneck_max_prob < 1.0:
         use_b = tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob)
         decoder_in = tf.where(use_b, b, encoded)
     else:
         decoder_in = b
     decoded = self.decoder(decoder_in)
     cropped = decoded[:, :shape[1], :shape[2], :]
     res = common_layers.mix(cropped, targets,
                             hparams.bottleneck_warmup_steps // 2,
                             is_training)
     return res, {"bottleneck_loss": b_loss}
Example #9
0
 def bottleneck(self, x):
   """Tanh bottleneck binarized to +/-1 with a straight-through estimator.

   In TRAIN mode each bit is flipped with probability
   `hparams.bottleneck_noise`. The hard code is then mixed with the continuous
   tanh activations using `hparams.discretize_warmup_steps`.

   Returns:
     `(code, 0.0)`: the bottleneck tensor and a constant zero loss.
   """
   hp = self.hparams
   training = hp.mode == tf.estimator.ModeKeys.TRAIN
   soft = tf.tanh(tf.layers.dense(x, hp.bottleneck_bits, name="bottleneck"))
   sign = 2.0 * tf.to_float(tf.less(0.0, soft)) - 1.0
   hard = soft + tf.stop_gradient(sign - soft)
   if training:
     draw = tf.random_uniform(common_layers.shape_list(soft))
     hard *= 2.0 * tf.to_float(tf.less(hp.bottleneck_noise, draw)) - 1.0
   code = common_layers.mix(hard, soft, hp.discretize_warmup_steps, training)
   return code, 0.0
Example #10
0
def tanh_discrete_bottleneck(x, bottleneck_size, bottleneck_noise,
                             discretize_warmup_steps, mode):
    """Simple discretization through tanh, flip bottleneck_noise many bits.

    Args:
      x: input tensor, projected to `bottleneck_size` units.
      bottleneck_size: width of the code.
      bottleneck_noise: per-bit flip probability, used in TRAIN mode only.
      discretize_warmup_steps: steps argument forwarded to `mix`.
      mode: a `tf.estimator.ModeKeys` value.

    Returns:
      The +/-1 straight-through code mixed with the tanh activations.
    """
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    activ = tf.tanh(
        tf.layers.dense(x, bottleneck_size, name="tanh_discrete_bottleneck"))
    binary = 2.0 * tf.to_float(tf.less(0.0, activ)) - 1.0
    # Value equals `binary`; gradient passes straight through `activ`.
    code = activ + tf.stop_gradient(binary - activ)
    if is_train:
        keep = tf.random_uniform(common_layers.shape_list(activ))
        code *= 2.0 * tf.to_float(tf.less(bottleneck_noise, keep)) - 1.0
    return common_layers.mix(code, activ, discretize_warmup_steps, is_train)
Example #11
0
 def dropout(self, x):
   """Mix `x` with zeros (probability up to `hparams.dropout`).

   A no-op when `hparams.dropout <= 0`. `broadcast_last=True` is forwarded
   to `common_layers.mix`, presumably so the last axis shares one decision.
   """
   rate = self.hparams.dropout
   if rate <= 0.0:
     return x
   zeros = tf.zeros_like(x)
   return common_layers.mix(
       zeros,
       x,
       self.hparams.bottleneck_warmup_steps,
       self.hparams.mode == tf.estimator.ModeKeys.TRAIN,
       max_prob=rate,
       broadcast_last=True)
Example #12
0
def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Simple discretization through tanh; flips `bottleneck_noise` of the bits.

  Args:
    x: input tensor, projected to `bottleneck_bits` units by a dense layer.
    bottleneck_bits: width of the binary code.
    bottleneck_noise: per-bit flip probability, applied in TRAIN mode only.
    discretize_warmup_steps: steps argument forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value.

  Returns:
    `(code, 0.0)`: the mixed +/-1 code and a constant zero extra loss.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN
  h = tf.tanh(tf.layers.dense(x, bottleneck_bits,
                              name="tanh_discrete_bottleneck"))
  # Straight-through: value is +1 where h > 0 else -1, gradient through h.
  hard = h + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, h)) - 1.0 - h)
  if training:
    u = tf.random_uniform(common_layers.shape_list(h))
    flip_mask = 2.0 * tf.to_float(tf.less(bottleneck_noise, u)) - 1.0
    hard = hard * flip_mask
  code = common_layers.mix(hard, h, discretize_warmup_steps, training)
  return code, 0.0
Example #13
0
def isemhash_bottleneck(x, bottleneck_bits, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck.

  Args:
    x: input tensor, projected to `bottleneck_bits` units.
    bottleneck_bits: number of code bits.
    bottleneck_noise: per-bit flip probability in TRAIN mode.
    discretize_warmup_steps: steps argument forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value.
    isemhash_noise_dev: stddev of truncated-normal noise added before the
      saturating sigmoid in TRAIN mode; only applied when > 0.
    isemhash_mix_prob: `max_prob` used when mixing the code with `2y - 1`.

  Returns:
    `(code, 0.0)`: the +/-1 code and a constant zero extra loss.
  """
  train_mode = mode == tf.estimator.ModeKeys.TRAIN
  with tf.variable_scope("isemhash_bottleneck"):
    pre = tf.layers.dense(x, bottleneck_bits, name="dense")
    soft = common_layers.saturating_sigmoid(pre)
    if isemhash_noise_dev > 0 and train_mode:
      gauss = tf.truncated_normal(
          common_layers.shape_list(pre), mean=0.0, stddev=isemhash_noise_dev)
      soft = common_layers.saturating_sigmoid(pre + gauss)
    # Hard 0/1 threshold at 0.5 forward, gradient of `soft` backward.
    code = tf.to_float(tf.less(0.5, soft)) + soft - tf.stop_gradient(soft)
    code = 2.0 * code - 1.0
    if train_mode:
      draw = tf.random_uniform(common_layers.shape_list(pre))
      sign = 2.0 * tf.to_float(tf.less(bottleneck_noise, draw)) - 1.0
      code = code * sign
      code = common_layers.mix(code, 2.0 * soft - 1.0,
                               discretize_warmup_steps, train_mode,
                               max_prob=isemhash_mix_prob)
    return code, 0.0
Example #14
0
  def body(self, features):
    """Base autoencoder followed by a causal autoregressive head.

    The head reads the base reconstruction concatenated with right-shifted
    (and, in training, partially zeroed) targets. It then applies a
    LEFT-padded conv1d chosen by `hparams.autoregressive_mode`, optionally
    followed by an SRU.

    Args:
      features: dict of tensors. `features["targets"]` is read, and it is
        overwritten with zeros on the first PREDICT step.

    Returns:
      `(result, losses)` with `losses` taken from the base autoencoder.

    Raises:
      ValueError: for an unknown `hparams.autoregressive_mode`.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses

    shape = common_layers.shape_list(basic_result)

    def to_1d(t):
      """Merge axes 1 and 2 of `t`, keeping axes 0 and 3."""
      return tf.reshape(t, [shape[0], -1, shape[3]])

    def left_conv(inputs, width, layer_name):
      """Causal (LEFT-padded) conv1d back to `shape[3]` channels."""
      return common_layers.conv1d(inputs, shape[3], width, padding="LEFT",
                                  activation=common_layers.belu,
                                  name=layer_name)

    basic1d = to_1d(basic_result)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      # Keep using the reconstruction from the first decoding step.
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d

    targets = features["targets"]
    if common_layers.shape_list(targets)[1] == 1:
      # Length-1 targets only occur on the first prediction step.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
      features["targets"] = targets

    targets_dropout = common_layers.mix(
        targets,
        tf.zeros_like(basic_result),
        hparams.bottleneck_warmup_steps,
        is_training,
        max_prob=1.0 - hparams.autoregressive_dropout,
        broadcast_last=True)
    if hparams.mode == tf.estimator.ModeKeys.EVAL:
      if hparams.autoregressive_eval_pure_autoencoder:
        targets_dropout = tf.zeros_like(basic_result)

    concat1d = tf.concat(
        [basic1d, common_layers.shift_right_3d(to_1d(targets_dropout))],
        axis=-1)
    if hparams.autoregressive_forget_base:
      # Purely autoregressive: ignore the autoencoder output.
      concat1d = common_layers.shift_right_3d(to_1d(targets))

    mode = hparams.autoregressive_mode
    if mode == "conv3":
      res = left_conv(concat1d, 3, "autoregressive_conv3")
    elif mode == "conv5":
      res = left_conv(concat1d, 5, "autoregressive_conv5")
    elif mode == "sru":
      res = common_layers.sru(
          left_conv(concat1d, 3, "autoregressive_sru_conv3"))
    else:
      raise ValueError(
          "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
    return tf.reshape(res, shape), losses
Example #15
0
  def body(self, features):
    """Autoencoder body with an optional sliced-GAN loss on sampled outputs.

    Outside PREDICT mode, `features["targets"]` is encoded, passed through
    `self.bottleneck` and decoded.

    When `hparams.gan_loss_factor != 0`:
    - a batch decoded from `self.sample()` is concatenated in front of the
      real batch along axis 0;
    - after decoding it is split off again and rendered to RGB;
    - the RGB rendering is compared with `features["targets_raw"]` using
      `common_layers.sliced_gan_loss`.

    Args:
      features: dict of tensors; reads "targets" and, when the GAN is
        enabled, "targets_raw".

    Returns:
      In PREDICT mode, `(decoder_output, {"bottleneck_loss": 0.0})`.
      Otherwise `(res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss})`,
      where `res` is cropped to the targets' sizes on axes 1 and 2 and mixed
      with the targets via `common_layers.mix`.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      shape = common_layers.shape_list(x)
      # Size-1 axis 2 marks 1-D inputs; kept on self (not read in this method).
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      # Cached so a later PREDICT-mode call decodes it instead of sampling.
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(), common_layers.shape_list(x)[-1], reuse=True)
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      # NOTE(review): the threshold is -1.0, so for any probability in [0, 1]
      # this branch never runs and x = b always. With the GAN batch
      # concatenated above, b and x would also differ along axis 0, which
      # tf.where would presumably reject. Confirm the -1.0 is intentional.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      # PREDICT: decode the cached bottleneck if there is one, else sample.
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]
    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      # Split back if we added a purely sampled batch.
      # The sampled half comes first, matching the tf.concat order above.
      res_gan, res = tf.split(res, 2, axis=0)
      num_channels = self.hparams.problem.num_channels
      res_rgb = common_layers.convert_real_to_rgb(
          tf.nn.sigmoid(tf.layers.dense(res_gan, num_channels, name="gan_rgb")))
      tf.summary.image(
          "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1)
      orig_rgb = tf.to_float(features["targets_raw"])

      def discriminate(x):
        return self.discriminator(x, is_training=is_training)

      # reverse_gradient on the generated side plus the negated "gan_loss"
      # returned below presumably give generator and discriminator opposing
      # objectives; confirm against reverse_gradient's definition.
      gan_loss = common_layers.sliced_gan_loss(orig_rgb,
                                               reverse_gradient(res_rgb),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor
    # Mix the final result and return.
    res = common_layers.mix(res, features["targets"],
                            hparams.bottleneck_warmup_steps // 2, is_training)
    return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
Example #16
0
  def body(self, features):
    """Autoencoder body with a VQ output layer and optional sliced-GAN loss.

    Outside PREDICT mode, `features["targets"]` is encoded, bottlenecked and
    decoded. The decoder output is then projected and quantized with
    `discretization.vq_loss` against one-hot `features["targets_raw"]`.

    When `hparams.gan_loss_factor != 0`, a batch decoded from `self.sample()`
    is concatenated in front of the real batch. After decoding, the two
    halves are split again, and the sampled half (`res_gan`) feeds the GAN
    terms.

    Args:
      features: dict of tensors; reads "targets" and "targets_raw" outside
        PREDICT mode.

    Returns:
      In PREDICT mode, `(decoder_output, {"bottleneck_loss": 0.0})`.
      Otherwise `(reconstr, losses)`, where `losses` has the keys "training",
      "code_loss", "code_loss_gan", "b_loss" and "gan_loss".

    Raises:
      ValueError: if the problem is neither an `image_utils.ImageProblem`
        nor a `text_problems.Text2TextProblem`.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      labels = features["targets_raw"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      # Cached so a later PREDICT-mode call decodes it instead of sampling.
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(), common_layers.shape_list(x)[-1], reuse=True)
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      # NOTE(review): with a -1.0 threshold this branch is effectively never
      # taken; b may also have twice x's batch size when the GAN is enabled.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    # Project for the VQ layer. The dense layers below are unnamed, so their
    # variable names depend on creation order; keep this order stable.
    is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
    if is_image:
      vocab_size = self.hparams.problem.vocab_size

      res = tf.layers.dense(
          res, self.hparams.problem.num_channels * self.hparams.hidden_size)
      output_shape = common_layers.shape_list(res)[:-1] + [
          self.hparams.problem.num_channels, self.hparams.hidden_size
      ]
      res = tf.reshape(res, output_shape)
    elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
      vocab_size = self._problem_hparams.target_modality.top_dimensionality
      res = tf.layers.dense(res, self.hparams.hidden_size)
    else:
      # ValueError subclasses Exception, so existing handlers still catch it.
      raise ValueError("Unsupported problem type: %s" % self.hparams.problem)

    one_hot_labels = tf.one_hot(labels, vocab_size)
    code_loss_gan = 0.0
    if hparams.gan_loss_factor != 0.0:
      # The sampled half comes first, matching the tf.concat order above.
      res_gan, res = tf.split(res, 2, axis=0)
      with tf.variable_scope("vq"):
        # Quantize the *sampled* half. This previously passed `res` (the real
        # half), so reconstr_gan and code_loss_gan described real data.
        reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, one_hot_labels, vocab_size)

    # AUTO_REUSE: creates the "vq" variables if the GAN branch did not,
    # otherwise shares them.
    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
      reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
          res, one_hot_labels, vocab_size)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      if is_image:
        tf.summary.image(
            "gan",
            common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
            max_outputs=1)

      def discriminate(x):
        return self.discriminator(x, is_training=is_training)

      # NOTE(review): compares VQ target codes against pre-VQ `res_gan`;
      # confirm both live in the same representation space.
      gan_loss = common_layers.sliced_gan_loss(target_codes,
                                               reverse_gradient(res_gan),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor

    if is_image:
      tf.summary.image(
          "ae",
          common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
          max_outputs=1)

    return reconstr, {
        "training": targets_loss,
        "code_loss": code_loss,
        "code_loss_gan": code_loss_gan,
        "b_loss": b_loss,
        "gan_loss": -gan_loss
    }
Example #17
0
  def body(self, features):
    """Autoencoder body over one-hot embedded `targets_raw` with a VQ output.

    Outside PREDICT mode, `features["targets_raw"]` is one-hot encoded with
    depth equal to the target modality's `top_dimensionality`. The one-hot
    axis is folded into the last axis, and the result is embedded, encoded,
    bottlenecked and decoded.

    If `hparams.gan_loss_factor != 0`, a batch decoded from
    `self.sample(shape=b_shape)` is concatenated in front of the real batch.
    It is split off again after decoding and used for the GAN terms.

    Args:
      features: dict of tensors; reads `features["targets_raw"]` outside
        PREDICT mode.

    Returns:
      In PREDICT mode, `(decoder_output, {"bottleneck_loss": 0.0})`.
      Otherwise `(logits, losses)`, where `logits` is the reconstruction from
      `discretization.vq_loss`. `losses` holds "code_loss", "training",
      "b_loss" and "gan_loss", plus "code_loss_gan" when the GAN is enabled.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      labels = features["targets_raw"]
      vocab_size = self._problem_hparams.target_modality.top_dimensionality
      shape = common_layers.shape_list(labels)
      # One-hot the labels, then merge the one-hot axis into the last axis.
      x = tf.one_hot(labels, vocab_size)
      x = tf.reshape(x, shape[:-1] + [shape[-1] * vocab_size])
      x = self.embed(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      # Shape of the real bottleneck; the GAN sample below is drawn to match.
      b_shape = common_layers.shape_list(b)
      # Cached so a later PREDICT-mode call decodes it instead of sampling.
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1], reuse=True)
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      # NOTE(review): with a -1.0 threshold this branch is effectively never
      # taken; b may also have twice x's batch size when the GAN is enabled.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      # PREDICT: decode the cached bottleneck if there is one, else sample.
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into [num_channels, hidden_size].
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    # Losses.
    losses = {}
    if hparams.gan_loss_factor != 0.0:
      # The sampled half comes first, matching the tf.concat order above.
      res_gan, res = tf.split(res, 2, axis=0)
      with tf.variable_scope("vq"):
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, labels, vocab_size)
        losses["code_loss_gan"] = (code_loss_gan * hparams.code_loss_factor *
                                   hparams.gan_loss_factor)

    # AUTO_REUSE: creates the "vq" variables if the GAN branch did not,
    # otherwise shares them.
    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
      (reconstr, _, target_codes, code_loss,
       targets_loss) = discretization.vq_loss(res, labels, vocab_size)

    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan)

      def discriminate(x):
        return self.discriminator(x, is_training=is_training)

      tc_shape = common_layers.shape_list(target_codes)
      # For codes of rank > 4, merge the last two axes. gan_codes is reshaped
      # with target_codes' shape, i.e. assumed to share it (confirm).
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      # reverse_gradient on the generated codes plus the negated "gan_loss"
      # below presumably give generator and discriminator opposing
      # objectives; confirm against reverse_gradient's definition.
      gan_loss = common_layers.sliced_gan_loss(target_codes,
                                               reverse_gradient(gan_codes),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor

    self.image_summary("ae", reconstr)

    losses["b_loss"] = b_loss
    losses["gan_loss"] = -gan_loss

    logits = reconstr
    return logits, losses