示例#1
0
def test(cfg):
    """Evaluate a saved QANet checkpoint on the dev set and dump its predicted answers.

    Args:
        cfg: configuration mapping. Reads ``dev_eval_file`` (JSON),
            ``dev_record_file``, ``batch_size``, ``word2ind_file``,
            ``args_filename`` (pickled ``QANet`` constructor kwargs) and
            ``dump_filename`` (model state dict). The optional key
            ``answers_file`` sets where the answer dict is written
            (default ``'logs/answers.json'``); its directory is created if needed.
    """
    logging.info('Model is loading...')
    with open(cfg['dev_eval_file'], "r") as fh:
        dev_eval_file = json.load(fh)
    # NOTE(review): -1 presumably means "all batches in the record file" — confirm against SQuADDataset.
    dev_dataset = SQuADDataset(cfg['dev_record_file'], -1, cfg['batch_size'], cfg['word2ind_file'])
    # SECURITY: pickle.load can execute arbitrary code; only load args files produced by this project.
    with open(cfg['args_filename'], 'rb') as fh:
        model_args = pickle.load(fh)
    model = QANet(**model_args)

    # Load tensors onto the CPU first so a checkpoint saved on a GPU also loads
    # on a CPU-only host; the model is moved to `device` immediately afterwards.
    model.load_state_dict(torch.load(cfg['dump_filename'], map_location='cpu'))
    model.to(device)

    # len(dev_dataset) is passed as the batch count, presumably so the whole dev set is evaluated.
    metrics, answer_dict = evaluation(model, dev_dataset, dev_eval_file, len(dev_dataset))

    answers_file = cfg.get('answers_file', 'logs/answers.json')
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        # Avoid losing the evaluation results to a missing output directory.
        os.makedirs(answers_dir, exist_ok=True)
    with open(answers_file, 'w') as f:
        json.dump(answer_dict, f)
    logging.info("TEST loss %f F1 %f EM %f\n", metrics["loss"], metrics["f1"], metrics["exact_match"])
示例#2
0
def _load_json(path):
    """Return the JSON document stored at *path*."""
    with open(path, "r") as fh:
        return json.load(fh)


def _warmup_lr_lambda(lr, warm_up):
    """Build the LambdaLR factor function implementing logarithmic learning-rate warm-up.

    The optimizer is created with a base learning rate of 1.0, so the factor
    returned by the lambda *is* the effective learning rate:
    ``lr * log2(step + 1) / log2(warm_up)`` while ``step < warm_up`` (reaching
    ``lr`` at ``step == warm_up - 1``), and a constant ``lr`` afterwards.
    """
    if warm_up <= 1:
        # log2(warm_up) would be 0 (division by zero) or negative: no warm-up phase.
        return lambda step: lr
    cr = lr / log2(warm_up)
    return lambda step: cr * log2(step + 1) if step < warm_up else lr


def _log_eval(writer, label, tag, metrics, step):
    """Log evaluation *metrics* as ``<label> loss/F1/EM`` and as ``<tag>_*`` TensorBoard scalars at *step*."""
    logging.info("%s loss %f F1 %f EM %f", label, metrics['loss'], metrics['f1'], metrics['exact_match'])
    writer.add_scalar(tag + '_loss', metrics['loss'], step)
    writer.add_scalar(tag + '_f1', metrics['f1'], step)
    writer.add_scalar(tag + '_em', metrics['exact_match'], step)


def _log_train_sample(dataset, passage_w, question_w, p1, p2, y1, y2):
    """Decode and log the passage, question, gold span and predicted span of a training batch.

    The predicted span is taken from row 0 of *p1*/*p2*; presumably these are
    (batch, passage_len) start/end scores — TODO confirm.
    """
    start = torch.argmax(p1[0, :]).item()
    # Search for the end only at or after the predicted start so the span is never inverted.
    end = torch.argmax(p2[0, start:]).item() + start
    passage = dataset.decode(passage_w)
    question = dataset.decode(question_w)
    generated_answer = dataset.decode(passage_w[:, start:end + 1])
    real_answer = dataset.decode(passage_w[:, y1[0]:y2[0] + 1])
    logging.info('Train Sample:\n Passage: %s\nQuestion: %s\nOriginal answer: %s\nGenerated answer: %s',
                 passage, question, real_answer, generated_answer)


def train(model_params, launch_params):
    """Train (or fine-tune) a QANet model, then save its weights and constructor args.

    Args:
        model_params: keyword arguments for ``QANet``. They are used only when
            training from scratch. In that case the word/char embedding matrices
            are added to a *copy* of this dict, so the caller's dict is not modified.
        launch_params: run configuration (file paths, step counts, logging/eval
            intervals, optimizer hyper-parameters, ``fine_tuning`` flag, ...).

    When ``launch_params['fine_tuning']`` is true:
        - the model is rebuilt from the pickled constructor args in ``args_filename``;
        - the weights are loaded from ``dump_filename``;
        - logged step numbers are shifted by ``fine_tuning_steps``.

    At the end:
        - the weights are written to ``dump_filename``;
        - the constructor args the model was actually built with are written to
          ``args_filename``.

    A ``RuntimeError`` inside a training step is logged and the step is skipped.
    ``KeyboardInterrupt`` stops training early, and the model is still saved.
    """
    train_eval_file = _load_json(launch_params['train_eval_file'])
    dev_eval_file = _load_json(launch_params['dev_eval_file'])

    logging.info('Load dataset and create model.')
    dev_dataset = SQuADDataset(launch_params['dev_record_file'], launch_params['test_num_batches'],
                               launch_params['batch_size'], launch_params['word2ind_file'])
    if launch_params['fine_tuning']:
        step_offset = launch_params['fine_tuning_steps']
        train_dataset = SQuADDataset(launch_params['train_record_file'], step_offset,
                                     launch_params['batch_size'], launch_params['word2ind_file'])
        # SECURITY: pickle.load can execute arbitrary code; only load args files produced by this project.
        with open(launch_params['args_filename'], 'rb') as fh:
            model_args = pickle.load(fh)
        model = QANet(**model_args)
        # Load onto the CPU first so a GPU-saved checkpoint also loads on a CPU-only host.
        model.load_state_dict(torch.load(launch_params['dump_filename'], map_location='cpu'))
        model.to(device)
    else:
        step_offset = 0
        train_dataset = SQuADDataset(launch_params['train_record_file'], launch_params['num_steps'],
                                     launch_params['batch_size'], launch_params['word2ind_file'])
        model_args = dict(
            model_params,
            word_mat=np.array(_load_json(launch_params['word_emb_file']), dtype=np.float32),
            char_mat=np.array(_load_json(launch_params['char_emb_file']), dtype=np.float32),
        )
        model = QANet(**model_args).to(device)

    trainable = [param for param in model.parameters() if param.requires_grad]
    # Base lr is 1.0 because the LambdaLR factor below is the absolute learning rate.
    optimizer = optim.Adam(trainable, lr=1.0, betas=(launch_params['beta1'], launch_params['beta2']),
                           eps=1e-7, weight_decay=3e-7)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=_warmup_lr_lambda(launch_params['learning_rate'], launch_params['lr_warm_up_num']))

    writer = SummaryWriter(os.path.join(launch_params['log'], launch_params['prefix']))
    logging.info('Start training.')
    # NOTE(review): the loop always runs num_steps iterations, but in fine-tuning mode
    # train_dataset was built with fine_tuning_steps batches — confirm SQuADDataset
    # indexing tolerates step >= fine_tuning_steps.
    try:
        for step in range(launch_params['num_steps']):
            global_step = step + step_offset
            try:
                passage_w, passage_c, question_w, question_c, y1, y2, _ = train_dataset[step]
                passage_w, passage_c = passage_w.to(device), passage_c.to(device)
                question_w, question_c = question_w.to(device), question_c.to(device)
                y1, y2 = y1.to(device), y2.to(device)
                loss, p1, p2 = model.train_step([passage_w, passage_c, question_w, question_c],
                                                y1, y2, optimizer, scheduler)
                if step % launch_params['train_interval'] == 0:
                    logging.info('Iteration %d; Loss: %f', global_step, loss)
                    writer.add_scalar('Loss', loss, global_step)
                if step % launch_params['train_sample_interval'] == 0:
                    _log_train_sample(train_dataset, passage_w, question_w, p1, p2, y1, y2)
                if step % launch_params['test_interval'] == 0:
                    # "VALID" metrics are computed on the training data, "TEST" metrics on the dev set.
                    metrics, _ = evaluation(model, train_dataset, train_eval_file,
                                            launch_params['val_num_batches'])
                    _log_eval(writer, "VALID", 'Valid', metrics, global_step)
                    metrics, _ = evaluation(model, dev_dataset, dev_eval_file,
                                            launch_params['test_num_batches'])
                    _log_eval(writer, "TEST", 'Test', metrics, global_step)
            except RuntimeError as e:
                # Best-effort: log the failed step and keep training.
                logging.error(str(e))
            except KeyboardInterrupt:
                break
    finally:
        # Flush pending TensorBoard events even if training aborts.
        writer.close()

    torch.save(model.cpu().state_dict(), launch_params['dump_filename'])
    # Save the args the model was actually built with, so the weights can be reloaded.
    with open(launch_params['args_filename'], 'wb') as fh:
        pickle.dump(model_args, fh)
    logging.info('Model has been saved.')