# Example no. 1
# (score: 0)
        'input_ids', 'segment_ids', 'input_mask', 'start_position',
        'end_position', 'question_mask', 'rationale_mask', 'yes_mask',
        'extractive_mask', 'no_mask', 'unk_mask', 'qid'
    ])
# Extra per-example fields (beyond the generator defaults) that the
# BertCoQA model reads from each batch.
_EVAL_EXTRA_FIELDS = [
    'input_ids', 'segment_ids', 'input_mask', 'start_position',
    'end_position', 'question_mask', 'rationale_mask', 'yes_mask',
    'extractive_mask', 'no_mask', 'unk_mask', 'qid'
]

# Evaluation batches: no shuffling (training=False), batch size 12.
eval_batch_generator = BatchGenerator(vocab,
                                      eval_data,
                                      training=False,
                                      batch_size=12,
                                      additional_fields=_EVAL_EXTRA_FIELDS)

# Build the CoQA BERT model with the answer-verification head enabled.
# FIX: keyword was misspelled 'answer_verificatioin'; the same class is
# constructed later in this script with 'answer_verification', so the two
# call sites are now consistent.
model = BertCoQA(bert_dir=bert_dir, answer_verification=True)

# Linear LR warmup over the first 10% of steps (standard BERT fine-tuning).
warmup_proportion = 0.1
# NOTE(review): 12 and 2 here must stay in sync with the generators'
# batch_size and the epochs argument below — confirm if either changes.
num_train_steps = int(len(train_data) / 12 * 2)
num_warmup_steps = int(num_train_steps * warmup_proportion)

# original paper adamax optimizer
model.compile(3e-5,
              num_train_steps=num_train_steps,
              num_warmup_steps=num_warmup_steps)
# 'eposides' is presumably the library's own (misspelled) parameter name —
# left as-is; renaming it would break the call.
model.train_and_evaluate(train_batch_generator,
                         eval_batch_generator,
                         evaluator,
                         epochs=2,
                         eposides=1)
#model.evaluate(eval_batch_generator,evaluator)
#model.evaluate(eval_batch_generator,evaluator)
                        "squad_seq2seq_dev_moses_tokenized")
# Input/output file locations for the SQuAD-as-CoQA evaluation run.
coqa_format_test_save_file = os.path.join(
    DATA_DIR,
    "squad_seq2seq_dev_moses_test_coqa_format.json")
src_squad_seq2seq_predicted_responses_file = os.path.join(
    DATA_DIR,
    "src_squad_seq2seq_dev_moses_test.txt")
# Predictions are written to the current working directory, not DATA_DIR.
predictions_save_file = "coqa_predictions_on_squad_seq2seq_dev_moses_test.txt"

# Read the CoQA-formatted test split and build its matching evaluator.
test_data = coqa_reader.read(coqa_format_test_save_file, 'test')
evaluator = CoQAEvaluator(coqa_format_test_save_file)

# Convert raw examples into BERT input features.
bert_dir = 'uncased_L-12_H-768_A-12'
best_model_path = os.path.join('models', 'best_weights')
bert_data_helper = BertDataHelper(bert_dir)
test_data = bert_data_helper.convert(test_data, data='coqa')

# Rebuild the network and restore the best checkpoint from training.
model = BertCoQA(bert_dir=bert_dir, answer_verification=True)
print("loading model")
model.load(best_model_path)
print("model loaded")

# Inference batches: deterministic order, smaller batch size than training.
my_batch_size = 6
test_batch_generator = BatchGenerator(
    vocab,
    test_data,
    training=False,
    batch_size=my_batch_size,
    additional_fields=[
        'input_ids', 'segment_ids', 'input_mask',
        'start_position', 'end_position', 'question_mask',
        'rationale_mask', 'yes_mask', 'extractive_mask',
        'no_mask', 'unk_mask', 'qid',
    ])