Example 1
def beam_decoder(features, mode, vocab, encoder_outputs, hps):
    """Beam search decoder.

  Args:
    features: Dictionary of input Tensors.
    mode: train or eval. Keys from tf.estimator.ModeKeys.
    vocab: A list of strings of words in the vocabulary.
    encoder_outputs: output tensors from the encoder
    hps: Hyperparams.

  Returns:
    Decoder outputs
  """
    assert mode != tf.estimator.ModeKeys.TRAIN, "Beam search is not used in training."
    embeddings = encoder_outputs.embeddings
    mem_input = encoder_outputs.mem_input
    batch_size = tensor_utils.shape(features["src_len"], 0)

    src_len, src_inputs = features["src_len"], features["src_inputs"]
    src_mask = tf.sequence_mask(src_len, tf.shape(src_inputs)[1])

    if hps.att_neighbor:
        neighbor_len = features["neighbor_len"]
        neighbor_inputs = features["neighbor_inputs"]
        neighbor_mask = tf.sequence_mask(neighbor_len,
                                         tf.shape(neighbor_inputs)[1])
        inputs = tf.concat([src_inputs, neighbor_inputs], 1)
        mask = tf.concat([src_mask, neighbor_mask], axis=1)
        # Note: unlike basic_decoder (which uses src_len + neighbor_len),
        # lens here covers the full padded width of the concatenated inputs.
        lens = tf.shape(mask)[1] * tf.ones([batch_size], tf.int32)
    else:
        inputs = features["src_inputs"]
        lens = features["src_len"]
        mask = src_mask

    sparse_inputs = None
    float_mask = tf.cast(mask, dtype=tf.float32)
    if hps.use_copy:
        sparse_inputs = sparse_map(tf.expand_dims(inputs, axis=2),
                                   tf.expand_dims(float_mask, axis=2),
                                   vocab.size())
        sparse_inputs = sparse_tile_batch(sparse_inputs,
                                          multiplier=hps.beam_width)

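    # Beam search tracks beam_width hypotheses per example, so per-example
    # tensors are tiled beam_width times along the batch dimension.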
    tiled_mask = tf.contrib.seq2seq.tile_batch(mask, multiplier=hps.beam_width)
    inputs = tf.contrib.seq2seq.tile_batch(inputs, multiplier=hps.beam_width)
    lens = tf.contrib.seq2seq.tile_batch(lens, multiplier=hps.beam_width)

    def _beam_decode(cell):
        """Beam decode."""
        with tf.variable_scope("beam_decoder"):
            initial_state = cell.zero_state(
                batch_size=batch_size * hps.beam_width, dtype=tf.float32)
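            # Bridge: initialize the decoder from the encoder's final LSTM
            # state instead of a zero state.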
            if hps.use_bridge:
                h_state, c_state = encoder_outputs.states
                h_state = tf.contrib.seq2seq.tile_batch(
                    h_state, multiplier=hps.beam_width)
                c_state = tf.contrib.seq2seq.tile_batch(
                    c_state, multiplier=hps.beam_width)
                initial_cell_state = tf.contrib.rnn.LSTMStateTuple(
                    h_state, c_state)
                initial_state = initial_state.clone(
                    cell_state=(initial_cell_state, ))

            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell=cell,
                embedding=embeddings,
                start_tokens=tf.fill([batch_size],
                                     vocab.word2id(data.START_DECODING)),
                end_token=vocab.word2id(data.STOP_DECODING),
                initial_state=initial_state,
                beam_width=hps.beam_width,
                length_penalty_weight=hps.length_norm,
                coverage_penalty_weight=hps.cp)
        with tf.variable_scope("dynamic_decode", reuse=tf.AUTO_REUSE):
            decoder_outputs, _, decoder_len = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, maximum_iterations=hps.max_dec_steps)
        return decoder_outputs, decoder_len

    # [batch_size*beam_width, src_len, encoder_dim]
    att_context = tf.contrib.seq2seq.tile_batch(encoder_outputs.att_context,
                                                multiplier=hps.beam_width)
    copy_context = None
    if hps.use_copy:
        copy_context = tf.contrib.seq2seq.tile_batch(
            encoder_outputs.copy_context, multiplier=hps.beam_width)
    with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
        if hps.att_type == "luong":
            attention = tf.contrib.seq2seq.LuongAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "bahdanau":
            attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "hyper":
            attention = HyperAttention(num_units=hps.decoder_dim,
                                       mem_input=mem_input,
                                       hps=hps,
                                       memory=att_context,
                                       use_beam=True,
                                       memory_sequence_length=lens)
        elif hps.att_type == "my":
            attention = MyAttention(num_units=hps.decoder_dim,
                                    memory=att_context,
                                    memory_sequence_length=lens,
                                    mask=tiled_mask)

    with tf.variable_scope("rnn_decoder", reuse=tf.AUTO_REUSE):
        decoder_cell = get_rnn_cell(mode=mode,
                                    hps=hps,
                                    input_dim=hps.decoder_dim + hps.emb_dim,
                                    num_units=hps.decoder_dim,
                                    num_layers=hps.num_decoder_layers,
                                    dropout=hps.decoder_drop,
                                    mem_input=mem_input,
                                    use_beam=True,
                                    cell_type=hps.rnn_cell)
    with tf.variable_scope("attention_wrapper", reuse=tf.AUTO_REUSE):
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell,
            attention,
            attention_layer_size=hps.decoder_dim,
            alignment_history=hps.use_copy)

    with tf.variable_scope("output_projection", reuse=tf.AUTO_REUSE):
        weights = tf.transpose(embeddings) if hps.tie_embedding else None
        hidden_dim = hps.emb_dim if hps.tie_embedding else hps.decoder_dim
        decoder_cell = OutputWrapper(
            decoder_cell,
            num_layers=hps.num_mlp_layers,
            hidden_dim=hidden_dim,
            output_size=vocab.size() if hps.tie_embedding else hps.output_size,
            weights=weights,
            dropout=hps.out_drop,
            use_copy=hps.use_copy,
            encoder_emb=copy_context,
            sparse_inputs=sparse_inputs,
            mask=tf.cast(tiled_mask, dtype=tf.float32),
            hps=hps,
            mode=mode,
            reuse=tf.AUTO_REUSE)

    decoder_outputs, decoder_len = _beam_decode(decoder_cell)

    return DecoderOutputs(decoder_outputs=decoder_outputs,
                          decoder_len=decoder_len)
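
For context, here is a minimal call-site sketch showing how beam_decoder might plug into a tf.estimator model_fn at predict time. The build_encoder helper and the params keys are assumptions standing in for the surrounding codebase; only the beam_decoder call itself comes from the example above.

# Hypothetical call site (TF 1.x). `build_encoder` and the params keys are
# assumptions; they are not part of the example above.
import tensorflow as tf

def model_fn(features, labels, mode, params):
    hps = params["hps"]      # hyperparameter namespace (assumed)
    vocab = params["vocab"]  # vocabulary with word2id()/size() (assumed)
    encoder_outputs = build_encoder(features, mode, vocab, hps)  # assumed

    if mode == tf.estimator.ModeKeys.PREDICT:
        outputs = beam_decoder(features, mode, vocab, encoder_outputs, hps)
        return tf.estimator.EstimatorSpec(
            mode,
            predictions={
                # predicted_ids: [batch_size, max_dec_steps, beam_width]
                "ids": outputs.decoder_outputs.predicted_ids,
                "len": outputs.decoder_len,
            })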
Example 2
def basic_decoder(features, mode, vocab, encoder_outputs, hps):
    """Decoder.

  Args:
    features: Dictionary of input Tensors.
    mode: train or eval. Keys from tf.estimator.ModeKeys.
    vocab: A list of strings of words in the vocabulary.
    encoder_outputs: output tensors from the encoder
    hps: Hyperparams.

  Returns:
    Decoder outputs
  """

    embeddings = encoder_outputs.embeddings
    mem_input = encoder_outputs.mem_input
    batch_size = tensor_utils.shape(mem_input, 0)

    src_len, src_inputs = features["src_len"], features["src_inputs"]
    src_mask = tf.sequence_mask(src_len, tf.shape(src_inputs)[1])

    if hps.att_neighbor:
        neighbor_len = features["neighbor_len"]
        neighbor_inputs = features["neighbor_inputs"]
        neighbor_mask = tf.sequence_mask(neighbor_len,
                                         tf.shape(neighbor_inputs)[1])
        inputs = tf.concat([src_inputs, neighbor_inputs], 1)
        lens = src_len + neighbor_len
        mask = tf.concat([src_mask, neighbor_mask], axis=1)
    else:
        inputs = features["src_inputs"]
        lens = features["src_len"]
        mask = src_mask

    sparse_inputs = None
    float_mask = tf.cast(mask, dtype=tf.float32)

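    # Copy mechanism: build a sparse map from source positions to vocabulary
    # ids, later used by the output layer to form the copy distribution.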
    if hps.use_copy:
        sparse_inputs = sparse_map(tf.expand_dims(inputs, axis=2),
                                   tf.expand_dims(float_mask, axis=2),
                                   vocab.size())
    # [batch_size, dec_len]
    decoder_inputs = features["decoder_inputs"]

    # [batch_size, dec_len, emb_dim]
    decoder_input_emb = tf.nn.embedding_lookup(embeddings, decoder_inputs)
    if mode == tf.estimator.ModeKeys.TRAIN and hps.emb_drop > 0.:
        decoder_input_emb = tf.nn.dropout(decoder_input_emb,
                                          keep_prob=1.0 - hps.emb_drop)

    def _decode(cell, helper):
        """Decode function.

    Args:
      cell: rnn cell
      helper: a helper instance from tf.contrib.seq2seq.

    Returns:
      decoded outputs and lengths.
    """
        with tf.variable_scope("decoder"):
            initial_state = cell.zero_state(batch_size, tf.float32)
            if hps.use_bridge:
                h_state, c_state = encoder_outputs.states
                initial_cell_state = tf.contrib.rnn.LSTMStateTuple(
                    h_state, c_state)
                initial_state = initial_state.clone(
                    cell_state=(initial_cell_state, ))

            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell, helper=helper, initial_state=initial_state)
        with tf.variable_scope("dynamic_decode", reuse=tf.AUTO_REUSE):
            decoder_outputs, _, decoder_len = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder, maximum_iterations=hps.max_dec_steps)
        return decoder_outputs, decoder_len

    att_context = encoder_outputs.att_context
    with tf.variable_scope("attention"):
        if hps.att_type == "luong":
            attention = tf.contrib.seq2seq.LuongAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "bahdanau":
            attention = tf.contrib.seq2seq.BahdanauAttention(
                num_units=hps.decoder_dim,
                memory=att_context,
                memory_sequence_length=lens)
        elif hps.att_type == "hyper":
            attention = HyperAttention(num_units=hps.decoder_dim,
                                       mem_input=mem_input,
                                       hps=hps,
                                       memory=att_context,
                                       use_beam=False,
                                       memory_sequence_length=lens)
        elif hps.att_type == "my":
            attention = MyAttention(num_units=hps.decoder_dim,
                                    memory=att_context,
                                    memory_sequence_length=lens,
                                    mask=mask)

    with tf.variable_scope("rnn_decoder"):
        decoder_cell = get_rnn_cell(mode=mode,
                                    hps=hps,
                                    input_dim=hps.decoder_dim + hps.emb_dim,
                                    num_units=hps.decoder_dim,
                                    num_layers=hps.num_decoder_layers,
                                    dropout=hps.decoder_drop,
                                    mem_input=mem_input,
                                    use_beam=False,
                                    cell_type=hps.rnn_cell)
    with tf.variable_scope("attention_wrapper"):
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell,
            attention,
            attention_layer_size=hps.decoder_dim,
            alignment_history=hps.use_copy)

    with tf.variable_scope("output_projection"):
        weights = tf.transpose(embeddings) if hps.tie_embedding else None
        hidden_dim = hps.emb_dim if hps.tie_embedding else hps.decoder_dim
        decoder_cell = OutputWrapper(
            decoder_cell,
            num_layers=hps.num_mlp_layers,
            hidden_dim=hidden_dim,
            output_size=hps.output_size,
            weights=weights,
            dropout=hps.out_drop,
            use_copy=hps.use_copy,
            encoder_emb=encoder_outputs.copy_context,
            sparse_inputs=sparse_inputs,
            mask=float_mask,
            mode=mode,
            hps=hps)

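    # Training uses teacher forcing, optionally mixing in sampled model
    # outputs (scheduled sampling) when sampling_probability > 0.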
    if mode == tf.estimator.ModeKeys.TRAIN:
        if hps.sampling_probability > 0.:
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs=decoder_input_emb,
                sequence_length=features["decoder_len"],
                embedding=embeddings,
                sampling_probability=hps.sampling_probability)
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(decoder_input_emb,
                                                       features["decoder_len"])
        decoder_outputs, _ = _decode(decoder_cell, helper=helper)
        return DecoderOutputs(decoder_outputs=decoder_outputs,
                              decoder_len=features["decoder_len"]), None

    # Teacher-forced pass over the gold targets, used to compute the eval loss.
    teacher_helper = tf.contrib.seq2seq.TrainingHelper(decoder_input_emb,
                                                       features["decoder_len"])
    teacher_decoder_outputs, _ = _decode(decoder_cell, helper=teacher_helper)

    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embeddings,
        start_tokens=tf.fill([batch_size], vocab.word2id(data.START_DECODING)),
        end_token=vocab.word2id(data.STOP_DECODING))
    decoder_outputs, decoder_len = _decode(decoder_cell, helper=helper)
    return (DecoderOutputs(decoder_outputs=decoder_outputs,
                           decoder_len=decoder_len),
            DecoderOutputs(decoder_outputs=teacher_decoder_outputs,
                           decoder_len=features["decoder_len"]))
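
As above, a short sketch of how basic_decoder's two return values might be consumed. The rnn_output and sample_id fields come from tf.contrib.seq2seq.BasicDecoderOutput; the surrounding variables are assumptions.

# Hypothetical consumer of basic_decoder's return values (TF 1.x).
outputs, teacher_outputs = basic_decoder(features, mode, vocab,
                                         encoder_outputs, hps)
if mode == tf.estimator.ModeKeys.TRAIN:
    # teacher_outputs is None; the single pass is already teacher-forced.
    logits = outputs.decoder_outputs.rnn_output
else:
    # Greedy decode for predictions; teacher-forced pass for the eval loss.
    predicted_ids = outputs.decoder_outputs.sample_id
    logits = teacher_outputs.decoder_outputs.rnn_output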