Example #1
def add_extra_summary(config,
                      vocab_i2s,
                      decoder_output,
                      base_word,
                      output_words,
                      src_words,
                      tgt_words,
                      inserted_words,
                      deleted_words,
                      oov,
                      vocab_size,
                      collections=None):
    pred_tokens = decoder.str_tokens(decoder_output, vocab_i2s, vocab_size,
                                     oov)
    pred_len = decoder.seq_length(decoder_output)

    ops = {}

    if config.get('logger.enable_trace', False):
        trace_summary = add_extra_summary_trace(pred_tokens, pred_len,
                                                base_word, output_words,
                                                src_words, tgt_words,
                                                inserted_words, deleted_words,
                                                vocab_i2s, collections)
        ops[ES_TRACE] = trace_summary

    if config.get('logger.enable_bleu', True):
        avg_bleu = add_extra_summary_avg_bleu(pred_tokens, pred_len,
                                              output_words, vocab_i2s,
                                              collections)
        ops[ES_BLEU] = avg_bleu

    return ops
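
# Usage sketch (an assumption, not the repository's actual implementation): the ops
# dict returned above is consumed by get_extra_summary_logger in model_fn below. A
# minimal version of such a logger could simply wrap the tensors in a
# tf.train.LoggingTensorHook; the helper name and the 'logger.log_every' key here
# are illustrative.
import tensorflow as tf

def make_extra_summary_logging_hook(ops, config):
    # Periodically logs the extra summary tensors (e.g. ops[ES_TRACE], ops[ES_BLEU]).
    return tf.train.LoggingTensorHook(
        tensors=ops,
        every_n_iter=config.get('logger.log_every', 100))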
Example #2
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, extended_base_words, \
        _, _, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features
        output_words = extended_output_words = tgt_words
    else:
        base_words, extended_base_words, \
        output_words, extended_output_words, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf.add_to_collection('is_training', is_training)

    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)

    vocab_i2s = vocab_tables[vocab.INT_TO_STR]
    vocab.init_embeddings(embedding_matrix)

    vocab_size = len(vocab_tables[vocab.RAW_WORD2ID])

    train_decoder_output, infer_decoder_output, \
    gold_dec_out, gold_dec_out_len = editor.editor_train(
        base_words, extended_base_words, output_words, extended_output_words,
        src_words, tgt_words, inserted_words, deleted_words, oov,
        vocab_size,
        config.editor.hidden_dim, config.editor.agenda_dim, config.editor.edit_dim,
        config.editor.edit_enc.micro_ev_dim, config.editor.edit_enc.num_heads,
        config.editor.encoder_layers, config.editor.decoder_layers, config.editor.attention_dim,
        config.editor.beam_width,
        config.editor.edit_enc.transformer,
        config.editor.edit_enc.wa_hidden_dim, config.editor.edit_enc.wa_hidden_layer,
        config.editor.edit_enc.meve_hidden_dim, config.editor.edit_enc.meve_hidden_layer,
        config.editor.max_sent_length, config.editor.dropout_keep, config.editor.lamb_reg,
        config.editor.norm_eps, config.editor.norm_max, config.editor.kill_edit,
        config.editor.draw_edit, config.editor.use_swap_memory,
        config.get('editor.use_beam_decoder', False), config.get('editor.enable_dropout', False),
        config.get('editor.no_insert_delete_attn', False), config.get('editor.enable_vae', True)
    )

    loss = optimizer.loss(train_decoder_output, gold_dec_out, gold_dec_out_len)
    train_op, gradients_norm = optimizer.train(loss, config.optim.learning_rate, config.optim.max_norm_observe_steps)

    tf.logging.info("Trainable variable")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.summary.scalar('grad_norm', gradients_norm)
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                base_words, output_words,
                                src_words, tgt_words,
                                inserted_words, deleted_words, oov,
                                vocab_size, collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(
            mode,
            train_op=train_op,
            loss=loss,
            training_hooks=hooks
        )

    elif mode == tf.estimator.ModeKeys.EVAL:
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                base_words, output_words,
                                src_words, tgt_words,
                                inserted_words, deleted_words, oov,
                                vocab_size, collections=['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])}
        )

    elif mode == tf.estimator.ModeKeys.PREDICT:
        lengths = decoder.seq_length(infer_decoder_output)
        tokens = decoder.str_tokens(infer_decoder_output, vocab_i2s, vocab_size, oov)
        # attns_weight = tf.get_collection('attns_weight')

        preds = {
            'str_tokens': tokens,
            'sample_id': decoder.sample_id(infer_decoder_output),
            'lengths': lengths,
            'joined': metrics.join_tokens(tokens, lengths),
        }

        tmee_attentions = tf.get_collection('TransformerMicroEditExtractor_Attentions')
        if len(tmee_attentions) > 0:
            preds.update({
                'tmee_attentions_st_enc_self': tmee_attentions[0][0],
                'tmee_attentions_st_dec_self': tmee_attentions[0][1],
                'tmee_attentions_st_dec_enc': tmee_attentions[0][2],
                'tmee_attentions_ts_enc_self': tmee_attentions[1][0],
                'tmee_attentions_ts_dec_self': tmee_attentions[1][1],
                'tmee_attentions_ts_dec_enc': tmee_attentions[1][2],
            })

        add_decoder_attention(config, preds)

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=preds
        )
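
# Wiring sketch (assumptions: the surrounding training script and these argument
# names are not shown in this excerpt). tf.estimator.Estimator expects a model_fn
# with the signature (features, labels, mode), so the extra arguments used above are
# typically closed over; labels stay unused because the targets travel inside
# `features`.
import tensorflow as tf

def make_estimator(config, embedding_matrix, vocab_tables, model_dir):
    def wrapped_model_fn(features, labels, mode):
        return model_fn(features, mode, config, embedding_matrix, vocab_tables)

    return tf.estimator.Estimator(model_fn=wrapped_model_fn, model_dir=model_dir)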
Example #3
def decoder_outputs_to_pq(decoder_output,
                          base_sent_embed,
                          src_words,
                          tgt_words,
                          src_len,
                          tgt_len,
                          temperature_starter,
                          decay_rate,
                          decay_steps,
                          agenda_dim,
                          enc_hidden_dim,
                          enc_num_layers,
                          dec_hidden_dim,
                          dec_num_layers,
                          swap_memory,
                          use_dropout=False,
                          dropout_keep=1.0):
    with tf.variable_scope(OPS_NAME):
        # [batch]
        output_length = decoder.seq_length(decoder_output)

        # [batch x max_len x word_dim]
        output_embed = prepare_output_embed(decoder_output,
                                            temperature_starter, decay_rate,
                                            decay_steps)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, dec_output_embedding = encoder.bidirectional_encoder(
            output_embed,
            output_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory,
            name='dec_output_encoder')

        agenda = tf.concat([dec_output_embedding, base_sent_embed], axis=1)
        agenda = tf.layers.dense(agenda,
                                 agenda_dim,
                                 activation=None,
                                 use_bias=False,
                                 name='agenda')  # [batch x agenda_dim]

        # append START and STOP tokens to create the decoder inputs and outputs
        out = prepare_decoder_input_output(src_words, src_len, None)
        p_dec_inp, p_dec_inp_len, p_dec_out, p_dec_out_len = out

        out = prepare_decoder_input_output(tgt_words, tgt_len, None)
        q_dec_inp, q_dec_inp_len, q_dec_out, q_dec_out_len = out

        # decode agenda twice
        p_dec, q_dec = decode_agenda(agenda,
                                     p_dec_inp,
                                     q_dec_inp,
                                     p_dec_inp_len,
                                     q_dec_inp_len,
                                     dec_hidden_dim,
                                     dec_num_layers,
                                     swap_memory,
                                     use_dropout=use_dropout,
                                     dropout_keep=dropout_keep)

        # calculate the loss
        with tf.name_scope('losses'):
            p_loss = decoding_loss(p_dec, p_dec_out, p_dec_out_len)
            q_loss = decoding_loss(q_dec, q_dec_out, q_dec_out_len)

            loss = p_loss + q_loss

    tf.summary.scalar('reconstruction_loss', loss, ['extra'])

    return loss
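
# Temperature schedule sketch: decoder_outputs_to_edit_vector below anneals the
# sampling temperature with tf.train.exponential_decay, and prepare_output_embed
# above takes the same (temperature_starter, decay_rate, decay_steps) arguments,
# presumably for the same schedule. The concrete values here are illustrative
# assumptions.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
# temperature(step) = temperature_starter * decay_rate ** (step / decay_steps)
temperature = tf.train.exponential_decay(
    5.0,        # temperature_starter
    global_step,
    10000,      # decay_steps
    0.9,        # decay_rate
    name='temperature')
# e.g. at step 20000 this yields 5.0 * 0.9 ** 2 = 4.05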
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps,
                                                 decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # replace non-positive probabilities with a small epsilon for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # relaxed (Gumbel-softmax) one-hot distribution over the vocabulary
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], approximately one-hot sample
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for layer_size in dense_layers:
            h = tf.layers.dense(h,
                                layer_size,
                                activation=tf.nn.relu,
                                name='hidden_%s' % layer_size)

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h,
                                      edit_dim,
                                      activation=None,
                                      name='linear')

    return edit_vector
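
# Standalone sketch of the soft embedding lookup used above: sample a relaxed
# (differentiable) one-hot vector from the decoder's softmax probabilities, then
# project it through the embedding matrix with einsum. The tfd alias and all shapes
# below are illustrative assumptions (tfd ~ tensorflow_probability.distributions).
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

batch, max_len, vocab_plus_oov, word_dim = 2, 7, 1000, 128

# Stand-ins for decoder.rnn_output(...) and the extended embedding matrix.
probs = tf.nn.softmax(tf.random_normal([batch, max_len, vocab_plus_oov]))
embeddings_extended = tf.random_normal([vocab_plus_oov, word_dim])

# Differentiable, approximately one-hot sample over the vocabulary dimension.
dist = tfd.RelaxedOneHotCategorical(temperature=0.5, probs=probs)
soft_one_hot = dist.sample()  # [batch x max_len x vocab_plus_oov]

# Soft lookup: one_hot^T * embedding_matrix -> [batch x max_len x word_dim]
outputs_embed = tf.einsum('btv,vd->btd', soft_one_hot, embeddings_extended)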