def gumbel_sample(self, reconstr_gan):
  """Discretize logits with a straight-through (one-hot forward) estimator.

  The logits are first normalized with a log-softmax, and then one of two
  paths runs.

  - Training with a positive `hparams.gumbel_temperature`: noise drawn by
    `discretization.gumbel_sample` and scaled by `hparams.gumbel_noise_factor`
    is added. A token is then picked with `latent_layers.multinomial_sample`
    at that temperature, and the relaxation is a softmax at the same
    temperature.
  - Otherwise: the argmax token is picked, and the relaxation is a softmax
    sharpened by a fixed divisor of 0.1.

  The result numerically equals the one-hot encoding of the picked token in
  the forward pass. Gradients flow through the softmax relaxation.

  Args:
    reconstr_gan: logits tensor whose last axis is the vocabulary axis
      (presumably of size `vocab_size` so the one-hot add lines up -- TODO
      confirm with callers).

  Returns:
    Tensor combining the one-hot sample and the relaxation; presumably the
    same shape as `reconstr_gan`.
  """
  hp = self.hparams
  training = hp.mode == tf.estimator.ModeKeys.TRAIN
  vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
  log_probs = tf.nn.log_softmax(reconstr_gan)
  use_gumbel = training and hp.gumbel_temperature > 0.0
  if not use_gumbel:
    chosen = tf.argmax(log_probs, axis=-1)
    relaxed = tf.nn.softmax(log_probs / 0.1)  # Sharpen a bit.
  else:
    noise = discretization.gumbel_sample(common_layers.shape_list(log_probs))
    noise *= hp.gumbel_noise_factor
    log_probs += noise
    chosen = latent_layers.multinomial_sample(
        log_probs, temperature=hp.gumbel_temperature)
    relaxed = tf.nn.softmax(log_probs / hp.gumbel_temperature)
  hard = tf.one_hot(chosen, vocab_size)
  # Straight-through: the forward value is `hard`, the gradient is `relaxed`'s.
  return relaxed + (hard - tf.stop_gradient(relaxed))
def body(self, features):
  """Run the autoencoder: encode, bottleneck, decode, and build the losses.

  Outside PREDICT mode:

  - The raw targets are one-hot encoded, embedded, encoded, passed through
    the bottleneck and un-bottlenecked.
  - In TRAIN mode, the decoder input is a random per-element choice between
    the encoder output and the un-bottlenecked code. An annealed-clipped L2
    distance between the two is also returned as "bottleneck_l2".
  - In other non-PREDICT modes, the decoder sees the un-bottlenecked code
    directly.

  In PREDICT mode, the code comes from `self._cur_bottleneck_tensor` if it is
  set, otherwise from `self.sample()`. Only the unbottleneck and decoder run,
  and the encoder is skipped.

  When `hparams.gan_loss_factor` is non-zero (and not in PREDICT mode), a
  batch built from freshly sampled bottleneck codes is prepended before
  decoding. After decoding it is split off again, and a sliced-GAN loss
  compares the real target codes with the codes of that sampled batch.

  Args:
    features: dict of input tensors. Outside PREDICT mode it must contain
      "targets_raw", integer ids fed to `tf.one_hot`. They are presumably
      rank >= 3 (`shape[1]` and `shape[2]` are used) -- TODO confirm.

  Returns:
    In PREDICT mode, a tuple `(decoder_output, {"bottleneck_loss": 0.0})`.

    Otherwise, a tuple `(logits, losses)`. `losses` always has
    "bottleneck_extra", "bottleneck_l2" and "training". It also has
    "code_loss" when `hparams.use_vq_loss` is set, and "gan_loss" when the
    GAN loss is enabled. "code_loss_gan" is present only when both are on.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  # Only the non-PREDICT path runs the encoder. Default to None so the decoder
  # call below does not hit an unbound local in PREDICT mode.
  encoder_layers = None
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    labels = features["targets_raw"]
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    shape = common_layers.shape_list(labels)
    x = tf.one_hot(labels, vocab_size)
    x = self.embed(x)
    target_codes = x
    # NOTE(review): `self.is1d` is only set on this path. In PREDICT mode the
    # decoder sees whatever value was set earlier, if any -- confirm.
    self.is1d = shape[2] == 1
    # Run encoder.
    x, encoder_layers = self.encoder(x)
    # Bottleneck.
    b, b_loss = self.bottleneck(x)
    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
    if not is_training:
      x = b
    else:
      num_layers_scale = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * num_layers_scale)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as a loss helps
      # learn the bottleneck function. If we back-propagated it to x, it would
      # be minimized by just setting x=0 and b=0. So we don't want too much of
      # its influence, and we stop the gradient to avoid zeroing out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      # Per element, feed the bottlenecked code with probability nomix_p.
      x = tf.where(tf.less(rand, nomix_p), b, x)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(shape=b_shape),
          common_layers.shape_list(x)[-1],
          reuse=True)
      x = tf.concat([g, x], axis=0)
      # Duplicate the skip connections so they match the doubled batch.
      encoder_layers = [
          tf.concat([layer, layer], axis=0) for layer in encoder_layers
      ]
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x, encoder_layers)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x, {"bottleneck_loss": 0.0}
  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]
  # Final dense layer.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  output_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, output_shape)
  if hparams.gan_loss_factor != 0.0:
    # Undo the concat above: the sampled batch was prepended.
    res_gan, res = tf.split(res, 2, axis=0)
  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }
  if hparams.use_vq_loss:
    vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.2,
        min_value=hparams.vq_temperature * 2)
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      vq_temperature = None
    with tf.variable_scope("vq_loss"):
      (reconstr, _, target_codes, code_loss,
       targets_loss) = discretization.vq_loss(
           res, labels, vocab_size, temperature=vq_temperature)
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss
  else:
    reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    targets_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=reconstr, labels=labels)
    losses["training"] = targets_loss
  # GAN losses.
  if hparams.gan_loss_factor != 0.0:
    update_means_factor = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps, min_value=0.0001)
    if hparams.use_vq_loss:
      with tf.variable_scope("vq_loss", reuse=True):
        update_means = tf.less(tf.random_uniform([]), update_means_factor)
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan,
            labels,
            vocab_size,
            do_update=update_means,
            temperature=vq_temperature)
        code_loss_gan *= hparams.code_loss_factor * update_means_factor
        losses["code_loss_gan"] = code_loss_gan
    else:
      reconstr_gan = tf.layers.dense(
          res_gan, vocab_size, name="autoencoder_final", reuse=True)
      # NOTE(review): this mirrors `gumbel_sample`, but that method reads
      # `modality["targets"]` while this method reads `target_modality`.
      # It is kept inline until the two are reconciled.
      reconstr_gan = tf.nn.log_softmax(reconstr_gan)
      if is_training and hparams.gumbel_temperature > 0.0:
        gumbel_samples = discretization.gumbel_sample(
            common_layers.shape_list(reconstr_gan))
        gumbel_samples *= hparams.gumbel_noise_factor
        reconstr_gan += gumbel_samples
        reconstr_sample = latent_layers.multinomial_sample(
            reconstr_gan, temperature=hparams.gumbel_temperature)
        reconstr_gan = tf.nn.softmax(reconstr_gan / hparams.gumbel_temperature)
      else:
        reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
        reconstr_gan = tf.nn.softmax(reconstr_gan / 0.1)  # Sharpen a bit.
      # Use 1-hot forward, softmax backward.
      reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
      reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
      # Embed to codes.
      gan_codes = self.embed(reconstr_gan)
  # Add GAN loss if requested.
  if hparams.gan_loss_factor != 0.0:
    self.image_summary("gan", reconstr_gan)

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    tc_shape = common_layers.shape_list(target_codes)
    if len(tc_shape) > 4:
      # Fold the last two axes together before the sliced GAN loss.
      target_codes = tf.reshape(
          target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_codes = tf.reshape(
          gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
    gan_lr = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.5)
    rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
    gan_loss = common_layers.sliced_gan_loss(
        target_codes, rev_grad_gan_codes, discriminate,
        self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor * update_means_factor
    losses["gan_loss"] = -gan_loss
  self.image_summary("ae", reconstr)
  return reconstr, losses