def losses(self, inputs, generated):
    """Return the sliced-GAN training loss for a real/generated pair.

    Args:
        inputs: real batch, handed to `common_layers.sliced_gan_loss` as-is.
        generated: generated batch, wrapped in `reverse_gradient` first.

    Returns:
        Dict with a single "training" entry holding the negated sliced loss.
    """
    training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

    def discriminate(x):
        return self.discriminator(x, is_training=training, reuse=False)

    flipped = reverse_gradient(generated)
    sliced_loss = common_layers.sliced_gan_loss(
        inputs, flipped, discriminate, self.hparams.num_sliced_vecs)
    return {"training": -sliced_loss}
def losses(self, inputs, generated):
    """Compute the training loss dict for the sliced GAN.

    Args:
        inputs: real batch, passed unchanged as the first loss argument.
        generated: generated batch; `reverse_gradient` is applied to it before
            it reaches `common_layers.sliced_gan_loss`.

    Returns:
        `{"training": -loss}` where `loss` is the sliced GAN loss.
    """
    in_train_mode = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

    def discriminate(x):
        return self.discriminator(x, is_training=in_train_mode, reuse=False)

    reversed_generated = reverse_gradient(generated)
    loss = common_layers.sliced_gan_loss(
        inputs,
        reversed_generated,
        discriminate,
        self.hparams.num_sliced_vecs,
    )
    return {"training": -loss}
def body(self, features):
    """Encode, bottleneck and decode the targets, with an optional GAN loss.

    In PREDICT mode the decoder is fed from the cached bottleneck tensor (or
    from `self.sample()` when nothing is cached) and the decoder output is
    returned together with `{"bottleneck_loss": 0.0}`.

    Args:
        features: dict of input tensors; reads "targets" and, when
            `hparams.gan_loss_factor` is non-zero, "targets_raw".

    Returns:
        `(res, losses)`. Outside PREDICT mode `losses` holds "bottleneck_loss"
        and "gan_loss" (the negated, scaled sliced GAN loss, or -0.0 when the
        GAN loss is disabled).
    """
    hparams = self.hparams
    training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        # Nothing to encode: decode from the cached or a sampled bottleneck.
        latent = self._cur_bottleneck_tensor
        if latent is None:
            latent = self.sample()
        width = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        width = min(width, hparams.max_hidden_size)
        decoded = self.decoder(self.unbottleneck(latent, width))
        return decoded, {"bottleneck_loss": 0.0}

    targets = features["targets"]
    shape = common_layers.shape_list(targets)
    self.is1d = shape[2] == 1

    # Encoder followed by the bottleneck.
    encoded = self.encoder(targets)
    latent, b_loss = self.bottleneck(encoded)
    self._cur_bottleneck_tensor = latent

    # Mix the un-bottlenecked signal with the encoder output during warmup.
    restored = self.unbottleneck(latent, common_layers.shape_list(encoded)[-1])
    restored = common_layers.mix(
        restored, encoded, hparams.bottleneck_warmup_steps, training)
    if hparams.gan_loss_factor != 0.0:
        # Prepend a purely sampled batch; the GAN loss is computed on it.
        sampled = self.unbottleneck(
            self.sample(), common_layers.shape_list(encoded)[-1], reuse=True)
        restored = tf.concat([sampled, restored], axis=0)

    # With probability bottleneck_max_prob use the bottleneck, otherwise the
    # encoder output.
    # NOTE(review): the threshold is -1.0, so for a probability-valued hparam
    # this branch looks unreachable -- confirm that is intentional.
    if hparams.bottleneck_max_prob < -1.0:
        pick_bottleneck = tf.less(
            tf.random_uniform([]), hparams.bottleneck_max_prob)
        decoder_input = tf.where(pick_bottleneck, restored, encoded)
    else:
        decoder_input = restored

    decoded = self.decoder(decoder_input)

    # Crop back to the spatial size of the targets.
    res = decoded[:, :shape[1], :shape[2], :]

    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: the first half is the sampled batch.
        res_gan, res = tf.split(res, 2, axis=0)
        num_channels = self.hparams.problem.num_channels
        gan_logits = tf.layers.dense(res_gan, num_channels, name="gan_rgb")
        res_rgb = common_layers.convert_real_to_rgb(tf.nn.sigmoid(gan_logits))
        tf.summary.image(
            "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1)
        orig_rgb = tf.to_float(features["targets_raw"])

        def discriminate(x):
            return self.discriminator(x, is_training=training)

        gan_loss = common_layers.sliced_gan_loss(
            orig_rgb,
            reverse_gradient(res_rgb),
            discriminate,
            self.hparams.num_sliced_vecs,
        )
        gan_loss *= hparams.gan_loss_factor

    # Mix the reconstruction with the targets (half the warmup) and return.
    res = common_layers.mix(
        res, features["targets"], hparams.bottleneck_warmup_steps // 2, training)
    return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
def body(self, features):
    """Build the discrete autoencoder graph and its loss dictionary.

    Outside PREDICT mode (or whenever `self._encode_on_predict` is set) the raw
    targets are one-hot encoded, embedded, encoded and bottlenecked. Otherwise
    the decoder is fed from the cached bottleneck tensor or `self.sample()`.

    Args:
        features: dict of input tensors; only `features["targets_raw"]` is
            read here.

    Returns:
        `(logits, losses)`. In PREDICT mode this is the reconstruction and
        `{"bottleneck_loss": 0.0}`. Otherwise `losses` holds
        "bottleneck_extra", "bottleneck_l2" and "training", plus "code_loss"
        when `hparams.use_vq_loss` is set, and "code_loss_gan" / "gan_loss"
        when `hparams.gan_loss_factor` is non-zero.

    Raises:
        Exception: when the GAN loss is built and `hparams.discriminator` is
            not one of "default", "patched", "single" or "double".
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
        # Round vocab_size up to the next multiple of vocab_divisor.
        vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT or
            self._encode_on_predict):
        labels = features["targets_raw"]
        labels_shape = common_layers.shape_list(labels)
        # Handle videos: rank-5 labels are passed through time_to_channels.
        if len(labels.shape) == 5:
            labels = time_to_channels(labels)
        shape = common_layers.shape_list(labels)
        x = tf.one_hot(labels, vocab_size)
        x = self.embed(x)
        # The embedded targets serve as the "real" side of the GAN loss below.
        target_codes = x
        if shape[2] == 1:
            self.is1d = True
        # Run encoder.
        x, encoder_layers = self.encoder(x)
        # Bottleneck.
        b, b_loss = self.bottleneck(x)
        xb_loss = 0.0
        b_shape = common_layers.shape_list(b)
        self._cur_bottleneck_tensor = b
        res_size = common_layers.shape_list(x)[-1]
        b = self.unbottleneck(b, res_size)
        if not is_training:
            x = b
        else:
            l = 2**hparams.num_hidden_layers
            warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
            nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
            if common_layers.should_generate_summaries():
                tf.summary.scalar("nomix_p_bottleneck", nomix_p)
            rand = tf.random_uniform(common_layers.shape_list(x))
            # This is the distance between b and x. Having this as loss helps
            # learn the bottleneck function, but if we back-propagated to x it
            # would be minimized by just setting x=0 and b=0 -- so we don't
            # want too much of the influence of this, and we stop-gradient to
            # not zero-out x.
            x_stop = tf.stop_gradient(x)
            xb_loss = tf.reduce_mean(tf.reduce_sum(
                tf.squared_difference(x_stop, b), axis=-1))
            # To prevent this loss from exploding we clip at 1, but anneal
            # clipping.
            clip_max = 1.0 / common_layers.inverse_exp_decay(
                warm_step, min_value=0.001)
            xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
            xb_loss *= clip_max / xb_clip
            # Element-wise: take b where rand < nomix_p, else keep x.
            x = tf.where(tf.less(rand, nomix_p), b, x)
        if hparams.gan_loss_factor != 0.0:
            # Add a purely sampled batch on which we'll compute the GAN loss.
            g = self.unbottleneck(
                self.sample(shape=b_shape),
                common_layers.shape_list(x)[-1],
                reuse=True)
            # Sampled batch goes second; the split further down relies on it.
            x = tf.concat([x, g], axis=0)
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        self._cur_bottleneck_tensor = b
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)
        x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into (num_channels, hidden_size).
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        if hparams.use_vq_loss:
            # NOTE(review): `labels` is only bound in the encode branch above;
            # PREDICT without `_encode_on_predict` would reach this line with
            # the name unbound -- confirm that combination is never used.
            (reconstr, _, _, _, _) = discretization.vq_loss(
                res, labels, vocab_size)
        else:
            reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
        return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: first half is real data, second is sampled.
        res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
        vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.2,
            min_value=hparams.vq_temperature * 2)
        # No temperature is passed outside of training.
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            vq_temperature = None
        with tf.variable_scope("vq_loss"):
            (reconstr, _, target_codes, code_loss,
             targets_loss) = discretization.vq_loss(
                 res, labels, vocab_size, temperature=vq_temperature)
        losses["code_loss"] = code_loss * hparams.code_loss_factor
        losses["training"] = targets_loss
    else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
        targets_loss = tf.losses.sparse_softmax_cross_entropy(
            logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
            labels=tf.reshape(labels, labels_shape))
        losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
        update_means_factor = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps, min_value=0.0001)
        if hparams.use_vq_loss:
            # Same scope as above with reuse=True.
            with tf.variable_scope("vq_loss", reuse=True):
                update_means = tf.less(tf.random_uniform([]), update_means_factor)
                reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                    res_gan,
                    labels,
                    vocab_size,
                    do_update=update_means,
                    temperature=vq_temperature)
                reconstr_gan_nonoise = reconstr_gan
                code_loss_gan *= hparams.code_loss_factor * update_means_factor
                losses["code_loss_gan"] = code_loss_gan
        else:
            reconstr_gan = tf.layers.dense(
                res_gan, vocab_size, name="autoencoder_final", reuse=True)
            # Keep the pre-sampling logits for the image summary below.
            reconstr_gan_nonoise = reconstr_gan
            reconstr_gan = self.gumbel_sample(reconstr_gan)
            # Embed to codes.
            gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        self.image_summary("gan", reconstr_gan_nonoise)

        def discriminate(x):
            """Run a discriminator depending on the hparams."""
            if hparams.discriminator == "default":
                return common_layers.deep_discriminator(
                    x, hparams.discriminator_batchnorm, is_training)
            elif hparams.discriminator == "patched":
                return common_layers.patch_discriminator(x)
            elif hparams.discriminator == "single":
                return common_layers.single_discriminator(
                    x,
                    hparams.discriminator_size,
                    hparams.discriminator_kernel_size,
                    hparams.discriminator_strides,
                    pure_mean=hparams.discriminator_pure_mean)
            elif hparams.discriminator == "double":
                return common_layers.double_discriminator(
                    x,
                    hparams.discriminator_size,
                    hparams.discriminator_kernel_size,
                    hparams.discriminator_strides,
                    pure_mean=hparams.discriminator_pure_mean)
            else:
                raise Exception("Unknown discriminator %s" % hparams.discriminator)

        tc_shape = common_layers.shape_list(target_codes)
        if len(tc_shape) > 4:
            # Merge the last two axes of both code tensors into one.
            target_codes = tf.reshape(target_codes,
                                      tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_codes = tf.reshape(gan_codes,
                                   tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_lr = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.5)
        rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
        gan_loss = common_layers.sliced_gan_loss(
            target_codes,
            rev_grad_gan_codes,
            discriminate,
            self.hparams.num_sliced_vecs,
            do_tanh=hparams.sliced_do_tanh)
        gan_loss *= hparams.gan_loss_factor * update_means_factor
        losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
def body(self, features):
    """Build the VQ autoencoder graph and return reconstruction and losses.

    Outside PREDICT mode the targets are encoded, bottlenecked, decoded and
    projected before `discretization.vq_loss` is applied. When
    `hparams.gan_loss_factor` is non-zero a purely sampled batch is prepended
    before decoding and split off again afterwards; the GAN-side quantities
    (`reconstr_gan`, `code_loss_gan`) are computed from that sampled half.

    Args:
        features: dict of input tensors; reads "targets" and "targets_raw".

    Returns:
        `(reconstr, losses)`. In PREDICT mode this is the decoder output and
        `{"bottleneck_loss": 0.0}`. Otherwise `losses` holds "training",
        "code_loss", "code_loss_gan", "b_loss" and "gan_loss".

    Raises:
        Exception: if `hparams.problem` is neither an
            `image_utils.ImageProblem` nor a
            `text_problems.Text2TextProblem`.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        x = features["targets"]
        labels = features["targets_raw"]
        shape = common_layers.shape_list(x)
        self.is1d = shape[2] == 1
        # Run encoder.
        x = self.encoder(x)
        # Bottleneck (mix during early training, not too important but stable).
        b, b_loss = self.bottleneck(x)
        self._cur_bottleneck_tensor = b
        b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
        b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
        if hparams.gan_loss_factor != 0.0:
            # Add a purely sampled batch on which we'll compute the GAN loss.
            g = self.unbottleneck(
                self.sample(), common_layers.shape_list(x)[-1], reuse=True)
            # Sampled batch goes first; the split further down relies on it.
            b = tf.concat([g, b], axis=0)
        # With probability bottleneck_max_prob use the bottleneck, otherwise x.
        # NOTE(review): the threshold is -1.0, so for a probability-valued
        # hparam this branch looks unreachable -- confirm it is intentional.
        if hparams.bottleneck_max_prob < -1.0:
            x = tf.where(
                tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
        else:
            x = b
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)
        x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
    if is_image:
        vocab_size = self.hparams.problem.vocab_size

        res = tf.layers.dense(
            res, self.hparams.problem.num_channels * self.hparams.hidden_size)
        # Split the last axis into (num_channels, hidden_size).
        output_shape = common_layers.shape_list(res)[:-1] + [
            self.hparams.problem.num_channels, self.hparams.hidden_size
        ]
        res = tf.reshape(res, output_shape)
    elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
        vocab_size = self._problem_hparams.target_modality.top_dimensionality
        res = tf.layers.dense(res, self.hparams.hidden_size)
    else:
        raise Exception("Unsupported problem type: %s" % self.hparams.problem)

    one_hot_labels = tf.one_hot(labels, vocab_size)
    code_loss_gan = 0.0
    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: first half is sampled, second is real data.
        res_gan, res = tf.split(res, 2, axis=0)
        with tf.variable_scope("vq"):
            # Quantize the *sampled* half. Using `res` here would make the
            # "gan" summary and code_loss_gan duplicate the autoencoder ones.
            reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
                res_gan, one_hot_labels, vocab_size)

    # AUTO_REUSE: shares the "vq" variables if the GAN branch created them.
    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
        reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
            res, one_hot_labels, vocab_size)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        if is_image:
            tf.summary.image(
                "gan",
                common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
                max_outputs=1)

        def discriminate(x):
            return self.discriminator(x, is_training=is_training)

        gan_loss = common_layers.sliced_gan_loss(
            target_codes,
            reverse_gradient(res_gan),
            discriminate,
            self.hparams.num_sliced_vecs)
        gan_loss *= hparams.gan_loss_factor

    if is_image:
        tf.summary.image(
            "ae",
            common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
            max_outputs=1)

    return reconstr, {
        "training": targets_loss,
        "code_loss": code_loss,
        "code_loss_gan": code_loss_gan,
        "b_loss": b_loss,
        "gan_loss": -gan_loss
    }
def body(self, features):
    """Build the discrete autoencoder graph and its loss dictionary.

    Outside PREDICT mode (or whenever `self._encode_on_predict` is set) the raw
    targets are one-hot encoded, embedded, encoded and bottlenecked. Otherwise
    the decoder is fed from the cached bottleneck tensor or `self.sample()`.

    Args:
        features: dict of input tensors; only `features["targets_raw"]` is
            read here.

    Returns:
        `(logits, losses)`. In PREDICT mode this is the reconstruction and
        `{"bottleneck_loss": 0.0}`. Otherwise `losses` holds
        "bottleneck_extra", "bottleneck_l2" and "training", plus "code_loss"
        when `hparams.use_vq_loss` is set, and "code_loss_gan" / "gan_loss"
        when `hparams.gan_loss_factor` is non-zero.

    Raises:
        Exception: when the GAN loss is built and `hparams.discriminator` is
            not one of "default", "patched", "single" or "double".
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT or
            self._encode_on_predict):
        labels = features["targets_raw"]
        labels_shape = common_layers.shape_list(labels)
        # Handle videos: rank-5 labels are passed through time_to_channels.
        if len(labels.shape) == 5:
            labels = time_to_channels(labels)
        shape = common_layers.shape_list(labels)
        x = tf.one_hot(labels, vocab_size)
        x = self.embed(x)
        # The embedded targets serve as the "real" side of the GAN loss below.
        target_codes = x
        if shape[2] == 1:
            self.is1d = True
        # Run encoder.
        x, encoder_layers = self.encoder(x)
        # Bottleneck.
        b, b_loss = self.bottleneck(x)
        xb_loss = 0.0
        b_shape = common_layers.shape_list(b)
        self._cur_bottleneck_tensor = b
        res_size = common_layers.shape_list(x)[-1]
        b = self.unbottleneck(b, res_size)
        if not is_training:
            x = b
        else:
            l = 2**hparams.num_hidden_layers
            warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
            nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
            if common_layers.should_generate_summaries():
                tf.summary.scalar("nomix_p_bottleneck", nomix_p)
            rand = tf.random_uniform(common_layers.shape_list(x))
            # This is the distance between b and x. Having this as loss helps
            # learn the bottleneck function, but if we back-propagated to x it
            # would be minimized by just setting x=0 and b=0 -- so we don't
            # want too much of the influence of this, and we stop-gradient to
            # not zero-out x.
            x_stop = tf.stop_gradient(x)
            xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
            # To prevent this loss from exploding we clip at 1, but anneal
            # clipping.
            clip_max = 1.0 / common_layers.inverse_exp_decay(
                warm_step, min_value=0.001)
            xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
            xb_loss *= clip_max / xb_clip
            # Element-wise: take b where rand < nomix_p, else keep x.
            x = tf.where(tf.less(rand, nomix_p), b, x)
        if hparams.gan_loss_factor != 0.0:
            # Add a purely sampled batch on which we'll compute the GAN loss.
            g = self.unbottleneck(
                self.sample(shape=b_shape),
                common_layers.shape_list(x)[-1],
                reuse=True)
            # Sampled batch goes second; the split further down relies on it.
            x = tf.concat([x, g], axis=0)
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        self._cur_bottleneck_tensor = b
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)
        x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into (num_channels, hidden_size).
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        if hparams.use_vq_loss:
            # NOTE(review): `labels` is only bound in the encode branch above;
            # PREDICT without `_encode_on_predict` would reach this line with
            # the name unbound -- confirm that combination is never used.
            (reconstr, _, _, _, _) = discretization.vq_loss(
                res, labels, vocab_size)
        else:
            reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
        return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: first half is real data, second is sampled.
        res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
        vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.2,
            min_value=hparams.vq_temperature * 2)
        # No temperature is passed outside of training.
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            vq_temperature = None
        with tf.variable_scope("vq_loss"):
            (reconstr, _, target_codes, code_loss,
             targets_loss) = discretization.vq_loss(
                 res, labels, vocab_size, temperature=vq_temperature)
        losses["code_loss"] = code_loss * hparams.code_loss_factor
        losses["training"] = targets_loss
    else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
        targets_loss = tf.losses.sparse_softmax_cross_entropy(
            logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
            labels=tf.reshape(labels, labels_shape))
        losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
        update_means_factor = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps, min_value=0.0001)
        if hparams.use_vq_loss:
            # Same scope as above with reuse=True.
            with tf.variable_scope("vq_loss", reuse=True):
                update_means = tf.less(tf.random_uniform([]), update_means_factor)
                reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                    res_gan,
                    labels,
                    vocab_size,
                    do_update=update_means,
                    temperature=vq_temperature)
                reconstr_gan_nonoise = reconstr_gan
                code_loss_gan *= hparams.code_loss_factor * update_means_factor
                losses["code_loss_gan"] = code_loss_gan
        else:
            reconstr_gan = tf.layers.dense(
                res_gan, vocab_size, name="autoencoder_final", reuse=True)
            # Keep the pre-sampling logits for the image summary below.
            reconstr_gan_nonoise = reconstr_gan
            reconstr_gan = self.gumbel_sample(reconstr_gan)
            # Embed to codes.
            gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        self.image_summary("gan", reconstr_gan_nonoise)

        def discriminate(x):
            """Run a discriminator depending on the hparams."""
            if hparams.discriminator == "default":
                return common_layers.deep_discriminator(
                    x, hparams.discriminator_batchnorm, is_training)
            elif hparams.discriminator == "patched":
                return common_layers.patch_discriminator(x)
            elif hparams.discriminator == "single":
                return common_layers.single_discriminator(
                    x,
                    hparams.discriminator_size,
                    hparams.discriminator_kernel_size,
                    hparams.discriminator_strides,
                    pure_mean=hparams.discriminator_pure_mean)
            elif hparams.discriminator == "double":
                return common_layers.double_discriminator(
                    x,
                    hparams.discriminator_size,
                    hparams.discriminator_kernel_size,
                    hparams.discriminator_strides,
                    pure_mean=hparams.discriminator_pure_mean)
            else:
                raise Exception("Unknown discriminator %s" % hparams.discriminator)

        tc_shape = common_layers.shape_list(target_codes)
        if len(tc_shape) > 4:
            # Merge the last two axes of both code tensors into one.
            target_codes = tf.reshape(target_codes,
                                      tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_codes = tf.reshape(gan_codes,
                                   tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_lr = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.5)
        rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
        gan_loss = common_layers.sliced_gan_loss(
            target_codes,
            rev_grad_gan_codes,
            discriminate,
            self.hparams.num_sliced_vecs,
            do_tanh=hparams.sliced_do_tanh)
        gan_loss *= hparams.gan_loss_factor * update_means_factor
        losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
def body(self, features):
    """Build the VQ autoencoder graph on raw targets and collect its losses.

    In PREDICT mode the decoder is fed from the cached bottleneck tensor (or
    from `self.sample()` when nothing is cached) and its output is returned
    with `{"bottleneck_loss": 0.0}`.

    Args:
        features: dict of input tensors; only "targets_raw" is read.

    Returns:
        `(logits, losses)`. Outside PREDICT mode `losses` holds "code_loss",
        "training", "b_loss" and "gan_loss", plus "code_loss_gan" when
        `hparams.gan_loss_factor` is non-zero.
    """
    hparams = self.hparams
    training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        # Nothing to encode: decode from the cached or a sampled bottleneck.
        latent = self._cur_bottleneck_tensor
        if latent is None:
            latent = self.sample()
        width = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        width = min(width, hparams.max_hidden_size)
        decoded = self.decoder(self.unbottleneck(latent, width))
        return decoded, {"bottleneck_loss": 0.0}

    labels = features["targets_raw"]
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    shape = common_layers.shape_list(labels)

    # One-hot the labels, fold the vocab axis into the last axis, then embed.
    one_hot = tf.one_hot(labels, vocab_size)
    one_hot = tf.reshape(one_hot, shape[:-1] + [shape[-1] * vocab_size])
    embedded = self.embed(one_hot)
    self.is1d = shape[2] == 1

    # Encoder followed by the bottleneck.
    encoded = self.encoder(embedded)
    latent, b_loss = self.bottleneck(encoded)
    latent_shape = common_layers.shape_list(latent)
    self._cur_bottleneck_tensor = latent

    # Mix the un-bottlenecked signal with the encoder output during warmup.
    restored = self.unbottleneck(latent, common_layers.shape_list(encoded)[-1])
    restored = common_layers.mix(
        restored, encoded, hparams.bottleneck_warmup_steps, training)
    if hparams.gan_loss_factor != 0.0:
        # Prepend a purely sampled batch; the GAN loss is computed on it.
        sampled = self.unbottleneck(
            self.sample(shape=latent_shape),
            common_layers.shape_list(encoded)[-1],
            reuse=True)
        restored = tf.concat([sampled, restored], axis=0)

    # With probability bottleneck_max_prob use the bottleneck, otherwise the
    # encoder output.
    # NOTE(review): the threshold is -1.0, so for a probability-valued hparam
    # this branch looks unreachable -- confirm that is intentional.
    if hparams.bottleneck_max_prob < -1.0:
        pick_bottleneck = tf.less(
            tf.random_uniform([]), hparams.bottleneck_max_prob)
        decoder_input = tf.where(pick_bottleneck, restored, encoded)
    else:
        decoder_input = restored

    decoded = self.decoder(decoder_input)

    # Crop to the label size, project, and split the last axis into
    # (num_channels, hidden_size).
    res = decoded[:, :shape[1], :shape[2], :]
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")
    res = tf.reshape(
        res,
        common_layers.shape_list(res)[:-1]
        + [self.num_channels, self.hparams.hidden_size])

    losses = {}
    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: the first half is the sampled batch.
        res_gan, res = tf.split(res, 2, axis=0)
        with tf.variable_scope("vq"):
            gan_outputs = discretization.vq_loss(res_gan, labels, vocab_size)
        reconstr_gan, gan_codes, _, code_loss_gan, _ = gan_outputs
        losses["code_loss_gan"] = (
            code_loss_gan * hparams.code_loss_factor * hparams.gan_loss_factor)

    # AUTO_REUSE: shares the "vq" variables if the GAN branch created them.
    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
        ae_outputs = discretization.vq_loss(res, labels, vocab_size)
    reconstr, _, target_codes, code_loss, targets_loss = ae_outputs
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss

    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        self.image_summary("gan", reconstr_gan)

        def discriminate(x):
            return self.discriminator(x, is_training=training)

        tc_shape = common_layers.shape_list(target_codes)
        if len(tc_shape) > 4:
            # Merge the last two axes of both code tensors into one.
            target_codes = tf.reshape(
                target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_codes = tf.reshape(
                gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_loss = common_layers.sliced_gan_loss(
            target_codes,
            reverse_gradient(gan_codes),
            discriminate,
            self.hparams.num_sliced_vecs,
        )
        gan_loss *= hparams.gan_loss_factor

    self.image_summary("ae", reconstr)

    losses["b_loss"] = b_loss
    losses["gan_loss"] = -gan_loss
    return reconstr, losses
def body(self, features):
    """Build the discrete autoencoder graph and its loss dictionary.

    Outside PREDICT mode the raw targets are one-hot encoded, embedded,
    encoded and bottlenecked. In PREDICT mode the decoder is fed from the
    cached bottleneck tensor or `self.sample()` and receives
    `encoder_layers=None`, since the encoder never runs.

    Args:
        features: dict of input tensors; only "targets_raw" is read.

    Returns:
        `(logits, losses)`. In PREDICT mode this is the decoder output and
        `{"bottleneck_loss": 0.0}`. Otherwise `losses` holds
        "bottleneck_extra", "bottleneck_l2" and "training", plus "code_loss"
        when `hparams.use_vq_loss` is set, and "code_loss_gan" / "gan_loss"
        when `hparams.gan_loss_factor` is non-zero.
    """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    # Bound up front so the decoder call below does not hit an unbound local
    # in PREDICT mode, where the encoder branch is skipped.
    encoder_layers = None
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        labels = features["targets_raw"]
        vocab_size = self._problem_hparams.target_modality.top_dimensionality
        shape = common_layers.shape_list(labels)
        x = tf.one_hot(labels, vocab_size)
        x = self.embed(x)
        # The embedded targets serve as the "real" side of the GAN loss below.
        target_codes = x
        self.is1d = shape[2] == 1
        # Run encoder.
        x, encoder_layers = self.encoder(x)
        # Bottleneck.
        b, b_loss = self.bottleneck(x)
        xb_loss = 0.0
        b_shape = common_layers.shape_list(b)
        self._cur_bottleneck_tensor = b
        b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
        if not is_training:
            x = b
        else:
            depth_factor = 2**hparams.num_hidden_layers
            warm_step = int(
                hparams.bottleneck_warmup_steps * 0.25 * depth_factor)
            nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
            if common_layers.should_generate_summaries():
                tf.summary.scalar("nomix_p_bottleneck", nomix_p)
            rand = tf.random_uniform(common_layers.shape_list(x))
            # This is the distance between b and x. Having this as loss helps
            # learn the bottleneck function, but if we back-propagated to x it
            # would be minimized by just setting x=0 and b=0 -- so we don't
            # want too much of the influence of this, and we stop-gradient to
            # not zero-out x.
            x_stop = tf.stop_gradient(x)
            xb_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(x_stop - b), axis=-1))
            # To prevent this loss from exploding we clip at 1, but anneal
            # clipping.
            clip_max = 1.0 / common_layers.inverse_exp_decay(
                warm_step, min_value=0.001)
            xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
            xb_loss *= clip_max / xb_clip
            # Element-wise: take b where rand < nomix_p, else keep x.
            x = tf.where(tf.less(rand, nomix_p), b, x)
        if hparams.gan_loss_factor != 0.0:
            # Add a purely sampled batch on which we'll compute the GAN loss.
            g = self.unbottleneck(
                self.sample(shape=b_shape),
                common_layers.shape_list(x)[-1],
                reuse=True)
            # Sampled batch goes first; the split further down relies on it.
            x = tf.concat([g, x], axis=0)
            # Duplicate each encoder layer along axis 0 to match the doubled
            # batch.
            encoder_layers = [
                tf.concat([layer, layer], axis=0) for layer in encoder_layers
            ]
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)
        x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        return x, {"bottleneck_loss": 0.0}

    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    # Split the last axis into (num_channels, hidden_size).
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.gan_loss_factor != 0.0:
        # Undo the concat above: first half is sampled, second is real data.
        res_gan, res = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
        vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.2,
            min_value=hparams.vq_temperature * 2)
        # No temperature is passed outside of training.
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            vq_temperature = None
        with tf.variable_scope("vq_loss"):
            (reconstr, _, target_codes, code_loss,
             targets_loss) = discretization.vq_loss(
                 res, labels, vocab_size, temperature=vq_temperature)
        losses["code_loss"] = code_loss * hparams.code_loss_factor
        losses["training"] = targets_loss
    else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
        targets_loss = tf.losses.sparse_softmax_cross_entropy(
            logits=reconstr, labels=labels)
        losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
        update_means_factor = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps, min_value=0.0001)
        if hparams.use_vq_loss:
            # Same scope as above with reuse=True.
            with tf.variable_scope("vq_loss", reuse=True):
                update_means = tf.less(tf.random_uniform([]), update_means_factor)
                reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                    res_gan,
                    labels,
                    vocab_size,
                    do_update=update_means,
                    temperature=vq_temperature)
                code_loss_gan *= hparams.code_loss_factor * update_means_factor
                losses["code_loss_gan"] = code_loss_gan
        else:
            reconstr_gan = tf.layers.dense(
                res_gan, vocab_size, name="autoencoder_final", reuse=True)
            reconstr_gan = tf.nn.log_softmax(reconstr_gan)
            if is_training and hparams.gumbel_temperature > 0.0:
                # Add scaled gumbel noise, then sample at the configured
                # temperature.
                gumbel_samples = discretization.gumbel_sample(
                    common_layers.shape_list(reconstr_gan))
                gumbel_samples *= hparams.gumbel_noise_factor
                reconstr_gan += gumbel_samples
                reconstr_sample = latent_layers.multinomial_sample(
                    reconstr_gan, temperature=hparams.gumbel_temperature)
                reconstr_gan = tf.nn.softmax(
                    reconstr_gan / hparams.gumbel_temperature)
            else:
                reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
                reconstr_gan = tf.nn.softmax(reconstr_gan / 0.1)  # Sharpen a bit.
            # Use 1-hot forward, softmax backward.
            reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
            reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
            # Embed to codes.
            gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
        self.image_summary("gan", reconstr_gan)

        def discriminate(x):
            return self.discriminator(x, is_training=is_training)

        tc_shape = common_layers.shape_list(target_codes)
        if len(tc_shape) > 4:
            # Merge the last two axes of both code tensors into one.
            target_codes = tf.reshape(
                target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_codes = tf.reshape(
                gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_lr = common_layers.inverse_exp_decay(
            hparams.gan_codes_warmup_steps * 1.5)
        rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
        gan_loss = common_layers.sliced_gan_loss(
            target_codes,
            rev_grad_gan_codes,
            discriminate,
            self.hparams.num_sliced_vecs)
        gan_loss *= hparams.gan_loss_factor * update_means_factor
        losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = reconstr
    return logits, losses