示例#1
0
 def gumbel_sample(self, reconstr_gan):
   """Turn scores into a one-hot sample with a straight-through gradient.

   The scores are normalized with a log-softmax. In training mode with a
   positive `gumbel_temperature`, scaled Gumbel noise is added and a class
   is drawn with `latent_layers.multinomial_sample`; otherwise the argmax
   class is taken and the softmax is sharpened with a fixed 0.1 temperature.

   Args:
     reconstr_gan: tensor of unnormalized scores; the last axis is reduced
       by the argmax (presumably the vocabulary axis -- confirm with caller).

   Returns:
     A tensor whose forward value is the one-hot encoding of the chosen
     class, while gradients flow through the tempered softmax.
   """
   hp = self.hparams
   training = hp.mode == tf.estimator.ModeKeys.TRAIN
   num_classes = self._problem_hparams.modality["targets"].top_dimensionality
   log_probs = tf.nn.log_softmax(reconstr_gan)
   add_noise = training and hp.gumbel_temperature > 0.0
   if not add_noise:
     # Deterministic path: pick the most likely class.
     chosen = tf.argmax(log_probs, axis=-1)
     soft = tf.nn.softmax(log_probs / 0.1)  # Sharpen a bit.
   else:
     # Stochastic path: perturb the log-probabilities, then sample.
     noise_shape = common_layers.shape_list(log_probs)
     noise = discretization.gumbel_sample(noise_shape)
     noise = noise * hp.gumbel_noise_factor
     log_probs = log_probs + noise
     chosen = latent_layers.multinomial_sample(
         log_probs, temperature=hp.gumbel_temperature)
     soft = tf.nn.softmax(log_probs / hp.gumbel_temperature)
   # Straight-through estimator: 1-hot in the forward pass, softmax in the
   # backward pass.
   hard = tf.one_hot(chosen, num_classes)
   return soft + (hard - tf.stop_gradient(soft))
示例#2
0
 def gumbel_sample(self, reconstr_gan):
   """Turn scores into a one-hot sample with a straight-through gradient.

   The scores are normalized with a log-softmax. In training mode with a
   positive `gumbel_temperature`, scaled Gumbel noise is added and a class
   is drawn with `latent_layers.multinomial_sample`; otherwise the argmax
   class is taken and the softmax is sharpened with a fixed 0.1 temperature.

   Args:
     reconstr_gan: tensor of unnormalized scores; the last axis is reduced
       by the argmax (presumably the vocabulary axis -- confirm with caller).

   Returns:
     A tensor whose forward value is the one-hot encoding of the chosen
     class, while gradients flow through the tempered softmax.
   """
   hp = self.hparams
   training = hp.mode == tf.estimator.ModeKeys.TRAIN
   num_classes = self._problem_hparams.modality["targets"].top_dimensionality
   log_probs = tf.nn.log_softmax(reconstr_gan)
   add_noise = training and hp.gumbel_temperature > 0.0
   if not add_noise:
     # Deterministic path: pick the most likely class.
     chosen = tf.argmax(log_probs, axis=-1)
     soft = tf.nn.softmax(log_probs / 0.1)  # Sharpen a bit.
   else:
     # Stochastic path: perturb the log-probabilities, then sample.
     noise_shape = common_layers.shape_list(log_probs)
     noise = discretization.gumbel_sample(noise_shape)
     noise = noise * hp.gumbel_noise_factor
     log_probs = log_probs + noise
     chosen = latent_layers.multinomial_sample(
         log_probs, temperature=hp.gumbel_temperature)
     soft = tf.nn.softmax(log_probs / hp.gumbel_temperature)
   # Straight-through estimator: 1-hot in the forward pass, softmax in the
   # backward pass.
   hard = tf.one_hot(chosen, num_classes)
   return soft + (hard - tf.stop_gradient(soft))
示例#3
0
    def body(self, features):
        """Build the autoencoder graph and return its logits and losses.

        Outside PREDICT mode the raw targets are one-hot encoded, embedded,
        run through the encoder and the bottleneck, and decoded again;
        reconstruction, bottleneck and (depending on hparams) VQ and GAN
        losses are collected. In PREDICT mode the decoder is run on the
        cached bottleneck tensor, or on a fresh sample when none is cached.

        Args:
            features: dict of input tensors. Only ``features["targets_raw"]``
                is read, and only when not in PREDICT mode.

        Returns:
            In PREDICT mode, the tuple
            ``(decoder_output, {"bottleneck_loss": 0.0})``.
            Otherwise the tuple ``(logits, losses)`` where ``losses`` is a
            dict with keys ``"bottleneck_extra"``, ``"bottleneck_l2"`` and
            ``"training"``, plus ``"code_loss"``, ``"code_loss_gan"`` and
            ``"gan_loss"`` when the corresponding hparams enable them.
        """
        hparams = self.hparams
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        # The encoder is skipped in PREDICT mode, but the decoder call below
        # is shared by all modes; bind the name up front so that path does not
        # fail with an unbound local.
        # NOTE(review): assumes self.decoder tolerates encoder_layers=None --
        # confirm against the decoder implementation.
        encoder_layers = None
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            labels = features["targets_raw"]
            vocab_size = self._problem_hparams.target_modality.top_dimensionality
            shape = common_layers.shape_list(labels)
            x = tf.one_hot(labels, vocab_size)
            x = self.embed(x)
            target_codes = x
            is1d = shape[2] == 1
            self.is1d = is1d
            # Run encoder.
            x, encoder_layers = self.encoder(x)
            # Bottleneck.
            b, b_loss = self.bottleneck(x)
            xb_loss = 0.0
            b_shape = common_layers.shape_list(b)
            # Cache the bottleneck so a later PREDICT-mode call can reuse it.
            self._cur_bottleneck_tensor = b
            b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
            if not is_training:
                x = b
            else:
                depth_scale = 2**hparams.num_hidden_layers
                warm_step = int(
                    hparams.bottleneck_warmup_steps * 0.25 * depth_scale)
                nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
                if common_layers.should_generate_summaries():
                    tf.summary.scalar("nomix_p_bottleneck", nomix_p)
                rand = tf.random_uniform(common_layers.shape_list(x))
                # This is the distance between b and x. Having this as loss helps learn
                # the bottleneck function, but if we back-propagated to x it would be
                # minimized by just setting x=0 and b=0 -- so we don't want too much
                # of the influence of this, and we stop-gradient to not zero-out x.
                x_stop = tf.stop_gradient(x)
                xb_loss = tf.reduce_mean(
                    tf.reduce_sum(tf.square(x_stop - b), axis=-1))
                # To prevent this loss from exploding we clip at 1, but anneal clipping.
                clip_max = 1.0 / common_layers.inverse_exp_decay(
                    warm_step, min_value=0.001)
                xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
                xb_loss *= clip_max / xb_clip
                # Element-wise mix: take the un-bottlenecked value where the
                # random draw falls below nomix_p, the encoder output elsewhere.
                x = tf.where(tf.less(rand, nomix_p), b, x)
            if hparams.gan_loss_factor != 0.0:
                # Add a purely sampled batch on which we'll compute the GAN loss.
                g = self.unbottleneck(self.sample(shape=b_shape),
                                      common_layers.shape_list(x)[-1],
                                      reuse=True)
                x = tf.concat([g, x], axis=0)
                # The batch was doubled above, so duplicate every encoder
                # layer along the batch axis to keep the sizes in step.
                encoder_layers = [
                    tf.concat([layer, layer], axis=0)
                    for layer in encoder_layers
                ]
        else:
            if self._cur_bottleneck_tensor is None:
                b = self.sample()
            else:
                b = self._cur_bottleneck_tensor
            res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
            res_size = min(res_size, hparams.max_hidden_size)
            x = self.unbottleneck(b, res_size)
        # Run decoder.
        x = self.decoder(x, encoder_layers)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            return x, {"bottleneck_loss": 0.0}
        # Cut to the right size and mix before returning.
        res = x[:, :shape[1], :shape[2], :]

        # Final dense layer.
        res = tf.layers.dense(res,
                              self.num_channels * hparams.hidden_size,
                              name="res_dense")

        # Split the last axis into (num_channels, hidden_size).
        output_shape = common_layers.shape_list(res)[:-1] + [
            self.num_channels, self.hparams.hidden_size
        ]
        res = tf.reshape(res, output_shape)

        if hparams.gan_loss_factor != 0.0:
            # Undo the batch concat: first half is the sampled (GAN) batch,
            # second half is the reconstruction batch.
            res_gan, res = tf.split(res, 2, axis=0)

        # Losses.
        losses = {
            "bottleneck_extra": b_loss,
            "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
        }

        if hparams.use_vq_loss:
            vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.2,
                min_value=hparams.vq_temperature * 2)
            # The temperature is only passed to vq_loss while training.
            if hparams.mode != tf.estimator.ModeKeys.TRAIN:
                vq_temperature = None
            with tf.variable_scope("vq_loss"):
                (reconstr, _, target_codes, code_loss,
                 targets_loss) = discretization.vq_loss(
                     res, labels, vocab_size, temperature=vq_temperature)
            losses["code_loss"] = code_loss * hparams.code_loss_factor
            losses["training"] = targets_loss
        else:
            reconstr = tf.layers.dense(res,
                                       vocab_size,
                                       name="autoencoder_final")
            targets_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=reconstr, labels=labels)
            losses["training"] = targets_loss

        # GAN losses.
        if hparams.gan_loss_factor != 0.0:
            update_means_factor = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps, min_value=0.0001)
            if hparams.use_vq_loss:
                with tf.variable_scope("vq_loss", reuse=True):
                    # Randomly gate the update with probability
                    # update_means_factor.
                    update_means = tf.less(tf.random_uniform([]),
                                           update_means_factor)
                    reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                        res_gan,
                        labels,
                        vocab_size,
                        do_update=update_means,
                        temperature=vq_temperature)
                    code_loss_gan *= hparams.code_loss_factor * update_means_factor
                    losses["code_loss_gan"] = code_loss_gan
            else:
                # Reuse the reconstruction projection for the sampled batch.
                reconstr_gan = tf.layers.dense(res_gan,
                                               vocab_size,
                                               name="autoencoder_final",
                                               reuse=True)
                reconstr_gan = tf.nn.log_softmax(reconstr_gan)
                if is_training and hparams.gumbel_temperature > 0.0:
                    gumbel_samples = discretization.gumbel_sample(
                        common_layers.shape_list(reconstr_gan))
                    gumbel_samples *= hparams.gumbel_noise_factor
                    reconstr_gan += gumbel_samples
                    reconstr_sample = latent_layers.multinomial_sample(
                        reconstr_gan, temperature=hparams.gumbel_temperature)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 hparams.gumbel_temperature)
                else:
                    reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 0.1)  # Sharpen a bit.
                # Use 1-hot forward, softmax backward.
                reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
                reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
                # Embed to codes.
                gan_codes = self.embed(reconstr_gan)

        # Add GAN loss if requested.
        gan_loss = 0.0
        if hparams.gan_loss_factor != 0.0:
            self.image_summary("gan", reconstr_gan)

            def discriminate(x):
                return self.discriminator(x, is_training=is_training)

            tc_shape = common_layers.shape_list(target_codes)
            if len(tc_shape) > 4:
                # Merge the two trailing axes so both code tensors are
                # at most rank 4 before the sliced GAN loss.
                target_codes = tf.reshape(
                    target_codes,
                    tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
                gan_codes = tf.reshape(
                    gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_lr = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.5)
            rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
            gan_loss = common_layers.sliced_gan_loss(
                target_codes, rev_grad_gan_codes, discriminate,
                self.hparams.num_sliced_vecs)
            gan_loss *= hparams.gan_loss_factor * update_means_factor
            # Stored negated under "gan_loss".
            losses["gan_loss"] = -gan_loss

        self.image_summary("ae", reconstr)
        logits = reconstr
        return logits, losses