def test(cfg):
    logging.info('Model is loading...')
    with open(cfg['dev_eval_file'], "r") as fh:
        dev_eval_file = json.load(fh)
    dev_dataset = SQuADDataset(cfg['dev_record_file'], -1, cfg['batch_size'], cfg['word2ind_file'])

    # Restore the model from its saved constructor arguments and weights.
    with open(cfg['args_filename'], 'rb') as fh:
        model_args = pickle.load(fh)
    model = QANet(**model_args)
    model.load_state_dict(torch.load(cfg['dump_filename']))
    model.to(device)

    # Evaluate on the full dev set and dump the predicted answers.
    metrics, answer_dict = evaluation(model, dev_dataset, dev_eval_file, len(dev_dataset))
    with open('logs/answers.json', 'w') as f:
        json.dump(answer_dict, f)
    logging.info("TEST loss %f F1 %f EM %f\n",
                 metrics["loss"], metrics["f1"], metrics["exact_match"])
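# Illustrative usage sketch (not part of the original script): one way to call
# test() with the configuration keys it reads above. The concrete paths and the
# batch size below are hypothetical placeholders, not values from this repo.
#
# test({
#     'dev_eval_file': 'data/dev_eval.json',
#     'dev_record_file': 'data/dev.npz',
#     'batch_size': 32,
#     'word2ind_file': 'data/word2ind.json',
#     'args_filename': 'dumps/model_args.pkl',
#     'dump_filename': 'dumps/model.pt',
# })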
def train(model_params, launch_params):
    # Load pretrained embedding matrices and the evaluation files.
    with open(launch_params['word_emb_file'], "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(launch_params['char_emb_file'], "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(launch_params['train_eval_file'], "r") as fh:
        train_eval_file = json.load(fh)
    with open(launch_params['dev_eval_file'], "r") as fh:
        dev_eval_file = json.load(fh)

    writer = SummaryWriter(os.path.join(launch_params['log'], launch_params['prefix']))

    lr = launch_params['learning_rate']
    base_lr = 1.0
    warm_up = launch_params['lr_warm_up_num']
    model_params['word_mat'] = word_mat
    model_params['char_mat'] = char_mat

    logging.info('Load dataset and create model.')
    dev_dataset = SQuADDataset(launch_params['dev_record_file'], launch_params['test_num_batches'],
                               launch_params['batch_size'], launch_params['word2ind_file'])
    if launch_params['fine_tuning']:
        # Fine-tuning: restore a previously saved model and continue training.
        train_dataset = SQuADDataset(launch_params['train_record_file'], launch_params['fine_tuning_steps'],
                                     launch_params['batch_size'], launch_params['word2ind_file'])
        with open(launch_params['args_filename'], 'rb') as fh:
            model_args = pickle.load(fh)
        model = QANet(**model_args)
        model.load_state_dict(torch.load(launch_params['dump_filename']))
        model.to(device)
    else:
        # Training from scratch.
        train_dataset = SQuADDataset(launch_params['train_record_file'], launch_params['num_steps'],
                                     launch_params['batch_size'], launch_params['word2ind_file'])
        model = QANet(**model_params).to(device)
        launch_params['fine_tuning_steps'] = 0

    params = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(params, lr=base_lr, betas=(launch_params['beta1'], launch_params['beta2']),
                           eps=1e-7, weight_decay=3e-7)
    # Logarithmic learning-rate warm-up for the first `warm_up` steps, then a constant rate `lr`.
    cr = lr / log2(warm_up)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda ee: cr * log2(ee + 1) if ee < warm_up else lr)

    logging.info('Start training.')
    for iter in range(launch_params['num_steps']):
        try:
            passage_w, passage_c, question_w, question_c, y1, y2, ids = train_dataset[iter]
            passage_w, passage_c = passage_w.to(device), passage_c.to(device)
            question_w, question_c = question_w.to(device), question_c.to(device)
            y1, y2 = y1.to(device), y2.to(device)
            loss, p1, p2 = model.train_step([passage_w, passage_c, question_w, question_c],
                                            y1, y2, optimizer, scheduler)

            if iter % launch_params['train_interval'] == 0:
                logging.info('Iteration %d; Loss: %f', iter + launch_params['fine_tuning_steps'], loss)
                writer.add_scalar('Loss', loss, iter + launch_params['fine_tuning_steps'])

            if iter % launch_params['train_sample_interval'] == 0:
                # Decode one training example to inspect the predicted answer span.
                start = torch.argmax(p1[0, :]).item()
                end = torch.argmax(p2[0, start:]).item() + start
                passage = train_dataset.decode(passage_w)
                question = train_dataset.decode(question_w)
                generated_answer = train_dataset.decode(passage_w[:, start:end + 1])
                real_answer = train_dataset.decode(passage_w[:, y1[0]:y2[0] + 1])
                logging.info('Train Sample:\n Passage: %s\nQuestion: %s\nOriginal answer: %s\nGenerated answer: %s',
                             passage, question, real_answer, generated_answer)

            if iter % launch_params['test_interval'] == 0:
                metrics, _ = evaluation(model, train_dataset, train_eval_file, launch_params['val_num_batches'])
                logging.info("VALID loss %f F1 %f EM %f", metrics['loss'], metrics['f1'], metrics['exact_match'])
                writer.add_scalar('Valid_loss', metrics['loss'], iter)
                writer.add_scalar('Valid_f1', metrics['f1'], iter)
                writer.add_scalar('Valid_em', metrics['exact_match'], iter)

            if iter % launch_params['test_interval'] == 0:
                metrics, _ = evaluation(model, dev_dataset, dev_eval_file, launch_params['test_num_batches'])
                logging.info("TEST loss %f F1 %f EM %f", metrics['loss'], metrics['f1'], metrics['exact_match'])
                writer.add_scalar('Test_loss', metrics['loss'], iter)
                writer.add_scalar('Test_f1', metrics['f1'], iter)
                writer.add_scalar('Test_em', metrics['exact_match'], iter)
        except RuntimeError as e:
            logging.error(str(e))
        except KeyboardInterrupt:
            break

    # Persist the weights and constructor arguments so the model can be restored later.
    torch.save(model.cpu().state_dict(), launch_params['dump_filename'])
    with open(launch_params['args_filename'], 'wb') as fh:
        pickle.dump(model_params, fh)
    logging.info('Model has been saved.')