def body(self, features):
  """Basic convolutional autoencoder over the "targets" feature.

  Down-convolves `num_hidden_layers` times with stride 2, applies the
  bottleneck, then up-convolves back to the input resolution. During the
  warmup period the bottlenecked/reconstructed signal is mixed with the
  original via `common_layers.mix` for stability.

  Args:
    features: dict of tensors; only "targets" is read here. Assumes a 4-D
      [batch, height, width, channels] layout -- TODO confirm with callers.

  Returns:
    A tensor cropped back to the original spatial size of "targets".
  """
  hparams = self._hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  x = features["targets"]
  shape = common_layers.shape_list(x)
  kernel = (hparams.kernel_height, hparams.kernel_width)
  # Width-1 inputs (flat sequences) use 1-d kernels and only stride axis 1.
  is1d = shape[2] == 1
  kernel = (hparams.kernel_height, 1) if is1d else kernel
  strides = (2, 1) if is1d else (2, 2)
  # Pad so each of the strided layers divides the length evenly.
  x, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1)
  if not is1d:
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=2)
  # Down-convolutions.
  for i in xrange(hparams.num_hidden_layers):
    x = tf.layers.conv2d(
        x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides,
        padding="SAME", activation=tf.nn.relu, name="conv_%d" % i)
    x = common_layers.layer_norm(x)
  # Bottleneck (mix during early training, not too important but very stable).
  b = self.bottleneck(x, hparams.hidden_size * 2**hparams.num_hidden_layers)
  x = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
  # Up-convolutions.
  for i in xrange(hparams.num_hidden_layers):
    j = hparams.num_hidden_layers - i - 1
    x = tf.layers.conv2d_transpose(
        x, hparams.hidden_size * 2**j, kernel, strides=strides,
        padding="SAME", activation=tf.nn.relu, name="deconv_%d" % j)
    x = common_layers.layer_norm(x)
  # Crop the padding off and ease the autoencoded output in over warmup.
  res = x[:, :shape[1], :shape[2], :]
  return common_layers.mix(res, features["targets"],
                           hparams.bottleneck_warmup_steps // 2, is_training)
def body(self, features):
  """Conv net over "inputs" conditioned on an action; predicts frame + reward.

  Down-strides, injects the embedded action, runs a residual conv stack,
  up-strides back, then crops to the original size.

  Args:
    features: dict with "inputs" (4-D tensor) and "input_action"; the action
      at index 1 along axis 1 is the one embedded -- TODO confirm semantics.

  Returns:
    dict with "targets" (same spatial size as inputs) and "target_reward"
    (spatial mean of the output, shape kept with keep_dims=True).
  """
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)
  # Pad to make size powers of 2 as needed.
  x = features["inputs"]
  inputs_shape = common_layers.shape_list(x)
  x, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=1)
  x, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2**hparams.num_compress_steps, axis=2)
  # Down-stride, doubling the filter count at every step.
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)
      filters *= 2
  # Add embedded action (broadcast over the spatial dims via the zeros trick).
  action = tf.reshape(features["input_action"][:, 1, :],
                      [-1, 1, 1, hparams.hidden_size])
  zeros = tf.zeros(common_layers.shape_list(x)[:-1] + [hparams.hidden_size],
                   dtype=tf.float32)
  x = tf.concat([x, action + zeros], axis=-1)
  # Run a stack of convolutions (residual with layer norm after the first).
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      y = tf.nn.dropout(y, 1.0 - hparams.dropout)
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)
  # Up-convolve, halving the filter count at every step.
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)
      x = tf.nn.dropout(x, 1.0 - hparams.dropout)
  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  # Reward prediction.
  reward_pred = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
  return {"targets": x, "target_reward": reward_pred}
def body(self, features):
  """Stacked autoencoder body: encode, full-stack bottleneck, decode.

  NOTE: temporarily mutates `hparams.num_hidden_layers` to 1 for the
  duration of the call and restores it just before returning on the
  training path; the PREDICT path returns early without restoring it.

  Args:
    features: dict of tensors; "targets" is read in non-PREDICT modes.

  Returns:
    In PREDICT mode: the decoded sample. Otherwise a (result, losses)
    tuple where losses contains "bottleneck0_loss".
  """
  hparams = self.hparams
  num_stacks = hparams.num_hidden_layers
  hparams.num_hidden_layers = 1
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    x = features["targets"]
    shape = common_layers.shape_list(x)
    is1d = shape[2] == 1
    self.is1d = is1d
    # Pad so the encoder's strided steps divide the length evenly.
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2**num_stacks, axis=1)
    if not is1d:
      x, _ = common_layers.pad_to_same_length(
          x, x, final_length_divisible_by=2**num_stacks, axis=2)
    # Run encoder.
    x = self.encoder(x)
    x_size = common_layers.shape_list(x)[-1]
    # Bottleneck (mix during early training, not too important but stable).
    b = self.bottleneck(x)
    b_loss = self.bottleneck_loss(b)
    losses = {"bottleneck0_loss": b_loss}
    b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_size, losses,
                        is_training, num_stacks - 1)
    b = self.unbottleneck(b, x_size)
    b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    # With probability bottleneck_max_prob use the bottleneck, otherwise x.
    if hparams.bottleneck_max_prob < 1.0:
      x = tf.where(tf.less(tf.random_uniform([]),
                           hparams.bottleneck_max_prob), b, x)
    else:
      x = b
  else:
    # PREDICT: sample a bottleneck and expand it to the decoder's input size.
    b = self.sample()
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x
  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]
  res = common_layers.mix(res, features["targets"],
                          hparams.bottleneck_warmup_steps // 2, is_training)
  hparams.num_hidden_layers = num_stacks
  return res, losses
def testPadToSameLength(self):
  """Both outputs get the longer length; divisibility rounds it up."""
  short_input = np.random.rand(5, 7, 11)
  long_input = np.random.rand(5, 9, 11)
  short_tensor = tf.constant(short_input, dtype=tf.float32)
  long_tensor = tf.constant(long_input, dtype=tf.float32)
  # Plain padding: both come out at the longer length, 9.
  padded_x, padded_y = common_layers.pad_to_same_length(
      short_tensor, long_tensor)
  # With final_length_divisible_by=4, length 9 rounds up to 12.
  div_x, div_y = common_layers.pad_to_same_length(
      short_tensor, long_tensor, final_length_divisible_by=4)
  plain_x, plain_y = self.evaluate([padded_x, padded_y])
  rounded_x, rounded_y = self.evaluate([div_x, div_y])
  self.assertEqual(plain_x.shape, (5, 9, 11))
  self.assertEqual(plain_y.shape, (5, 9, 11))
  self.assertEqual(rounded_x.shape, (5, 12, 11))
  self.assertEqual(rounded_y.shape, (5, 12, 11))
def testPadToSameLength(self):
  """pad_to_same_length equalizes axis-1 lengths, optionally rounding up."""
  first = tf.constant(np.random.rand(5, 7, 11), dtype=tf.float32)
  second = tf.constant(np.random.rand(5, 9, 11), dtype=tf.float32)
  plain_pair = common_layers.pad_to_same_length(first, second)
  rounded_pair = common_layers.pad_to_same_length(
      first, second, final_length_divisible_by=4)
  plain_results = self.evaluate(list(plain_pair))
  rounded_results = self.evaluate(list(rounded_pair))
  # Without divisibility both tensors match the longer length, 9.
  for result in plain_results:
    self.assertEqual(result.shape, (5, 9, 11))
  # With divisibility 4, length 9 is rounded up to 12.
  for result in rounded_results:
    self.assertEqual(result.shape, (5, 12, 11))
def slicenet_similarity_cost(a, b, hparams=None):
  """Hinge cosine similarity poached from slicenet.

  TODO: Not clear on cost_im or why we're clearing the diagonals.
  """
  margin = 0.2
  with tf.name_scope("slicenet_loss"):
    a, b = common_layers.pad_to_same_length(a, b)
    sim = _cosine_similarity(a, b)
    matched = tf.diag_part(sim)
    # Hinge against the matched (diagonal) pair along rows and columns.
    row_cost = tf.maximum(0.0, margin - matched + sim)
    col_cost = tf.maximum(0.0, margin - tf.reshape(matched, [-1, 1]) + sim)
    # Zero the diagonal so a pair is not penalized against itself.
    batch = tf.shape(a)[0]
    off_diagonal = tf.ones_like(row_cost) - tf.eye(batch)
    row_cost *= off_diagonal
    col_cost *= off_diagonal
    return tf.reduce_mean(row_cost) + tf.reduce_mean(col_cost)
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training.

  Encodes the (length-doubled) inputs, VAE-compresses the targets
  conditioned on them, joins z back with the encoded inputs, and runs a
  second encoder stack as the decoder.

  Returns:
    (output, losses) where losses contains the warmup-scaled "kl" term.
  """
  with tf.variable_scope("vae_transformer"):
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
    inputs.set_shape([None, None, hparams.hidden_size])
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)
    inputs = encode(inputs, target_space, hparams, "input_enc")
    # Compress and vae.
    z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2),
                                    tf.expand_dims(inputs, axis=2),
                                    hparams, "vae_compress", "vae_decompress")
    # Join z with inputs, run decoder.
    to_decode = common_layers.conv_block(
        tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3),
        hparams.hidden_size, [((1, 1), (1, 1))], name="join_z")
    ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec")
    # For experiments with one-sided decoder:
    # decoder_in = tf.squeeze(to_decode, axis=2)
    # (decoder_input, decoder_self_attention_bias) = (
    #     transformer.transformer_prepare_decoder(decoder_in, hparams))
    # ret = transformer.transformer_decoder(
    #     decoder_input, inputs, decoder_self_attention_bias, None, hparams)
    # Ramp the KL term in over kl_warmup_steps.
    kl_loss *= common_layers.inverse_exp_decay(hparams.kl_warmup_steps) * 3.0
    losses = {"kl": kl_loss}
    return tf.expand_dims(ret, axis=2), losses
def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training.

  Runs a dilated-conv encoder over the (extended) inputs, concatenates it
  with the right-shifted targets, and runs a left-padded (causal) dilated
  decoder stack on top.

  Returns:
    The decoder stack's output tensor.
  """
  with tf.variable_scope("bytenet"):
    # Flatten inputs and extend length by 50%.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    inputs_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
    inputs_shape[1] = None
    inputs.set_shape(inputs_shape)  # Don't lose the other shapes when padding.
    # Pad inputs and targets to be the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat,
                                          "SAME", "encoder", hparams)
    # Shift targets right so the decoder can't see the current position.
    shifted_targets = common_layers.shift_right(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([final_encoder, shifted_targets], axis=3),
        hparams.hidden_size, [((1, 1), kernel)], padding="LEFT")
    return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                 "LEFT", "decoder", hparams)
def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training."""
  with tf.variable_scope("bytenet"):
    # Flatten the inputs and append 50% extra length for the decoder side.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extra = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    static_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extra], [0, 0], [0, 0]])
    static_shape[1] = None
    # Re-assert the static dims that the dynamic pad did not touch.
    inputs.set_shape(static_shape)
    # Equalize input/target lengths, rounded up to a multiple of 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    encoded = residual_dilated_conv(
        inputs, hparams.num_block_repeat, "SAME", "encoder", hparams)
    targets_shifted = common_layers.shift_right(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([encoded, targets_shifted], axis=3),
        hparams.hidden_size, [((1, 1), kernel)], padding="LEFT")
    return residual_dilated_conv(
        decoder_start, hparams.num_block_repeat, "LEFT", "decoder", hparams)
def ae_transformer_internal(inputs, targets, target_space, hparams):
  """AE Transformer, main step used for training.

  Discretely autoencodes the targets, trains an autoregressive model over
  the discrete codes conditioned on the encoded inputs, and decompresses.

  Returns:
    (output, losses) with warmup-scaled "kl" and "reconstruction" losses.
  """
  with tf.variable_scope("ae_transformer"):
    # Prepare inputs, targets, k.
    k = 2**hparams.num_compress_steps
    _, targets = common_layers.pad_to_same_length(
        targets, targets, final_length_divisible_by=k)
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    # Compress and ae.
    ae, hot, kl = ae_compress(targets, hparams.is_2d, hparams, "ae")
    tf.summary.histogram("hot", tf.reshape(tf.argmax(hot, axis=-1), [-1]))
    emb = ae_embed(hot, hparams, "ae", reuse=True)
    # Compress context and run autoregressive decoder on emb-hot.
    emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2)
    dec_c = decode(None, None, emb_flat, inputs, ed, hparams)
    dec_c = tf.reshape(dec_c, tf.shape(emb))
    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=hot, logits=c_z)
    # If not training, use the predicted z instead of the autoregressive one.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)
    # Decompress, pass for ae loss.
    z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae")
    # Ramp both losses in over (fractions of) startup_steps.
    kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8))
    reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
    losses = {"kl": kl, "reconstruction": reconstruct_loss}
    return z, losses
def similarity_cost(inputs_encoded, targets_encoded):
  """Loss telling to be more similar to your own targets than to others."""
  # Simple first cut: pad both encodings to a common length, then flatten
  # everything into one big batch of depth-sized vectors. In need of a
  # better way.
  padded_inputs, padded_targets = common_layers.pad_to_same_length(
      inputs_encoded, targets_encoded)
  depth = tf.shape(inputs_encoded)[3]
  flat_inputs = tf.reshape(padded_inputs, [-1, depth])
  flat_targets = tf.reshape(padded_targets, [-1, depth])
  return rank_loss(flat_inputs, flat_targets)
def body(self, features):
  """Stacked autoencoder body: encode, full-stack bottleneck, decode.

  NOTE: temporarily mutates `hparams.num_hidden_layers` to 1 for the
  duration of the call and restores it just before returning on the
  training path; the PREDICT path returns early without restoring it.

  Args:
    features: dict of tensors; "targets" is read in non-PREDICT modes.

  Returns:
    In PREDICT mode: the decoded sample. Otherwise a (result, losses)
    tuple where losses contains "bottleneck0_loss".
  """
  hparams = self.hparams
  num_stacks = hparams.num_hidden_layers
  hparams.num_hidden_layers = 1
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    x = features["targets"]
    shape = common_layers.shape_list(x)
    is1d = shape[2] == 1
    self.is1d = is1d
    # Pad so the encoder's strided steps divide the length evenly.
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2**num_stacks, axis=1)
    if not is1d:
      x, _ = common_layers.pad_to_same_length(
          x, x, final_length_divisible_by=2**num_stacks, axis=2)
    # Run encoder.
    x = self.encoder(x)
    x_size = common_layers.shape_list(x)[-1]
    # Bottleneck (mix during early training, not too important but stable).
    b, b_loss = self.bottleneck(x)
    losses = {"bottleneck0_loss": b_loss}
    b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_bits, losses,
                        is_training, num_stacks - 1)
    b = self.unbottleneck(b, x_size)
    b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    # With probability bottleneck_max_prob use the bottleneck, otherwise x.
    if hparams.bottleneck_max_prob < 1.0:
      x = tf.where(tf.less(tf.random_uniform([]),
                           hparams.bottleneck_max_prob), b, x)
    else:
      x = b
  else:
    # PREDICT: sample a bottleneck and expand it to the decoder's input size.
    b = self.sample()
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x
  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]
  res = common_layers.mix(res, features["targets"],
                          hparams.bottleneck_warmup_steps // 2, is_training)
  hparams.num_hidden_layers = num_stacks
  return res, losses
def similarity_cost(inputs_encoded, targets_encoded):
  """Loss telling to be more similar to your own targets than to others."""
  # First, very simple version: pad both encodings to one length and treat
  # every depth-sized vector as a separate batch element. A better approach
  # to variable lengths is still needed.
  x, y = common_layers.pad_to_same_length(inputs_encoded, targets_encoded)
  depth = tf.shape(inputs_encoded)[3]
  return rank_loss(tf.reshape(x, [-1, depth]), tf.reshape(y, [-1, depth]))
def make_even_size(self, x):
  """Pad x so its length dims are even; 1-d inputs only pad axis 1."""
  if not self.is1d:
    # The 2-d case is handled by the shared helper.
    return common_layers.make_even_size(x)
  length = x.get_shape().as_list()[1]
  if length is not None and length % 2 == 0:
    # Statically known even length -- nothing to do.
    return x
  padded, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2, axis=1)
  return padded
def make_even_size(self, x):
  """Pad x to be even-sized on axes 1 and 2, but only where necessary."""
  # Unknown (None) dims count as odd (-1 % 2 == 1), so they always get the
  # divisibility pad; the pad is a no-op check is done per axis.
  dims = [-1 if dim is None else dim for dim in x.get_shape().as_list()]
  if dims[1] % 2 != 0:
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2, axis=1)
  if dims[2] % 2 != 0:
    x, _ = common_layers.pad_to_same_length(
        x, x, final_length_divisible_by=2, axis=2)
  return x
def make_even_size(self, x):
  """Pad x so axis 1 (and axis 2, unless 1-d) has an even length."""
  # Unknown (None) dims count as odd, so they always get padded.
  dims = [-1 if dim is None else dim for dim in x.get_shape().as_list()]
  height_even = dims[1] % 2 == 0
  width_even = dims[2] % 2 == 0
  # Nothing to do when the relevant static dims are already even.
  if height_even and (width_even or self.is1d):
    return x
  x, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2, axis=1)
  if self.is1d:
    return x
  x, _ = common_layers.pad_to_same_length(
      x, x, final_length_divisible_by=2, axis=2)
  return x
def cycle_vae_gan_internal(inputs, targets, _, hparams):
  """Cycle GAN, main step used for training.

  Splits the batch into an input-only and a target-only half, VAE-compresses
  each into a shared hypothesis space, and trains reconstruction, KL and
  discriminator losses on the two halves.

  Returns:
    (tgt, losses) where tgt is a gradient-stopped input->target
    reconstruction used only for tracking progress.
  """
  with tf.variable_scope("cycle_vae_gan"):
    # Embed inputs and targets.
    inputs_orig, targets_orig = tf.to_int32(inputs), tf.to_int32(targets)
    k = 2**hparams.num_compress_steps
    inputs_orig, targets_orig = common_layers.pad_to_same_length(
        inputs_orig, targets_orig, final_length_divisible_by=k)
    inputs = common_layers.embedding(
        inputs_orig, hparams.vocab_size, hparams.hidden_size, "embed")
    targets = common_layers.embedding(
        targets_orig, hparams.vocab_size, hparams.hidden_size,
        "embed", reuse=True)
    # Split the batch into input-input and target-target parts.
    inputs1, _ = split_on_batch(inputs)
    _, targets2 = split_on_batch(targets)
    # Input-input part.
    inp1_back, kl_loss1, inp1_mu, inp1_log_sigma = transformer_vae.vae_compress(
        inputs1, None, hparams, "inp2hyp", "hyp2inp")
    inp1_hyp = tf.concat([inp1_mu, inp1_log_sigma], axis=3)
    # Target-target part.
    tgt2_back, kl_loss2, tgt2_mu, tgt2_log_sigma = transformer_vae.vae_compress(
        targets2, None, hparams, "tgt2hyp", "hyp2tgt")
    tgt2_hyp = tf.concat([tgt2_mu, tgt2_log_sigma], axis=3)
    # Reconstruction losses.
    inp1_orig, _ = split_on_batch(inputs_orig)
    _, tgt2_orig = split_on_batch(targets_orig)
    inp1_loss = reconstruct_loss(
        inp1_back, tf.squeeze(inp1_orig, axis=3), hparams)
    tgt2_loss = reconstruct_loss(
        tgt2_back, tf.squeeze(tgt2_orig, axis=3), hparams, reuse=True)
    # Discriminator loss.
    dloss = discriminate_loss(inp1_hyp, tgt2_hyp, False, hparams, "dloss")
    # Reconstruct targets from inputs (reusing the compress/decompress nets).
    tgt, _, _, _ = transformer_vae.vae_compress(
        inputs, None, hparams, "inp2hyp", "hyp2tgt", reuse=True)
    tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax", reuse=True)
    # We use the reconstruction only for tracking progress, no gradients here!
    tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2))
    # KL terms ramp in over kl_warmup_steps.
    kl_rev_decay = common_layers.inverse_exp_decay(hparams.kl_warmup_steps)
    losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss,
              "target_target": hparams.cycle_loss_multiplier * tgt2_loss,
              "input_kl": kl_loss1 * kl_rev_decay * 15.0,
              "target_kl": kl_loss2 * kl_rev_decay * 15.0,
              "discriminator": dloss}
    return tgt, losses
def adversary(embedded, inputs, hparams, name, reuse=False):
  """Score the (embedded, inputs) pair with a small conv discriminator."""
  with tf.variable_scope(name, reuse=reuse):
    # Pad both to a shared length that is a multiple of 16.
    hidden, raw = common_layers.pad_to_same_length(
        embedded, inputs, final_length_divisible_by=16)
    hidden = tf.concat([hidden, tf.expand_dims(raw, axis=2)], axis=-1)
    hidden = tf.layers.dense(hidden, hparams.hidden_size, name="io")
    # Two compression stages, then pool down to a single score per example.
    compressed = transformer_vae.compress(
        hidden, None, False, hparams, "compress1")
    compressed = transformer_vae.compress(
        compressed, None, False, hparams, "compress2")
    pooled = tf.reduce_mean(compressed, axis=[1, 2])
    logit = tf.squeeze(tf.layers.dense(pooled, 1), axis=-1)
    return tf.nn.sigmoid(logit)
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training.

  Builds a single learned latent z by attending over encoded inputs and
  targets, runs it through a VAE, and decodes with a transformer decoder.

  Returns:
    (output, losses) where losses contains the warmup-scaled "kl" term.
  """
  with tf.variable_scope("vae_transformer"):
    is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
    inputs.set_shape([None, None, hparams.hidden_size])
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)
    inputs = encode(inputs, target_space, hparams, "input_enc")
    # Dropout targets or swap for zeros 5% of the time.
    targets_nodrop = targets
    max_prestep = hparams.kl_warmup_steps
    prob_targets = 0.95 if is_training else 1.0
    targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01
    targets = dropmask(targets, targets_dropout_max * 0.7, is_training)
    targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                      lambda: targets, lambda: tf.zeros_like(targets))
    # NOTE(review): the next line discards the dropout/zeroing above and
    # restores the untouched targets -- presumably a deliberate experiment
    # toggle, but worth confirming.
    targets = targets_nodrop
    # Compress and vae.
    z = tf.get_variable("z", [hparams.hidden_size])
    z = tf.reshape(z, [1, 1, 1, -1])
    z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1])
    # Let z attend to the encoded inputs and the targets in turn.
    z = attend(z, inputs, hparams, "z_attendsi")
    z = ffn(z, hparams, "zff2")
    z = attend(z, targets, hparams, "z_attendst2")
    z = ffn(z, hparams, "zff3")
    z, kl_loss, _, _ = vae(z, hparams, name="vae")
    z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
    # z, kl_loss, _, _ = vae_compress(
    #     tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2),
    #     hparams, "vae_compress", "vae_decompress")
    # Broadcast z over the target length and run the transformer decoder.
    decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets)
    (decoder_input, decoder_self_attention_bias) = (
        transformer.transformer_prepare_decoder(decoder_in, hparams))
    ret = transformer.transformer_decoder(
        decoder_input, inputs, decoder_self_attention_bias, None, hparams)
    kl_loss *= common_layers.inverse_exp_decay(int(max_prestep * 1.5)) * 5.0
    losses = {"kl": kl_loss}
    return tf.expand_dims(ret, axis=2), losses
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training.

  Encodes the (length-doubled) inputs and VAE-compresses/decompresses the
  targets conditioned on them, returning the decompressed output along with
  warmup-scaled KL and reconstruction losses.
  """
  with tf.variable_scope("vae_transformer"):
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    input_len = tf.shape(inputs)[1]  # Double input size to cover targets.
    inputs = tf.pad(inputs, [[0, 0], [0, input_len], [0, 0]])
    inputs.set_shape([None, None, hparams.hidden_size])
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)
    inputs, ed_bias = encode(inputs, target_space, hparams, "input_enc")
    # Compress and vae.
    z, kl, r = vae_compress(tf.expand_dims(targets, axis=2),
                            tf.expand_dims(inputs, axis=2),
                            ed_bias, hparams, "vae_compress", "vae_decompress")
    # Ramp the losses in on different schedules around startup_steps.
    kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.5))
    r *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 2.0))
    losses = {"kl": kl, "reconstruction": r}
    return z, losses
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0, means=None,
                            ema_count=None, ema_means=None):
  """AE Transformer, main step used for training.

  Compresses and discretely bottlenecks the targets, trains an extra model
  predicting the latents from the inputs, then decompresses with progressive
  masking and decodes. `means`/`ema_count`/`ema_means` are passed through to
  the bottleneck (EMA codebook state -- TODO confirm exact semantics).

  Returns:
    (res, losses, cache) -- losses has "extra" and "latent_pred" entries.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False
  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None
  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # Pad targets to at most twice the input length, divisible by the
    # total compression factor.
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2 * 2048, "vc", means, ema_count, ema_means)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      # With probability (1 - pc) per example, bypass the bottleneck early
      # in training.
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16,
                                       name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        # Continuous bottlenecks: L2 match between predicted and true codes.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc", means,
                                     ema_count, ema_means)
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      # PREDICT: produce latents from the inputs (or sample them).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc",
                                            means, ema_count, ema_means)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc", means,
                                    ema_count, ema_means)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess: prepend a position and add learned positional embeddings.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      # Decompress the latents back to target length and blend by mask.
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)
  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    # Strip the prepended (reversed) latents from the decoded output.
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training only the extra model of latents after 400K steps.
    # Before we train only this, we decrease lr for other weights.
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Compresses and discretely bottlenecks the targets, trains an extra model
  predicting the latents from the inputs, then decompresses with progressive
  masking and decodes.

  Returns:
    (res, losses, cache) -- losses has "extra" and "latent_pred" entries.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False
  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None
  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # Pad targets to at most twice the input length, divisible by the
    # total compression factor.
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2*2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      # With probability (1 - pc) per example, bypass the bottleneck early
      # in training.
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16,
                                       name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        # Continuous bottlenecks: L2 match between predicted and true codes.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      # PREDICT: produce latents from the inputs (or sample them).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2*2048, "vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess: prepend a position and add learned positional embeddings.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      # Decompress the latents back to target length and blend by mask.
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)
  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    # Strip the prepended (reversed) latents from the decoded output.
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training only the extra model of latents after 400K steps.
    # Before we train only this, we decrease lr for other weights.
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams, beam_size,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Compresses targets into a short latent sequence via `bottleneck`, then
  decodes the (partially masked) targets conditioned on those latents. In
  PREDICT mode the discrete latents are sampled with `ae_latent_sample` and
  tiled across `beam_size` beams.

  Args:
    inputs: encoder input tensor, or None for input-less operation.
    targets: target tensor; reshaped internally to
      [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads do_ae, do_mask, do_refine, mode,
      bottleneck_kind, num_compress_steps, startup_steps, hidden_size.
    beam_size: number of beams the sampled latent cache is tiled over.
    cache: previously sampled discrete latents, or None to sample here.
    predict_mask: masking value substituted in PREDICT mode.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" scalar losses, and the (possibly newly sampled) cache.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False
  # Prepare.
  orig_targets = targets
  batch_size = tf.shape(orig_targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None
  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # NOTE(review): this branch assumes inputs is not None when do_ae is set;
    # the concat below would fail otherwise — confirm with callers.
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
      # With probability 1 - pc bypass the bottleneck and use the compressed
      # targets directly (warmup mixing schedule; always 1.0 outside TRAIN).
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([]), pc)
      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
      losses["extra"] = vc_loss * tf.to_float(cond)
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        t_pred = decode_transformer(
            inputs, ed, tf.stop_gradient(t_c), hparams, "extra")
        t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=t_bit, logits=t_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
    else:
      # PREDICT mode: obtain latents by sampling (or reuse the given cache).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        targets_rand = tf.random_uniform(tf.shape(targets_c))
        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2*2048, "vc")
      else:
        latent_len = tf.shape(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc")
        t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
          # Keep the first beam's sample and tile it across all beams.
          cache = cache[0, :, :]
          cache = tf.reshape(cache, [1, latent_len, 1])
          cache = tf.tile(cache, [beam_size, 1, 1])
        t_c = embed(cache)
    # Postprocess: add a learned positional signal, with one extra leading
    # (padded) position.
    d = t_c
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      # Decompress the latents back to target resolution.
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      # mask==1 keeps the original target position; mask==0 takes the
      # decompressed latent instead.
      targets = mask * targets + (1.0 - mask) * d
    # Prepend the (reversed) latents to the decoder input.
    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    # Strip the latent prefix from the decoder output.
    res = res[:, tf.shape(t_c)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      # Refine only when (nearly) every position was taken from the latents
      # (mask sums to ~0, i.e. no original target positions were kept).
      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
      res = tf.cond(all_masked, refine_res, lambda: res)
  return res, losses, cache
def body(self, features):
  """Stack padded raw inputs under raw targets, then run the parent body.

  Outside EVAL mode the raw inputs are padded to the raw targets' length
  and concatenated along the batch axis into "targets_raw", so the parent
  model sees both streams; in EVAL mode features pass through untouched.
  """
  if self.hparams.mode != tf.estimator.ModeKeys.EVAL:
    targets_raw = features["targets_raw"]
    inputs_raw = features["inputs_raw"]
    targets_raw, inputs_raw = common_layers.pad_to_same_length(
        targets_raw, inputs_raw)
    features["targets_raw"] = tf.concat([targets_raw, inputs_raw], axis=0)
  return super(AutoencoderDualDiscrete, self).body(features)
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Supports "translate" and "image" tasks. Targets are optionally
  word-shuffled and word-dropped, compressed into latents through
  hparams.bottleneck, then decoded conditioned on the decompressed latents.

  Args:
    inputs: encoder input tensor, or None (input-less operation).
    targets: target tensor; reshaped to [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads bottleneck (callable), bottleneck_kind,
      task, do_ae, do_mask, do_refine, do_attend_decompress, word_shuffle,
      word_dropout, mask_startup_steps, unmasked_percentage, sum_over_latents,
      prior_scale, entropy_scale, causal, use_predict_mask, and more.
    cache: previously sampled discrete latents, or None to sample here.
    predict_mask: mask value substituted in PREDICT mode.

  Returns:
    (res, losses, cache, data_dim, latent_dim): decoder output, loss dict
    ("extra", "latent_pred", "neg_q_entropy"), the latent cache, and the
    sequence lengths of the output and the compressed latents.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {
      "extra": tf.constant(0.0),
      "latent_pred": tf.constant(0.0),
      "neg_q_entropy": tf.constant(0.0)
  }
  if hparams.do_ae:
    # flatten here
    original_targets = targets
    original_targets_shape = tf.shape(original_targets)
    if hparams.task == "image":
      # NOTE(review): the return value is discarded, so this call has no
      # effect on `targets` — confirm whether the result should be assigned.
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      if inputs is not None:
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
      else:
        max_targets_len_from_inputs = targets
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    if hparams.word_shuffle:
      # Jitter each position index by uniform noise and re-sort: positions
      # may swap with neighbours within roughly `word_shuffle` distance.
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      targets_idx = tf.range(start=0,
                             limit=common_layers.shape_list(targets)[1],
                             delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(
          shape=common_layers.shape_list(targets_idx),
          minval=0,
          maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = contrib.framework().argsort(targets_idx)
      targets_permuted = tf.gather(targets, indices=permutation, axis=1)
      targets = targets_permuted
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    # Add positional information
    targets_shape = common_layers.shape_list(targets)
    targets = tf.reshape(
        targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
    targets = common_attention.add_positional_embedding(
        targets, hparams.max_length, name="targets_position")
    targets = tf.reshape(targets, shape=targets_shape)
    if hparams.word_dropout:
      # Zero out each target position with probability `word_dropout`.
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(inputs=targets_c,
                             filter_size=hparams.compress_filter_size,
                             mode=hparams.mode,
                             name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      # With probability 1 - pc (per batch element) bypass the bottleneck
      # and use the compressed targets directly (warmup schedule).
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before
      # mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(inputs_ex, ed_ex,
                                          embed(latents_discrete), hparams,
                                          "extra", task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        # Dense/VAE bottleneck: L2 loss between decoded inputs and the
        # compressed targets instead of a discrete prediction loss.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            tf.squared_difference(inputs_c, targets_c)) * 20
        def bn_inputs():
          # Re-apply the bottleneck to the decoded inputs, reusing the
          # variables created for the targets bottleneck above.
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _, _ = hparams.bottleneck(
                inputs=inputs_c,
                filter_size=hparams.compress_filter_size,
                mode=hparams.mode,
                name="vc")
          return bn
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      # PREDICT mode: obtain latents by sampling (or reuse the given cache).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess: add a learned positional signal to the latents.
    d = latents_dense
    d_shape = common_layers.shape_list(d)
    d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
    d = common_attention.add_positional_embedding(d, hparams.max_length,
                                                  name="latents_position")
    d = tf.reshape(d, shape=d_shape)

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    else:
      targets = d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)
      # Refine batch elements whose mask sums to ~0 (no original target
      # positions were kept).
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after
    # mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)

  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing loss with respect
  # to the original (unpadded) targets. So we remove their extra elements
  # here.
  res = res[:, :original_targets_shape[1], :, :]

  data_dim = common_layers.shape_list(res)[1]
  latent_dim = common_layers.shape_list(targets_c)[1]
  return res, losses, cache, data_dim, latent_dim
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None):
  """Main step used for training.

  Vector-quantized variant: compresses targets, discretizes them with
  `vq_discrete_bottleneck`, and trains an extra transformer to predict the
  latent codes from the inputs. In PREDICT mode latents are obtained by
  beam sampling (`ae_latent_sample_beam`).

  Args:
    inputs: encoder input tensor, or None (input-less operation).
    targets: target tensor; reshaped to [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads num_compress_steps, startup_steps,
      mask_startup_steps, bottleneck_bits, hidden_size, mode.
    cache: previously sampled discrete latents, or None to sample here.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" losses, and the latent cache.
  """
  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  # NOTE(review): assumes inputs is not None; the concat below would fail
  # otherwise — confirm with callers.
  max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
  targets, _ = common_layers.pad_to_same_length(
      targets, max_targets_len_from_inputs,
      final_length_divisible_by=2**hparams.num_compress_steps)
  targets_c = compress(targets, hparams, "compress")
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    # Compress and bottleneck.
    latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
        x=targets_c, hparams=hparams)
    latents_dense = vq_discrete_unbottleneck(latents_discrete_hot, hparams)
    latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
    tf.summary.histogram("codes",
                         tf.reshape(latents_discrete[:, 0, :], [-1]))
    # With probability 1 - pc (per batch element) bypass the bottleneck and
    # use the compressed targets directly (warmup schedule).
    pc = common_layers.inverse_exp_decay(hparams.startup_steps)
    pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cond = tf.less(tf.random_uniform([batch_size]), pc)
    latents_dense = tf.where(cond, latents_dense, targets_c)
    losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
    # Extra loss predicting latent code from input. Discrete only.
    latents_pred = decode_transformer(inputs_ex, ed_ex, latents_dense,
                                      hparams, "extra")
    latent_pred_loss = get_latent_pred_loss(latents_pred,
                                            latents_discrete_hot, hparams)
    losses["latent_pred"] = tf.reduce_mean(
        latent_pred_loss * tf.to_float(cond))
  else:
    # PREDICT mode: sample discrete latents (or reuse the given cache).
    latent_len = common_layers.shape_list(targets_c)[1]
    embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs_ex, ed_ex, embed,
                                    hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits))

  # Postprocess: add a learned positional signal with one extra leading
  # (padded) position.
  d = latents_dense
  pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
  pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
  latents_dense = tf.pad(latents_dense,
                         [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

  # Decompressing the dense latents
  for i in range(hparams.num_compress_steps):
    j = hparams.num_compress_steps - i - 1
    d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
    d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

  # NOTE(review): the decompressed `d` is never mixed back into `targets`
  # here — the decoder consumes the padded targets directly, so the
  # decompression above has no effect on `res`. Sibling variants blend `d`
  # into `targets` via a mask; confirm whether that step was intentionally
  # omitted.
  res = decode_transformer(inputs, ed, targets, hparams, "decoder")

  # We'll start training the extra model of latents after mask_startup_steps.
  nonlatent_steps = hparams.mask_startup_steps
  latent_time = tf.less(nonlatent_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None):
  """Main step used for training.

  Vector-quantized variant with a straight-through estimator: the dense
  latents are passed forward but gradients flow through `targets_c`.

  Args:
    inputs: encoder input tensor (required; not optional in this variant).
    targets: target tensor; padded to match twice the input length.
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads num_compress_steps, mask_startup_steps,
      bottleneck_bits, hidden_size, mode.
    cache: previously sampled discrete latents, or None to sample here.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" losses, and the latent cache.
  """
  # Encoder.
  inputs = common_layers.flatten4d3d(inputs)
  inputs, ed = encode(inputs, target_space, hparams, "input_enc")

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
  targets, _ = common_layers.pad_to_same_length(
      targets, max_targets_len_from_inputs,
      final_length_divisible_by=2**hparams.num_compress_steps)
  targets_c = compress(targets, hparams, "compress")
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    # Compress and bottleneck.
    latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
        x=targets_c, hparams=hparams)
    latents_dense = vq_discrete_unbottleneck(latents_discrete_hot,
                                             hparams=hparams)
    # Straight-through: forward pass uses the quantized latents, backward
    # pass treats the quantization as identity on targets_c.
    latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)
    latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
    tf.summary.histogram("codes",
                         tf.reshape(latents_discrete[:, 0, :], [-1]))
    losses["extra"] = extra_loss

    # Extra loss predicting latent code from input.
    latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                      "extra")
    latent_pred_loss = get_latent_pred_loss(latents_pred,
                                            latents_discrete_hot, hparams)
    losses["latent_pred"] = tf.reduce_mean(latent_pred_loss)
  else:
    # PREDICT mode: sample discrete latents (or reuse the given cache).
    latent_len = common_layers.shape_list(targets_c)[1]
    embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
    if cache is None:
      cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed, hparams)
    cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    latents_dense = embed(cache_hot)

  # Postprocess: add a learned positional signal with one extra leading
  # (padded) position.
  d = latents_dense
  pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
  pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
  latents_dense = tf.pad(latents_dense,
                         [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

  # Decompressing the dense latents
  for i in range(hparams.num_compress_steps):
    j = hparams.num_compress_steps - i - 1
    d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
    d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

  # Blend original targets with decompressed latents according to a
  # schedule; at PREDICT time everything comes from the latents.
  masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
  masking *= common_layers.inverse_exp_decay(
      hparams.mask_startup_steps // 4)  # Not much at start.
  masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    masking = 1.0
  mask = tf.less(masking,
                 tf.random_uniform(common_layers.shape_list(targets)[:-1]))
  mask = tf.expand_dims(tf.to_float(mask), 3)

  # targets is always [batch, length, 1, depth]
  targets = mask * targets + (1.0 - mask) * d

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  # Train the latent-prediction loss only after mask_startup_steps.
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams, beam_size,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Compresses targets into a short latent sequence via `bottleneck`, then
  decodes the (partially masked) targets conditioned on those latents. In
  PREDICT mode the discrete latents are sampled with `ae_latent_sample` and
  tiled across `beam_size` beams.

  Args:
    inputs: encoder input tensor, or None for input-less operation.
    targets: target tensor; reshaped internally to
      [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads do_ae, do_mask, do_refine, mode,
      bottleneck_kind, num_compress_steps, startup_steps, hidden_size.
    beam_size: number of beams the sampled latent cache is tiled over.
    cache: previously sampled discrete latents, or None to sample here.
    predict_mask: masking value substituted in PREDICT mode.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" scalar losses, and the (possibly newly sampled) cache.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False
  # Prepare.
  orig_targets = targets
  batch_size = tf.shape(orig_targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None
  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # NOTE(review): this branch assumes inputs is not None when do_ae is set;
    # the concat below would fail otherwise — confirm with callers.
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048, "vc")
      if _DO_SUMMARIES:
        tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
      # With probability 1 - pc bypass the bottleneck and use the compressed
      # targets directly (warmup mixing schedule; always 1.0 outside TRAIN).
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([]), pc)
      t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
      losses["extra"] = vc_loss * tf.to_float(cond)
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        t_pred = decode_transformer(inputs, ed, tf.stop_gradient(t_c),
                                    hparams, "extra")
        t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
        losses[
            "latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=t_bit, logits=t_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"]) * 0.5 * tf.to_float(cond)
    else:
      # PREDICT mode: obtain latents by sampling (or reuse the given cache).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        targets_rand = tf.random_uniform(tf.shape(targets_c))
        t_c, _, _, _ = bottleneck(targets_rand, hparams, 2 * 2048, "vc")
      else:
        latent_len = tf.shape(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc")
        t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams)
          # Keep the first beam's sample and tile it across all beams.
          cache = cache[0, :, :]
          cache = tf.reshape(cache, [1, latent_len, 1])
          cache = tf.tile(cache, [beam_size, 1, 1])
        t_c = embed(cache)
    # Postprocess: add a learned positional signal, with one extra leading
    # (padded) position.
    d = t_c
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
    t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(
          25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      # Decompress the latents back to target resolution.
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      # mask==1 keeps the original target position; mask==0 takes the
      # decompressed latent instead.
      targets = mask * targets + (1.0 - mask) * d
    # Prepend the (reversed) latents to the decoder input.
    targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    # Strip the latent prefix from the decoder output.
    res = res[:, tf.shape(t_c)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      # Refine only when (nearly) every position was taken from the latents
      # (mask sums to ~0, i.e. no original target positions were kept).
      all_masked = tf.less(tf.reduce_sum(mask), 0.1)
      res = tf.cond(all_masked, refine_res, lambda: res)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Supports "translate" and "image" tasks: compresses targets into latents
  through hparams.bottleneck, then decodes the (partially masked) targets
  conditioned on the decompressed latents.

  Fix: in the dense/vae training branch, `bn_inputs` was assigned as a
  function object instead of being called, so the subsequent `tf.where`
  received a callable where a tensor is required and failed at graph
  construction. It is now invoked (matching the sibling variant of this
  function that calls `bn_inputs()`).

  Args:
    inputs: encoder input tensor, or None (input-less operation).
    targets: target tensor; reshaped to [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads bottleneck (callable), bottleneck_kind,
      task, do_ae, do_mask, do_refine, do_attend_decompress,
      mask_startup_steps, unmasked_percentage, use_predict_mask, causal, etc.
    cache: previously sampled discrete latents, or None to sample here.
    predict_mask: mask value substituted in PREDICT mode.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" losses, and the latent cache.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      # With probability 1 - pc (per batch element) bypass the bottleneck
      # and use the compressed targets directly (warmup schedule).
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before
      # mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(inputs_ex, ed_ex,
                                          embed(latents_discrete), hparams,
                                          "extra", task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                               tf.to_float(cond))
      else:
        # Dense/VAE bottleneck: L2 loss between decoded inputs and the
        # compressed targets instead of a discrete prediction loss.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            (inputs_c - targets_c)**2) * 20
        def bn_inputs():
          # Re-apply the bottleneck to the decoded inputs, reusing the
          # variables created for the targets bottleneck above.
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        # BUG FIX: call bn_inputs() — previously the function object itself
        # was assigned, which made the tf.where below fail (a callable
        # cannot be converted to a tensor).
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      # PREDICT mode: obtain latents by sampling (or reuse the given cache).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)
    # Postprocess: add a learned positional signal with one extra leading
    # (padded) position.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(
          masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    if hparams.task == "translate":
      # Prepend the (reversed) latents to the decoder input.
      targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.task == "translate":
      # Strip the latent prefix from the decoder output.
      res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)
      # Refine batch elements whose mask sums to ~0 (no original target
      # positions were kept).
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after
    # mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def body(self, features):
  """Augment "targets_raw" with the raw inputs before delegating upward.

  In every mode except EVAL, pads "inputs_raw" to the length of
  "targets_raw" and appends it along the batch dimension, then calls the
  superclass body on the modified features.
  """
  is_eval = self.hparams.mode == tf.estimator.ModeKeys.EVAL
  if not is_eval:
    tgt = features["targets_raw"]
    inp = features["inputs_raw"]
    tgt, inp = common_layers.pad_to_same_length(tgt, inp)
    features["targets_raw"] = tf.concat([tgt, inp], axis=0)
  return super(AutoencoderDualDiscrete, self).body(features)
def ae_transformer_internal(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training.

  Variant with word dropout on the targets before compression. Compresses
  targets into latents through hparams.bottleneck and decodes the
  (partially masked) targets conditioned on the decompressed latents.

  Fix: in the dense/vae training branch, `bn_inputs` was assigned as a
  function object instead of being called, so the subsequent `tf.where`
  received a callable where a tensor is required and failed at graph
  construction. It is now invoked (matching the sibling variant of this
  function that calls `bn_inputs()`).

  Args:
    inputs: encoder input tensor, or None (input-less operation).
    targets: target tensor; reshaped to [batch, length, 1, hidden_size].
    target_space: target-space id forwarded to the encoder.
    hparams: hyperparameters; reads bottleneck (callable), bottleneck_kind,
      task, do_ae, do_mask, do_refine, do_attend_decompress, word_dropout,
      mask_startup_steps, unmasked_percentage, use_predict_mask, causal, etc.
    cache: previously sampled discrete latents, or None to sample here.
    predict_mask: mask value substituted in PREDICT mode.

  Returns:
    (res, losses, cache): decoder output, dict with "extra" and
    "latent_pred" losses, and the latent cache.
  """
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      # Zero out each target position with probability `word_dropout`.
      mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                               minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      # With probability 1 - pc (per batch element) bypass the bottleneck
      # and use the compressed targets directly (warmup schedule).
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before
      # mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        # Dense/VAE bottleneck: L2 loss between decoded inputs and the
        # compressed targets instead of a discrete prediction loss.
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          # Re-apply the bottleneck to the decoded inputs, reusing the
          # variables created for the targets bottleneck above.
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        # BUG FIX: call bn_inputs() — previously the function object itself
        # was assigned, which made the tf.where below fail (a callable
        # cannot be converted to a tensor).
        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]),
                                         ptc),
                                 latents_dense, inputs_c)
    else:
      # PREDICT mode: obtain latents by sampling (or reuse the given cache).
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess: add a learned positional signal with one extra leading
    # (padded) position.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                      "refine_enc")
        return tf.expand_dims(r, axis=2)
      # Refine batch elements whose mask sums to ~0 (no original target
      # positions were kept).
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after
    # mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training.

  Builds the training graph: a Transformer encoder over `inputs`, a VAE that
  compresses `targets` into a latent `z`, and a Transformer decoder whose
  input is chosen at random (per step) between the autoregressive path and
  the VAE-decompressed path.

  Args:
    inputs: 4-D Tensor of encoder inputs (flattened to 3-D here).
    targets: 4-D Tensor of decoder targets (flattened to 3-D here).
    target_space: target space id passed to encoder preparation.
    hparams: model hyperparameters.

  Returns:
    A (output, kl_loss) pair; `output` is either the decoder output or the
    raw decompressed latents (chosen randomly during training), `kl_loss`
    is a scalar KL penalty scaled by a warmup schedule.
  """
  with tf.variable_scope("vae_transformer"):
    # Was tf.contrib.learn.ModeKeys (deprecated alias, same string value);
    # normalized to tf.estimator.ModeKeys for consistency with the file.
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    # Targets must be divisible by k so they survive k halvings in compress.
    _, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)

    # Transformer preparations and encoder.
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder(
         inputs, target_space, hparams)
    residual_fn = transformer.get_residual_fn(hparams)
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.residual_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, residual_fn, encoder_self_attention_bias, hparams)

    def get_decoder_autoregressive():
      """Decoder input for autoregressive computation."""
      (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
      return (a, b, tf.constant(0.0))

    # 10% of the time we compress all-zeros, as will be at decoding start.
    prob_targets = 0.9 if is_training else 1.0
    to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                          lambda: targets, lambda: tf.zeros_like(targets))
    z, kl_loss = compress_vae(to_compress, hparams, "vae")
    # Decompress. Scope indices count down so "decompress_0" is the last
    # (finest) step, mirroring the compression stack. Fixed: the index was
    # computed from num_hidden_layers, which yields wrong (even negative)
    # scope names whenever num_hidden_layers != num_compress_steps; the
    # sibling functions in this file all use num_compress_steps.
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      z = decompress(z, hparams, "decompress_%d" % j)

    def get_decoder_from_vae():
      """Decoder input computed by VAE."""
      # Return decoder stuff.
      (a, b) = transformer.transformer_prepare_decoder(
          tf.squeeze(z, axis=2), hparams)
      return (a, b, kl_loss)

    # Randomize decoder inputs: with growing probability use the VAE path.
    prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
    # Was tf.contrib.framework.get_global_step() (deprecated alias).
    step = tf.to_float(tf.train.get_global_step())
    if not is_training:
      # Eval/predict: VAE path only, once past the warmup boundary.
      prob_do_vae = tf.cond(tf.less(step, 40000.0),
                            lambda: tf.constant(0.0),
                            lambda: tf.constant(1.0))
    (decoder_input, decoder_self_attention_bias, kl_loss2) = tf.cond(
        tf.less(tf.random_uniform([]), prob_do_vae),
        get_decoder_from_vae, get_decoder_autoregressive)

    # Transformer decoder.
    decoder_output = transformer.transformer_decoder(
        decoder_input, encoder_output, residual_fn,
        decoder_self_attention_bias, encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, 2)

    # Sometimes (40% of training steps; schedule-gated otherwise) return the
    # raw decompressed latents so the VAE learns to reconstruct on its own.
    cond_self = tf.cond(tf.less(step, 30000.0),
                        lambda: tf.constant(1.0),
                        lambda: tf.constant(0.0))
    prob_self = 0.4 if is_training else cond_self
    (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self),
                             lambda: (z, kl_loss),
                             lambda: (decoder_output, kl_loss2))

    # Warm up the KL penalty over the first 50k steps.
    kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
    return ret, kl_loss
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None):
  """Main step used for training.

  Encodes the inputs, discretizes the compressed targets through a VQ
  bottleneck (with a straight-through gradient), then decodes a randomly
  masked mix of the true targets and the decompressed latents. At predict
  time the latent codes are beam-sampled (or taken from `cache`) instead.

  Returns:
    A (decoder_output, losses_dict, cache) triple; losses_dict carries the
    "extra" (bottleneck) and "latent_pred" (code-prediction) losses.
  """
  # Encoder.
  inputs = common_layers.flatten4d3d(inputs)
  inputs, enc_bias = encode(inputs, target_space, hparams, "input_enc")

  # Both autoencoding losses start at zero.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

  # Pad targets to a multiple of 2^num_compress_steps, at least twice the
  # input length, so compression divides evenly.
  max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
  targets, _ = common_layers.pad_to_same_length(
      targets, max_targets_len_from_inputs,
      final_length_divisible_by=2**hparams.num_compress_steps)
  compressed = compress(targets, hparams, "compress")

  is_predict = hparams.mode == tf.estimator.ModeKeys.PREDICT
  if not is_predict:
    # Compress and bottleneck; straight-through estimator for the gradient.
    codes_hot, extra_loss = vq_discrete_bottleneck(x=compressed,
                                                   hparams=hparams)
    dense_latents = vq_discrete_unbottleneck(codes_hot, hparams=hparams)
    dense_latents = compressed + tf.stop_gradient(dense_latents - compressed)
    codes = tf.argmax(codes_hot, axis=-1)
    tf.summary.histogram("codes", tf.reshape(codes[:, 0, :], [-1]))
    losses["extra"] = extra_loss

    # Extra loss predicting latent code from input.
    code_logits = decode_transformer(inputs, enc_bias, dense_latents,
                                     hparams, "extra")
    per_code_loss = get_latent_pred_loss(code_logits, codes_hot, hparams)
    losses["latent_pred"] = tf.reduce_mean(per_code_loss)
  else:
    # Predict mode: sample latent codes with a beam, or reuse the cache.
    latent_len = common_layers.shape_list(compressed)[1]
    embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
    dense_latents = tf.zeros_like(compressed[:, :latent_len, :, :])
    if cache is None:
      cache = ae_latent_sample_beam(dense_latents, inputs, enc_bias, embed,
                                    hparams)
    cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
    dense_latents = embed(cache_hot)

  # Postprocess: keep a handle on the raw latents for decompression, then
  # shift-right and add a learned positional signal.
  # NOTE(review): the padded `dense_latents` is not read again afterwards —
  # only `decoded` (the pre-pad value) flows on; preserved as in original.
  decoded = dense_latents
  pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
  pos = pos[:, :common_layers.shape_list(dense_latents)[1] + 1, :, :]
  dense_latents = tf.pad(dense_latents,
                         [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

  # Decompressing the dense latents back to target resolution; scope ids
  # count down so "decompress_0" is the final step.
  for step_idx, scope_id in enumerate(
      reversed(range(hparams.num_compress_steps))):
    decoded = residual_conv(decoded, 1, (3, 1), hparams,
                            "decompress_rc_%d" % scope_id)
    decoded = decompress_step(decoded, hparams, step_idx > 0,
                              "decompress_%d" % scope_id)

  # Mix true targets with decompressed latents under a random mask whose
  # density follows a warmup schedule.
  mask_prob = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
  mask_prob *= common_layers.inverse_exp_decay(
      hparams.mask_startup_steps // 4)  # Not much at start.
  mask_prob = tf.minimum(tf.maximum(mask_prob, 0.0), 1.0)
  if is_predict:
    mask_prob = 1.0
  keep = tf.less(mask_prob,
                 tf.random_uniform(common_layers.shape_list(targets)[:-1]))
  keep = tf.expand_dims(tf.to_float(keep), 3)

  # targets is always [batch, length, 1, depth]
  targets = keep * targets + (1.0 - keep) * decoded

  res = decode_transformer(inputs, enc_bias, targets, hparams, "decoder")
  # Only train the latent-prediction loss once the mask warmup has passed.
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams, beam_size,
                            cache=None):
  """AE Transformer, main step used for training.

  Encodes `inputs`, compresses `targets` and discretizes them through a
  bottleneck, then prepends the (reversed, position-augmented) latent
  sequence to the targets before decoding with a Transformer.

  Args:
    inputs: encoder inputs Tensor, or None for a decoder-only model.
    targets: decoder targets; reshaped here to [batch, -1, 1, hidden_size].
    target_space: target space id forwarded to `encode`.
    hparams: model hyperparameters; `z_size` is overwritten below.
    beam_size: number of beams the cached latent codes are tiled over in
      predict mode.
    cache: previously sampled latent codes, or None to sample them here.

  Returns:
    (res, losses, cache): decoder output with the latent prefix stripped,
    a dict with "vc" (bottleneck) and "sm" (latent-prediction) losses, and
    the latent code cache.
  """
  # NOTE(review): mutates the shared hparams object in place — any caller
  # relying on the previous z_size will observe this change.
  hparams.z_size = hparams.hidden_size
  with tf.variable_scope("ae_transformer"):
    # Prepare inputs, targets, k.
    orig_targets = targets
    batch_size = tf.shape(orig_targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])
    k = hparams.num_compress_steps  # compression shrinks length by 2**k.

    # Encoder (skipped entirely when there are no inputs).
    if inputs is not None:
      inputs = common_layers.flatten4d3d(inputs)
      inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    else:
      ed = None

    # Autoencoding.
    losses = {"vc": tf.constant(0.0), "sm": tf.constant(0.0)}
    latent_len = hparams.latent_length
    if hparams.do_ae:
      # Pad so that after k halvings at least latent_len positions remain.
      targets_pad, _ = common_layers.pad_to_same_length(
          targets, targets, final_length_divisible_by=latent_len * 2**k)
      targets_c = compress(targets_pad, None, False, hparams, "compress")
      targets_c = targets_c[:, :latent_len, :, :]
      if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        # Compress and bottleneck.
        t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048,
                                            "vc")
        tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
        # Probability of using the bottlenecked code ramps up with training;
        # otherwise fall back to the raw compressed targets.
        pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
        pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        cond = tf.less(tf.random_uniform([]), pc)
        t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
        # Only charge the bottleneck loss when the code path was taken.
        losses["vc"] = vc_loss * tf.to_float(cond)
        # Extra loss predicting latent code from input.
        t_pred = decode_transformer(inputs, ed, tf.stop_gradient(t_c),
                                    hparams, "extra")
        t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
        losses["sm"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=t_bit, logits=t_pred)
        # Scaled down, and likewise gated on the code path being taken.
        losses["sm"] = tf.reduce_mean(
            losses["sm"]) * 0.2 * tf.to_float(cond)
      else:
        # Predict mode: sample latent codes (or reuse the supplied cache).
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc")
        t_c = tf.zeros_like(targets_c)
        if cache is None:
          cache = ae_latent_sample(t_c, inputs, ed, embed, 3, hparams)
          # Take the first batch element and tile it across all beams.
          # NOTE(review): an externally supplied cache is assumed to already
          # have shape [beam_size, latent_len, 1] — confirm against callers.
          cache = cache[0, :, :]
          cache = tf.reshape(cache, [1, latent_len, 1])
          cache = tf.tile(cache, [beam_size, 1, 1])
        t_c = embed(cache)
      # Postprocess: shift latents right, add a learned positional signal,
      # then prepend the reversed latent sequence to the targets.
      pos = tf.get_variable("pos",
                            [1, latent_len + 1, 1, hparams.hidden_size])
      t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
      targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)
    else:
      # No autoencoder: just reserve the latent prefix with zero padding.
      targets = tf.pad(targets,
                       [[0, 0], [latent_len + 1, 0], [0, 0], [0, 0]])
    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
    res = res[:, latent_len + 1:, :, :]  # Strip the latent prefix.
    return res, losses, cache