Example #1
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, _, src_words, tgt_words, inserted_words, common_words = features
        output_words = tgt_words
    else:
        base_words, output_words, src_words, tgt_words, inserted_words, common_words = features

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf.add_to_collection('is_training', is_training)

    # Disable all dropout outside of training so EVAL and PREDICT are deterministic.
    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)
        config.put('editor.dropout', 0.0)

        config.put('editor.transformer.enable_dropout', False)
        config.put('editor.transformer.layer_postprocess_dropout', 0.0)
        config.put('editor.transformer.attention_dropout', 0.0)
        config.put('editor.transformer.relu_dropout', 0.0)

    vocab.init_embeddings(embedding_matrix)
    EmbeddingSharedWeights.init_from_embedding_matrix()

    editor_model = Editor(config)
    logits, beam_prediction = editor_model(base_words, src_words, tgt_words,
                                           inserted_words, common_words,
                                           output_words)

    targets = decoder.prepare_decoder_output(
        output_words, sequence.length_pre_embedding(output_words))
    target_lengths = sequence.length_pre_embedding(targets)

    vocab_size = embedding_matrix.shape[0]
    loss, weights = optimizer.padded_cross_entropy_loss(
        logits, targets, target_lengths, config.optim.label_smoothing,
        vocab_size)

    train_op = optimizer.get_train_op(loss, config)

    tf.logging.info("Trainable variable")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    tf.logging.info("Num of Trainable parameters")
    tf.logging.info(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        ops = add_extra_summary(config,
                                decoded_ids,
                                target_lengths,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                common_words,
                                collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(mode,
                                          train_op=train_op,
                                          loss=loss,
                                          training_hooks=hooks)

    elif mode == tf.estimator.ModeKeys.EVAL:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        ops = add_extra_summary(config,
                                decoded_ids,
                                target_lengths,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                common_words,
                                collections=['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])})

    elif mode == tf.estimator.ModeKeys.PREDICT:
        decoded_ids, decoded_lengths, scores = beam_prediction
        tokens = decoder.str_tokens(decoded_ids)

        preds = {
            'str_tokens': tf.transpose(tokens, [0, 2, 1]),
            'decoded_ids': tf.transpose(decoded_ids, [0, 2, 1]),
            'lengths': decoded_lengths,
            'joined': metrics.join_tokens(tokens, decoded_lengths)
        }

        tmee_attentions = tf.get_collection(
            'TransformerMicroEditExtractor_Attentions')
        if len(tmee_attentions) > 0:
            preds.update({
                'tmee_attentions_st_enc_self': tmee_attentions[0][0],
                'tmee_attentions_st_dec_self': tmee_attentions[0][1],
                'tmee_attentions_st_dec_enc': tmee_attentions[0][2],
                'tmee_attentions_ts_enc_self': tmee_attentions[1][0],
                'tmee_attentions_ts_dec_self': tmee_attentions[1][1],
                'tmee_attentions_ts_dec_enc': tmee_attentions[1][2],
                'src_words': src_words,
                'tgt_words': tgt_words,
                'base_words': base_words,
                'output_words': output_words
            })

        return tf.estimator.EstimatorSpec(mode, predictions=preds)
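This model_fn (like the two variants below) takes extra arguments beyond the standard Estimator signature: config, embedding_matrix, and vocab_tables. It therefore has to be wrapped before being handed to tf.estimator.Estimator. A minimal sketch of that wiring, assuming hypothetical helpers load_embedding_matrix and build_vocab_tables that are not part of the example:

import tensorflow as tf

def make_estimator(config, model_dir):
    # embedding_matrix is a plain NumPy array, so it can be built once up front.
    embedding_matrix = load_embedding_matrix(config)   # hypothetical helper

    # Estimator calls model_fn(features, labels, mode, params); in these examples
    # the gold outputs travel inside `features`, so `labels` and `params` are
    # simply ignored by the wrapper.
    def wrapped_model_fn(features, labels, mode, params):
        # Lookup tables must live in the graph Estimator builds for this call,
        # so they are created inside the wrapper, not captured from outside.
        vocab_tables = build_vocab_tables(config)       # hypothetical helper
        return model_fn(features, mode, config, embedding_matrix, vocab_tables)

    return tf.estimator.Estimator(model_fn=wrapped_model_fn, model_dir=model_dir)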
Example #2
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, src_words, tgt_words, inserted_words, deleted_words = features
    else:
        src_words, tgt_words, inserted_words, deleted_words = features
        base_words = src_words

    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)

    vocab_s2i = vocab_tables[vocab.STR_TO_INT]
    vocab_i2s = vocab_tables[vocab.INT_TO_STR]

    vocab.init_embeddings(embedding_matrix)

    train_decoder_output, infer_decoder_output, \
    gold_dec_out, gold_dec_out_len = editor.editor_train(
        base_words, src_words, tgt_words, inserted_words, deleted_words, embedding_matrix, vocab_s2i,
        config.editor.hidden_dim, config.editor.agenda_dim, config.editor.edit_dim,
        config.editor.encoder_layers, config.editor.decoder_layers, config.editor.attention_dim,
        config.editor.beam_width,
        config.editor.max_sent_length, config.editor.dropout_keep, config.editor.lamb_reg,
        config.editor.norm_eps, config.editor.norm_max, config.editor.kill_edit,
        config.editor.draw_edit, config.editor.use_swap_memory,
        config.get('editor.use_beam_decoder', False), config.get('editor.enable_dropout', False),
        config.get('editor.no_insert_delete_attn', False), config.get('editor.enable_vae', True)
    )

    loss = optimizer.loss(train_decoder_output, gold_dec_out, gold_dec_out_len)
    train_op, gradients_norm = optimizer.train(
        loss, config.optim.learning_rate, config.optim.max_norm_observe_steps)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.summary.scalar('grad_norm', gradients_norm)
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                src_words, tgt_words, inserted_words,
                                deleted_words, ['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(mode,
                                          train_op=train_op,
                                          loss=loss,
                                          training_hooks=hooks)

    elif mode == tf.estimator.ModeKeys.EVAL:
        ops = add_extra_summary(config, vocab_i2s, train_decoder_output,
                                src_words, tgt_words, inserted_words,
                                deleted_words, ['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])})

    elif mode == tf.estimator.ModeKeys.PREDICT:
        lengths = decoder.seq_length(infer_decoder_output)
        tokens = decoder.str_tokens(infer_decoder_output, vocab_i2s)
        preds = {
            'str_tokens': tokens,
            'sample_id': decoder.sample_id(infer_decoder_output),
            'lengths': lengths,
            'joined': metrics.join_tokens(tokens, lengths),
        }

        return tf.estimator.EstimatorSpec(mode, predictions=preds)
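In PREDICT mode the EstimatorSpec above only declares a dictionary of output tensors; the decoded strings come out when iterating over Estimator.predict. A rough usage sketch, assuming an already constructed estimator and a hypothetical predict_input_fn:

for prediction in estimator.predict(input_fn=predict_input_fn):
    length = prediction['lengths']                 # number of decoded tokens
    tokens = prediction['str_tokens'][:length]     # byte strings from vocab_i2s
    # 'joined' is the detokenized sentence produced by metrics.join_tokens;
    # joining the tokens manually is an equivalent fallback.
    print(prediction.get('joined', b' '.join(tokens)))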
Example #3
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, extended_base_words, \
        _, _, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features
        output_words = extended_output_words = tgt_words
    else:
        base_words, extended_base_words, \
        output_words, extended_output_words, \
        src_words, tgt_words, \
        inserted_words, deleted_words, \
        oov = features

    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)

    vocab_i2s = vocab_tables[vocab.INT_TO_STR]
    vocab.init_embeddings(embedding_matrix)

    vocab_size = len(vocab_tables[vocab.RAW_WORD2ID])

    train_decoder_output, infer_decoder_output, \
    gold_dec_out, gold_dec_out_len = editor.editor_train(
        base_words, extended_base_words, output_words, extended_output_words,
        src_words, tgt_words, inserted_words, deleted_words, oov,
        vocab_size,
        config.editor.hidden_dim, config.editor.agenda_dim, config.editor.edit_dim,
        config.editor.edit_enc.micro_ev_dim, config.editor.edit_enc.num_heads,
        config.editor.encoder_layers, config.editor.decoder_layers, config.editor.attention_dim,
        config.editor.beam_width,
        config.editor.edit_enc.ctx_hidden_dim, config.editor.edit_enc.ctx_hidden_layer,
        config.editor.edit_enc.wa_hidden_dim, config.editor.edit_enc.wa_hidden_layer,
        config.editor.edit_enc.meve_hidden_dim, config.editor.edit_enc.meve_hidden_layer,
        config.editor.max_sent_length, config.editor.dropout_keep, config.editor.lamb_reg,
        config.editor.norm_eps, config.editor.norm_max, config.editor.kill_edit,
        config.editor.draw_edit, config.editor.use_swap_memory,
        config.get('editor.use_beam_decoder', False), config.get('editor.enable_dropout', False),
        config.get('editor.no_insert_delete_attn', False), config.get('editor.enable_vae', True)
    )

    loss = optimizer.loss(train_decoder_output, gold_dec_out, gold_dec_out_len)
    train_op, gradients_norm = optimizer.train(
        loss, config.optim.learning_rate, config.optim.max_norm_observe_steps)

    tf.logging.info("Trainable variable")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.summary.scalar('grad_norm', gradients_norm)
        ops = add_extra_summary(config,
                                vocab_i2s,
                                train_decoder_output,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                deleted_words,
                                oov,
                                vocab_size,
                                collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(mode,
                                          train_op=train_op,
                                          loss=loss,
                                          training_hooks=hooks)

    elif mode == tf.estimator.ModeKeys.EVAL:
        ops = add_extra_summary(config,
                                vocab_i2s,
                                train_decoder_output,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                deleted_words,
                                oov,
                                vocab_size,
                                collections=['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])})

    elif mode == tf.estimator.ModeKeys.PREDICT:
        lengths = decoder.seq_length(infer_decoder_output)
        tokens = decoder.str_tokens(infer_decoder_output, vocab_i2s,
                                    vocab_size, oov)

        preds = {
            'str_tokens': tokens,
            'sample_id': decoder.sample_id(infer_decoder_output),
            'lengths': lengths,
            'joined': metrics.join_tokens(tokens, lengths),
        }

        return tf.estimator.EstimatorSpec(mode, predictions=preds)
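Example #3 decodes over an extended vocabulary: vocab_size and the per-example oov tensor are passed to decoder.str_tokens so that copied out-of-vocabulary words can be recovered. The following sketch shows how such a mapping can be done; it is an assumption about what str_tokens does here, not the repository's actual implementation, and ids_to_tokens is a hypothetical name:

import tensorflow as tf

def ids_to_tokens(sample_ids, vocab_i2s, vocab_size, oov):
    """Map extended-vocabulary ids back to strings.

    sample_ids: [batch, time] ids over the vocabulary plus per-example OOV slots
    oov:        [batch, max_oov] string tensor of per-example OOV words
    """
    sample_ids = tf.cast(sample_ids, tf.int64)
    in_vocab = sample_ids < vocab_size

    # Clamp ids before each lookup so both branches of tf.where stay valid.
    vocab_tokens = vocab_i2s.lookup(tf.minimum(sample_ids, vocab_size - 1))
    oov_index = tf.maximum(sample_ids - vocab_size, 0)
    # Per-example gather; needs TF >= 1.14 (use tf.batch_gather on older 1.x).
    oov_tokens = tf.gather(oov, oov_index, batch_dims=1)

    return tf.where(in_vocab, vocab_tokens, oov_tokens)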