Example #1
def build_attention_model(params,
                          src_vocab,
                          trg_vocab,
                          source_placeholders,
                          target_placeholders,
                          beam_size=1,
                          mode=MODE.TRAIN,
                          burn_in_step=100000,
                          increment_step=10000,
                          teacher_rate=1.0,
                          max_step=100):
    """
    Build a model.

    :param params: dict.
     {encoder: {rnn_cell: {},
                ...},
      decoder: {rnn_cell: {},
                ...}}
      for example (this variant additionally reads top-level
      'char_encoder' and 'attention_key_size' entries, as used in the
      body below):
        {'encoder': {'rnn_cell': {'state_size': 512,
                                   'cell_name': 'BasicLSTMCell',
                                   'num_layers': 2,
                                   'input_keep_prob': 1.0,
                                   'output_keep_prob': 1.0},
                      'attention_key_size': attention_size},
        'decoder':  {'rnn_cell': {'cell_name': 'BasicLSTMCell',
                                   'state_size': 512,
                                   'num_layers': 1,
                                   'input_keep_prob': 1.0,
                                   'output_keep_prob': 1.0},
                      'trg_vocab_size': trg_vocab_size}}
    :param src_vocab: Vocab of source symbols.
    :param trg_vocab: Vocab of target symbols.
    :param source_placeholders: dict of placeholders with keys 'src',
     'src_len', 'src_sample_matrix' and 'src_word_len'.
    :param target_placeholders: dict of placeholders with keys 'trg'
     and 'trg_len'.
    :param beam_size: beam width used in beam-search inference.
    :param mode: one of MODE.TRAIN, MODE.EVAL, MODE.RL or MODE.INFER.
    :param burn_in_step: global step after which RL training starts.
    :param increment_step: number of steps between extending the RL
     portion of each sequence by one time step.
    :param teacher_rate: teacher-forcing rate for the training feedback.
    :param max_step: maximum number of decoding steps.
    :return: decoder output and final state in EVAL/INFER mode,
     otherwise decoder output and the loss terms.
    """
    if mode != MODE.TRAIN:
        params = sq.disable_dropout(params)

    tf.logging.info(json.dumps(params, indent=4))

    # parameters
    decoder_params = params['decoder']
    # placeholders
    source_ids = source_placeholders['src']
    source_seq_length = source_placeholders['src_len']
    source_sample_matrix = source_placeholders['src_sample_matrix']
    source_word_seq_length = source_placeholders['src_word_len']

    target_ids = target_placeholders['trg']
    target_seq_length = target_placeholders['trg_len']

    # Because the source encoder differs from the target feedback,
    # the source embedding table is constructed manually.
    source_char_embedding_table = sq.LookUpOp(src_vocab.vocab_size,
                                              src_vocab.embedding_dim,
                                              name='source')
    source_char_embedded = source_char_embedding_table(source_ids)

    # encode char to word
    char_encoder = sq.StackRNNEncoder(params['char_encoder'],
                                      params['attention_key_size']['char'],
                                      name='char_rnn',
                                      mode=mode)

    # char_encoder_outputs has shape (T_c, B, F)
    char_encoded_representation = char_encoder.encode(source_char_embedded,
                                                      source_seq_length)
    char_encoder_outputs = char_encoded_representation.outputs
    # An alternative that gathered the outputs at space positions via
    # tf.where/tf.gather_nd is omitted; the precomputed
    # source_sample_matrix below performs the same word pooling.

    # transpose the time-major (T_c, B, F) outputs to batch-major so the
    # per-example sample matrices (B, T_w, T_c) can pool the characters
    # into one vector per word via matmul
    char_encoder_outputs = tf.transpose(char_encoder_outputs, perm=(1, 0, 2))
    sampled_word_embedded = tf.matmul(source_sample_matrix,
                                      char_encoder_outputs)
    # back to time-major (T_w, B, F) for the word-level encoder
    source_embedded = tf.transpose(sampled_word_embedded, perm=(1, 0, 2))

    char_attention_keys = char_encoded_representation.attention_keys
    char_attention_values = char_encoded_representation.attention_values
    char_attention_length = char_encoded_representation.attention_length

    encoder = sq.StackBidirectionalRNNEncoder(
        params['encoder'],
        params['attention_key_size']['word'],
        name='stack_rnn',
        mode=mode)
    encoded_representation = encoder.encode(source_embedded,
                                            source_word_seq_length)
    attention_keys = encoded_representation.attention_keys
    attention_values = encoded_representation.attention_values
    attention_length = encoded_representation.attention_length
    encoder_final_states_bw = encoded_representation.final_state[-1][-1].h
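    # hidden state of the top-layer backward cell; bridged into the
    # decoder's initial state further below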

    # feedback
    if mode == MODE.RL:
        tf.logging.info('BUILDING RL TRAIN FEEDBACK......')
        dynamical_batch_size = tf.shape(attention_keys)[1]
        feedback = sq.RLTrainingFeedBack(target_ids,
                                         target_seq_length,
                                         trg_vocab,
                                         dynamical_batch_size,
                                         burn_in_step=burn_in_step,
                                         increment_step=increment_step,
                                         max_step=max_step)
    elif mode == MODE.TRAIN:
        tf.logging.info('BUILDING TRAIN FEEDBACK WITH {} TEACHER_RATE'
                        '......'.format(teacher_rate))
        feedback = sq.TrainingFeedBack(target_ids,
                                       target_seq_length,
                                       trg_vocab,
                                       teacher_rate,
                                       max_step=max_step)
    elif mode == MODE.EVAL:
        tf.logging.info('BUILDING EVAL FEEDBACK ......')
        feedback = sq.TrainingFeedBack(target_ids,
                                       target_seq_length,
                                       trg_vocab,
                                       0.,
                                       max_step=max_step)
    else:
        tf.logging.info('BUILDING INFER FEEDBACK WITH BEAM_SIZE {}'
                        '......'.format(beam_size))
        infer_key_size = attention_keys.get_shape().as_list()[-1]
        infer_value_size = attention_values.get_shape().as_list()[-1]
        infer_states_bw_shape = encoder_final_states_bw.get_shape().as_list(
        )[-1]

        infer_char_key_size = char_attention_keys.get_shape().as_list()[-1]
        infer_char_value_size = char_attention_values.get_shape().as_list()[-1]

        encoder_final_states_bw = tf.reshape(
            tf.tile(encoder_final_states_bw, [1, beam_size]),
            [-1, infer_states_bw_shape])
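        # each batch row of the final backward state is now repeated
        # beam_size times consecutively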

        # expand beam: replicate keys/values beam_size times per batch
        # item so that each beam hypothesis attends to its own copy
        if TIME_MAJOR:
            # the batch size must remain dynamic
            dynamical_batch_size = tf.shape(attention_keys)[1]
            final_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_key_size
            ]
            final_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, 1, beam_size])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, 1, beam_size])),
                final_value_shape)

            final_char_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_char_key_size
            ]
            final_char_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_char_value_size
            ]
            char_attention_keys = tf.reshape(
                (tf.tile(char_attention_keys, [1, 1, beam_size])),
                final_char_key_shape)
            char_attention_values = tf.reshape(
                (tf.tile(char_attention_values, [1, 1, beam_size])),
                final_char_value_shape)

        else:
            dynamical_batch_size = tf.shape(attention_keys)[0]
            final_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_key_size
            ]
            final_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_value_size
            ]
            final_char_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_char_key_size
            ]
            final_char_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_char_value_size
            ]

            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, beam_size, 1])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, beam_size, 1])),
                final_value_shape)

            char_attention_keys = tf.reshape(
                (tf.tile(char_attention_keys, [1, beam_size, 1])),
                final_char_key_shape)
            char_attention_values = tf.reshape(
                (tf.tile(char_attention_values, [1, beam_size, 1])),
                final_char_value_shape)

        attention_length = tf.reshape(
            tf.transpose(tf.tile([attention_length], [beam_size, 1])), [-1])
        char_attention_length = tf.reshape(
            tf.transpose(tf.tile([char_attention_length], [beam_size, 1])),
            [-1])

        feedback = sq.BeamFeedBack(trg_vocab,
                                   beam_size,
                                   dynamical_batch_size,
                                   max_step=max_step)

    encoder_decoder_bridge = EncoderDecoderBridge(
        encoder_final_states_bw.get_shape().as_list()[-1],
        decoder_params['rnn_cell'])
    decoder_state_size = decoder_params['rnn_cell']['state_size']
    # attention
    attention = sq.AvAttention(decoder_state_size, attention_keys,
                               attention_values, attention_length,
                               char_attention_keys, char_attention_values,
                               char_attention_length)
    context_size = attention.context_size

    with tf.variable_scope('logits_func'):
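        # a two-hidden-layer ReLU MLP: mix [context; embedding; state],
        # narrow to state_size // 2, then project onto the vocabulary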
        attention_mix = LinearOp(context_size + feedback.embedding_dim +
                                 decoder_state_size,
                                 decoder_state_size,
                                 name='attention_mix')
        attention_mix_middle = LinearOp(decoder_state_size,
                                        decoder_state_size // 2,
                                        name='attention_mix_middle')
        logits_trans = LinearOp(decoder_state_size // 2,
                                feedback.vocab_size,
                                name='logits_trans')
        logits_func = lambda _softmax: logits_trans(
            tf.nn.relu(
                attention_mix_middle(tf.nn.relu(attention_mix(_softmax)))))

    # decoder
    decoder = sq.AttentionRNNDecoder(
        decoder_params,
        attention,
        feedback,
        logits_func=logits_func,
        init_state=encoder_decoder_bridge(encoder_final_states_bw),
        mode=mode)
    decoder_output, decoder_final_state = sq.dynamic_decode(decoder,
                                                            swap_memory=True,
                                                            scope='decoder')

    # not training
    if mode == MODE.EVAL or mode == MODE.INFER:
        return decoder_output, decoder_final_state

    # BOS is added by the feedback, so target_ids lines up step-by-step
    # with the predicted ids
    if not TIME_MAJOR:
        ground_truth_ids = tf.transpose(target_ids, [1, 0])
    else:
        ground_truth_ids = target_ids

    # construct the loss
    if mode == MODE.RL:
        # fetch the existing global_step variable from the graph collection
        global_step_tensor = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='global_step')[0]
        rl_time_steps = tf.floordiv(
            tf.maximum(global_step_tensor - burn_in_step, 0), increment_step)
        start_rl_step = target_seq_length - rl_time_steps
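        # e.g. with burn_in_step=100000 and increment_step=10000, at
        # global step 130000: rl_time_steps = 3, so RL covers the last
        # three time steps of each target sequence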

        baseline_states = tf.stop_gradient(decoder_output.baseline_states)
        predict_ids = tf.stop_gradient(decoder_output.predicted_ids)

        # TODO: bug in tensorflow
        ground_or_predict_ids = tf.cond(tf.greater(rl_time_steps,
                                                   0), lambda: predict_ids,
                                        lambda: ground_truth_ids)

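        # _py_func (defined elsewhere in the module) scores the chosen
        # ids against the ground truth up to eos_id and returns
        # per-sequence rewards and effective lengths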
        reward, sequence_length = tf.py_func(
            func=_py_func,
            inp=[ground_or_predict_ids, ground_truth_ids, trg_vocab.eos_id],
            Tout=[tf.float32, tf.int32],
            name='reward')
        sequence_length.set_shape((None, ))

        total_loss_avg, entropy_loss_avg, reward_loss_rmse, reward_predicted \
            = rl_sequence_loss(
            logits=decoder_output.logits,
            predict_ids=predict_ids,
            sequence_length=sequence_length,
            baseline_states=baseline_states,
            start_rl_step=start_rl_step,
            reward=reward)
        return decoder_output, total_loss_avg, entropy_loss_avg, \
               reward_loss_rmse, reward_predicted
    else:
        total_loss_avg = cross_entropy_sequence_loss(
            logits=decoder_output.logits,
            targets=ground_truth_ids,
            sequence_length=target_seq_length)
        return decoder_output, total_loss_avg, total_loss_avg, \
               tf.to_float(0.), tf.to_float(0.)
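
A note on the `source_sample_matrix` placeholder used above: it is a
precomputed, per-example selection matrix that pools the char-encoder
outputs into one vector per word. Its construction is not shown in this
example; the sketch below is only a minimal illustration (the helper
name, the space-terminated word convention, and the shapes are
assumptions, not the repo's actual code).

import numpy as np

def build_sample_matrix(char_ids, space_id, max_words):
    """Return a (max_words, T_c) matrix whose row i selects the
    character position that ends word i, so that matrix @ char_outputs
    pools a (T_c, F) array down to (max_words, F)."""
    boundaries = np.flatnonzero(np.asarray(char_ids) == space_id)
    matrix = np.zeros((max_words, len(char_ids)), dtype=np.float32)
    for word_idx, char_pos in enumerate(boundaries[:max_words]):
        matrix[word_idx, char_pos] = 1.0
    return matrix

# toy usage: two words, each terminated by a (hypothetical) space id 0
char_ids = [3, 4, 0, 5, 6, 0]
char_outputs = np.random.rand(len(char_ids), 8).astype(np.float32)
words = build_sample_matrix(char_ids, space_id=0, max_words=2) @ char_outputs
assert words.shape == (2, 8)
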
Example #2
def build_attention_model(params,
                          src_vocab,
                          trg_vocab,
                          source_ids,
                          source_seq_length,
                          target_ids,
                          target_seq_length,
                          beam_size=1,
                          mode=MODE.TRAIN,
                          burn_in_step=100000,
                          increment_step=10000,
                          teacher_rate=1.0,
                          max_step=100):
    """
    Build a model.

    :param params: dict.
     {encoder: {rnn_cell: {},
                ...},
      decoder: {rnn_cell: {},
                ...}}
      for example:
        {'encoder': {'rnn_cell': {'state_size': 512,
                                   'cell_name': 'BasicLSTMCell',
                                   'num_layers': 2,
                                   'input_keep_prob': 1.0,
                                   'output_keep_prob': 1.0},
                      'attention_key_size': attention_size},
        'decoder':  {'rnn_cell': {'cell_name': 'BasicLSTMCell',
                                   'state_size': 512,
                                   'num_layers': 1,
                                   'input_keep_prob': 1.0,
                                   'output_keep_prob': 1.0},
                      'trg_vocab_size': trg_vocab_size}}
    :param src_vocab: Vocab of source symbols.
    :param trg_vocab: Vocab of target symbols.
    :param source_ids: placeholder
    :param source_seq_length: placeholder
    :param target_ids: placeholder
    :param target_seq_length: placeholder
    :param beam_size: beam width used in beam-search inference.
    :param mode: one of MODE.TRAIN, MODE.EVAL, MODE.RL or MODE.INFER.
    :param burn_in_step: global step after which RL training starts.
    :param increment_step: number of steps between extending the RL
     portion of each sequence by one time step.
    :param teacher_rate: teacher-forcing rate for the training feedback.
    :param max_step: maximum number of decoding steps.
    :return: decoder output and final state in EVAL/INFER mode,
     otherwise decoder output and the loss terms.
    """
    if mode != MODE.TRAIN:
        params = sq.disable_dropout(params)

    tf.logging.info(json.dumps(params, indent=4))

    # parameters
    encoder_params = params['encoder']
    decoder_params = params['decoder']

    # Because the source encoder differs from the target feedback,
    # the source embedding table is constructed manually.
    source_embedding_table = sq.LookUpOp(src_vocab.vocab_size,
                                         src_vocab.embedding_dim,
                                         name='source')
    source_embedded = source_embedding_table(source_ids)

    encoder = sq.StackBidirectionalRNNEncoder(encoder_params,
                                              name='stack_rnn',
                                              mode=mode)
    encoded_representation = encoder.encode(source_embedded, source_seq_length)
    attention_keys = encoded_representation.attention_keys
    attention_values = encoded_representation.attention_values
    attention_length = encoded_representation.attention_length
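    # attention_length holds the true source lengths, used to mask
    # padded positions when attention weights are computed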

    # feedback
    if mode == MODE.RL:
        tf.logging.info('BUILDING RL TRAIN FEEDBACK......')
        dynamical_batch_size = tf.shape(attention_keys)[1]
        feedback = sq.RLTrainingFeedBack(target_ids,
                                         target_seq_length,
                                         trg_vocab,
                                         dynamical_batch_size,
                                         burn_in_step=burn_in_step,
                                         increment_step=increment_step,
                                         max_step=max_step)
    elif mode == MODE.TRAIN:
        tf.logging.info('BUILDING TRAIN FEEDBACK WITH {} TEACHER_RATE'
                        '......'.format(teacher_rate))
        feedback = sq.TrainingFeedBack(target_ids,
                                       target_seq_length,
                                       trg_vocab,
                                       teacher_rate,
                                       max_step=max_step)
    elif mode == MODE.EVAL:
        tf.logging.info('BUILDING EVAL FEEDBACK ......')
        feedback = sq.TrainingFeedBack(target_ids,
                                       target_seq_length,
                                       trg_vocab,
                                       0.,
                                       max_step=max_step)
    else:
        tf.logging.info('BUILDING INFER FEEDBACK WITH BEAM_SIZE {}'
                        '......'.format(beam_size))
        infer_key_size = attention_keys.get_shape().as_list()[-1]
        infer_value_size = attention_values.get_shape().as_list()[-1]

        # expand beam: replicate keys/values beam_size times per batch
        # item so that each beam hypothesis attends to its own copy
        if TIME_MAJOR:
            # the batch size must remain dynamic
            dynamical_batch_size = tf.shape(attention_keys)[1]
            final_key_shape = [
                -1, dynamical_batch_size * beam_size, infer_key_size
            ]
            final_value_shape = [
                -1, dynamical_batch_size * beam_size, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, 1, beam_size])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, 1, beam_size])),
                final_value_shape)
        else:
            dynamical_batch_size = tf.shape(attention_keys)[0]
            final_key_shape = [
                dynamical_batch_size * beam_size, -1, infer_key_size
            ]
            final_value_shape = [
                dynamical_batch_size * beam_size, -1, infer_value_size
            ]
            attention_keys = tf.reshape(
                (tf.tile(attention_keys, [1, beam_size, 1])), final_key_shape)
            attention_values = tf.reshape(
                (tf.tile(attention_values, [1, beam_size, 1])),
                final_value_shape)

        attention_length = tf.reshape(
            tf.transpose(tf.tile([attention_length], [beam_size, 1])), [-1])
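        # each source length is repeated beam_size times, matching the
        # tiled layout of the keys and values above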

        feedback = sq.BeamFeedBack(trg_vocab,
                                   beam_size,
                                   dynamical_batch_size,
                                   max_step=max_step)

    # attention
    attention = sq.Attention(decoder_params['rnn_cell']['state_size'],
                             attention_keys, attention_values,
                             attention_length)

    # decoder
    decoder = sq.AttentionRNNDecoder(decoder_params,
                                     attention,
                                     feedback,
                                     mode=mode)
    decoder_output, decoder_final_state = sq.dynamic_decode(decoder,
                                                            swap_memory=True,
                                                            scope='decoder')

    # not training
    if mode == MODE.EVAL or mode == MODE.INFER:
        return decoder_output, decoder_final_state

    # BOS is added by the feedback, so target_ids lines up step-by-step
    # with the predicted ids
    if not TIME_MAJOR:
        ground_truth_ids = tf.transpose(target_ids, [1, 0])
    else:
        ground_truth_ids = target_ids

    # construct the loss
    if mode == MODE.RL:
        # fetch the existing global_step variable from the graph collection
        global_step_tensor = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='global_step')[0]
        rl_time_steps = tf.floordiv(
            tf.maximum(global_step_tensor - burn_in_step, 0), increment_step)
        start_rl_step = target_seq_length - rl_time_steps

        baseline_states = tf.stop_gradient(decoder_output.baseline_states)
        predict_ids = tf.stop_gradient(decoder_output.predicted_ids)

        # TODO: bug in tensorflow
        ground_or_predict_ids = tf.cond(tf.greater(rl_time_steps,
                                                   0), lambda: predict_ids,
                                        lambda: ground_truth_ids)

        reward, sequence_length = tf.py_func(
            func=_py_func,
            inp=[ground_or_predict_ids, ground_truth_ids, trg_vocab.eos_id],
            Tout=[tf.float32, tf.int32],
            name='reward')
        sequence_length.set_shape((None, ))

        total_loss_avg, entropy_loss_avg, reward_loss_rmse, reward_predicted \
            = rl_sequence_loss(
            logits=decoder_output.logits,
            predict_ids=predict_ids,
            sequence_length=sequence_length,
            baseline_states=baseline_states,
            start_rl_step=start_rl_step,
            reward=reward)
        return decoder_output, total_loss_avg, entropy_loss_avg, \
               reward_loss_rmse, reward_predicted
    else:
        total_loss_avg = cross_entropy_sequence_loss(
            logits=decoder_output.logits,
            targets=ground_truth_ids,
            sequence_length=target_seq_length)
        return decoder_output, total_loss_avg, total_loss_avg, \
               tf.to_float(0.), tf.to_float(0.)
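
The beam expansion in both examples uses the same tile-then-reshape
trick: tiling along the beam axis and folding it into the batch axis
leaves the beam_size copies of each batch item adjacent, which is the
layout sq.BeamFeedBack presumably expects. A small standalone check of
the batch-major branch (toy shapes, NumPy in place of TensorFlow):

import numpy as np

batch, time, feat, beam = 2, 3, 4, 2
keys = np.arange(batch * time * feat,
                 dtype=np.float32).reshape(batch, time, feat)

# mirrors tf.reshape(tf.tile(attention_keys, [1, beam_size, 1]),
#                    [batch * beam_size, -1, feat])
expanded = np.tile(keys, (1, beam, 1)).reshape(batch * beam, time, feat)

# the beam copies of each batch item sit next to each other
assert np.array_equal(expanded[0], keys[0])
assert np.array_equal(expanded[1], keys[0])
assert np.array_equal(expanded[2], keys[1])
assert np.array_equal(expanded[3], keys[1])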