def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
  """Build the unidirectional NAS-cell Generator graph.

  The Generator imputes tokens that have been masked out of the input
  sequence. At every step it emits the real target where the target is
  present and a token sampled from its own softmax where it is not.

  Args:
    hparams: Hyperparameters. This function reads `gen_nas_keep_prob_0`,
      `gen_nas_keep_prob_1` and `gen_rnn_size`, and forwards `hparams` to the
      variational dropout mask generator.
    inputs: Token ids looked up in the embedding and sliced per time step as
      `[:, t]` (presumably [batch_size, sequence_length] -- confirm with
      caller).
    targets: Real tokens, sliced per time step as `[:, t]`.
    targets_present: Condition tensor for `tf.where`, sliced per time step as
      `[:, t]`; True selects the real token, False the sampled one.
    is_training: Python boolean. Enables the variational dropout masks and
      the dropout on the embedded inputs.
    is_validating: Python boolean. When true the real input is always fed
      back (teacher forcing).
    reuse (Optional): Forwarded to the 'gen' variable scope.

  Returns:
    Tuple of (sequence, logits, log_probs, initial_state, final_state). The
    first three are the per-step results stacked along axis 1. `final_state`
    is the state reached by a second pass that is fed only the real inputs.
  """
  tf.logging.info(
      'Undirectional generative model is not a useful model for this MaskGAN '
      'because future context is needed. Use only for debugging purposes.')

  nas_config = get_config()
  nas_config.keep_prob = [
      hparams.gen_nas_keep_prob_0, hparams.gen_nas_keep_prob_1
  ]
  configs.print_config(nas_config)

  scale = nas_config.init_scale
  uniform_init = tf.random_uniform_initializer(-scale, scale)

  with tf.variable_scope('gen', reuse=reuse, initializer=uniform_init):
    # Neural architecture search cell.
    base_cell = custom_cell.Alien(nas_config.hidden_size)

    # The masks are only built (and only consulted) in training mode.
    h2h_masks = output_mask = None
    if is_training:
      h2h_masks, _, _, output_mask = (
          variational_dropout.generate_variational_dropout_masks(
              hparams, nas_config.keep_prob))

    cell_gen = custom_cell.GenericMultiRNNCell(
        [base_cell] * nas_config.num_layers)
    initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn'):
      # The output projection is tied to the (transposed) embedding matrix.
      embedding = tf.get_variable(
          'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      embedded_inputs = tf.nn.embedding_lookup(embedding, inputs)
      if is_training and FLAGS.keep_prob < 1:
        embedded_inputs = tf.nn.dropout(embedded_inputs, FLAGS.keep_prob)

      # While validating, or when training purely with cross entropy, the
      # model runs in the teacher forcing regime.
      teacher_forcing = is_validating or (
          is_training and FLAGS.gen_training_strategy == 'cross_entropy')

      step_outputs, step_logits, step_log_probs = [], [], []
      sampled = None
      state = initial_state
      for t in xrange(FLAGS.sequence_length):
        if t == 0:
          # The first token is always the real one: it provides context.
          step_input = embedded_inputs[:, t]
        else:
          tf.get_variable_scope().reuse_variables()
          real_step = embedded_inputs[:, t]
          # Built unconditionally, even when teacher forcing ignores it.
          fake_step = tf.nn.embedding_lookup(embedding, sampled)
          if teacher_forcing:
            step_input = real_step
          else:
            # Read the real input where the previous target was present,
            # otherwise read the previously generated token.
            step_input = tf.where(targets_present[:, t - 1], real_step,
                                  fake_step)

        if is_training:
          # Apply the per-layer hidden-to-hidden mask to element 1 of each
          # layer's state tuple; element 0 is passed through untouched.
          state = [
              LSTMTuple(layer_state[0], layer_state[1] * h2h_masks[layer_idx])
              for layer_idx, layer_state in enumerate(state)
          ]

        # RNN.
        cell_output, state = cell_gen(step_input, state)
        if is_training:
          cell_output = output_mask * cell_output
        logit = tf.matmul(cell_output, softmax_w) + softmax_b

        # Real sample, then a sample drawn from the model.
        real_token = targets[:, t]
        token_dist = tf.contrib.distributions.Categorical(logits=logit)
        sampled = token_dist.sample()
        sampled_log_prob = token_dist.log_prob(sampled)

        # If present: return real.  If not present: return fake.
        step_outputs.append(
            tf.where(targets_present[:, t], real_token, sampled))
        step_log_probs.append(sampled_log_prob)
        step_logits.append(logit)

      # Produce the RNN state had the model operated only over real data.
      real_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()
        _, real_state = cell_gen(embedded_inputs[:, t], real_state)
      final_state = real_state

  return (tf.stack(step_outputs, axis=1), tf.stack(step_logits, axis=1),
          tf.stack(step_log_probs, axis=1), initial_state, final_state)
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
  """Build the Encoder graph.

  The encoder is run twice over the same cell: once over the inputs with the
  missing token substituted where the target is not present (producing the
  hidden states handed to the decoder), and once over the untouched inputs
  (producing `final_state`).

  Args:
    hparams: Hyperparameters for the MaskGAN.
    inputs: tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
      up to, but not including, vocab_size.
    targets_present: tf.bool Tensor of shape [batch_size, sequence_length] with
      True representing the presence of the target.
    is_training: Boolean indicating operational mode (train/inference).
    reuse (Optional): Whether to reuse the variables.

  Returns:
    Tuple of ((hidden_states, final_masked_state), initial_state,
    final_state). `hidden_states` holds the per-step outputs of the masked
    pass stacked along axis 1, `final_masked_state` is the state at the end
    of that pass, `initial_state` is the zero state of the cell and
    `final_state` is the state at the end of the pass over the real inputs.
  """
  nas_config = get_config()
  configs.print_config(nas_config)

  share_embedding = FLAGS.seq2seq_share_embedding
  embedding_shape = [FLAGS.vocab_size, hparams.gen_rnn_size]

  # We will use the same variable from the decoder.
  # NOTE(review): this scope does not forward `reuse`; it appears to rely on
  # the caller's enclosing variable scope for that -- confirm.
  if share_embedding:
    with tf.variable_scope('decoder/rnn'):
      embedding = tf.get_variable('embedding', embedding_shape)

  with tf.variable_scope('encoder', reuse=reuse):
    # Neural architecture search cell.
    base_cell = custom_cell.Alien(nas_config.hidden_size)

    h2h_masks = h2i_masks = output_mask = None
    if is_training:
      h2h_masks, h2i_masks, _, output_mask = (
          variational_dropout.generate_variational_dropout_masks(
              hparams, nas_config.keep_prob))

    cell = custom_cell.GenericMultiRNNCell(
        [base_cell] * nas_config.num_layers)
    initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

    # Add a missing token for inputs not present.
    masked_inputs = transform_input_with_is_missing_token(
        inputs, targets_present)

    with tf.variable_scope('rnn'):
      # Split the embedding into two parts so that we can load the PTB
      # weights into one part of the Variable.
      if not share_embedding:
        embedding = tf.get_variable('embedding', embedding_shape)
      missing_embedding = tf.get_variable('missing_embedding',
                                          [1, hparams.gen_rnn_size])
      embedding = tf.concat([embedding, missing_embedding], axis=0)

      real_embedded = tf.nn.embedding_lookup(embedding, inputs)
      masked_embedded = tf.nn.embedding_lookup(embedding, masked_inputs)
      # Dropout is applied to the masked inputs only.
      if is_training and FLAGS.keep_prob < 1:
        masked_embedded = tf.nn.dropout(masked_embedded, FLAGS.keep_prob)

      step_outputs = []
      masked_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        if is_training:
          # Apply the per-layer hidden-to-hidden mask to element 1 of each
          # layer's state tuple; element 0 is passed through untouched.
          masked_state = [
              LSTMTuple(layer_state[0], layer_state[1] * h2h_masks[layer_idx])
              for layer_idx, layer_state in enumerate(masked_state)
          ]

        cell_output, masked_state = cell(masked_embedded[:, t], masked_state,
                                         h2i_masks)
        if is_training:
          cell_output = output_mask * cell_output
        step_outputs.append(cell_output)

      final_masked_state = masked_state
      hidden_states = tf.stack(step_outputs, axis=1)

      # Produce the RNN state had the model operated only over real data.
      real_state = initial_state
      for t in xrange(FLAGS.sequence_length):
        tf.get_variable_scope().reuse_variables()
        _, real_state = cell(real_embedded[:, t], real_state)
      final_state = real_state

  return (hidden_states, final_masked_state), initial_state, final_state
def discriminator(hparams, sequence, is_training, reuse=None):
  """Build the unidirectional NAS-cell Discriminator graph.

  Args:
    hparams: Hyperparameters. This function reads `dis_nas_keep_prob_0`,
      `dis_nas_keep_prob_1`, `dis_rnn_size` and (when the embedding is shared)
      `gen_rnn_size`, and forwards `hparams` to the variational dropout mask
      generator.
    sequence: Token tensor; it is cast to tf.int32 and sliced per time step
      after the embedding lookup (presumably [batch_size, sequence_length] --
      confirm with caller).
    is_training: Python boolean. Enables the variational dropout masks and
      the dropout on the embedded inputs.
    reuse (Optional): Forwarded to the 'dis' variable scope.

  Returns:
    The per-step single-unit linear outputs, stacked along axis 1 with the
    trailing unit axis squeezed away (one value per sequence position).
  """
  tf.logging.info(
      'Undirectional Discriminative model is not a useful model for this '
      'MaskGAN because future context is needed. Use only for debugging '
      'purposes.')
  token_ids = tf.cast(sequence, tf.int32)

  share_embedding = FLAGS.dis_share_embedding
  if share_embedding:
    assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
        'If you wish to share Discriminator/Generator embeddings, they must be'
        ' same dimension.')
    # Pick up the Generator's embedding variable, which must already exist.
    with tf.variable_scope('gen/rnn', reuse=True):
      embedding = tf.get_variable(
          'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])

  nas_config = get_config()
  nas_config.keep_prob = [
      hparams.dis_nas_keep_prob_0, hparams.dis_nas_keep_prob_1
  ]
  configs.print_config(nas_config)

  with tf.variable_scope('dis', reuse=reuse):
    # Neural architecture search cell.
    base_cell = custom_cell.Alien(nas_config.hidden_size)

    # The masks are only built (and only consulted) in training mode.
    h2h_masks = output_mask = None
    if is_training:
      h2h_masks, _, _, output_mask = (
          variational_dropout.generate_variational_dropout_masks(
              hparams, nas_config.keep_prob))

    cell_dis = custom_cell.GenericMultiRNNCell(
        [base_cell] * nas_config.num_layers)
    state = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

    with tf.variable_scope('rnn') as rnn_scope:
      if not share_embedding:
        embedding = tf.get_variable(
            'embedding', [FLAGS.vocab_size, hparams.dis_rnn_size])

      embedded_inputs = tf.nn.embedding_lookup(embedding, token_ids)
      if is_training and FLAGS.keep_prob < 1:
        embedded_inputs = tf.nn.dropout(embedded_inputs, FLAGS.keep_prob)

      step_predictions = []
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        step_input = embedded_inputs[:, t]

        if is_training:
          # Apply the per-layer hidden-to-hidden mask to element 1 of each
          # layer's state tuple; element 0 is passed through untouched.
          state = [
              LSTMTuple(layer_state[0], layer_state[1] * h2h_masks[layer_idx])
              for layer_idx, layer_state in enumerate(state)
          ]

        # RNN.
        cell_output, state = cell_dis(step_input, state)
        if is_training:
          cell_output = output_mask * cell_output

        # Prediction is linear output for Discriminator.  The captured 'rnn'
        # scope is handed to the layer as its scope.
        step_predictions.append(
            tf.contrib.layers.linear(cell_output, 1, scope=rnn_scope))

  stacked_predictions = tf.stack(step_predictions, axis=1)
  return tf.squeeze(stacked_predictions, axis=2)
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Build the Decoder graph.

  The Decoder imputes tokens that have been masked out of the input
  sequence. At every step it emits the real target where the target is
  present and a token sampled from its own softmax where it is not.

  Args:
    hparams: Hyperparameters. This function reads `gen_rnn_size` and forwards
      `hparams` to the variational dropout mask generator.
    inputs: Token ids looked up in the embedding and sliced per time step as
      `[:, t]` (presumably [batch_size, sequence_length] -- confirm with
      caller).
    targets: Real tokens, sliced per time step as `[:, t]`.
    targets_present: Condition tensor for `tf.where`, sliced per time step as
      `[:, t]`; True selects the real token, False the sampled one.
    encoding_state: Indexable pair. Element 0 is the encoder's hidden states
      (used to prepare attention); element 1 is the state tuple the decoder
      starts from.
    is_training: Python boolean. Enables the variational dropout masks and
      the dropout on the embedded inputs.
    is_validating: Python boolean. When true the real input is always fed
      back (teacher forcing).
    reuse (Optional): Forwarded to the 'decoder' variable scope and to the
      attention preparation.

  Returns:
    Tuple of (sequence, logits, log_probs), each the per-step results
    stacked along axis 1.
  """
  nas_config = get_config()
  gen_decoder_rnn_size = hparams.gen_rnn_size
  share_embedding = FLAGS.seq2seq_share_embedding
  embedding_shape = [FLAGS.vocab_size, gen_decoder_rnn_size]

  # A shared embedding must already exist under 'decoder/rnn'.
  if share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding', embedding_shape)

  with tf.variable_scope('decoder', reuse=reuse):
    # Neural architecture search cell.
    base_cell = custom_cell.Alien(nas_config.hidden_size)

    # The masks are only built (and only consulted) in training mode.
    h2h_masks = output_mask = None
    if is_training:
      h2h_masks, _, _, output_mask = (
          variational_dropout.generate_variational_dropout_masks(
              hparams, nas_config.keep_prob))

    cell_gen = custom_cell.GenericMultiRNNCell(
        [base_cell] * nas_config.num_layers)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    state = encoding_state[1]

    use_attention = FLAGS.attention_option is not None
    if use_attention:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      if not share_embedding:
        embedding = tf.get_variable('embedding', embedding_shape)

      # The output projection is tied to the (transposed) embedding matrix.
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      embedded_inputs = tf.nn.embedding_lookup(embedding, inputs)
      if is_training and FLAGS.keep_prob < 1:
        embedded_inputs = tf.nn.dropout(embedded_inputs, FLAGS.keep_prob)

      # While validating, or when training purely with cross entropy, the
      # decoder runs in the teacher forcing regime.
      teacher_forcing = is_validating or (
          is_training and FLAGS.gen_training_strategy == 'cross_entropy')

      step_outputs, step_logits, step_log_probs = [], [], []
      sampled = None
      for t in xrange(FLAGS.sequence_length):
        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          step_input = embedded_inputs[:, t]
        else:
          tf.get_variable_scope().reuse_variables()
          real_step = embedded_inputs[:, t]
          # Built unconditionally, even when teacher forcing ignores it.
          fake_step = tf.nn.embedding_lookup(embedding, sampled)
          if teacher_forcing:
            step_input = real_step
          else:
            # Read the real input where the previous target was present,
            # otherwise read the previously generated token.
            step_input = tf.where(targets_present[:, t - 1], real_step,
                                  fake_step)

        if is_training:
          # Apply the per-layer hidden-to-hidden mask to element 1 of each
          # layer's state tuple; element 0 is passed through untouched.
          state = [
              LSTMTuple(layer_state[0], layer_state[1] * h2h_masks[layer_idx])
              for layer_idx, layer_state in enumerate(state)
          ]

        # RNN.
        cell_output, state = cell_gen(step_input, state)
        if is_training:
          cell_output = output_mask * cell_output

        if use_attention:
          cell_output = attention_construct_fn(cell_output, attention_keys,
                                               attention_values)
          # # TODO(liamfedus): Assert not "monotonic" attention_type.
          # # TODO(liamfedus): FLAGS.attention_type.
          # context_state = revised_attention_utils._empty_state()
          # rnn_out, context_state = attention_construct_fn(
          #     rnn_out, attention_keys, attention_values, context_state, t)
        logit = tf.matmul(cell_output, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:      Return real at t+1.
        # If input is not present:  Return fake for t+1.
        real_token = targets[:, t]
        token_dist = tf.contrib.distributions.Categorical(logits=logit)
        sampled = token_dist.sample()
        sampled_log_prob = token_dist.log_prob(sampled)

        step_outputs.append(
            tf.where(targets_present[:, t], real_token, sampled))
        step_log_probs.append(sampled_log_prob)
        step_logits.append(logit)

  return (tf.stack(step_outputs, axis=1), tf.stack(step_logits, axis=1),
          tf.stack(step_log_probs, axis=1))