def dis_decoder(hparams,
                sequence,
                encoding_state,
                is_training,
                reuse=None,
                embedding=None):
  """Define the Discriminator decoder.

  Read in the sequence and predict at each time point.
  """
  sequence = tf.cast(sequence, tf.int32)

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    state = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=hparams.dis_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # floor is 0. on [keep_prob, 1.0) and 1. on [1.0, 1.0 + keep_prob).
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

    with tf.variable_scope('rnn') as vs:
      predictions = []
      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state = cell_dis(rnn_in, state)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        # Prediction is a linear output for the Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
        predictions.append(pred)

  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
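# A minimal numpy sketch (not part of the model) of the inverted-dropout mask
# built by make_mask above: keep_prob + U[0, 1) lies in
# [keep_prob, 1 + keep_prob), so its floor is 1 with probability keep_prob and
# 0 otherwise; dividing by keep_prob keeps the expected activation unchanged.
# The numpy names below are illustrative, not from this codebase.
import numpy as np


def make_mask_np(keep_prob, batch_size, units):
  random_tensor = keep_prob + np.random.uniform(size=(batch_size, units))
  return np.floor(random_tensor) / keep_prob


# Example: each entry is either 0.0 or 1 / keep_prob = 2.0, so the mean of a
# large mask approaches 1.0 (unbiased at test time without rescaling).
mask = make_mask_np(0.5, batch_size=4, units=8)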
# Generator decoder variant: layer-norm LSTM cells with optional zoneout.
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.

  The Decoder will now impute tokens that have been masked from the input
  sequence.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          gen_decoder_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.get_variable('softmax_w',
                                  [gen_decoder_rnn_size, FLAGS.vocab_size])
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated token.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should operate in teacher-forcing
          # mode. Also use teacher forcing when training purely with
          # cross_entropy.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        # # TODO(liamfedus): Assert not "monotonic" attention_type.
        # # TODO(liamfedus): FLAGS.attention_type.
        # context_state = revised_attention_utils._empty_state()
        # rnn_out, context_state = attention_construct_fn(
        #     rnn_out, attention_keys, attention_values, context_state, t)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:  return the real token at t + 1.
        # If input is not present:  return the fake token for t + 1.
        real = targets[:, t]
        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1))
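# A minimal numpy sketch (not part of the model) of the imputation rule in the
# decoder loop above: where the previous target was present, feed the real
# token at t; where it was masked, feed the decoder's own sample from t - 1.
# Names here are illustrative only, and the mixing is shown on token ids
# rather than on embeddings.
import numpy as np

real_tokens = np.array([5, 9, 2, 7])            # ground-truth tokens at t
prev_samples = np.array([5, 1, 3, 7])           # tokens sampled at t - 1
present = np.array([True, False, False, True])  # targets_present[:, t - 1]

# Mirrors tf.where(targets_present[:, t - 1], real_rnn_inp, fake_rnn_inp).
rnn_input_ids = np.where(present, real_tokens, prev_samples)
# -> array([5, 1, 3, 7])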
# Generator decoder variant: variational dropout, with optionally shared
# input/output embeddings.
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.

  The Decoder will now impute tokens that have been masked from the input
  sequence.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  targets = tf.Print(targets, [targets], message='targets', summarize=50)
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          gen_decoder_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.gen_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
            hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # floor is 0. on [keep_prob, 1.0) and 1. on [1.0, 1.0 + keep_prob).
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
      # TODO(adai): Perhaps append IMDB labels placeholder to input at
      # each time point.

      rnn_outs = []

      fake = None
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated token.
        else:
          real_rnn_inp = rnn_inputs[:, t]

          # While validating, the decoder should operate in teacher-forcing
          # mode. Also use teacher forcing when training purely with
          # cross_entropy.
          if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
            rnn_inp = real_rnn_inp
          else:
            fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        rnn_outs.append(rnn_out)
        if FLAGS.gen_training_strategy != 'cross_entropy':
          logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)

          # Output for Decoder.
          # If input is present:  return the real token at t + 1.
          # If input is not present:  return the fake token for t + 1.
          real = targets[:, t]
          categorical = tf.contrib.distributions.Categorical(logits=logit)
          if FLAGS.use_gen_mode:
            fake = categorical.mode()
          else:
            fake = categorical.sample()
          log_prob = categorical.log_prob(fake)
          output = tf.where(targets_present[:, t], real, fake)

        else:
          real = targets[:, t]
          logit = tf.zeros(tf.stack([FLAGS.batch_size, FLAGS.vocab_size]))
          log_prob = tf.zeros(tf.stack([FLAGS.batch_size]))
          output = real

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      if FLAGS.gen_training_strategy == 'cross_entropy':
        logits = tf.nn.bias_add(
            tf.matmul(
                tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
                softmax_w), softmax_b)
        logits = tf.reshape(logits,
                            [-1, FLAGS.sequence_length, FLAGS.vocab_size])
      else:
        logits = tf.stack(logits, axis=1)

  return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
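# A small numpy sketch (not part of the model) of the weight tying used above:
# softmax_w is the transpose of the input embedding matrix, so the output
# logits are h . E^T + b and no separate output projection is learned.
# Shapes below are illustrative.
import numpy as np

vocab_size, hidden = 10, 4
embedding = np.random.randn(vocab_size, hidden)  # input embedding E
softmax_w = embedding.T                          # tied output projection E^T
softmax_b = np.zeros(vocab_size)

rnn_out = np.random.randn(2, hidden)             # batch of decoder outputs
logits = rnn_out @ softmax_w + softmax_b         # shape [2, vocab_size]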
# Generator decoder variant: neural-architecture-search (NAS) cell with
# variational dropout masks on the recurrent state and outputs.
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph.

  The Decoder will now impute tokens that have been masked from the input
  sequence.
  """
  config = get_config()
  gen_decoder_rnn_size = hparams.gen_rnn_size

  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, _, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      output_mask = None

    cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated token.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should operate in teacher-forcing
          # mode. Also use teacher forcing when training purely with
          # cross_entropy.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        if is_training:
          state_gen = list(state_gen)
          for layer_num, per_layer_state in enumerate(state_gen):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state_gen[layer_num] = per_layer_state

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if is_training:
          rnn_out = output_mask * rnn_out

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        # # TODO(liamfedus): Assert not "monotonic" attention_type.
        # # TODO(liamfedus): FLAGS.attention_type.
        # context_state = revised_attention_utils._empty_state()
        # rnn_out, context_state = attention_construct_fn(
        #     rnn_out, attention_keys, attention_values, context_state, t)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:  return the real token at t + 1.
        # If input is not present:  return the fake token for t + 1.
        real = targets[:, t]
        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
          tf.stack(log_probs, axis=1))
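# A hedged numpy sketch (not from this codebase) of the recurrent-state
# masking in the loop above: the same per-layer dropout mask multiplies the
# hidden state h at every time step, while the cell state c passes through
# untouched. The namedtuple and helper below are illustrative stand-ins for
# the LSTMTuple used by the model.
import collections
import numpy as np

LSTMTuple = collections.namedtuple('LSTMTuple', ['c', 'h'])


def apply_h2h_masks(state, h2h_masks):
  # state: one LSTMTuple per layer; h2h_masks: one fixed mask per layer.
  return [LSTMTuple(s.c, s.h * mask) for s, mask in zip(state, h2h_masks)]


layers = [LSTMTuple(np.ones((2, 4)), np.ones((2, 4))) for _ in range(2)]
masks = [np.floor(0.5 + np.random.uniform(size=(2, 4))) / 0.5
         for _ in range(2)]
masked_state = apply_h2h_masks(layers, masks)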