def train_fever(model, index, config, args, best_score, optimizer, scheduler):
    """Run one epoch of FEVER fine-tuning and return the best validation AUC.

    Trains ``model`` on the FEVER training split, combining a BCE loss over
    per-node evidence scores with a cross-entropy loss over the final claim
    label. Every ``args.checkpoint`` steps the model is evaluated on the
    validation split and saved if the AUC improves.

    Args:
        model: wrapper exposing ``.network``, ``.train()``, ``.eval()``,
            ``.save(path)``; ``model.network(batch, device)`` returns
            ``(logits_score, logits_pred)``.
        index: unused here; kept for call-site compatibility.
        config: nested config dict (``system``, ``model``, ``training``,
            ``name`` keys are read).
        args: namespace with ``tokenizer``, ``device``, ``fp16``, ``n_gpu``,
            ``gradient_accumulation_steps``, ``checkpoint``.
        best_score: best validation AUC seen so far (carried across epochs).
        optimizer: optimizer; for fp16 it must provide ``.backward(loss)``.
        scheduler: LR scheduler stepped once per optimizer step.

    Returns:
        float: the (possibly improved) best validation AUC.
    """
    model.train()
    dataset = FEVERDataset(config["system"]['train_data'], config["model"],
                           True, args.tokenizer)
    device = args.device
    train_sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset=dataset,
                            sampler=train_sampler,
                            batch_size=config['training']['train_batch_size'],
                            collate_fn=batcher_fever(device),
                            num_workers=0)
    print_loss = 0
    criterion = CrossEntropyLoss()
    bce_loss_logits = nn.BCEWithLogitsLoss()

    for step, batch in enumerate(tqdm(dataloader)):
        # batch[1]: per-node evidence labels; batch[2]: claim label.
        # (assumed from usage below — TODO confirm against batcher_fever)
        logits_score, logits_pred = model.network(batch, device)

        # BCEWithLogitsLoss needs targets in the same dtype as the logits,
        # hence the .half() cast under fp16.
        if args.fp16:
            node_loss = bce_loss_logits(logits_score, batch[1].half())
        else:
            node_loss = bce_loss_logits(logits_score, batch[1])

        # Fix: pass dim explicitly. logits_score is 1-D (it is unsqueezed to
        # (1, N) just below), so dim=0 matches the legacy implicit-dim
        # behavior while silencing the deprecation warning.
        logits_score = F.softmax(logits_score, dim=0)
        logits_pred = F.softmax(logits_pred, dim=1)
        # Attention-style pooling: weight each node's class distribution by
        # its evidence score -> (1, num_classes).
        final_score = torch.mm(logits_score.unsqueeze(0), logits_pred)
        pred_loss = criterion(final_score, batch[2])
        loss = pred_loss + node_loss

        if args.n_gpu > 1:
            loss = loss.mean()  # average per-GPU losses from DataParallel
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        # Fix: .item() instead of .data.cpu().numpy() — .data is deprecated
        # and a scalar needs no numpy round-trip.
        print_loss += loss.item()

        # fp16 optimizers (e.g. apex FusedAdam) own the loss scaling, so
        # backward goes through the optimizer there.
        if args.fp16:
            optimizer.backward(loss)
        else:
            loss.backward()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Periodic validation + checkpointing of the best model.
        if (step + 1) % args.checkpoint == 0:
            logging.info("********* loss ************{}".format(print_loss))
            print_loss = 0
            model.eval()
            eval_file = config['system']['validation_data']
            auc, _ = evaluation_fever(model, eval_file, config, args)
            if auc > best_score:
                best_score = auc
                model.save(
                    os.path.join(
                        base_dir, config['name'],
                        "saved_models/model_finetuned_epoch_{}.pt".format(0)))
            model.train()
    return best_score
### Model Evaluation if args.test: model.load(os.path.join(base_dir, args.in_model)) model.eval() if os.path.exists('atts.txt'): os.remove('atts.txt') for eval_file in [ config["system"]['train_data'], config["system"]['validation_data'], config["system"]['test_data'] ]: if config['task'] == 'hotpotqa': final_pred = evaluation_hotpot(model, eval_file, config, args) json.dump(final_pred, open("out_dev.json", "w")) elif config['task'] == 'fever': auc, pred_dict, f1_mac = evaluation_fever( model, eval_file, config, args) with open( 'outputs/preds_' + args.in_model.split('.')[0] + '_' + eval_file.split('/')[-1], 'w+') as f: f.write(json.dumps(pred_dict)) ### Model Training else: if args.in_model: model.load(os.path.join(base_dir, args.in_model)) # final layers fine-tuning """for param in model.network.parameters(): param.requires_grad = False