Beispiel #1
0
  def body(self, features):
    """Build the autoencoder graph; return logits and a dict of losses.

    When not in PREDICT mode, or when `self._encode_on_predict` is set,
    `features["targets_raw"]` is one-hot encoded, embedded, run through the
    encoder and the bottleneck. Otherwise the bottleneck tensor comes from
    `self._cur_bottleneck_tensor` or, if that is None, from `self.sample()`.
    The un-bottlenecked tensor is then decoded and projected through the
    "res_dense" layer.

    Args:
      features: dict of tensors; this method reads only
        `features["targets_raw"]`.

    Returns:
      A `(logits, losses)` tuple. In PREDICT mode `logits` is the
      reconstruction and `losses` is `{"bottleneck_loss": 0.0}`. Otherwise
      `logits` is the reconstruction reshaped to the original shape of
      `targets_raw` plus a trailing `vocab_size` axis, and `losses` contains
      "bottleneck_extra", "bottleneck_l2" and "training", plus "code_loss"
      when `hparams.use_vq_loss` is set, "gan_loss" when
      `hparams.gan_loss_factor` is non-zero, and "code_loss_gan" when both
      hold.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      # Round vocab_size up to the next multiple of vocab_divisor.
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    # Stays None when the encoder is skipped (the sampling path below).
    encoder_layers = None
    # Default from hparams; may be switched to True from the label shape below.
    self.is1d = hparams.sample_width == 1
    # NOTE: `labels`, `labels_shape`, `shape`, `target_codes`, `b_loss` and
    # `xb_loss` are bound only inside this branch.
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      # Shape before any video conversion; used to reshape the final logits.
      labels_shape = common_layers.shape_list(labels)
      # Handle videos: rank-5 labels are passed through time_to_channels.
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      # Embedded targets; used as the "real" codes for the GAN loss unless
      # the VQ branch below overwrites them.
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      # Default used when the training-only mixing branch below is skipped.
      xb_loss = 0.0
      # Saved so the GAN branch can sample a batch of the same shape.
      b_shape = common_layers.shape_list(b)
      # Cached so a later call on the sampling path (else branch) can reuse it.
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        # Outside training the decoder sees only the bottleneck output.
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        # Threshold for the element-wise choice between b and x below.
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        # One uniform sample per element of x, compared against nomix_p below.
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(
            tf.squared_difference(x_stop, b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        # Rescale so the loss value never exceeds clip_max: the factor is 1
        # while xb_loss <= clip_max and clip_max / xb_loss otherwise.
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        # Element-wise: take b where rand < nomix_p, otherwise keep x.
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        # Order matters: the tf.split after the decoder expects [real, sampled].
        x = tf.concat([x, g], axis=0)
    else:
      # Sampling path: reuse the cached bottleneck tensor or draw a new one.
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into (num_channels, hidden_size).
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        # NOTE(review): `labels` is only assigned in the encode branch above,
        # so this call raises UnboundLocalError in PREDICT mode when
        # `self._encode_on_predict` is falsy. It also runs outside the
        # "vq_loss" variable scope used by the non-PREDICT call below --
        # confirm whether that is intended.
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      # NOTE(review): if the encode branch ran with gan_loss_factor != 0, the
      # sampled batch concatenated above is still part of `res` here, since
      # the split below is never reached -- confirm.
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      # Undo the [real, sampled] concatenation done before the decoder.
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      # No temperature is passed to vq_loss outside of training.
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      # Both logits and labels are reshaped back to the pre-conversion label
      # shape before computing the cross-entropy.
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      # Scales both the GAN code loss and the GAN loss itself.
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        # Reuse the variables created by the reconstruction vq_loss call above.
        with tf.variable_scope("vq_loss", reuse=True):
          # Random boolean, true with probability update_means_factor.
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        # Share the "autoencoder_final" layer with the reconstruction path.
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        # Keep the pre-sampling output for the image summary.
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a discriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        # Merge the last two axes so both code tensors have rank 4.
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      # The GAN loss enters the loss dict negated.
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    # Restore the original label shape with a trailing vocab_size axis.
    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
Beispiel #2
0
  def body(self, features):
    """Build the autoencoder graph with a VQ reconstruction and optional GAN loss.

    Outside PREDICT mode `features["targets"]` is encoded, passed through the
    bottleneck and decoded; the decoder output is cut to the input's spatial
    size, projected, and turned into a reconstruction by
    `discretization.vq_loss`. In PREDICT mode the bottleneck tensor is taken
    from `self._cur_bottleneck_tensor` (or sampled if that is None) and the
    raw decoder output is returned.

    Args:
      features: dict of tensors; reads `features["targets"]` and
        `features["targets_raw"]` (only outside PREDICT mode).

    Returns:
      A tuple `(output, losses)`. In PREDICT mode `output` is the decoder
      output and `losses` is `{"bottleneck_loss": 0.0}`. Otherwise `output`
      is the VQ reconstruction and `losses` has the keys "training",
      "code_loss", "code_loss_gan", "b_loss" and "gan_loss" (the latter two
      GAN entries are 0.0 / -0.0 when `hparams.gan_loss_factor` is zero).

    Raises:
      Exception: if `self.hparams.problem` is neither an
        `image_utils.ImageProblem` nor a `text_problems.Text2TextProblem`.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      labels = features["targets_raw"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      # Cached so a later PREDICT-mode call can reuse it.
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(), common_layers.shape_list(x)[-1], reuse=True)
        # Order matters: the tf.split below expects [sampled, real].
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      # NOTE(review): the `< -1.0` threshold means this branch is never taken
      # for a probability in [0, 1]; presumably it is disabled on purpose --
      # confirm before changing.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      # Sampling path: reuse the cached bottleneck tensor or draw a new one.
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
    if is_image:
      vocab_size = self.hparams.problem.vocab_size

      res = tf.layers.dense(
          res, self.hparams.problem.num_channels * self.hparams.hidden_size)
      # Split the last axis into (num_channels, hidden_size).
      output_shape = common_layers.shape_list(res)[:-1] + [
          self.hparams.problem.num_channels, self.hparams.hidden_size
      ]
      res = tf.reshape(res, output_shape)
    elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
      vocab_size = self._problem_hparams.target_modality.top_dimensionality
      res = tf.layers.dense(res, self.hparams.hidden_size)
    else:
      raise Exception("Unsupported problem type: %s" % self.hparams.problem)

    one_hot_labels = tf.one_hot(labels, vocab_size)
    code_loss_gan = 0.0
    if hparams.gan_loss_factor != 0.0:
      # Undo the [sampled, real] concatenation done before the decoder.
      res_gan, res = tf.split(res, 2, axis=0)
      with tf.variable_scope("vq"):
        # Fix: evaluate the *sampled* half here. The previous code passed
        # `res`, which made the GAN reconstruction and its code loss exact
        # duplicates of the real-data ones computed just below.
        reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, one_hot_labels, vocab_size)

    # AUTO_REUSE shares the variables with the GAN call when it ran, and
    # creates them otherwise.
    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
      reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
          res, one_hot_labels, vocab_size)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      if is_image:
        tf.summary.image(
            "gan",
            common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
            max_outputs=1)

      def discriminate(x):
        """Run this model's discriminator on `x`."""
        return self.discriminator(x, is_training=is_training)

      gan_loss = common_layers.sliced_gan_loss(target_codes,
                                               reverse_gradient(res_gan),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor

    if is_image:
      tf.summary.image(
          "ae",
          common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
          max_outputs=1)

    # The GAN loss enters the loss dict negated.
    return reconstr, {
        "training": targets_loss,
        "code_loss": code_loss,
        "code_loss_gan": code_loss_gan,
        "b_loss": b_loss,
        "gan_loss": -gan_loss
    }
Beispiel #3
0
  def body(self, features):
    """Run the autoencoder and return `(logits, losses)`.

    In PREDICT mode the cached bottleneck tensor (or a fresh sample when
    nothing is cached) is un-bottlenecked and decoded, and the decoder output
    is returned together with `{"bottleneck_loss": 0.0}`.

    In every other mode `features["targets_raw"]` is one-hot encoded,
    embedded, encoded, squeezed through the bottleneck and decoded; the
    result is cropped, projected and fed to `discretization.vq_loss`. When
    `hparams.gan_loss_factor` is non-zero, an extra sampled batch travels
    through the decoder alongside the real one and contributes a sliced GAN
    loss.

    Args:
      features: dict of tensors; only `features["targets_raw"]` is read.

    Returns:
      `(logits, losses)`. Outside PREDICT mode `logits` is the VQ
      reconstruction and `losses` has the keys "code_loss", "training",
      "b_loss" and "gan_loss", plus "code_loss_gan" when the GAN loss is
      enabled.
    """
    hparams = self.hparams
    mode = hparams.mode
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if mode == tf.estimator.ModeKeys.PREDICT:
      # Sampling path: no encoder, no losses.
      latent = self._cur_bottleneck_tensor
      if latent is None:
        latent = self.sample()
      full_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      decoder_in = self.unbottleneck(
          latent, min(full_size, hparams.max_hidden_size))
      return self.decoder(decoder_in), {"bottleneck_loss": 0.0}

    labels = features["targets_raw"]
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    shape = common_layers.shape_list(labels)

    # One-hot encode, fold the one-hot axis into the last axis, then embed.
    one_hot = tf.one_hot(labels, vocab_size)
    folded = tf.reshape(one_hot, shape[:-1] + [shape[-1] * vocab_size])
    embedded = self.embed(folded)
    self.is1d = shape[2] == 1

    encoded = self.encoder(embedded)

    # Bottleneck; its output is mixed with the encoder output early on.
    latent, b_loss = self.bottleneck(encoded)
    latent_shape = common_layers.shape_list(latent)
    self._cur_bottleneck_tensor = latent
    encoded_depth = common_layers.shape_list(encoded)[-1]
    restored = self.unbottleneck(latent, encoded_depth)
    restored = common_layers.mix(
        restored, encoded, hparams.bottleneck_warmup_steps, is_training)

    gan_enabled = hparams.gan_loss_factor != 0.0
    if gan_enabled:
      # Prepend a purely sampled batch; it is split off again after decoding.
      sampled = self.unbottleneck(
          self.sample(shape=latent_shape),
          common_layers.shape_list(encoded)[-1],
          reuse=True)
      restored = tf.concat([sampled, restored], axis=0)

    # With probability bottleneck_max_prob use the bottleneck, otherwise the
    # encoder output.
    if hparams.bottleneck_max_prob < -1.0:
      take_bottleneck = tf.less(
          tf.random_uniform([]), hparams.bottleneck_max_prob)
      decoder_in = tf.where(take_bottleneck, restored, encoded)
    else:
      decoder_in = restored

    decoded = self.decoder(decoder_in)

    # Crop to the label's spatial size.
    cropped = decoded[:, :shape[1], :shape[2], :]

    # Final dense layer, then split the last axis per channel.
    projected = tf.layers.dense(
        cropped, self.num_channels * hparams.hidden_size, name="res_dense")
    per_channel_shape = common_layers.shape_list(projected)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(projected, per_channel_shape)

    losses = {}
    if gan_enabled:
      # First half is the sampled batch, second half the real one.
      res_gan, res = tf.split(res, 2, axis=0)
      with tf.variable_scope("vq"):
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, labels, vocab_size)
        losses["code_loss_gan"] = (code_loss_gan * hparams.code_loss_factor *
                                   hparams.gan_loss_factor)

    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
      vq_outputs = discretization.vq_loss(res, labels, vocab_size)
    reconstr, _, target_codes, code_loss, targets_loss = vq_outputs

    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss

    gan_loss = 0.0
    if gan_enabled:
      self.image_summary("gan", reconstr_gan)

      def discriminate(x):
        """Apply this model's discriminator to `x`."""
        return self.discriminator(x, is_training=is_training)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        # Merge the two trailing axes of both code tensors.
        target_codes = tf.reshape(
            target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(
            gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      reversed_codes = reverse_gradient(gan_codes)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes, reversed_codes, discriminate,
          self.hparams.num_sliced_vecs)
      gan_loss = gan_loss * hparams.gan_loss_factor

    self.image_summary("ae", reconstr)

    losses["b_loss"] = b_loss
    losses["gan_loss"] = -gan_loss

    return reconstr, losses
Beispiel #4
0
  def body(self, features):
    """Build the autoencoder graph; return logits and a dict of losses.

    When not in PREDICT mode, or when `self._encode_on_predict` is set,
    `features["targets_raw"]` is one-hot encoded, embedded, run through the
    encoder and the bottleneck. Otherwise the bottleneck tensor comes from
    `self._cur_bottleneck_tensor` or, if that is None, from `self.sample()`.
    The un-bottlenecked tensor is then decoded and projected through the
    "res_dense" layer.

    Args:
      features: dict of tensors; this method reads only
        `features["targets_raw"]`.

    Returns:
      A `(logits, losses)` tuple. In PREDICT mode `logits` is the
      reconstruction and `losses` is `{"bottleneck_loss": 0.0}`. Otherwise
      `logits` is the reconstruction reshaped to the original shape of
      `targets_raw` plus a trailing `vocab_size` axis, and `losses` contains
      "bottleneck_extra", "bottleneck_l2" and "training", plus "code_loss"
      when `hparams.use_vq_loss` is set, "gan_loss" when
      `hparams.gan_loss_factor` is non-zero, and "code_loss_gan" when both
      hold.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    # Stays None when the encoder is skipped (the sampling path below).
    encoder_layers = None
    # Default from hparams; may be switched to True from the label shape below.
    self.is1d = hparams.sample_width == 1
    # NOTE: `labels`, `labels_shape`, `shape`, `target_codes`, `b_loss` and
    # `xb_loss` are bound only inside this branch.
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      # Shape before any video conversion; used to reshape the final logits.
      labels_shape = common_layers.shape_list(labels)
      # Handle videos: rank-5 labels are passed through time_to_channels.
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      # Embedded targets; used as the "real" codes for the GAN loss unless
      # the VQ branch below overwrites them.
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      # Default used when the training-only mixing branch below is skipped.
      xb_loss = 0.0
      # Saved so the GAN branch can sample a batch of the same shape.
      b_shape = common_layers.shape_list(b)
      # Cached so a later call on the sampling path (else branch) can reuse it.
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        # Outside training the decoder sees only the bottleneck output.
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        # Threshold for the element-wise choice between b and x below.
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        # One uniform sample per element of x, compared against nomix_p below.
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        # Rescale so the loss value never exceeds clip_max: the factor is 1
        # while xb_loss <= clip_max and clip_max / xb_loss otherwise.
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        # Element-wise: take b where rand < nomix_p, otherwise keep x.
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        # Order matters: the tf.split after the decoder expects [real, sampled].
        x = tf.concat([x, g], axis=0)
    else:
      # Sampling path: reuse the cached bottleneck tensor or draw a new one.
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into (num_channels, hidden_size).
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        # NOTE(review): `labels` is only assigned in the encode branch above,
        # so this call raises UnboundLocalError in PREDICT mode when
        # `self._encode_on_predict` is falsy. It also runs outside the
        # "vq_loss" variable scope used by the non-PREDICT call below --
        # confirm whether that is intended.
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      # NOTE(review): if the encode branch ran with gan_loss_factor != 0, the
      # sampled batch concatenated above is still part of `res` here, since
      # the split below is never reached -- confirm.
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      # Undo the [real, sampled] concatenation done before the decoder.
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      # No temperature is passed to vq_loss outside of training.
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      # Both logits and labels are reshaped back to the pre-conversion label
      # shape before computing the cross-entropy.
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      # Scales both the GAN code loss and the GAN loss itself.
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        # Reuse the variables created by the reconstruction vq_loss call above.
        with tf.variable_scope("vq_loss", reuse=True):
          # Random boolean, true with probability update_means_factor.
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        # Share the "autoencoder_final" layer with the reconstruction path.
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        # Keep the pre-sampling output for the image summary.
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a discriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        # Merge the last two axes so both code tensors have rank 4.
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      # The GAN loss enters the loss dict negated.
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    # Restore the original label shape with a trailing vocab_size axis.
    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
Beispiel #5
0
    def body(self, features):
        """Build the autoencoder graph; return logits and a dict of losses.

        Outside PREDICT mode `features["targets_raw"]` is one-hot encoded,
        embedded, encoded, passed through the bottleneck and decoded; the
        decoder output is cropped, projected and turned into a reconstruction
        either by `discretization.vq_loss` (when `hparams.use_vq_loss`) or by
        the "autoencoder_final" dense layer. In PREDICT mode the bottleneck
        tensor comes from `self._cur_bottleneck_tensor` (or `self.sample()`
        when that is None) and the raw decoder output is returned.

        Args:
          features: dict of tensors; only `features["targets_raw"]` is read,
            and only outside PREDICT mode.

        Returns:
          A `(logits, losses)` tuple. In PREDICT mode `logits` is the decoder
          output and `losses` is `{"bottleneck_loss": 0.0}`. Otherwise
          `logits` is the reconstruction and `losses` contains
          "bottleneck_extra", "bottleneck_l2" and "training", plus
          "code_loss" when `hparams.use_vq_loss` is set, "gan_loss" when
          `hparams.gan_loss_factor` is non-zero, and "code_loss_gan" when
          both hold.
        """
        hparams = self.hparams
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        # Fix: the PREDICT path skips the encoder but still calls
        # `self.decoder(x, encoder_layers)`; without this default that call
        # raised UnboundLocalError.
        encoder_layers = None
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            labels = features["targets_raw"]
            vocab_size = self._problem_hparams.target_modality.top_dimensionality
            shape = common_layers.shape_list(labels)
            x = tf.one_hot(labels, vocab_size)
            x = self.embed(x)
            # Embedded targets; used as the "real" codes for the GAN loss
            # unless the VQ branch below overwrites them.
            target_codes = x
            is1d = shape[2] == 1
            self.is1d = is1d
            # Run encoder.
            x, encoder_layers = self.encoder(x)
            # Bottleneck.
            b, b_loss = self.bottleneck(x)
            # Default used when the training-only mixing branch is skipped.
            xb_loss = 0.0
            # Saved so the GAN branch can sample a batch of the same shape.
            b_shape = common_layers.shape_list(b)
            # Cached so a later PREDICT-mode call can reuse it.
            self._cur_bottleneck_tensor = b
            b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
            if not is_training:
                # Outside training the decoder sees only the bottleneck output.
                x = b
            else:
                l = 2**hparams.num_hidden_layers
                warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
                # Threshold for the element-wise choice between b and x below.
                nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
                if common_layers.should_generate_summaries():
                    tf.summary.scalar("nomix_p_bottleneck", nomix_p)
                # One uniform sample per element of x.
                rand = tf.random_uniform(common_layers.shape_list(x))
                # This is the distance between b and x. Having this as loss helps learn
                # the bottleneck function, but if we back-propagated to x it would be
                # minimized by just setting x=0 and b=0 -- so we don't want too much
                # of the influence of this, and we stop-gradient to not zero-out x.
                x_stop = tf.stop_gradient(x)
                xb_loss = tf.reduce_mean(
                    tf.reduce_sum(tf.square(x_stop - b), axis=-1))
                # To prevent this loss from exploding we clip at 1, but anneal clipping.
                clip_max = 1.0 / common_layers.inverse_exp_decay(
                    warm_step, min_value=0.001)
                # Rescale so the loss value never exceeds clip_max.
                xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
                xb_loss *= clip_max / xb_clip
                # Element-wise: take b where rand < nomix_p, otherwise keep x.
                x = tf.where(tf.less(rand, nomix_p), b, x)
            if hparams.gan_loss_factor != 0.0:
                # Add a purely sampled batch on which we'll compute the GAN loss.
                g = self.unbottleneck(self.sample(shape=b_shape),
                                      common_layers.shape_list(x)[-1],
                                      reuse=True)
                # Order matters: the tf.split below expects [sampled, real].
                x = tf.concat([g, x], axis=0)
                # Duplicate each encoder layer along the batch axis so it
                # matches the doubled batch fed to the decoder.
                encoder_layers = [
                    tf.concat([layer, layer], axis=0)
                    for layer in encoder_layers
                ]
        else:
            # Sampling path: reuse the cached bottleneck tensor or draw one.
            if self._cur_bottleneck_tensor is None:
                b = self.sample()
            else:
                b = self._cur_bottleneck_tensor
            res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
            res_size = min(res_size, hparams.max_hidden_size)
            x = self.unbottleneck(b, res_size)
        # Run decoder.
        x = self.decoder(x, encoder_layers)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            return x, {"bottleneck_loss": 0.0}
        # Cut to the right size and mix before returning.
        res = x[:, :shape[1], :shape[2], :]

        # Final dense layer.
        res = tf.layers.dense(res,
                              self.num_channels * hparams.hidden_size,
                              name="res_dense")

        # Split the last axis into (num_channels, hidden_size).
        output_shape = common_layers.shape_list(res)[:-1] + [
            self.num_channels, self.hparams.hidden_size
        ]
        res = tf.reshape(res, output_shape)

        if hparams.gan_loss_factor != 0.0:
            # Undo the [sampled, real] concatenation done before the decoder.
            res_gan, res = tf.split(res, 2, axis=0)

        # Losses.
        losses = {
            "bottleneck_extra": b_loss,
            "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
        }

        if hparams.use_vq_loss:
            vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.2,
                min_value=hparams.vq_temperature * 2)
            # No temperature is passed to vq_loss outside of training.
            if hparams.mode != tf.estimator.ModeKeys.TRAIN:
                vq_temperature = None
            with tf.variable_scope("vq_loss"):
                (reconstr, _, target_codes, code_loss,
                 targets_loss) = discretization.vq_loss(
                     res, labels, vocab_size, temperature=vq_temperature)
            losses["code_loss"] = code_loss * hparams.code_loss_factor
            losses["training"] = targets_loss
        else:
            reconstr = tf.layers.dense(res,
                                       vocab_size,
                                       name="autoencoder_final")
            targets_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=reconstr, labels=labels)
            losses["training"] = targets_loss

        # GAN losses.
        if hparams.gan_loss_factor != 0.0:
            # Scales both the GAN code loss and the GAN loss itself.
            update_means_factor = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps, min_value=0.0001)
            if hparams.use_vq_loss:
                # Reuse the variables created by the vq_loss call above.
                with tf.variable_scope("vq_loss", reuse=True):
                    # Random boolean, true with probability
                    # update_means_factor.
                    update_means = tf.less(tf.random_uniform([]),
                                           update_means_factor)
                    reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                        res_gan,
                        labels,
                        vocab_size,
                        do_update=update_means,
                        temperature=vq_temperature)
                    code_loss_gan *= hparams.code_loss_factor * update_means_factor
                    losses["code_loss_gan"] = code_loss_gan
            else:
                # Share the "autoencoder_final" layer with the
                # reconstruction path.
                reconstr_gan = tf.layers.dense(res_gan,
                                               vocab_size,
                                               name="autoencoder_final",
                                               reuse=True)
                reconstr_gan = tf.nn.log_softmax(reconstr_gan)
                if is_training and hparams.gumbel_temperature > 0.0:
                    # Training: perturb the log-probabilities with scaled
                    # Gumbel noise, then sample and soften by temperature.
                    gumbel_samples = discretization.gumbel_sample(
                        common_layers.shape_list(reconstr_gan))
                    gumbel_samples *= hparams.gumbel_noise_factor
                    reconstr_gan += gumbel_samples
                    reconstr_sample = latent_layers.multinomial_sample(
                        reconstr_gan, temperature=hparams.gumbel_temperature)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 hparams.gumbel_temperature)
                else:
                    reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 0.1)  # Sharpen a bit.
                # Use 1-hot forward, softmax backward.
                reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
                reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
                # Embed to codes.
                gan_codes = self.embed(reconstr_gan)

        # Add GAN loss if requested.
        gan_loss = 0.0
        if hparams.gan_loss_factor != 0.0:
            self.image_summary("gan", reconstr_gan)

            def discriminate(x):
                """Run this model's discriminator on `x`."""
                return self.discriminator(x, is_training=is_training)

            tc_shape = common_layers.shape_list(target_codes)
            if len(tc_shape) > 4:
                # Merge the last two axes so both code tensors have rank 4.
                target_codes = tf.reshape(
                    target_codes,
                    tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
                gan_codes = tf.reshape(
                    gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_lr = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.5)
            rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
            gan_loss = common_layers.sliced_gan_loss(
                target_codes, rev_grad_gan_codes, discriminate,
                self.hparams.num_sliced_vecs)
            gan_loss *= hparams.gan_loss_factor * update_means_factor
            # The GAN loss enters the loss dict negated.
            losses["gan_loss"] = -gan_loss

        self.image_summary("ae", reconstr)
        logits = reconstr
        return logits, losses