def evaluate(data_source):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    total_loss = 0.
    batch_size = args.batch_size
    dataloader = DataLoader(data_source, batch_size=batch_size,
                            shuffle=True, collate_fn=pad_squad_data)
    ans_pred_tokens_samples = []
    vocab = data_source.vocab
    with torch.no_grad():
        for idx, (seq_input, ans_pos_list, tok_type) in enumerate(dataloader):
            start_pos, end_pos = model(seq_input, token_type_input=tok_type)
            target_start_pos, target_end_pos = [], []
            for item in ans_pos_list:
                _target_start_pos, _target_end_pos = item.to(device).split(1, dim=-1)
                target_start_pos.append(_target_start_pos.squeeze(-1))
                target_end_pos.append(_target_end_pos.squeeze(-1))
            # The dev set carries three answer annotations per question;
            # use the first one to compute the loss here.
            loss = (criterion(start_pos, target_start_pos[0])
                    + criterion(end_pos, target_end_pos[0])) / 2
            total_loss += loss.item()
            start_pos = nn.functional.softmax(start_pos, dim=1).argmax(1)
            end_pos = nn.functional.softmax(end_pos, dim=1).argmax(1)

            # [TODO] remove '<unk>', '<cls>', '<pad>', '<MASK>' from ans_tokens and pred_tokens
            # Go through the batch and convert ids to token lists.
            seq_input = seq_input.transpose(0, 1)  # convert from (S, N) to (N, S)
            for num in range(0, seq_input.size(0)):
                if int(start_pos[num]) > int(end_pos[num]):
                    continue  # skip if the predicted start comes after the predicted end
                ans_tokens = []
                for _idx in range(len(target_end_pos)):
                    ans_tokens.append([
                        vocab.itos[int(seq_input[num][i])]
                        for i in range(target_start_pos[_idx][num],
                                       target_end_pos[_idx][num] + 1)
                    ])
                pred_tokens = [
                    vocab.itos[int(seq_input[num][i])]
                    for i in range(start_pos[num], end_pos[num] + 1)
                ]
                ans_pred_tokens_samples.append((ans_tokens, pred_tokens))
    return total_loss / (len(data_source) // batch_size), \
        compute_qa_exact(ans_pred_tokens_samples), \
        compute_qa_f1(ans_pred_tokens_samples)
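# The collate_fn `pad_squad_data` referenced above is not shown in this section.
# Below is a minimal sketch of what such a collate function might look like; the
# per-sample layout (token_ids, answer_spans, token_type_ids) and the pad_id
# value are assumptions for illustration, not taken from the original script.
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_squad_data(batch, pad_id=1):
    # Assumed sample layout: (token_ids, answer_spans, token_type_ids), where
    # answer_spans is a (num_answers, 2) tensor of start/end token indices.
    seqs, spans, tok_types = zip(*batch)

    # Pad to shape (S, N); evaluate() transposes back to (N, S) later.
    seq_input = pad_sequence(seqs, batch_first=False, padding_value=pad_id)
    tok_type = pad_sequence(tok_types, batch_first=False, padding_value=0)

    # Regroup the spans so that ans_pos_list[k] is an (N, 2) tensor holding the
    # k-th annotated answer for every sample in the batch (the dev split has
    # several reference answers per question). Assumes every sample in the
    # batch carries the same number of annotations.
    num_answers = spans[0].size(0)
    ans_pos_list = [torch.stack([s[k] for s in spans]) for k in range(num_answers)]
    return seq_input, ans_pos_list, tok_type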
def evaluate(data_source, vocab):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    total_loss = 0.
    batch_size = args.batch_size
    dataloader = DataLoader(data_source, batch_size=batch_size,
                            shuffle=True, collate_fn=collate_batch)
    ans_pred_tokens_samples = []
    with torch.no_grad():
        for idx, (seq_input, ans_pos_list, tok_type) in enumerate(dataloader):
            start_pos, end_pos = model(seq_input, token_type_input=tok_type)
            target_start_pos, target_end_pos = [], []
            for item in ans_pos_list:
                _target_start_pos, _target_end_pos = item.to(device).split(1, dim=-1)
                target_start_pos.append(_target_start_pos.squeeze(-1))
                target_end_pos.append(_target_end_pos.squeeze(-1))
            # Use the first answer annotation to compute the loss.
            loss = (criterion(start_pos, target_start_pos[0])
                    + criterion(end_pos, target_end_pos[0])) / 2
            total_loss += loss.item()
            start_pos = nn.functional.softmax(start_pos, dim=1).argmax(1)
            end_pos = nn.functional.softmax(end_pos, dim=1).argmax(1)
            seq_input = seq_input.transpose(0, 1)  # convert from (S, N) to (N, S)
            for num in range(0, seq_input.size(0)):
                if int(start_pos[num]) > int(end_pos[num]):
                    continue  # skip if the predicted start comes after the predicted end
                ans_tokens = []
                for _idx in range(len(target_end_pos)):
                    ans_tokens.append([
                        vocab.itos[int(seq_input[num][i])]
                        for i in range(target_start_pos[_idx][num],
                                       target_end_pos[_idx][num] + 1)
                    ])
                pred_tokens = [
                    vocab.itos[int(seq_input[num][i])]
                    for i in range(start_pos[num], end_pos[num] + 1)
                ]
                ans_pred_tokens_samples.append((ans_tokens, pred_tokens))
    return total_loss / (len(data_source) // batch_size), \
        compute_qa_exact(ans_pred_tokens_samples), \
        compute_qa_f1(ans_pred_tokens_samples)
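# For context, a rough sketch of how this evaluate() variant might be driven
# from the main script. The dataset/vocab variable names, the train() routine,
# the args.epochs flag, and the checkpoint path are assumptions for
# illustration only; they are not confirmed by the original code.
best_f1 = 0.0
for epoch in range(1, args.epochs + 1):
    train()  # assumed training routine defined elsewhere in the script
    val_loss, exact_score, f1_score = evaluate(dev_dataset, vocab)
    print(f"| epoch {epoch:3d} | dev loss {val_loss:5.2f} "
          f"| exact {exact_score:5.2f} | F1 {f1_score:5.2f}")
    if f1_score > best_f1:
        best_f1 = f1_score
        torch.save(model.state_dict(), "bert_qa_best.pt")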