Example #1
def encoder(features, mode, vocab, hps):
    """Model function.

    Attention seq2seq model, augmented with an encoder
    over the targets of the nearest neighbors.

    Args:
      features: Dictionary of input Tensors.
      mode: train or eval. Keys from tf.estimator.ModeKeys.
      vocab: A list of strings of words in the vocabulary.
      hps: Hyperparams.

    Returns:
      Encoder outputs.
    """

    # [batch_size, src_len]
    src_inputs = features["src_inputs"]
    src_len = features["src_len"]

    with tf.variable_scope("embeddings"):
        scale = (3.0 / hps.emb_dim)**0.5
        embeddings = tf.get_variable("embeddings", [vocab.size(), hps.emb_dim],
                                     dtype=tf.float32,
                                     initializer=tf.random_uniform_initializer(
                                         minval=-scale, maxval=scale))

    # [batch_size, src_len, emb_dim]
    src_input_emb = tf.nn.embedding_lookup(embeddings, src_inputs)

    if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
        src_input_emb = tf.nn.dropout(src_input_emb,
                                      keep_prob=1.0 - hps.emb_drop)
    src_att_context, neighbor_att_context = None, None
    src_copy_context, neighbor_copy_context = None, None
    with tf.variable_scope("src_encoder"):

        # 2 * [batch_size, src_len, encoder_dim]
        src_encoder_outputs, src_encoder_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=get_rnn_cell(mode=mode,
                                 hps=hps,
                                 input_dim=hps.emb_dim,
                                 num_units=hps.encoder_dim,
                                 num_layers=hps.num_encoder_layers,
                                 dropout=hps.encoder_drop,
                                 cell_type="lstm"),
            cell_bw=get_rnn_cell(mode=mode,
                                 hps=hps,
                                 input_dim=hps.emb_dim,
                                 num_units=hps.encoder_dim,
                                 num_layers=hps.num_encoder_layers,
                                 dropout=hps.encoder_drop,
                                 cell_type="lstm"),
            inputs=src_input_emb,
            dtype=tf.float32,
            sequence_length=src_len)

        # [batch_size, src_len, 2*encoder_dim]
        src_encoder_outputs = tf.concat(src_encoder_outputs, 2)

        with tf.variable_scope("src_att_context"):
            src_att_context = _build_context(
                hps=hps, encoder_outputs=src_encoder_outputs)
        if hps.use_copy:
            with tf.variable_scope("src_copy_context"):
                src_copy_context = _build_context(
                    hps=hps, encoder_outputs=src_encoder_outputs)

    if hps.encode_neighbor or hps.att_neighbor or hps.sum_neighbor:
        # [batch_size, neighbor_len]
        neighbor_inputs = features["neighbor_inputs"]
        neighbor_len = features["neighbor_len"]

        # [batch_size, neighbor_len, emb_dim]
        neighbor_input_emb = tf.nn.embedding_lookup(embeddings,
                                                    neighbor_inputs)

        if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
            neighbor_input_emb = tf.nn.dropout(neighbor_input_emb,
                                               keep_prob=1.0 - hps.emb_drop)
        if hps.binary_neighbor:
            neighbor_binary_input = features["neighbor_binary"]
            if hps.binary_dim == 1:
                neighbor_binary_emb = tf.to_float(neighbor_binary_input)
                neighbor_binary_emb = tf.expand_dims(neighbor_binary_emb,
                                                     axis=-1)
            else:
                with tf.variable_scope("binary_emb"):
                    scale = (3.0 / hps.binary_dim)**0.5
                    binary_embeddings = tf.get_variable(
                        "binary_emb", [2, hps.binary_dim],
                        dtype=tf.float32,
                        initializer=tf.random_uniform_initializer(
                            minval=-scale, maxval=scale))
                neighbor_binary_emb = tf.nn.embedding_lookup(
                    binary_embeddings, neighbor_binary_input)

            neighbor_input_emb = tf.concat(
                [neighbor_input_emb, neighbor_binary_emb], axis=2)

        with tf.variable_scope("neighbor_encoder"):
            # 2 * [batch_size, neighbor_len, encoder_dim]
            input_dim = hps.emb_dim
            if hps.binary_neighbor:
                input_dim += hps.binary_dim
            neighbor_encoder_outputs, neighbor_encoder_states = \
                tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=get_rnn_cell(
                        mode=mode, hps=hps,
                        input_dim=input_dim,
                        num_units=hps.neighbor_dim,
                        num_layers=1,
                        dropout=hps.encoder_drop,
                        cell_type="lstm"),
                    cell_bw=get_rnn_cell(
                        mode=mode, hps=hps,
                        input_dim=input_dim,
                        num_units=hps.neighbor_dim,
                        num_layers=1,
                        dropout=hps.encoder_drop,
                        cell_type="lstm"),
                    inputs=neighbor_input_emb,
                    dtype=tf.float32,
                    sequence_length=neighbor_len)
            # [batch_size, neighbor_len, 2*encoder_dim]
            neighbor_encoder_outputs = tf.concat(neighbor_encoder_outputs, 2)
            if hps.att_neighbor:
                with tf.variable_scope("neighbor_att_context"):
                    neighbor_att_context = _build_context(
                        hps=hps, encoder_outputs=neighbor_encoder_outputs)
                if hps.use_copy:
                    with tf.variable_scope("neighbor_copy_context"):
                        neighbor_copy_context = _build_context(
                            hps=hps, encoder_outputs=neighbor_encoder_outputs)
    att_context, copy_context = None, None
    if hps.att_neighbor:
        att_context = tf.concat([src_att_context, neighbor_att_context], 1)
        if hps.use_copy:
            copy_context = tf.concat([src_copy_context, neighbor_copy_context],
                                     1)
    else:
        att_context = src_att_context
        if hps.use_copy:
            copy_context = src_copy_context
    if hps.encode_neighbor:
        neighbor_fw_states, neighbor_bw_states = neighbor_encoder_states
        neighbor_h = tf.concat(
            [neighbor_fw_states[-1].h, neighbor_bw_states[-1].h], axis=1)
        if mode == tf.estimator.ModeKeys.TRAIN and hps.drop > 0.:
            neighbor_h = tf.nn.dropout(neighbor_h, keep_prob=1.0 - hps.drop)
        mem_input = tf.layers.dense(
            neighbor_h,
            units=hps.decoder_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="mem_input")
        if mode == tf.estimator.ModeKeys.TRAIN and hps.drop > 0.:
            mem_input = tf.nn.dropout(mem_input, keep_prob=1.0 - hps.drop)
    elif hps.sum_neighbor:

        src_fw_states, src_bw_states = src_encoder_states
        src_h = tf.concat([src_fw_states[-1].h, src_bw_states[-1].h], axis=1)

        if mode == tf.estimator.ModeKeys.TRAIN and hps.drop > 0.:
            src_h = tf.nn.dropout(src_h, keep_prob=1.0 - hps.drop)
        src_h = tf.layers.dense(
            src_h,
            units=hps.decoder_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="proj_src_h")
        neighbor_encoder_outputs = tf.layers.dense(
            neighbor_encoder_outputs,
            units=hps.decoder_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="proj_neighbor_out")

        # [batch_size, neighbor_len]
        alpha = tf.einsum("bij,bj->bi", neighbor_encoder_outputs, src_h)
        alpha = tf.nn.softmax(alpha)
        mem_input = tf.reduce_sum(
            neighbor_encoder_outputs * tf.expand_dims(alpha, -1), 1)
        # mem_input = tf.reduce_mean(mem_input, axis=1)
        if mode == tf.estimator.ModeKeys.TRAIN and hps.drop > 0.:
            mem_input = tf.nn.dropout(mem_input, keep_prob=1.0 - hps.drop)
    else:
        assert hps.rnn_cell != "hyper_lstm"
        assert hps.att_type != "hyper"
        mem_input = None

    if hps.use_bridge:
        with tf.variable_scope("bridge"):
            out_dim = hps.num_decoder_layers * hps.decoder_dim
            fw_states, bw_states = src_encoder_states
            c_states, h_states = [], []
            for (fw, bw) in zip(fw_states, bw_states):
                c_states.append(tf.concat((fw.c, bw.c), 1))
                h_states.append(tf.concat((fw.h, bw.h), 1))
            cs, hs = c_states[-1], h_states[-1]

            if mode == tf.estimator.ModeKeys.TRAIN and hps.drop > 0.:
                hs = tf.nn.dropout(hs, keep_prob=1.0 - hps.drop)
                cs = tf.nn.dropout(cs, keep_prob=1.0 - hps.drop)

            h_state = tf.layers.dense(
                hs,
                units=out_dim,
                activation=tf.nn.tanh,
                use_bias=True,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                name="h_layer")
            c_state = tf.layers.dense(
                cs,
                units=out_dim,
                activation=tf.nn.tanh,
                use_bias=True,
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                name="c_layer")
    else:
        h_state, c_state = None, None
    return EncoderOutputs(embeddings=embeddings,
                          mem_input=mem_input,
                          att_context=att_context,
                          copy_context=copy_context,
                          states=(h_state, c_state))
Example #2
def basic_decoder(features, mode, vocab, encoder_outputs, hps):
    """Decoder.

    Args:
      features: Dictionary of input Tensors.
      mode: train or eval. Keys from tf.estimator.ModeKeys.
      vocab: A list of strings of words in the vocabulary.
      encoder_outputs: Output tensors from the encoder.
      hps: Hyperparams.

    Returns:
      Decoder outputs.
    """

    embeddings = encoder_outputs.embeddings
    mem_input = encoder_outputs.mem_input
    batch_size = tensor_utils.shape(mem_input, 0)

    src_len, src_inputs = features["src_len"], features["src_inputs"]
    src_mask = tf.sequence_mask(src_len, tf.shape(src_inputs)[1])

    if hps.att_neighbor:
        neighbor_len, neighbor_inputs \
            = features["neighbor_len"], features["neighbor_inputs"]
        neighbor_mask = tf.sequence_mask(neighbor_len,
                                         tf.shape(neighbor_inputs)[1])
        inputs = tf.concat([src_inputs, neighbor_inputs], 1)
        lens = src_len + neighbor_len
        mask = tf.concat([src_mask, neighbor_mask], axis=1)
    else:
        inputs = features["src_inputs"]
        lens = features["src_len"]
        mask = src_mask

    sparse_inputs = None
    float_mask = tf.cast(mask, dtype=tf.float32)

    if hps.use_copy:
        sparse_inputs = sparse_map(tf.expand_dims(inputs, axis=2),
                                   tf.expand_dims(float_mask, axis=2),
                                   vocab.size())
    # [batch_size, dec_len]
    decoder_inputs = features["decoder_inputs"]

    # [batch_size, dec_len, emb_dim]
    decoder_input_emb = tf.nn.embedding_lookup(embeddings, decoder_inputs)
    if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
        decoder_input_emb = tf.nn.dropout(decoder_input_emb,
                                          keep_prob=1.0 - hps.emb_drop)

    def _decode(cell, helper):
        """Decode function.

        Args:
          cell: An RNN cell.
          helper: A helper instance from tf.contrib.seq2seq.

        Returns:
          Decoded outputs and lengths.
        """
        with tf.variable_scope("decoder"):
            initial_state = cell.zero_state(batch_size, tf.float32)
            if hps.use_bridge:
                h_state, c_state = encoder_outputs.states
                initial_cell_state = tf.contrib.rnn.LSTMStateTuple(
                    h_state, c_state)
                initial_state = initial_state.clone(
                    cell_state=(initial_cell_state, ))

            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)
        with tf.variable_scope("dynamic_decode", reuse=tf.AUTO_REUSE):
            decoder_outputs, _, decoder_len = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, maximum_iterations=hps.max_dec_steps)
        return decoder_outputs, decoder_len

    att_context = encoder_outputs.att_context
    with tf.variable_scope("attention"):
        if hps.att_type == "luong":
            attention = tf.contrib.seq2seq.LuongAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "bahdanau":
            attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "hyper":
            attention = HyperAttention(num_units=hps.decoder_dim,
                                       mem_input=mem_input,
                                       hps=hps,
                                       memory=att_context,
                                       use_beam=False,
                                       memory_sequence_length=lens)
        elif hps.att_type == "my":
            attention = MyAttention(num_units=hps.decoder_dim,
                                    memory=att_context,
                                    memory_sequence_length=lens,
                                    mask=mask)

    with tf.variable_scope("rnn_decoder"):
        decoder_cell = get_rnn_cell(mode=mode,
                                    hps=hps,
                                    input_dim=hps.decoder_dim + hps.emb_dim,
                                    num_units=hps.decoder_dim,
                                    num_layers=hps.num_decoder_layers,
                                    dropout=hps.decoder_drop,
                                    mem_input=mem_input,
                                    use_beam=False,
                                    cell_type=hps.rnn_cell)
    with tf.variable_scope("attention_wrapper"):
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell,
            attention,
            attention_layer_size=hps.decoder_dim,
            alignment_history=hps.use_copy)

    with tf.variable_scope("output_projection"):
        weights = tf.transpose(embeddings) if hps.tie_embedding else None
        hidden_dim = hps.emb_dim if hps.tie_embedding else hps.decoder_dim
        decoder_cell = OutputWrapper(
            decoder_cell,
            num_layers=hps.num_mlp_layers,
            hidden_dim=hidden_dim,
            output_size=hps.output_size,
            #  output_size=vocab.size(),
            weights=weights,
            dropout=hps.out_drop,
            use_copy=hps.use_copy,
            encoder_emb=encoder_outputs.copy_context,
            sparse_inputs=sparse_inputs,
            mask=float_mask,
            mode=mode,
            hps=hps)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if hps.sampling_probability > 0.:
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs=decoder_input_emb,
                sequence_length=features["decoder_len"],
                embedding=embeddings,
                sampling_probability=hps.sampling_probability)
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(decoder_input_emb,
                                                       features["decoder_len"])
        decoder_outputs, _ = _decode(decoder_cell, helper=helper)
        return DecoderOutputs(decoder_outputs=decoder_outputs,
                              decoder_len=features["decoder_len"]), None

    # used to compute loss
    teacher_helper = tf.contrib.seq2seq.TrainingHelper(decoder_input_emb,
                                                       features["decoder_len"])
    teacher_decoder_outputs, _ = _decode(decoder_cell, helper=teacher_helper)

    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embeddings,
        start_tokens=tf.fill([batch_size], vocab.word2id(data.START_DECODING)),
        end_token=vocab.word2id(data.STOP_DECODING))
    decoder_outputs, decoder_len = _decode(decoder_cell, helper=helper)
    return DecoderOutputs(
        decoder_outputs=decoder_outputs,
        decoder_len=decoder_len), \
           DecoderOutputs(
               decoder_outputs=teacher_decoder_outputs,
               decoder_len=features["decoder_len"]
           )
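Both the encoder and basic_decoder read many hyperparameters from hps. One illustrative way to bundle them is tf.contrib.training.HParams; only the field names below come from the code above, the values are placeholders:

import tensorflow as tf

hps = tf.contrib.training.HParams(
    # Embeddings and dropout.
    emb_dim=128, emb_drop=0.1, drop=0.1, out_drop=0.1,
    # Source and neighbor encoders.
    encoder_dim=256, num_encoder_layers=2, encoder_drop=0.1,
    neighbor_dim=256, encode_neighbor=False, att_neighbor=False,
    sum_neighbor=False, binary_neighbor=False, binary_dim=16,
    # Decoder and output projection.
    decoder_dim=256, num_decoder_layers=2, decoder_drop=0.1,
    num_mlp_layers=1, output_size=50000, max_dec_steps=100,
    sampling_probability=0.0,
    # Attention and architecture switches.
    att_type="luong", rnn_cell="lstm", use_copy=False,
    use_bridge=True, tie_embedding=False, model="seq2seq",
    # Beam search (Example #3).
    beam_width=4, length_norm=0.0, cp=0.0)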
Example #3
def beam_decoder(features, mode, vocab, encoder_outputs, hps):
    """Beam search decoder.

    Args:
      features: Dictionary of input Tensors.
      mode: train or eval. Keys from tf.estimator.ModeKeys.
      vocab: A list of strings of words in the vocabulary.
      encoder_outputs: Output tensors from the encoder.
      hps: Hyperparams.

    Returns:
      Decoder outputs.
    """
    assert mode != tf.estimator.ModeKeys.TRAIN, "Not using beam in training."
    embeddings = encoder_outputs.embeddings
    mem_input = encoder_outputs.mem_input
    batch_size = tensor_utils.shape(features["src_len"], 0)

    src_len, src_inputs = features["src_len"], features["src_inputs"]
    src_mask = tf.sequence_mask(src_len, tf.shape(src_inputs)[1])

    if hps.att_neighbor:
        neighbor_len, neighbor_inputs \
            = features["neighbor_len"], features["neighbor_inputs"]
        neighbor_mask = tf.sequence_mask(neighbor_len,
                                         tf.shape(neighbor_inputs)[1])
        inputs = tf.concat([src_inputs, neighbor_inputs], 1)
        # lens = src_len + neighbor_len
        mask = tf.concat([src_mask, neighbor_mask], axis=1)
        lens = tf.shape(mask)[1] * tf.ones([batch_size], tf.int32)
    else:
        inputs = features["src_inputs"]
        lens = features["src_len"]
        mask = src_mask

    sparse_inputs = None
    float_mask = tf.cast(mask, dtype=tf.float32)
    if hps.use_copy:
        sparse_inputs = sparse_map(tf.expand_dims(inputs, axis=2),
                                   tf.expand_dims(float_mask, axis=2),
                                   vocab.size())
        sparse_inputs = sparse_tile_batch(sparse_inputs,
                                          multiplier=hps.beam_width)

    tiled_mask = tf.contrib.seq2seq.tile_batch(mask, multiplier=hps.beam_width)
    inputs = tf.contrib.seq2seq.tile_batch(inputs, multiplier=hps.beam_width)
    lens = tf.contrib.seq2seq.tile_batch(lens, multiplier=hps.beam_width)

    def _beam_decode(cell):
        """Beam decode."""
        with tf.variable_scope("beam_decoder"):

            initial_state = cell.zero_state(batch_size=batch_size *
                                            hps.beam_width,
                                            dtype=tf.float32)
            if hps.use_bridge:
                h_state, c_state = encoder_outputs.states
                h_state = tf.contrib.seq2seq.tile_batch(
                    h_state, multiplier=hps.beam_width)
                c_state = tf.contrib.seq2seq.tile_batch(
                    c_state, multiplier=hps.beam_width)
                initial_cell_state = tf.contrib.rnn.LSTMStateTuple(
                    h_state, c_state)
                initial_state = initial_state.clone(
                    cell_state=(initial_cell_state, ))

            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=cell,
                embedding=embeddings,
                start_tokens=tf.fill([batch_size],
                                     vocab.word2id(data.START_DECODING)),
                end_token=vocab.word2id(data.STOP_DECODING),
                initial_state=initial_state,
                beam_width=hps.beam_width,
                length_penalty_weight=hps.length_norm,
                coverage_penalty_weight=hps.cp)
        with tf.variable_scope("dynamic_decode", reuse=tf.AUTO_REUSE):
            decoder_outputs, _, decoder_len = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, maximum_iterations=hps.max_dec_steps)
        return decoder_outputs, decoder_len

    # [batch_size*beam_width, src_len, encoder_dim]
    att_context = tf.contrib.seq2seq.tile_batch(encoder_outputs.att_context,
                                                multiplier=hps.beam_width)
    copy_context = None
    if hps.use_copy:
        copy_context = tf.contrib.seq2seq.tile_batch(
            encoder_outputs.copy_context, multiplier=hps.beam_width)
    with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
        if hps.att_type == "luong":
            attention = tf.contrib.seq2seq.LuongAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "bahdanau":
            attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "hyper":
            attention = HyperAttention(num_units=hps.decoder_dim,
                                       mem_input=mem_input,
                                       hps=hps,
                                       memory=att_context,
                                       use_beam=True,
                                       memory_sequence_length=lens)
        elif hps.att_type == "my":
            attention = MyAttention(num_units=hps.decoder_dim,
                                    memory=att_context,
                                    memory_sequence_length=lens,
                                    mask=tiled_mask)

    with tf.variable_scope("rnn_decoder", reuse=tf.AUTO_REUSE):
        decoder_cell = get_rnn_cell(mode=mode,
                                    hps=hps,
                                    input_dim=hps.decoder_dim + hps.emb_dim,
                                    num_units=hps.decoder_dim,
                                    num_layers=hps.num_decoder_layers,
                                    dropout=hps.decoder_drop,
                                    mem_input=mem_input,
                                    use_beam=True,
                                    cell_type=hps.rnn_cell)
    with tf.variable_scope("attention_wrapper", reuse=tf.AUTO_REUSE):
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell,
            attention,
            attention_layer_size=hps.decoder_dim,
            alignment_history=hps.use_copy)

    with tf.variable_scope("output_projection", reuse=tf.AUTO_REUSE):
        weights = tf.transpose(embeddings) if hps.tie_embedding else None
        hidden_dim = hps.emb_dim if hps.tie_embedding else hps.decoder_dim
        decoder_cell = OutputWrapper(
            decoder_cell,
            num_layers=hps.num_mlp_layers,
            hidden_dim=hidden_dim,
            output_size=vocab.size() if hps.tie_embedding else hps.output_size,
            weights=weights,
            dropout=hps.out_drop,
            use_copy=hps.use_copy,
            encoder_emb=copy_context,
            sparse_inputs=sparse_inputs,
            mask=tf.cast(tiled_mask, dtype=tf.float32),
            hps=hps,
            mode=mode,
            reuse=tf.AUTO_REUSE)

    decoder_outputs, decoder_len = _beam_decode(decoder_cell)

    return DecoderOutputs(decoder_outputs=decoder_outputs,
                          decoder_len=decoder_len)
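In the beam-search decoder every tensor that the attention mechanism reads (memory, mask, lengths, bridge states, sparse copy inputs) is repeated beam_width times along the batch axis. A small self-contained illustration of what tf.contrib.seq2seq.tile_batch does in TF 1.x graph mode (the values are arbitrary):

import tensorflow as tf

lens = tf.constant([3, 5])                                 # [batch_size]
tiled = tf.contrib.seq2seq.tile_batch(lens, multiplier=2)  # each entry repeated
with tf.Session() as sess:
    print(sess.run(tiled))  # [3 3 5 5]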
Example #4
def encoder(features, mode, vocab, hps):
  """Model function.

  Attention seq2seq model, augmented with an encoder
  over the targets of the nearest neighbors.

  Args:
    features: Dictionary of input Tensors.
    mode: train or eval. Keys from tf.estimator.ModeKeys.
    vocab: A list of strings of words in the vocabulary.
    hps: Hyperparams.

  Returns:
    Encoder outputs.
  """

  # [batch_size, src_len]
  src_inputs = features["src_inputs"]
  src_len = features["src_len"]

  with tf.variable_scope("embeddings"):
    embeddings = tf.get_variable(
        "embeddings",
        [vocab.size(), hps.emb_dim],
        dtype=tf.float32,
        initializer=tf.uniform_unit_scaling_initializer())

  # [batch_size, src_len, emb_dim]
  src_encoder_input_emb = tf.nn.embedding_lookup(embeddings, src_inputs)
  if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
    src_encoder_input_emb = tf.nn.dropout(
        src_encoder_input_emb, keep_prob=1.0-hps.emb_drop)

  src_att_context, neighbor_att_context = None, None
  src_copy_context, neighbor_copy_context = None, None
  with tf.variable_scope("src_encoder"):

    # 2 * [batch_size, src_len, encoder_dim]
    src_encoder_outputs, src_encoder_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=get_rnn_cell(
            mode=mode, hps=hps,
            input_dim=hps.emb_dim,
            num_units=hps.encoder_dim,
            num_layers=hps.num_encoder_layers,
            dropout=hps.encoder_drop,
            cell_type=hps.rnn_cell),
        cell_bw=get_rnn_cell(
            mode=mode, hps=hps,
            input_dim=hps.emb_dim,
            num_units=hps.encoder_dim,
            num_layers=hps.num_encoder_layers,
            dropout=hps.encoder_drop,
            cell_type=hps.rnn_cell),
        inputs=src_encoder_input_emb,
        dtype=tf.float32,
        sequence_length=src_len)

    # [batch_size, src_len, 2*encoder_dim]
    src_encoder_outputs = tf.concat(src_encoder_outputs, 2)
    with tf.variable_scope("src_att_context"):
      src_att_context = _build_context(
          hps=hps,
          encoder_outputs=src_encoder_outputs)
    if hps.use_copy:
      with tf.variable_scope("src_copy_context"):
        src_copy_context = _build_context(
            hps=hps,
            encoder_outputs=src_encoder_outputs)

  if hps.model == "nn2seq":

    # [batch_size, neighbor_len]
    neighbor_inputs = features["neighbor_inputs"]
    neighbor_len = features["neighbor_len"]

    # [batch_size, neighbor_len, emb_dim]
    neighbor_input_emb = tf.nn.embedding_lookup(
        embeddings, neighbor_inputs)
    if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
      neighbor_input_emb = tf.nn.dropout(
          neighbor_input_emb, keep_prob=1.0-hps.emb_drop)
    with tf.variable_scope("neighbor_encoder"):
      # 2 * [batch_size, neighbor_len, encoder_dim]
      neighbor_encoder_outputs, _ = \
          tf.nn.bidirectional_dynamic_rnn(
              cell_fw=get_rnn_cell(
                  mode=mode, hps=hps,
                  input_dim=hps.emb_dim,
                  num_units=hps.encoder_dim,
                  num_layers=1,
                  dropout=hps.encoder_drop,
                  cell_type=hps.rnn_cell),
              cell_bw=get_rnn_cell(
                  mode=mode, hps=hps,
                  input_dim=hps.emb_dim,
                  num_units=hps.encoder_dim,
                  num_layers=1,
                  dropout=hps.encoder_drop,
                  cell_type=hps.rnn_cell),
              inputs=neighbor_input_emb,
              dtype=tf.float32,
              sequence_length=neighbor_len)

      neighbor_encoder_outputs = tf.concat(neighbor_encoder_outputs, 2)
      with tf.variable_scope("neighbor_att_context"):
        neighbor_att_context = _build_context(
            hps=hps,
            encoder_outputs=neighbor_encoder_outputs)
      if hps.use_copy:
        with tf.variable_scope("neighbor_copy_context"):
          neighbor_copy_context = _build_context(
              hps=hps,
              encoder_outputs=neighbor_encoder_outputs)
  att_context, copy_context = None, None
  if hps.model == "nn2seq":
    att_context = tf.concat([src_att_context, neighbor_att_context], 1)
    if hps.use_copy:
      copy_context = tf.concat(
          [src_copy_context, neighbor_copy_context], 1)
  elif hps.model == "seq2seq":
    att_context = src_att_context
    if hps.use_copy:
      copy_context = src_copy_context
  else:
    assert False, "hps.model should be one of [`nn2seq`, `seq2seq`]."
  if hps.use_bridge:
    with tf.variable_scope("bridge"):
      out_dim = hps.num_decoder_layers * hps.decoder_dim
      if hps.rnn_cell == "lstm":
        fw_states, bw_states = src_encoder_states
        hs = tf.concat([fw_states[-1].h, bw_states[-1].h], axis=1)
        cs = tf.concat([fw_states[-1].c, bw_states[-1].c], axis=1)

        h_state = tf.layers.dense(
            hs, units=out_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="h_layer")
        c_state = tf.layers.dense(
            cs, units=out_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="c_layer")
      elif hps.rnn_cell == "gru":
        fw_states, bw_states = src_encoder_states
        hs = tf.concat([fw_states[-1], bw_states[-1]], axis=1)
        h_state = tf.layers.dense(
            hs, units=out_dim,
            activation=tf.nn.tanh,
            use_bias=True,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            name="h_layer")
        c_state = None
  else:
    h_state, c_state = None, None

  return EncoderOutputs(
      embeddings=embeddings,
      att_context=att_context,
      copy_context=copy_context,
      states=(h_state, c_state)
  )
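Putting the pieces together: a hypothetical sketch of how these functions might be wired inside a tf.estimator model function. Only encoder, basic_decoder and beam_decoder come from the examples above; the model_fn wrapper and the params keys are assumptions, and loss/EstimatorSpec construction is omitted:

def model_fn(features, labels, mode, params):
    """Hypothetical wrapper around the encoder/decoder functions above."""
    hps, vocab = params["hps"], params["vocab"]
    encoder_outputs = encoder(features, mode, vocab, hps)
    if mode == tf.estimator.ModeKeys.PREDICT and hps.beam_width > 1:
        # Beam search is only used outside training (see the assert in Example #3).
        decoder_outputs = beam_decoder(features, mode, vocab, encoder_outputs, hps)
        teacher_outputs = None
    else:
        # In TRAIN mode basic_decoder returns (teacher-forced outputs, None);
        # otherwise it returns (greedy outputs, teacher-forced outputs).
        decoder_outputs, teacher_outputs = basic_decoder(
            features, mode, vocab, encoder_outputs, hps)
    return decoder_outputs, teacher_outputs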