# Example #1
def generator(hparams,
              inputs,
              targets,
              targets_present,
              is_training,
              is_validating,
              reuse=None):
    """Define the Generator graph.

    G will now impute tokens that have been masked from the input sequence.

    Args:
      hparams:  Hyperparameters for the MaskGAN.  This function reads
        gen_nas_keep_prob_0, gen_nas_keep_prob_1 and gen_rnn_size.
      inputs:  Token ids fed through the embedding lookup; after the lookup
        they are indexed as [:, t] for t in range(FLAGS.sequence_length).
      targets:  Ground-truth tokens; targets[:, t] is the real token at step t.
      targets_present:  Boolean Tensor; where targets_present[:, t] is True the
        real token is emitted at step t, otherwise the sampled token is.
      is_training:  Python boolean.  Enables input dropout (when
        FLAGS.keep_prob < 1) and the variational dropout masks.
      is_validating:  Python boolean.  When True the decoder is always fed the
        real previous input (teacher forcing).
      reuse (Optional):   Whether to reuse the variables of the 'gen' scope.

    Returns:
      Tuple of (sequence, logits, log_probs, initial_state, final_state).  The
      first three are the per-step outputs stacked along axis 1.  final_state
      is the RNN state obtained by running the cell over the real inputs only.
    """
    # NOTE(review): the parameter list after `hparams` and the logging call
    # below were reconstructed from the names this body uses -- confirm the
    # argument order against the callers.
    tf.logging.info(
        'Undirectional generative model is not a useful model for this MaskGAN '
        'because future context is needed.  Use only for debugging purposes.')
    config = get_config()
    config.keep_prob = [
        hparams.gen_nas_keep_prob_0, hparams.gen_nas_keep_prob_1
    ]

    init_scale = config.init_scale
    initializer = tf.random_uniform_initializer(-init_scale, init_scale)

    with tf.variable_scope('gen', reuse=reuse, initializer=initializer):
        # Neural architecture search cell.
        cell = custom_cell.Alien(config.hidden_size)

        if is_training:
            [h2h_masks, _, _, output_mask
             ] = variational_dropout.generate_variational_dropout_masks(
                 hparams, config.keep_prob)
        else:
            # No dropout masks outside of training.
            output_mask = None

        cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
        initial_state = cell_gen.zero_state(FLAGS.batch_size, tf.float32)

        with tf.variable_scope('rnn'):
            sequence, logits, log_probs = [], [], []
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])
            # The output projection is tied to the input embedding.
            softmax_w = tf.matrix_transpose(embedding)
            softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

            rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

            if is_training and FLAGS.keep_prob < 1:
                rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    # The cell variables were created at t == 0; share them.
                    tf.get_variable_scope().reuse_variables()

                # Input to the model is the first token to provide context.  The
                # model will then predict token t > 0.
                if t == 0:
                    # Always provide the real input at t = 0.
                    state_gen = initial_state
                    rnn_inp = rnn_inputs[:, t]

                # If the input is present, read in the input at t.
                # If the input is not present, read in the previously generated.
                else:
                    real_rnn_inp = rnn_inputs[:, t]
                    # `fake` is the token sampled on the previous iteration.
                    fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

                    # While validating, the decoder should be operating in teacher
                    # forcing regime.  Also, if we're just training with cross_entropy
                    # use teacher forcing.
                    if is_validating or (is_training
                                         and FLAGS.gen_training_strategy
                                         == 'cross_entropy'):
                        rnn_inp = real_rnn_inp
                    else:
                        rnn_inp = tf.where(targets_present[:, t - 1],
                                           real_rnn_inp, fake_rnn_inp)

                if is_training:
                    # Apply the hidden-to-hidden dropout mask to the second
                    # component of every layer's state; the first component is
                    # passed through unchanged.
                    # NOTE(review): LSTMTuple is presumably a two-field
                    # (c, h) tuple -- confirm against its definition.
                    state_gen = list(state_gen)
                    for layer_num, per_layer_state in enumerate(state_gen):
                        per_layer_state = LSTMTuple(
                            per_layer_state[0],
                            per_layer_state[1] * h2h_masks[layer_num])
                        state_gen[layer_num] = per_layer_state

                # RNN.
                rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

                if is_training:
                    rnn_out = output_mask * rnn_out

                logit = tf.matmul(rnn_out, softmax_w) + softmax_b

                # Real sample.
                real = targets[:, t]

                categorical = tf.contrib.distributions.Categorical(
                    logits=logit)
                fake = categorical.sample()
                log_prob = categorical.log_prob(fake)

                # Output for Generator will either be generated or the input.
                # If present:   Return real.
                # If not present:  Return fake.
                output = tf.where(targets_present[:, t], real, fake)

                # Add to lists.
                sequence.append(output)
                log_probs.append(log_prob)
                logits.append(logit)

            # Produce the RNN state had the model operated only
            # over real data.
            real_state_gen = initial_state
            for t in xrange(FLAGS.sequence_length):
                # Reuse the cell variables created by the loop above.
                tf.get_variable_scope().reuse_variables()

                rnn_inp = rnn_inputs[:, t]

                # RNN.
                rnn_out, real_state_gen = cell_gen(rnn_inp, real_state_gen)

            final_state = real_state_gen

    return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1),
            tf.stack(log_probs, axis=1), initial_state, final_state)
# Example #2
def discriminator(hparams, sequence, is_training, reuse=None):
    """Define the Discriminator graph.

    Args:
      hparams:  Hyperparameters for the MaskGAN.  This function reads
        dis_nas_keep_prob_0, dis_nas_keep_prob_1, dis_rnn_size and
        gen_rnn_size.
      sequence:  Tensor of token ids; it is cast to tf.int32 and, after the
        embedding lookup, indexed as [:, t] for t in
        range(FLAGS.sequence_length).
      is_training:  Python boolean.  Enables input dropout (when
        FLAGS.keep_prob < 1) and the variational dropout masks.
      reuse (Optional):   Whether to reuse the variables of the 'dis' scope.

    Returns:
      The per-step linear predictions stacked along axis 1, with the trailing
      size-1 axis (axis 2) squeezed away.
    """
    # NOTE(review): the logging call wrapping this message was reconstructed;
    # confirm the intended log level.
    tf.logging.info(
        'Undirectional Discriminative model is not a useful model for this '
        'MaskGAN because future context is needed.  Use only for debugging '
        'purposes.')
    sequence = tf.cast(sequence, tf.int32)

    if FLAGS.dis_share_embedding:
        assert hparams.dis_rnn_size == hparams.gen_rnn_size, (
            'If you wish to share Discriminator/Generator embeddings, they must be'
            ' same dimension.')
        # Fetch the Generator's embedding matrix instead of creating a new one.
        with tf.variable_scope('gen/rnn', reuse=True):
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])

    config = get_config()
    config.keep_prob = [
        hparams.dis_nas_keep_prob_0, hparams.dis_nas_keep_prob_1
    ]

    with tf.variable_scope('dis', reuse=reuse):
        # Neural architecture search cell.
        cell = custom_cell.Alien(config.hidden_size)

        if is_training:
            [h2h_masks, _, _, output_mask
             ] = variational_dropout.generate_variational_dropout_masks(
                 hparams, config.keep_prob)
        else:
            # No dropout masks outside of training.
            output_mask = None

        cell_dis = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)
        state_dis = cell_dis.zero_state(FLAGS.batch_size, tf.float32)

        with tf.variable_scope('rnn') as vs:
            predictions = []
            if not FLAGS.dis_share_embedding:
                embedding = tf.get_variable(
                    'embedding', [FLAGS.vocab_size, hparams.dis_rnn_size])

            rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

            if is_training and FLAGS.keep_prob < 1:
                rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    # The variables were created at t == 0; share them.
                    tf.get_variable_scope().reuse_variables()

                rnn_in = rnn_inputs[:, t]

                if is_training:
                    # Apply the hidden-to-hidden dropout mask to the second
                    # component of every layer's state; the first component is
                    # passed through unchanged.
                    # NOTE(review): LSTMTuple is presumably a two-field
                    # (c, h) tuple -- confirm against its definition.
                    state_dis = list(state_dis)
                    for layer_num, per_layer_state in enumerate(state_dis):
                        per_layer_state = LSTMTuple(
                            per_layer_state[0],
                            per_layer_state[1] * h2h_masks[layer_num])
                        state_dis[layer_num] = per_layer_state

                # RNN.
                rnn_out, state_dis = cell_dis(rnn_in, state_dis)

                if is_training:
                    rnn_out = output_mask * rnn_out

                # Prediction is linear output for Discriminator.
                pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)

                predictions.append(pred)

    predictions = tf.stack(predictions, axis=1)
    return tf.squeeze(predictions, axis=2)
# Example #3
def gen_encoder(hparams, inputs, targets_present, is_training, reuse=None):
    """Define the Encoder graph.

    Args:
      hparams:  Hyperparameters for the MaskGAN.
      inputs:  tf.int32 Tensor of shape [batch_size, sequence_length] with tokens
        up to, but not including, vocab_size.
      targets_present:  tf.bool Tensor of shape [batch_size, sequence_length] with
        True representing the presence of the target.
      is_training:  Boolean indicating operational mode (train/inference).
      reuse (Optional):   Whether to reuse the variables.

    Returns:
      Tuple of ((hidden_states, final_masked_state), initial_state,
      final_state).  hidden_states are the per-step outputs over the masked
      inputs stacked along axis 1, final_masked_state is the RNN state after
      the masked inputs, and final_state is the RNN state after running the
      cell over the real (unmasked) inputs.
    """
    config = get_config()
    # We will use the same variable from the decoder.
    if FLAGS.seq2seq_share_embedding:
        with tf.variable_scope('decoder/rnn'):
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])

    with tf.variable_scope('encoder', reuse=reuse):
        # Neural architecture search cell.
        cell = custom_cell.Alien(config.hidden_size)

        if is_training:
            [h2h_masks, h2i_masks, _, output_mask
             ] = variational_dropout.generate_variational_dropout_masks(
                 hparams, config.keep_prob)
        else:
            # No dropout masks outside of training.
            h2i_masks, output_mask = None, None

        cell = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

        initial_state = cell.zero_state(FLAGS.batch_size, tf.float32)

        # Add a missing token for inputs not present.
        real_inputs = inputs
        masked_inputs = transform_input_with_is_missing_token(
            inputs, targets_present)

        with tf.variable_scope('rnn'):
            hidden_states = []

            # Split the embedding into two parts so that we can load the PTB
            # weights into one part of the Variable.
            if not FLAGS.seq2seq_share_embedding:
                embedding = tf.get_variable(
                    'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])
            missing_embedding = tf.get_variable('missing_embedding',
                                                [1, hparams.gen_rnn_size])
            # Append one extra row that embeds the "missing" token.
            embedding = tf.concat([embedding, missing_embedding], axis=0)

            real_rnn_inputs = tf.nn.embedding_lookup(embedding, real_inputs)
            masked_rnn_inputs = tf.nn.embedding_lookup(embedding,
                                                       masked_inputs)

            if is_training and FLAGS.keep_prob < 1:
                masked_rnn_inputs = tf.nn.dropout(masked_rnn_inputs,
                                                  FLAGS.keep_prob)

            state = initial_state
            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    # The cell variables were created at t == 0; share them.
                    tf.get_variable_scope().reuse_variables()

                rnn_inp = masked_rnn_inputs[:, t]

                if is_training:
                    # Apply the hidden-to-hidden dropout mask to the second
                    # component of every layer's state; the first component is
                    # passed through unchanged.
                    # NOTE(review): LSTMTuple is presumably a two-field
                    # (c, h) tuple -- confirm against its definition.
                    state = list(state)
                    for layer_num, per_layer_state in enumerate(state):
                        per_layer_state = LSTMTuple(
                            per_layer_state[0],
                            per_layer_state[1] * h2h_masks[layer_num])
                        state[layer_num] = per_layer_state

                rnn_out, state = cell(rnn_inp, state, h2i_masks)

                if is_training:
                    rnn_out = output_mask * rnn_out

                hidden_states.append(rnn_out)

            final_masked_state = state
            hidden_states = tf.stack(hidden_states, axis=1)

            # Produce the RNN state had the model operated only
            # over real data.
            real_state = initial_state
            for t in xrange(FLAGS.sequence_length):
                # Reuse the cell variables created by the loop above.
                tf.get_variable_scope().reuse_variables()

                # RNN.
                rnn_inp = real_rnn_inputs[:, t]
                rnn_out, real_state = cell(rnn_inp, real_state)
            final_state = real_state

    return (hidden_states, final_masked_state), initial_state, final_state
