def body(self, features):
  """Build a strided-convolution autoencoder over `features["targets"]`.

  The targets are padded so axis 1 (and axis 2, unless it has size 1) is
  divisible by 2**num_hidden_layers, run through `num_hidden_layers`
  stride-2 convolutions, passed through `self.bottleneck`, and expanded
  again by the same number of transposed convolutions.

  Args:
    features: dict holding the input under the key "targets"; it is sliced
      as a 4-D tensor below.

  Returns:
    The decoder output cut back to the unpadded extent of axes 1 and 2 and
    combined with the raw targets through `common_layers.mix`.
  """
  hparams = self._hparams
  num_layers = hparams.num_hidden_layers
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  targets = features["targets"]
  shape = common_layers.shape_list(targets)
  is1d = shape[2] == 1

  # 1-D inputs keep axis 2 untouched: kernel width 1 and stride 1 there.
  height, width = hparams.kernel_height, hparams.kernel_width
  if is1d:
    kernel, strides = (height, 1), (2, 1)
  else:
    kernel, strides = (height, width), (2, 2)

  # Pad every axis that will be strided so it divides evenly.
  divisor = 2**num_layers
  net = targets
  for axis in ((1,) if is1d else (1, 2)):
    net, _ = common_layers.pad_to_same_length(
        net, net, final_length_divisible_by=divisor, axis=axis)

  # Encoder: each layer doubles the filter count.
  for i in xrange(num_layers):
    filters = hparams.hidden_size * 2**(i + 1)
    net = tf.layers.conv2d(
        net, filters, kernel, strides=strides, padding="SAME",
        activation=tf.nn.relu, name="conv_%d" % i)
    net = common_layers.layer_norm(net)

  # Bottleneck, combined with its own input through common_layers.mix.
  squeezed = self.bottleneck(net, hparams.hidden_size * 2**num_layers)
  net = common_layers.mix(
      squeezed, net, hparams.bottleneck_warmup_steps, training)

  # Decoder: mirror of the encoder, counting the layer index down.
  for j in xrange(num_layers - 1, -1, -1):
    filters = hparams.hidden_size * 2**j
    net = tf.layers.conv2d_transpose(
        net, filters, kernel, strides=strides, padding="SAME",
        activation=tf.nn.relu, name="deconv_%d" % j)
    net = common_layers.layer_norm(net)

  # Drop the padding added above before mixing with the targets.
  cropped = net[:, :shape[1], :shape[2], :]
  return common_layers.mix(
      cropped, features["targets"],
      hparams.bottleneck_warmup_steps // 2, training)
def body(self, features):
  """Build the stacked autoencoder graph.

  `hparams.num_hidden_layers` is read as the number of stacks and is
  temporarily overridden to 1 while the graph is built. The original value
  is restored on every exit path, including the PREDICT early return and
  any exception, so the shared hparams object is never left modified.

  Args:
    features: dict of input tensors; "targets" is read in every mode except
      PREDICT.

  Returns:
    In PREDICT mode, the decoder output tensor alone. In every other mode,
    a `(res, losses)` tuple where `losses` is a dict that contains at least
    "bottleneck0_loss" (it is also handed to `self.full_stack`).
  """
  hparams = self.hparams
  num_stacks = hparams.num_hidden_layers
  hparams.num_hidden_layers = 1
  try:
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Pad so that the padded axes are divisible by 2**num_stacks.
      x, _ = common_layers.pad_to_same_length(
          x, x, final_length_divisible_by=2**num_stacks, axis=1)
      if not is1d:
        x, _ = common_layers.pad_to_same_length(
            x, x, final_length_divisible_by=2**num_stacks, axis=2)
      # Run encoder.
      x = self.encoder(x)
      x_size = common_layers.shape_list(x)[-1]
      # Bottleneck (mix during early training, not too important but stable).
      b = self.bottleneck(x)
      b_loss = self.bottleneck_loss(b)
      losses = {"bottleneck0_loss": b_loss}
      b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_size, losses,
                          is_training, num_stacks - 1)
      b = self.unbottleneck(b, x_size)
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      if hparams.bottleneck_max_prob < 1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      b = self.sample()
      # num_hidden_layers is still overridden to 1 while this is evaluated.
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]
    res = common_layers.mix(res, features["targets"],
                            hparams.bottleneck_warmup_steps // 2, is_training)
    return res, losses
  finally:
    # Undo the temporary override no matter how the function exits.
    hparams.num_hidden_layers = num_stacks
def isemhash_bottleneck(x, bottleneck_size, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck.

  Args:
    x: input tensor, projected to `bottleneck_size` units by a dense layer.
    bottleneck_size: width of the dense projection.
    bottleneck_noise: threshold compared against uniform samples to choose
      which outputs get their sign flipped during training.
    discretize_warmup_steps: step count forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value; noise is only added in TRAIN.
    isemhash_noise_dev: stddev of the truncated-normal noise added before
      the sigmoid during training; values <= 0 disable it.
    isemhash_mix_prob: forwarded to `common_layers.mix` as `max_prob`.

  Returns:
    The result of mixing the thresholded code (rescaled to [-1, 1]) with the
    rescaled sigmoid output.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN
  with tf.variable_scope("isemhash_bottleneck"):
    pre_act = tf.layers.dense(x, bottleneck_size, name="dense")
    soft = common_layers.saturating_sigmoid(pre_act)
    if isemhash_noise_dev > 0 and training:
      gauss = tf.truncated_normal(
          common_layers.shape_list(pre_act), mean=0.0,
          stddev=isemhash_noise_dev)
      soft = common_layers.saturating_sigmoid(pre_act + gauss)

    # Forward value is the 0/1 threshold; the gradient flows through `soft`.
    hard = tf.to_float(tf.less(0.5, soft)) + soft - tf.stop_gradient(soft)
    hard = 2.0 * hard - 1.0  # Rescale from [0, 1] to [-1, 1].

    if training:
      # Multiply by -1 wherever the uniform sample is <= bottleneck_noise.
      uniform = tf.random_uniform(common_layers.shape_list(pre_act))
      signs = 2.0 * tf.to_float(tf.less(bottleneck_noise, uniform)) - 1.0
      hard *= signs

    return common_layers.mix(
        hard, 2.0 * soft - 1.0, discretize_warmup_steps, training,
        max_prob=isemhash_mix_prob)
def body(self, features):
  """Build the stacked autoencoder graph.

  `hparams.num_hidden_layers` is read as the number of stacks and is
  temporarily overridden to 1 while the graph is built. The original value
  is restored on every exit path, including the PREDICT early return and
  any exception, so the shared hparams object is never left modified.

  Args:
    features: dict of input tensors; "targets" is read in every mode except
      PREDICT.

  Returns:
    In PREDICT mode, the decoder output tensor alone. In every other mode,
    a `(res, losses)` tuple where `losses` is a dict that contains at least
    "bottleneck0_loss" (it is also handed to `self.full_stack`).
  """
  hparams = self.hparams
  num_stacks = hparams.num_hidden_layers
  hparams.num_hidden_layers = 1
  try:
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Pad so that the padded axes are divisible by 2**num_stacks.
      x, _ = common_layers.pad_to_same_length(
          x, x, final_length_divisible_by=2**num_stacks, axis=1)
      if not is1d:
        x, _ = common_layers.pad_to_same_length(
            x, x, final_length_divisible_by=2**num_stacks, axis=2)
      # Run encoder.
      x = self.encoder(x)
      x_size = common_layers.shape_list(x)[-1]
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      losses = {"bottleneck0_loss": b_loss}
      b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_bits, losses,
                          is_training, num_stacks - 1)
      b = self.unbottleneck(b, x_size)
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      if hparams.bottleneck_max_prob < 1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      b = self.sample()
      # num_hidden_layers is still overridden to 1 while this is evaluated.
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]
    res = common_layers.mix(res, features["targets"],
                            hparams.bottleneck_warmup_steps // 2, is_training)
    return res, losses
  finally:
    # Undo the temporary override no matter how the function exits.
    hparams.num_hidden_layers = num_stacks
def body(self, features):
  """Run the base autoencoder, then an optional autoregressive layer on top.

  Args:
    features: dict of input tensors. `features["targets"]` is read and, when
      its axis 1 has size 1, replaced in place by zeros shaped like the base
      result.

  Returns:
    A `(tensor, losses)` tuple. With `autoregressive_mode == "none"` this is
    the parent class result unchanged; otherwise the tensor is the output of
    the selected autoregressive layer reshaped to the base result's shape.

  Raises:
    ValueError: if `hparams.autoregressive_mode` is not one of "none",
      "conv3", "conv5" or "sru".
  """
  hparams = self.hparams
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # The plain autoencoder always runs first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses

  shape = common_layers.shape_list(basic_result)
  flat_shape = [shape[0], -1, shape[3]]
  basic1d = tf.reshape(basic_result, flat_shape)

  # At inference time reuse the tensor stored on hparams by an earlier call
  # instead of resampling; the first call stores it.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d

  # A length-1 target only occurs on the first step of predictions.
  if common_layers.shape_list(features["targets"])[1] == 1:
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    features["targets"] = tf.zeros_like(basic_result)

  targets_dropout = common_layers.mix(
      features["targets"], tf.zeros_like(basic_result),
      hparams.bottleneck_warmup_steps, training,
      max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True)

  # Optionally evaluate without any autoregressive input.
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)

  # Base reconstruction concatenated with the right-shifted targets.
  targets1d = tf.reshape(targets_dropout, flat_shape)
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)

  # forget_base: feed only the shifted targets, ignoring the autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(features["targets"], flat_shape)
    concat1d = common_layers.shift_right_3d(concat1d)

  # Each supported mode is a left-padded conv1d: (kernel width, layer name).
  conv_specs = {
      "conv3": (3, "autoregressive_conv3"),
      "conv5": (5, "autoregressive_conv5"),
      "sru": (3, "autoregressive_sru_conv3"),
  }
  if hparams.autoregressive_mode not in conv_specs:
    raise ValueError("Unsupported autoregressive mode: %s"
                     % hparams.autoregressive_mode)
  kernel_width, layer_name = conv_specs[hparams.autoregressive_mode]
  res = common_layers.conv1d(concat1d, shape[3], kernel_width, padding="LEFT",
                             activation=common_layers.belu, name=layer_name)
  if hparams.autoregressive_mode == "sru":
    res = common_layers.sru(res)
  return tf.reshape(res, shape), losses
def bottleneck(self, x, res_size):
  """Squeeze `x` through a tanh bottleneck and project back to `res_size`.

  Args:
    x: input tensor.
    res_size: number of units of the final "unbottleneck" dense layer.

  Returns:
    The output of the "unbottleneck" dense layer.
  """
  hparams = self._hparams
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  squashed = tf.tanh(
      tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck"))

  # Forward value is sign(squashed) in {-1, +1}; the gradient is the
  # identity through `squashed`.
  signs = 2 * tf.to_float(tf.less(0.0, squashed)) - 1.0
  discrete = squashed + tf.stop_gradient(signs - squashed)

  dropped = tf.nn.dropout(squashed, keep_prob=1.0 - hparams.dropout)
  mixed = common_layers.mix(
      discrete, dropped, hparams.discretize_warmup_steps, training)
  return tf.layers.dense(mixed, res_size, name="unbottleneck")
def dropout(self, x):
  """Apply the model's mix-based dropout to `x`.

  Returns `x` itself when `hparams.dropout` is not positive. Otherwise `x`
  is combined with an all-zeros tensor through `common_layers.mix`, using
  the dropout rate as `max_prob`. (A plain alternative would be
  `tf.nn.dropout(x, 1.0 - self.hparams.dropout)`.)

  Args:
    x: tensor to apply dropout to.

  Returns:
    `x` unchanged, or the mixed tensor.
  """
  hparams = self.hparams
  rate = hparams.dropout
  if rate <= 0.0:
    return x
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  zeros = tf.zeros_like(x)
  return common_layers.mix(
      zeros, x, hparams.bottleneck_warmup_steps, training,
      max_prob=rate, broadcast_last=True)
def body(self, features):
  """Build the encoder / discrete bottleneck / decoder graph.

  Args:
    features: dict of input tensors; "targets" is read in every mode except
      PREDICT.

  Returns:
    A `(tensor, losses)` tuple. In PREDICT mode the tensor is the raw decoder
    output and the loss dict holds a constant 0.0; otherwise the tensor is
    the decoder output cut to the targets' size on axes 1 and 2 and mixed
    with the targets, and the loss dict holds the bottleneck loss.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # Decode from the stored bottleneck tensor if there is one, else sample.
    latent = self._cur_bottleneck_tensor
    if latent is None:
      latent = self.sample()
    res_size = min(
        self.hparams.hidden_size * 2**self.hparams.num_hidden_layers,
        hparams.max_hidden_size)
    decoded = self.decoder(self.unbottleneck(latent, res_size))
    return decoded, {"bottleneck_loss": 0.0}

  targets = features["targets"]
  shape = common_layers.shape_list(targets)
  self.is1d = shape[2] == 1

  encoded = self.encoder(targets)

  # Bottleneck; the tensor is stored on the instance for the PREDICT path.
  latent = self.bottleneck(encoded)
  self._cur_bottleneck_tensor = latent
  b_loss = self.bottleneck_loss(latent)
  restored = self.unbottleneck(latent, common_layers.shape_list(encoded)[-1])
  restored = common_layers.mix(
      restored, encoded, hparams.bottleneck_warmup_steps, is_training)

  # With probability bottleneck_max_prob feed the bottleneck output to the
  # decoder, otherwise the encoder output.
  if hparams.bottleneck_max_prob < 1.0:
    use_bottleneck = tf.less(
        tf.random_uniform([]), hparams.bottleneck_max_prob)
    decoder_input = tf.where(use_bottleneck, restored, encoded)
  else:
    decoder_input = restored

  decoded = self.decoder(decoder_input)

  # Cut to the targets' size and mix before returning.
  res = decoded[:, :shape[1], :shape[2], :]
  res = common_layers.mix(res, features["targets"],
                          hparams.bottleneck_warmup_steps // 2, is_training)
  return res, {"bottleneck_loss": b_loss}
def bottleneck(self, x):
  """Discretize `x` through a tanh layer of `hparams.bottleneck_bits` units.

  Args:
    x: input tensor.

  Returns:
    A `(tensor, loss)` pair: the mix of the sign-discretized and the
    continuous tanh activations, and a constant loss of 0.0.
  """
  hparams = self.hparams
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  soft = tf.tanh(tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck"))

  # Forward value is sign(soft) in {-1, +1}; gradient passes through `soft`.
  signs = 2.0 * tf.to_float(tf.less(0.0, soft)) - 1.0
  hard = soft + tf.stop_gradient(signs - soft)

  if training:
    # Multiply by -1 wherever the uniform sample is <= bottleneck_noise.
    uniform = tf.random_uniform(common_layers.shape_list(soft))
    flips = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, uniform)) - 1.0
    hard *= flips

  mixed = common_layers.mix(
      hard, soft, hparams.discretize_warmup_steps, training)
  return mixed, 0.0
def tanh_discrete_bottleneck(x, bottleneck_size, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Simple discretization through tanh, flip bottleneck_noise many bits.

  Args:
    x: input tensor.
    bottleneck_size: number of units of the dense layer feeding the tanh.
    bottleneck_noise: threshold compared against uniform samples to choose
      which outputs get their sign flipped during training.
    discretize_warmup_steps: step count forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value; flipping happens only in TRAIN.

  Returns:
    The mix of the sign-discretized and the continuous tanh activations.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  soft = tf.tanh(
      tf.layers.dense(x, bottleneck_size, name="tanh_discrete_bottleneck"))

  # Forward value is sign(soft) in {-1, +1}; gradient passes through `soft`.
  signs = 2.0 * tf.to_float(tf.less(0.0, soft)) - 1.0
  hard = soft + tf.stop_gradient(signs - soft)

  if training:
    # Multiply by -1 wherever the uniform sample is <= bottleneck_noise.
    uniform = tf.random_uniform(common_layers.shape_list(soft))
    flips = 2.0 * tf.to_float(tf.less(bottleneck_noise, uniform)) - 1.0
    hard *= flips

  return common_layers.mix(hard, soft, discretize_warmup_steps, training)
def dropout(self, x):
  """Apply the model's mix-based dropout to `x`.

  Returns `x` itself when `hparams.dropout` is not positive. Otherwise `x`
  is combined with an all-zeros tensor through `common_layers.mix`, using
  the dropout rate as `max_prob`. (A plain alternative would be
  `tf.nn.dropout(x, 1.0 - self.hparams.dropout)`.)

  Args:
    x: tensor to apply dropout to.

  Returns:
    `x` unchanged, or the mixed tensor.
  """
  hparams = self.hparams
  rate = hparams.dropout
  if rate <= 0.0:
    return x
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  zeros = tf.zeros_like(x)
  return common_layers.mix(
      zeros, x, hparams.bottleneck_warmup_steps, training,
      max_prob=rate, broadcast_last=True)
def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
                             discretize_warmup_steps, mode):
  """Simple discretization through tanh, flip bottleneck_noise many bits.

  Args:
    x: input tensor.
    bottleneck_bits: number of units of the dense layer feeding the tanh.
    bottleneck_noise: threshold compared against uniform samples to choose
      which outputs get their sign flipped during training.
    discretize_warmup_steps: step count forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value; flipping happens only in TRAIN.

  Returns:
    A `(tensor, loss)` pair: the mix of the sign-discretized and the
    continuous tanh activations, and a constant loss of 0.0.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN

  soft = tf.tanh(
      tf.layers.dense(x, bottleneck_bits, name="tanh_discrete_bottleneck"))

  # Forward value is sign(soft) in {-1, +1}; gradient passes through `soft`.
  signs = 2.0 * tf.to_float(tf.less(0.0, soft)) - 1.0
  hard = soft + tf.stop_gradient(signs - soft)

  if training:
    # Multiply by -1 wherever the uniform sample is <= bottleneck_noise.
    uniform = tf.random_uniform(common_layers.shape_list(soft))
    flips = 2.0 * tf.to_float(tf.less(bottleneck_noise, uniform)) - 1.0
    hard *= flips

  mixed = common_layers.mix(hard, soft, discretize_warmup_steps, training)
  return mixed, 0.0
def isemhash_bottleneck(x, bottleneck_bits, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck.

  Args:
    x: input tensor, projected to `bottleneck_bits` units by a dense layer.
    bottleneck_bits: width of the dense projection.
    bottleneck_noise: threshold compared against uniform samples to choose
      which outputs get their sign flipped during training.
    discretize_warmup_steps: step count forwarded to `common_layers.mix`.
    mode: a `tf.estimator.ModeKeys` value; noise is only added in TRAIN.
    isemhash_noise_dev: stddev of the truncated-normal noise added before
      the sigmoid during training; values <= 0 disable it.
    isemhash_mix_prob: forwarded to `common_layers.mix` as `max_prob`.

  Returns:
    A `(tensor, loss)` pair: the mix of the thresholded code (rescaled to
    [-1, 1]) with the rescaled sigmoid output, and a constant loss of 0.0.
  """
  training = mode == tf.estimator.ModeKeys.TRAIN
  with tf.variable_scope("isemhash_bottleneck"):
    pre_act = tf.layers.dense(x, bottleneck_bits, name="dense")
    soft = common_layers.saturating_sigmoid(pre_act)
    if isemhash_noise_dev > 0 and training:
      gauss = tf.truncated_normal(
          common_layers.shape_list(pre_act), mean=0.0,
          stddev=isemhash_noise_dev)
      soft = common_layers.saturating_sigmoid(pre_act + gauss)

    # Forward value is the 0/1 threshold; the gradient flows through `soft`.
    hard = tf.to_float(tf.less(0.5, soft)) + soft - tf.stop_gradient(soft)
    hard = 2.0 * hard - 1.0  # Rescale from [0, 1] to [-1, 1].

    if training:
      # Multiply by -1 wherever the uniform sample is <= bottleneck_noise.
      uniform = tf.random_uniform(common_layers.shape_list(pre_act))
      signs = 2.0 * tf.to_float(tf.less(bottleneck_noise, uniform)) - 1.0
      hard *= signs

    mixed = common_layers.mix(
        hard, 2.0 * soft - 1.0, discretize_warmup_steps, training,
        max_prob=isemhash_mix_prob)
    return mixed, 0.0
def body(self, features):
  """Run the base autoencoder, then an optional autoregressive layer on top.

  Args:
    features: dict of input tensors. `features["targets"]` is read and, when
      its axis 1 has size 1, replaced in place by zeros shaped like the base
      result.

  Returns:
    A `(tensor, losses)` tuple. With `autoregressive_mode == "none"` this is
    the parent class result unchanged; otherwise the tensor is the output of
    the selected autoregressive layer reshaped to the base result's shape.

  Raises:
    ValueError: if `hparams.autoregressive_mode` is not one of "none",
      "conv3", "conv5" or "sru".
  """
  hparams = self.hparams
  training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # The plain autoencoder always runs first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses

  shape = common_layers.shape_list(basic_result)
  flat_shape = [shape[0], -1, shape[3]]
  basic1d = tf.reshape(basic_result, flat_shape)

  # At inference time reuse the tensor stored on hparams by an earlier call
  # instead of resampling; the first call stores it.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d

  # A length-1 target only occurs on the first step of predictions.
  if common_layers.shape_list(features["targets"])[1] == 1:
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    features["targets"] = tf.zeros_like(basic_result)

  targets_dropout = common_layers.mix(
      features["targets"], tf.zeros_like(basic_result),
      hparams.bottleneck_warmup_steps, training,
      max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True)

  # Optionally evaluate without any autoregressive input.
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)

  # Base reconstruction concatenated with the right-shifted targets.
  targets1d = tf.reshape(targets_dropout, flat_shape)
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)

  # forget_base: feed only the shifted targets, ignoring the autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(features["targets"], flat_shape)
    concat1d = common_layers.shift_right_3d(concat1d)

  # Each supported mode is a left-padded conv1d: (kernel width, layer name).
  conv_specs = {
      "conv3": (3, "autoregressive_conv3"),
      "conv5": (5, "autoregressive_conv5"),
      "sru": (3, "autoregressive_sru_conv3"),
  }
  if hparams.autoregressive_mode not in conv_specs:
    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
  kernel_width, layer_name = conv_specs[hparams.autoregressive_mode]
  res = common_layers.conv1d(
      concat1d, shape[3], kernel_width, padding="LEFT",
      activation=common_layers.belu, name=layer_name)
  if hparams.autoregressive_mode == "sru":
    res = common_layers.sru(res)
  return tf.reshape(res, shape), losses
def body(self, features):
  """Build the autoencoder graph with an optional sliced-GAN loss.

  Args:
    features: dict of input tensors. "targets" is read in every mode except
      PREDICT; "targets_raw" is read only when `hparams.gan_loss_factor` is
      non-zero.

  Returns:
    A `(tensor, losses)` tuple. In PREDICT mode the tensor is the raw decoder
    output and the dict holds only "bottleneck_loss" (0.0). Otherwise the
    dict holds "bottleneck_loss" and "gan_loss" (the negated, scaled GAN
    loss, or -0.0 when the GAN term is disabled).
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # Decode from the stored bottleneck tensor if there is one, else sample.
    latent = self._cur_bottleneck_tensor
    if latent is None:
      latent = self.sample()
    res_size = min(
        self.hparams.hidden_size * 2**self.hparams.num_hidden_layers,
        hparams.max_hidden_size)
    decoded = self.decoder(self.unbottleneck(latent, res_size))
    return decoded, {"bottleneck_loss": 0.0}

  use_gan = hparams.gan_loss_factor != 0.0

  targets = features["targets"]
  shape = common_layers.shape_list(targets)
  self.is1d = shape[2] == 1

  encoded = self.encoder(targets)

  # Bottleneck; the tensor is stored on the instance for the PREDICT path.
  latent, b_loss = self.bottleneck(encoded)
  self._cur_bottleneck_tensor = latent
  restored = self.unbottleneck(latent, common_layers.shape_list(encoded)[-1])
  restored = common_layers.mix(
      restored, encoded, hparams.bottleneck_warmup_steps, is_training)

  if use_gan:
    # Prepend a purely sampled batch; the GAN loss is computed on it below.
    sampled = self.unbottleneck(
        self.sample(), common_layers.shape_list(encoded)[-1], reuse=True)
    restored = tf.concat([sampled, restored], axis=0)

  # NOTE(review): the -1.0 threshold means this branch is never taken for a
  # probability in [0, 1]; it looks deliberately disabled -- confirm.
  if hparams.bottleneck_max_prob < -1.0:
    use_bottleneck = tf.less(
        tf.random_uniform([]), hparams.bottleneck_max_prob)
    decoder_input = tf.where(use_bottleneck, restored, encoded)
  else:
    decoder_input = restored

  decoded = self.decoder(decoder_input)

  # Cut back to the targets' size on axes 1 and 2.
  res = decoded[:, :shape[1], :shape[2], :]

  gan_loss = 0.0
  if use_gan:
    # Separate the sampled half from the reconstruction half again.
    res_gan, res = tf.split(res, 2, axis=0)
    num_channels = self.hparams.problem.num_channels
    gan_logits = tf.layers.dense(res_gan, num_channels, name="gan_rgb")
    res_rgb = common_layers.convert_real_to_rgb(tf.nn.sigmoid(gan_logits))
    tf.summary.image(
        "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1)
    orig_rgb = tf.to_float(features["targets_raw"])

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    gan_loss = common_layers.sliced_gan_loss(
        orig_rgb, reverse_gradient(res_rgb), discriminate,
        self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor

  # Mix the final result with the targets and return.
  res = common_layers.mix(res, features["targets"],
                          hparams.bottleneck_warmup_steps // 2, is_training)
  return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
def body(self, features):
  """Build the autoencoder graph with a VQ output loss and optional GAN loss.

  Args:
    features: dict of input tensors. "targets" and "targets_raw" are read in
      every mode except PREDICT.

  Returns:
    A `(tensor, losses)` tuple. In PREDICT mode the tensor is the raw decoder
    output and the dict holds only "bottleneck_loss" (0.0). Otherwise the
    tensor is the VQ reconstruction and the dict holds "training",
    "code_loss", "code_loss_gan", "b_loss" and "gan_loss".

  Raises:
    Exception: if `hparams.problem` is neither an `ImageProblem` nor a
      `Text2TextProblem`.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    x = features["targets"]
    labels = features["targets_raw"]
    shape = common_layers.shape_list(x)
    is1d = shape[2] == 1
    self.is1d = is1d
    # Run encoder.
    x = self.encoder(x)
    # Bottleneck (mix during early training, not too important but stable).
    b, b_loss = self.bottleneck(x)
    self._cur_bottleneck_tensor = b
    b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
    b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(), common_layers.shape_list(x)[-1], reuse=True)
      b = tf.concat([g, b], axis=0)
    # With probability bottleneck_max_prob use the bottleneck, otherwise x.
    # NOTE(review): the -1.0 threshold means this branch is never taken for a
    # probability in [0, 1]; it looks deliberately disabled -- confirm.
    if hparams.bottleneck_max_prob < -1.0:
      x = tf.where(
          tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
    else:
      x = b
  else:
    # Decode from the stored bottleneck tensor if there is one, else sample.
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x, {"bottleneck_loss": 0.0}
  # Cut to the right size before the output projection.
  res = x[:, :shape[1], :shape[2], :]
  is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
  if is_image:
    vocab_size = self.hparams.problem.vocab_size
    res = tf.layers.dense(
        res, self.hparams.problem.num_channels * self.hparams.hidden_size)
    # Give every channel its own hidden_size-wide vector.
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.hparams.problem.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)
  elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    res = tf.layers.dense(res, self.hparams.hidden_size)
  else:
    raise Exception("Unsupported problem type: %s" % self.hparams.problem)
  one_hot_labels = tf.one_hot(labels, vocab_size)
  code_loss_gan = 0.0
  if hparams.gan_loss_factor != 0.0:
    # Split back into the purely sampled half and the reconstruction half.
    res_gan, res = tf.split(res, 2, axis=0)
    with tf.variable_scope("vq"):
      # Fix: run the VQ loss on the sampled half (`res_gan`). Passing `res`
      # here made `reconstr_gan` and `code_loss_gan` duplicate the
      # reconstruction-side values computed just below.
      reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
          res_gan, one_hot_labels, vocab_size)
  with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
    reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
        res, one_hot_labels, vocab_size)
  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    if is_image:
      tf.summary.image(
          "gan",
          common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
          max_outputs=1)

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    gan_loss = common_layers.sliced_gan_loss(
        target_codes, reverse_gradient(res_gan), discriminate,
        self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor
  if is_image:
    tf.summary.image(
        "ae",
        common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
        max_outputs=1)
  return reconstr, {
      "training": targets_loss,
      "code_loss": code_loss,
      "code_loss_gan": code_loss_gan,
      "b_loss": b_loss,
      "gan_loss": -gan_loss
  }
def body(self, features):
  """Build the one-hot-input autoencoder graph with VQ and optional GAN loss.

  Args:
    features: dict of input tensors; "targets_raw" is read in every mode
      except PREDICT.

  Returns:
    A `(tensor, losses)` tuple. In PREDICT mode the tensor is the raw decoder
    output and the dict holds only "bottleneck_loss" (0.0). Otherwise the
    tensor is the VQ reconstruction (used as logits) and the dict holds
    "code_loss", "training", "b_loss", "gan_loss" and, when the GAN term is
    enabled, "code_loss_gan".
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # Decode from the stored bottleneck tensor if there is one, else sample.
    latent = self._cur_bottleneck_tensor
    if latent is None:
      latent = self.sample()
    res_size = min(
        self.hparams.hidden_size * 2**self.hparams.num_hidden_layers,
        hparams.max_hidden_size)
    decoded = self.decoder(self.unbottleneck(latent, res_size))
    return decoded, {"bottleneck_loss": 0.0}

  use_gan = hparams.gan_loss_factor != 0.0

  labels = features["targets_raw"]
  vocab_size = self._problem_hparams.target_modality.top_dimensionality
  shape = common_layers.shape_list(labels)

  # One-hot encode the labels, fold the vocab axis into the last axis, embed.
  one_hot = tf.one_hot(labels, vocab_size)
  one_hot = tf.reshape(one_hot, shape[:-1] + [shape[-1] * vocab_size])
  embedded = self.embed(one_hot)
  self.is1d = shape[2] == 1

  encoded = self.encoder(embedded)

  # Bottleneck; the tensor is stored on the instance for the PREDICT path.
  latent, b_loss = self.bottleneck(encoded)
  b_shape = common_layers.shape_list(latent)
  self._cur_bottleneck_tensor = latent
  restored = self.unbottleneck(latent, common_layers.shape_list(encoded)[-1])
  restored = common_layers.mix(
      restored, encoded, hparams.bottleneck_warmup_steps, is_training)

  if use_gan:
    # Prepend a purely sampled batch; the GAN loss is computed on it below.
    sampled = self.unbottleneck(
        self.sample(shape=b_shape), common_layers.shape_list(encoded)[-1],
        reuse=True)
    restored = tf.concat([sampled, restored], axis=0)

  # NOTE(review): the -1.0 threshold means this branch is never taken for a
  # probability in [0, 1]; it looks deliberately disabled -- confirm.
  if hparams.bottleneck_max_prob < -1.0:
    use_bottleneck = tf.less(
        tf.random_uniform([]), hparams.bottleneck_max_prob)
    decoder_input = tf.where(use_bottleneck, restored, encoded)
  else:
    decoder_input = restored

  decoded = self.decoder(decoder_input)

  # Cut back to the labels' size on axes 1 and 2.
  res = decoded[:, :shape[1], :shape[2], :]

  # Final dense layer, then one hidden_size-wide vector per channel.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  per_channel_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, per_channel_shape)

  # Losses.
  losses = {}
  if use_gan:
    # Separate the sampled half from the reconstruction half again.
    res_gan, res = tf.split(res, 2, axis=0)
    with tf.variable_scope("vq"):
      reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
          res_gan, labels, vocab_size)
      losses["code_loss_gan"] = (
          code_loss_gan * hparams.code_loss_factor * hparams.gan_loss_factor)
  with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
    vq_outputs = discretization.vq_loss(res, labels, vocab_size)
    reconstr, _, target_codes, code_loss, targets_loss = vq_outputs
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss

  # Add GAN loss if requested.
  gan_loss = 0.0
  if use_gan:
    self.image_summary("gan", reconstr_gan)

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    tc_shape = common_layers.shape_list(target_codes)
    if len(tc_shape) > 4:
      # Merge the last two axes of both code tensors into one.

      def merge_last_two(codes):
        return tf.reshape(
            codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])

      target_codes = merge_last_two(target_codes)
      gan_codes = merge_last_two(gan_codes)
    gan_loss = common_layers.sliced_gan_loss(
        target_codes, reverse_gradient(gan_codes), discriminate,
        self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor

  self.image_summary("ae", reconstr)

  losses["b_loss"] = b_loss
  losses["gan_loss"] = -gan_loss
  logits = reconstr
  return logits, losses