import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# load_data_list, QADataset, QAModel, and the module-level tokenizer are
# defined elsewhere in the project (see the sketch after main()).


def main():
    torch.manual_seed(94)
    batch_size = 4

    train_data_list = load_data_list('./train_data_list')
    train_set = QADataset(train_data_list)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    dev_data_list = load_data_list('./dev_data_list')
    dev_set = QADataset(dev_data_list)
    dev_loader = DataLoader(dev_set, batch_size=batch_size, shuffle=True)

    model = QAModel()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    EPOCHS = 6
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=5e-5)

    os.makedirs('./model', exist_ok=True)  # checkpoint directory must exist
    loss = 0
    for epoch in range(EPOCHS):
        cum_loss = 0
        print(f'epoch: {epoch}')
        ite = 0
        model.train()
        for i, (question_id, context, question, answerable, ans_start,
                ans_end) in enumerate(tqdm(train_loader)):
            ite = i + 1
            # Pad/truncate every batch to the fixed length of 512 that the
            # position mask below assumes; without padding the tokenizer
            # cannot return a single tensor for variable-length batches.
            pt = tokenizer(context, question, return_tensors='pt',
                           padding='max_length', truncation=True,
                           max_length=512)
            bs = len(context)
            # Positions >= 466 (the question segment and padding tail) can
            # never be answer boundaries; the cutoff is dataset-specific.
            mask = torch.zeros(bs, 512).bool()
            mask[:, 466:] = True

            # answerable is unused during training; only the span indices
            # are needed on the device.
            ans_start = ans_start.to(device)
            ans_end = ans_end.to(device)
            pt = pt.to(device)
            mask = mask.to(device)

            # target[:, 0] is the start index, target[:, 1] the end index.
            target = torch.cat((ans_start.unsqueeze(1), ans_end.unsqueeze(1)),
                               dim=1)

            optimizer.zero_grad()
            output = model(pt)  # shape (batch_size, 512, 2)
            # Block masked positions from receiving probability mass.
            output[:, :, 0].masked_fill_(mask, float('-inf'))
            output[:, :, 1].masked_fill_(mask, float('-inf'))
            # CrossEntropyLoss over 512 position classes, applied jointly to
            # the start logits (output[:, :, 0]) and end logits ([:, :, 1]).
            loss = loss_fn(output, target)
            cum_loss += float(loss)
            loss.backward()
            optimizer.step()
        print(cum_loss)

        model.eval()
        dev_loss = 0
        dev_ite = 0
        for i, (question_id, context, question, answerable, ans_start,
                ans_end) in enumerate(tqdm(dev_loader)):
            bs = len(context)
            mask = torch.zeros(bs, 512).bool()
            mask[:, 466:] = True
            with torch.no_grad():
                dev_ite = i + 1
                pt = tokenizer(context, question, return_tensors='pt',
                               padding='max_length', truncation=True,
                               max_length=512)
                ans_start = ans_start.to(device)
                ans_end = ans_end.to(device)
                pt = pt.to(device)
                mask = mask.to(device)
                target = torch.cat((ans_start.unsqueeze(1),
                                    ans_end.unsqueeze(1)), dim=1)
                output = model(pt)
                output[:, :, 0].masked_fill_(mask, float('-inf'))
                output[:, :, 1].masked_fill_(mask, float('-inf'))
                loss = loss_fn(output, target)
                dev_loss += float(loss)
        print('avg_train_loss: {}, avg_dev_loss: {}'.format(
            cum_loss / ite, dev_loss / dev_ite))

        SAVED_MDL_PATH = './model/' + str(epoch + 1) + '.pt'
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': float(loss),  # last dev-batch loss
            }, SAVED_MDL_PATH)
        print('model {} saved'.format(SAVED_MDL_PATH))
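
# The helpers referenced above -- load_data_list, QADataset, QAModel, and the
# module-level tokenizer -- live elsewhere in the project. Below is a minimal
# sketch of the pieces main() relies on, included only to document the
# expected shapes; the 'bert-base-chinese' checkpoint and the span-head
# design are assumptions, not the project's actual definitions.
from transformers import BertModel, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')  # assumed


class QAModel(nn.Module):
    """Sketch: a BERT encoder with a 2-way span head, so that model(pt)
    returns logits of shape (batch_size, 512, 2), where [:, :, 0] are the
    start logits and [:, :, 1] the end logits."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')  # assumed
        self.span_head = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, pt):
        hidden = self.bert(
            input_ids=pt['input_ids'],
            token_type_ids=pt['token_type_ids'],
            attention_mask=pt['attention_mask']).last_hidden_state
        return self.span_head(hidden)  # (batch_size, seq_len, 2)
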
def predict(MDL_PATH, DATA_PATH):
    batch_size = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QAModel()
    model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(MDL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The optimizer, epoch, and loss are only needed to resume training;
    # prediction uses just the model weights. The initial lr here is
    # overwritten by the loaded optimizer state anyway.
    optimizer = optim.SGD(model.parameters(), lr=3e-5)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.eval()

    dev_data_list = load_data_list(DATA_PATH)
    dev_set = QADataset(dev_data_list)
    dev_loader = DataLoader(dev_set, batch_size=batch_size, shuffle=False)

    print('run prediction')
    dic = {}
    for (question_id, context, question, answerable, ans_start,
         ans_end) in tqdm(dev_loader):
        bs = len(context)
        mask = torch.zeros(bs, 512).bool()
        mask[:, 466:] = True
        with torch.no_grad():
            # The gold labels (answerable, ans_start, ans_end) are unused at
            # prediction time.
            pt = tokenizer(context, question, return_tensors='pt',
                           padding='max_length', truncation=True,
                           max_length=512)
            pt = pt.to(device)
            mask = mask.to(device)
            output = model(pt)  # shape (batch_size, 512, 2)
            output[:, :, 0].masked_fill_(mask, float('-inf'))
            output[:, :, 1].masked_fill_(mask, float('-inf'))
            for batch_idx, sample in enumerate(output):  # sample: (512, 2)
                start = sample[:, 0]  # start logits, shape (512,)
                end = sample[:, 1]    # end logits, shape (512,)
                start_candidates = torch.topk(start, k=30)
                end_candidates = torch.topk(end, k=30)
                ans_candidates = []
                scores = []
                for si, s in enumerate(start_candidates.indices):
                    for ei, e in enumerate(end_candidates.indices):
                        score = float(start_candidates.values[si]
                                      + end_candidates.values[ei])
                        if s == 0 and e == 0:
                            # (0, 0) points at [CLS]: the "no answer" span.
                            ans_candidates.append((0, 0))
                            scores.append(score)
                        elif s < e and e - s <= 30:
                            # A valid span of at most 30 tokens.
                            ans_candidates.append((int(s), int(e)))
                            scores.append(score)
                # Rank candidate spans by summed start/end logit scores.
                results = sorted(zip(scores, ans_candidates), reverse=True)
                if not results or results[0][1][0] == 0:
                    # No valid span, or the best span starts at [CLS]:
                    # predict "unanswerable" as an empty string.
                    dic[question_id[batch_idx]] = ""
                else:
                    s, e = results[0][1]
                    ids = pt['input_ids'][batch_idx][s:e]
                    # decode() inserts spaces between tokens; strip them to
                    # reassemble the answer string.
                    dic[question_id[batch_idx]] = tokenizer.decode(
                        ids).replace(" ", "")

    with open('prediction.json', 'w') as fp:
        json.dump(dic, fp)
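
# A minimal entry point, assuming this module is run as a script. The paths
# below are illustrative: './model/6.pt' is the checkpoint main() saves
# after its final epoch, and the dev split stands in for a real test file.
if __name__ == '__main__':
    main()
    predict('./model/6.pt', './dev_data_list')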