Example #1
def dis_decoder(hparams,
                sequence,
                encoding_state,
                is_training,
                reuse=None,
                embedding=None):
    """Define the Discriminator decoder.  Read in the sequence and predict
    at each time point."""
    sequence = tf.cast(sequence, tf.int32)

    with tf.variable_scope('decoder', reuse=reuse):

        def lstm_cell():
            return tf.contrib.rnn.BasicLSTMCell(hparams.dis_rnn_size,
                                                forget_bias=0.0,
                                                state_is_tuple=True,
                                                reuse=reuse)

        attn_cell = lstm_cell
        if is_training and hparams.dis_vd_keep_prob < 1:

            def attn_cell():
                return variational_dropout.VariationalDropoutWrapper(
                    lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
                    hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

        cell_dis = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(hparams.dis_num_layers)],
            state_is_tuple=True)

        # Hidden encoder states.
        hidden_vector_encodings = encoding_state[0]

        # Carry forward the final state tuple from the encoder.
        # State tuples.
        state = encoding_state[1]

        if FLAGS.attention_option is not None:
            (attention_keys, attention_values, _,
             attention_construct_fn) = attention_utils.prepare_attention(
                 hidden_vector_encodings,
                 FLAGS.attention_option,
                 num_units=hparams.dis_rnn_size,
                 reuse=reuse)

        def make_mask(keep_prob, units):
            random_tensor = keep_prob
            # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
            random_tensor += tf.random_uniform(
                tf.stack([FLAGS.batch_size, units]))
            return tf.floor(random_tensor) / keep_prob

        if is_training:
            output_mask = make_mask(hparams.dis_vd_keep_prob,
                                    hparams.dis_rnn_size)

        with tf.variable_scope('rnn') as vs:
            predictions = []

            rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                rnn_in = rnn_inputs[:, t]
                rnn_out, state = cell_dis(rnn_in, state)

                if FLAGS.attention_option is not None:
                    rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                     attention_values)
                if is_training:
                    rnn_out *= output_mask

                # Prediction is linear output for Discriminator.
                pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
                predictions.append(pred)

    predictions = tf.stack(predictions, axis=1)
    return tf.squeeze(predictions, axis=2)
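
The make_mask helper above implements inverted dropout with a floor trick: adding uniform noise in [0, 1) to keep_prob and flooring gives 1 with probability keep_prob and 0 otherwise, and dividing by keep_prob rescales the surviving units. A minimal standalone sketch (NumPy only, toy shapes chosen here for illustration):

import numpy as np

def make_mask_np(keep_prob, batch_size, units, rng=np.random):
    # Same floor trick as make_mask above, written against NumPy.
    random_tensor = keep_prob + rng.uniform(size=(batch_size, units))
    return np.floor(random_tensor) / keep_prob

mask = make_mask_np(0.75, batch_size=4, units=8)
print(mask)               # entries are either 0.0 or 1 / 0.75
print((mask > 0).mean())  # roughly 0.75 of the units survive
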
Example #2
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
    """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
    gen_decoder_rnn_size = hparams.gen_rnn_size

    with tf.variable_scope('decoder', reuse=reuse):

        def lstm_cell():
            return tf.contrib.rnn.LayerNormBasicLSTMCell(gen_decoder_rnn_size,
                                                         reuse=reuse)

        attn_cell = lstm_cell
        if FLAGS.zoneout_drop_prob > 0.0:

            def attn_cell():
                return zoneout.ZoneoutWrapper(
                    lstm_cell(),
                    zoneout_drop_prob=FLAGS.zoneout_drop_prob,
                    is_training=is_training)

        cell_gen = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(hparams.gen_num_layers)],
            state_is_tuple=True)

        # Hidden encoder states.
        hidden_vector_encodings = encoding_state[0]

        # Carry forward the final state tuple from the encoder.
        # State tuples.
        state_gen = encoding_state[1]

        if FLAGS.attention_option is not None:
            (attention_keys, attention_values, _,
             attention_construct_fn) = attention_utils.prepare_attention(
                 hidden_vector_encodings,
                 FLAGS.attention_option,
                 num_units=gen_decoder_rnn_size,
                 reuse=reuse)

        with tf.variable_scope('rnn'):
            sequence, logits, log_probs = [], [], []
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, gen_decoder_rnn_size])
            softmax_w = tf.get_variable(
                'softmax_w', [gen_decoder_rnn_size, FLAGS.vocab_size])
            softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

            rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                # Input to the Decoder.
                if t == 0:
                    # Always provide the real input at t = 0.
                    rnn_inp = rnn_inputs[:, t]

                # If the input is present, read in the input at t.
                # If the input is not present, read in the previously generated.
                else:
                    real_rnn_inp = rnn_inputs[:, t]
                    fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

                    # While validating, the decoder should be operating in teacher
                    # forcing regime.  Also, if we're just training with cross_entropy
                    # use teacher forcing.
                    if is_validating or (is_training
                                         and FLAGS.gen_training_strategy
                                         == 'cross_entropy'):
                        rnn_inp = real_rnn_inp
                    else:
                        rnn_inp = tf.where(targets_present[:, t - 1],
                                           real_rnn_inp, fake_rnn_inp)

                # RNN.
                rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

                if FLAGS.attention_option is not None:
                    rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                     attention_values)
                #   # TODO(liamfedus): Assert not "monotonic" attention_type.
                #   # TODO(liamfedus): FLAGS.attention_type.
                #   context_state = revised_attention_utils._empty_state()
                #   rnn_out, context_state = attention_construct_fn(
                #       rnn_out, attention_keys, attention_values, context_state, t)
                logit = tf.matmul(rnn_out, softmax_w) + softmax_b

                # Output for Decoder.
                # If input is present:   Return real at t+1.
                # If input is not present:  Return fake for t+1.
                real = targets[:, t]

                categorical = tf.contrib.distributions.Categorical(
                    logits=logit)
                fake = categorical.sample()
                log_prob = categorical.log_prob(fake)

                output = tf.where(targets_present[:, t], real, fake)

                # Add to lists.
                sequence.append(output)
                log_probs.append(log_prob)
                logits.append(logit)

    return (tf.stack(sequence,
                     axis=1), tf.stack(logits,
                                       axis=1), tf.stack(log_probs, axis=1))
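
The tf.where call above is what switches each example between teacher forcing and free running: wherever the previous target was present in the masked input, the real embedding is fed; wherever it was masked, the decoder feeds back the embedding of the token it sampled itself. A minimal NumPy sketch of that per-example selection (toy shapes and values):

import numpy as np

batch_size, emb_dim = 3, 4
targets_present_prev = np.array([True, False, True])  # presence mask at t - 1
real_rnn_inp = np.ones((batch_size, emb_dim))          # embeddings of the real tokens
fake_rnn_inp = np.zeros((batch_size, emb_dim))         # embeddings of sampled tokens

rnn_inp = np.where(targets_present_prev[:, None], real_rnn_inp, fake_rnn_inp)
print(rnn_inp)  # rows 0 and 2 come from real_rnn_inp, row 1 from fake_rnn_inp
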
Example #3
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
    """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
    gen_decoder_rnn_size = hparams.gen_rnn_size

    targets = tf.Print(targets, [targets], message='targets', summarize=50)
    if FLAGS.seq2seq_share_embedding:
        with tf.variable_scope('decoder/rnn', reuse=True):
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])

    with tf.variable_scope('decoder', reuse=reuse):

        def lstm_cell():
            return tf.contrib.rnn.BasicLSTMCell(gen_decoder_rnn_size,
                                                forget_bias=0.0,
                                                state_is_tuple=True,
                                                reuse=reuse)

        attn_cell = lstm_cell
        if is_training and hparams.gen_vd_keep_prob < 1:

            def attn_cell():
                return variational_dropout.VariationalDropoutWrapper(
                    lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
                    hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)

        cell_gen = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(hparams.gen_num_layers)],
            state_is_tuple=True)

        # Hidden encoder states.
        hidden_vector_encodings = encoding_state[0]

        # Carry forward the final state tuple from the encoder.
        # State tuples.
        state_gen = encoding_state[1]

        if FLAGS.attention_option is not None:
            (attention_keys, attention_values, _,
             attention_construct_fn) = attention_utils.prepare_attention(
                 hidden_vector_encodings,
                 FLAGS.attention_option,
                 num_units=gen_decoder_rnn_size,
                 reuse=reuse)

        def make_mask(keep_prob, units):
            random_tensor = keep_prob
            # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
            random_tensor += tf.random_uniform(
                tf.stack([FLAGS.batch_size, units]))
            return tf.floor(random_tensor) / keep_prob

        if is_training:
            output_mask = make_mask(hparams.gen_vd_keep_prob,
                                    hparams.gen_rnn_size)

        with tf.variable_scope('rnn'):
            sequence, logits, log_probs = [], [], []

            if not FLAGS.seq2seq_share_embedding:
                embedding = tf.get_variable(
                    'embedding', [FLAGS.vocab_size, hparams.gen_rnn_size])
            softmax_w = tf.matrix_transpose(embedding)
            softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

            rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
            # TODO(adai): Perhaps append IMDB labels placeholder to input at
            # each time point.

            rnn_outs = []

            fake = None
            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                # Input to the Decoder.
                if t == 0:
                    # Always provide the real input at t = 0.
                    rnn_inp = rnn_inputs[:, t]

                # If the input is present, read in the input at t.
                # If the input is not present, read in the previously generated.
                else:
                    real_rnn_inp = rnn_inputs[:, t]

                    # While validating, the decoder should be operating in teacher
                    # forcing regime.  Also, if we're just training with cross_entropy
                    # use teacher forcing.
                    if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
                        rnn_inp = real_rnn_inp
                    else:
                        fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
                        rnn_inp = tf.where(targets_present[:, t - 1],
                                           real_rnn_inp, fake_rnn_inp)

                # RNN.
                rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

                if FLAGS.attention_option is not None:
                    rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                     attention_values)
                if is_training:
                    rnn_out *= output_mask

                rnn_outs.append(rnn_out)
                if FLAGS.gen_training_strategy != 'cross_entropy':
                    logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w),
                                           softmax_b)

                    # Output for Decoder.
                    # If input is present:   Return real at t+1.
                    # If input is not present:  Return fake for t+1.
                    real = targets[:, t]

                    categorical = tf.contrib.distributions.Categorical(
                        logits=logit)
                    if FLAGS.use_gen_mode:
                        fake = categorical.mode()
                    else:
                        fake = categorical.sample()
                    log_prob = categorical.log_prob(fake)
                    output = tf.where(targets_present[:, t], real, fake)

                else:
                    real = targets[:, t]
                    logit = tf.zeros(
                        tf.stack([FLAGS.batch_size, FLAGS.vocab_size]))
                    log_prob = tf.zeros(tf.stack([FLAGS.batch_size]))
                    output = real

                # Add to lists.
                sequence.append(output)
                log_probs.append(log_prob)
                logits.append(logit)

            if FLAGS.gen_training_strategy == 'cross_entropy':
                logits = tf.nn.bias_add(
                    tf.matmul(
                        tf.reshape(tf.stack(rnn_outs, 1),
                                   [-1, gen_decoder_rnn_size]), softmax_w),
                    softmax_b)
                logits = tf.reshape(
                    logits, [-1, FLAGS.sequence_length, FLAGS.vocab_size])
            else:
                logits = tf.stack(logits, axis=1)

    return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
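
This variant ties the output projection to the input embedding: softmax_w is just the transposed embedding matrix, so the logit for a token is the dot product of the RNN output with that token's embedding plus a per-token bias, and no separate softmax matrix is learned. A minimal NumPy sketch of the tied projection (toy sizes):

import numpy as np

vocab_size, rnn_size, batch_size = 10, 6, 2
embedding = np.random.randn(vocab_size, rnn_size)
softmax_w = embedding.T                   # [rnn_size, vocab_size], no extra parameters
softmax_b = np.zeros(vocab_size)

rnn_out = np.random.randn(batch_size, rnn_size)
logit = rnn_out @ softmax_w + softmax_b   # [batch_size, vocab_size]
print(logit.shape)
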
Example #4
def dis_decoder(hparams,
                sequence,
                encoding_state,
                is_training,
                reuse=None,
                embedding=None):
  """Define the Discriminator decoder.  Read in the sequence and predict
    at each time point."""
  sequence = tf.cast(sequence, tf.int32)

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          hparams.dis_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.dis_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.dis_rnn_size,
            hparams.dis_vd_keep_prob, hparams.dis_vd_keep_prob)

    cell_dis = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.dis_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=hparams.dis_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.dis_vd_keep_prob, hparams.dis_rnn_size)

    with tf.variable_scope('rnn') as vs:
      predictions = []

      rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        rnn_in = rnn_inputs[:, t]
        rnn_out, state = cell_dis(rnn_in, state)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        # Prediction is linear output for Discriminator.
        pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
        predictions.append(pred)

  predictions = tf.stack(predictions, axis=1)
  return tf.squeeze(predictions, axis=2)
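
dis_decoder returns one logit per time step with shape [batch_size, sequence_length]. As a hypothetical illustration of how such per-token logits are typically consumed (this wiring is an assumption, not part of the example above), each step can be scored with a sigmoid cross-entropy against a real/imputed label:

import numpy as np

def sigmoid_xent(logits, labels):
    # Numerically stable elementwise sigmoid cross-entropy.
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

batch_size, seq_len = 2, 5
dis_logits = np.random.randn(batch_size, seq_len)   # stands in for the dis_decoder output
token_is_real = np.random.randint(0, 2, (batch_size, seq_len)).astype(float)

per_token_loss = sigmoid_xent(dis_logits, token_is_real)
print(per_token_loss.shape, per_token_loss.mean())
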
Example #5
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  targets = tf.Print(targets, [targets], message='targets', summarize=50)
  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, hparams.gen_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.BasicLSTMCell(
          gen_decoder_rnn_size,
          forget_bias=0.0,
          state_is_tuple=True,
          reuse=reuse)

    attn_cell = lstm_cell
    if is_training and hparams.gen_vd_keep_prob < 1:

      def attn_cell():
        return variational_dropout.VariationalDropoutWrapper(
            lstm_cell(), FLAGS.batch_size, hparams.gen_rnn_size,
            hparams.gen_vd_keep_prob, hparams.gen_vd_keep_prob)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    def make_mask(keep_prob, units):
      random_tensor = keep_prob
      # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
      random_tensor += tf.random_uniform(tf.stack([FLAGS.batch_size, units]))
      return tf.floor(random_tensor) / keep_prob

    if is_training:
      output_mask = make_mask(hparams.gen_vd_keep_prob, hparams.gen_rnn_size)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, hparams.gen_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)
      # TODO(adai): Perhaps append IMDB labels placeholder to input at
      # each time point.

      rnn_outs = []

      fake = None
      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or FLAGS.gen_training_strategy == 'cross_entropy':
            rnn_inp = real_rnn_inp
          else:
            fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        if is_training:
          rnn_out *= output_mask

        rnn_outs.append(rnn_out)
        if FLAGS.gen_training_strategy != 'cross_entropy':
          logit = tf.nn.bias_add(tf.matmul(rnn_out, softmax_w), softmax_b)

          # Output for Decoder.
          # If input is present:   Return real at t+1.
          # If input is not present:  Return fake for t+1.
          real = targets[:, t]

          categorical = tf.contrib.distributions.Categorical(logits=logit)
          if FLAGS.use_gen_mode:
            fake = categorical.mode()
          else:
            fake = categorical.sample()
          log_prob = categorical.log_prob(fake)
          output = tf.where(targets_present[:, t], real, fake)

        else:
          real = targets[:, t]
          logit = tf.zeros(tf.stack([FLAGS.batch_size, FLAGS.vocab_size]))
          log_prob = tf.zeros(tf.stack([FLAGS.batch_size]))
          output = real

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

      if FLAGS.gen_training_strategy == 'cross_entropy':
        logits = tf.nn.bias_add(
            tf.matmul(
                tf.reshape(tf.stack(rnn_outs, 1), [-1, gen_decoder_rnn_size]),
                softmax_w), softmax_b)
        logits = tf.reshape(logits,
                            [-1, FLAGS.sequence_length, FLAGS.vocab_size])
      else:
        logits = tf.stack(logits, axis=1)

  return (tf.stack(sequence, axis=1), logits, tf.stack(log_probs, axis=1))
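
The FLAGS.use_gen_mode branch above chooses between deterministic and stochastic decoding: the mode of the categorical distribution is its argmax, while sampling draws a token and records its log-probability in the log_probs list. A minimal NumPy sketch of that choice (toy logits):

import numpy as np

logits = np.array([2.0, 0.5, -1.0])
probs = np.exp(logits - logits.max())
probs /= probs.sum()

mode_token = int(np.argmax(logits))                          # deterministic "mode" decoding
sampled_token = int(np.random.choice(len(logits), p=probs))  # stochastic decoding
log_prob = np.log(probs[sampled_token])                      # log-prob of the sampled token
print(mode_token, sampled_token, log_prob)
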
Example #6
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
    """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
    config = get_config()
    gen_decoder_rnn_size = hparams.gen_rnn_size

    if FLAGS.seq2seq_share_embedding:
        with tf.variable_scope('decoder/rnn', reuse=True):
            embedding = tf.get_variable(
                'embedding', [FLAGS.vocab_size, gen_decoder_rnn_size])

    with tf.variable_scope('decoder', reuse=reuse):
        # Neural architecture search cell.
        cell = custom_cell.Alien(config.hidden_size)

        if is_training:
            [h2h_masks, _, _, output_mask
             ] = variational_dropout.generate_variational_dropout_masks(
                 hparams, config.keep_prob)
        else:
            output_mask = None

        cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

        # Hidden encoder states.
        hidden_vector_encodings = encoding_state[0]

        # Carry forward the final state tuple from the encoder.
        # State tuples.
        state_gen = encoding_state[1]

        if FLAGS.attention_option is not None:
            (attention_keys, attention_values, _,
             attention_construct_fn) = attention_utils.prepare_attention(
                 hidden_vector_encodings,
                 FLAGS.attention_option,
                 num_units=gen_decoder_rnn_size,
                 reuse=reuse)

        with tf.variable_scope('rnn'):
            sequence, logits, log_probs = [], [], []

            if not FLAGS.seq2seq_share_embedding:
                embedding = tf.get_variable(
                    'embedding', [FLAGS.vocab_size, gen_decoder_rnn_size])
            softmax_w = tf.matrix_transpose(embedding)
            softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

            rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

            if is_training and FLAGS.keep_prob < 1:
                rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

            for t in xrange(FLAGS.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                # Input to the Decoder.
                if t == 0:
                    # Always provide the real input at t = 0.
                    rnn_inp = rnn_inputs[:, t]

                # If the input is present, read in the input at t.
                # If the input is not present, read in the previously generated.
                else:
                    real_rnn_inp = rnn_inputs[:, t]
                    fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

                    # While validating, the decoder should be operating in teacher
                    # forcing regime.  Also, if we're just training with cross_entropy
                    # use teacher forcing.
                    if is_validating or (is_training
                                         and FLAGS.gen_training_strategy
                                         == 'cross_entropy'):
                        rnn_inp = real_rnn_inp
                    else:
                        rnn_inp = tf.where(targets_present[:, t - 1],
                                           real_rnn_inp, fake_rnn_inp)

                if is_training:
                    state_gen = list(state_gen)
                    for layer_num, per_layer_state in enumerate(state_gen):
                        per_layer_state = LSTMTuple(
                            per_layer_state[0],
                            per_layer_state[1] * h2h_masks[layer_num])
                        state_gen[layer_num] = per_layer_state

                # RNN.
                rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

                if is_training:
                    rnn_out = output_mask * rnn_out

                if FLAGS.attention_option is not None:
                    rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                     attention_values)
                #   # TODO(liamfedus): Assert not "monotonic" attention_type.
                #   # TODO(liamfedus): FLAGS.attention_type.
                #   context_state = revised_attention_utils._empty_state()
                #   rnn_out, context_state = attention_construct_fn(
                #       rnn_out, attention_keys, attention_values, context_state, t)
                logit = tf.matmul(rnn_out, softmax_w) + softmax_b

                # Output for Decoder.
                # If input is present:   Return real at t+1.
                # If input is not present:  Return fake for t+1.
                real = targets[:, t]

                categorical = tf.contrib.distributions.Categorical(
                    logits=logit)
                fake = categorical.sample()
                log_prob = categorical.log_prob(fake)

                output = tf.where(targets_present[:, t], real, fake)

                # Add to lists.
                sequence.append(output)
                log_probs.append(log_prob)
                logits.append(logit)

    return (tf.stack(sequence,
                     axis=1), tf.stack(logits,
                                       axis=1), tf.stack(log_probs, axis=1))
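
Here recurrent variational dropout is applied directly to the state: the same per-layer mask multiplies the second element of each layer's state tuple at every time step (the hidden state h in the usual (c, h) layout), while the first element is left untouched. A minimal NumPy sketch of that masking, with a namedtuple standing in for LSTMTuple and hypothetical shapes:

import collections
import numpy as np

LSTMTupleNP = collections.namedtuple('LSTMTupleNP', ['c', 'h'])

batch_size, hidden_size, num_layers, keep_prob = 2, 4, 2, 0.75
h2h_masks = [np.floor(keep_prob + np.random.uniform(size=(batch_size, hidden_size))) / keep_prob
             for _ in range(num_layers)]
state_gen = [LSTMTupleNP(c=np.random.randn(batch_size, hidden_size),
                         h=np.random.randn(batch_size, hidden_size))
             for _ in range(num_layers)]

# Mirrors LSTMTuple(per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num]) above.
state_gen = [LSTMTupleNP(s.c, s.h * h2h_masks[i]) for i, s in enumerate(state_gen)]
print(state_gen[0].h)  # masked hidden state carried into the next RNN step
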
Example #7
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
  gen_decoder_rnn_size = hparams.gen_rnn_size

  with tf.variable_scope('decoder', reuse=reuse):

    def lstm_cell():
      return tf.contrib.rnn.LayerNormBasicLSTMCell(
          gen_decoder_rnn_size, reuse=reuse)

    attn_cell = lstm_cell
    if FLAGS.zoneout_drop_prob > 0.0:

      def attn_cell():
        return zoneout.ZoneoutWrapper(
            lstm_cell(),
            zoneout_drop_prob=FLAGS.zoneout_drop_prob,
            is_training=is_training)

    cell_gen = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(hparams.gen_num_layers)],
        state_is_tuple=True)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.get_variable('softmax_w',
                                  [gen_decoder_rnn_size, FLAGS.vocab_size])
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        #   # TODO(liamfedus): Assert not "monotonic" attention_type.
        #   # TODO(liamfedus): FLAGS.attention_type.
        #   context_state = revised_attention_utils._empty_state()
        #   rnn_out, context_state = attention_construct_fn(
        #       rnn_out, attention_keys, attention_values, context_state, t)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:   Return real at t+1.
        # If input is not present:  Return fake for t+1.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
      log_probs, axis=1))
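
This variant regularizes the layer-normalized LSTM with zoneout.ZoneoutWrapper. A minimal NumPy sketch of the usual zoneout rule, under the assumption that zoneout_drop_prob is the probability of preserving a unit's previous value during training, with the expected mixture used at evaluation time:

import numpy as np

def zoneout_step(prev_state, new_state, zoneout_drop_prob, is_training, rng=np.random):
    if is_training:
        # Each unit keeps its previous value with probability zoneout_drop_prob.
        keep_old = rng.uniform(size=prev_state.shape) < zoneout_drop_prob
        return np.where(keep_old, prev_state, new_state)
    # At evaluation time, use the expected update instead of a stochastic one.
    return zoneout_drop_prob * prev_state + (1.0 - zoneout_drop_prob) * new_state

prev_h = np.zeros(8)
new_h = np.ones(8)
print(zoneout_step(prev_h, new_h, zoneout_drop_prob=0.25, is_training=True))
print(zoneout_step(prev_h, new_h, zoneout_drop_prob=0.25, is_training=False))
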
Example #8
def gen_decoder(hparams,
                inputs,
                targets,
                targets_present,
                encoding_state,
                is_training,
                is_validating,
                reuse=None):
  """Define the Decoder graph. The Decoder will now impute tokens that
      have been masked from the input seqeunce.
  """
  config = get_config()
  gen_decoder_rnn_size = hparams.gen_rnn_size

  if FLAGS.seq2seq_share_embedding:
    with tf.variable_scope('decoder/rnn', reuse=True):
      embedding = tf.get_variable('embedding',
                                  [FLAGS.vocab_size, gen_decoder_rnn_size])

  with tf.variable_scope('decoder', reuse=reuse):
    # Neural architecture search cell.
    cell = custom_cell.Alien(config.hidden_size)

    if is_training:
      [h2h_masks, _, _,
       output_mask] = variational_dropout.generate_variational_dropout_masks(
           hparams, config.keep_prob)
    else:
      output_mask = None

    cell_gen = custom_cell.GenericMultiRNNCell([cell] * config.num_layers)

    # Hidden encoder states.
    hidden_vector_encodings = encoding_state[0]

    # Carry forward the final state tuple from the encoder.
    # State tuples.
    state_gen = encoding_state[1]

    if FLAGS.attention_option is not None:
      (attention_keys, attention_values, _,
       attention_construct_fn) = attention_utils.prepare_attention(
           hidden_vector_encodings,
           FLAGS.attention_option,
           num_units=gen_decoder_rnn_size,
           reuse=reuse)

    with tf.variable_scope('rnn'):
      sequence, logits, log_probs = [], [], []

      if not FLAGS.seq2seq_share_embedding:
        embedding = tf.get_variable('embedding',
                                    [FLAGS.vocab_size, gen_decoder_rnn_size])
      softmax_w = tf.matrix_transpose(embedding)
      softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size])

      rnn_inputs = tf.nn.embedding_lookup(embedding, inputs)

      if is_training and FLAGS.keep_prob < 1:
        rnn_inputs = tf.nn.dropout(rnn_inputs, FLAGS.keep_prob)

      for t in xrange(FLAGS.sequence_length):
        if t > 0:
          tf.get_variable_scope().reuse_variables()

        # Input to the Decoder.
        if t == 0:
          # Always provide the real input at t = 0.
          rnn_inp = rnn_inputs[:, t]

        # If the input is present, read in the input at t.
        # If the input is not present, read in the previously generated.
        else:
          real_rnn_inp = rnn_inputs[:, t]
          fake_rnn_inp = tf.nn.embedding_lookup(embedding, fake)

          # While validating, the decoder should be operating in teacher
          # forcing regime.  Also, if we're just training with cross_entropy
          # use teacher forcing.
          if is_validating or (is_training and
                               FLAGS.gen_training_strategy == 'cross_entropy'):
            rnn_inp = real_rnn_inp
          else:
            rnn_inp = tf.where(targets_present[:, t - 1], real_rnn_inp,
                               fake_rnn_inp)

        if is_training:
          state_gen = list(state_gen)
          for layer_num, per_layer_state in enumerate(state_gen):
            per_layer_state = LSTMTuple(
                per_layer_state[0], per_layer_state[1] * h2h_masks[layer_num])
            state_gen[layer_num] = per_layer_state

        # RNN.
        rnn_out, state_gen = cell_gen(rnn_inp, state_gen)

        if is_training:
          rnn_out = output_mask * rnn_out

        if FLAGS.attention_option is not None:
          rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                           attention_values)
        #   # TODO(liamfedus): Assert not "monotonic" attention_type.
        #   # TODO(liamfedus): FLAGS.attention_type.
        #   context_state = revised_attention_utils._empty_state()
        #   rnn_out, context_state = attention_construct_fn(
        #       rnn_out, attention_keys, attention_values, context_state, t)
        logit = tf.matmul(rnn_out, softmax_w) + softmax_b

        # Output for Decoder.
        # If input is present:   Return real at t+1.
        # If input is not present:  Return fake for t+1.
        real = targets[:, t]

        categorical = tf.contrib.distributions.Categorical(logits=logit)
        fake = categorical.sample()
        log_prob = categorical.log_prob(fake)

        output = tf.where(targets_present[:, t], real, fake)

        # Add to lists.
        sequence.append(output)
        log_probs.append(log_prob)
        logits.append(logit)

  return (tf.stack(sequence, axis=1), tf.stack(logits, axis=1), tf.stack(
      log_probs, axis=1))