Ejemplo n.º 1
0
def evaluate(data_source):
    """Evaluate the global ``model`` on ``data_source`` for extractive QA.

    Runs ``model`` in eval mode without gradient tracking, accumulates the
    mean of the start/end ``criterion`` losses per batch, and collects
    ``(reference answer tokens, predicted answer tokens)`` pairs that are
    scored by ``compute_qa_exact`` and ``compute_qa_f1``.

    Relies on module-level names defined elsewhere: ``model``, ``args``,
    ``criterion``, ``device``, ``pad_squad_data``, ``compute_qa_exact`` and
    ``compute_qa_f1``.

    Args:
        data_source: Dataset with a ``vocab`` attribute (its ``itos`` maps
            token ids to strings) whose samples ``pad_squad_data`` collates
            into ``(seq_input, ans_pos_list, tok_type)`` batches.

    Returns:
        ``(mean_loss, exact, f1)``: ``mean_loss`` is the summed batch loss
        divided by the number of batches actually evaluated (0.0 if there
        were none); ``exact`` / ``f1`` are the results of
        ``compute_qa_exact`` / ``compute_qa_f1`` on the collected samples.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    # Evaluation order does not affect the metrics; a fixed order keeps the
    # batch composition (and therefore the reported loss) reproducible.
    dataloader = DataLoader(data_source,
                            batch_size=args.batch_size,
                            shuffle=False,
                            collate_fn=pad_squad_data)
    itos = data_source.vocab.itos
    total_loss = 0.
    num_batches = 0
    ans_pred_tokens_samples = []

    def span_tokens(token_ids, start, end):
        # Tokens of the inclusive span [start, end]; empty when start > end.
        return [itos[token_ids[i]] for i in range(start, end + 1)]

    with torch.no_grad():
        for seq_input, ans_pos_list, tok_type in dataloader:
            start_logits, end_logits = model(seq_input,
                                             token_type_input=tok_type)
            target_start_pos, target_end_pos = [], []
            for item in ans_pos_list:
                _target_start_pos, _target_end_pos = item.to(device).split(
                    1, dim=-1)
                target_start_pos.append(_target_start_pos.squeeze(-1))
                target_end_pos.append(_target_end_pos.squeeze(-1))

            # In dev, positions come in three sets; use the first one to
            # compute the loss here.
            loss = (criterion(start_logits, target_start_pos[0]) +
                    criterion(end_logits, target_end_pos[0])) / 2
            total_loss += loss.item()
            num_batches += 1

            # softmax is monotonic, so argmax over the raw logits selects the
            # same positions without the extra computation.
            pred_starts = start_logits.argmax(1).tolist()
            pred_ends = end_logits.argmax(1).tolist()
            # Convert tensors to Python lists once per batch instead of doing
            # one device->host conversion per token.
            # seq_input is transposed from (S, N) to (N, S): one row per sample.
            token_rows = seq_input.transpose(0, 1).tolist()
            ans_starts = [pos.tolist() for pos in target_start_pos]
            ans_ends = [pos.tolist() for pos in target_end_pos]

            # [TODO] remove '<unk>', '<cls>', '<pad>', '<MASK>' from ans_tokens and pred_tokens
            # Go through batch and convert ids to tokens list
            for num, token_ids in enumerate(token_rows):
                ans_tokens = [span_tokens(token_ids, starts[num], ends[num])
                              for starts, ends in zip(ans_starts, ans_ends)]
                # A predicted start after the end yields an empty prediction,
                # which is scored as wrong. Skipping the sample instead would
                # drop it from the EM/F1 denominator and inflate both scores.
                # NOTE(review): assumes compute_qa_exact/compute_qa_f1 accept
                # an empty prediction token list — confirm.
                pred_tokens = span_tokens(token_ids, pred_starts[num],
                                          pred_ends[num])
                ans_pred_tokens_samples.append((ans_tokens, pred_tokens))

    # Average over the batches actually evaluated: the DataLoader keeps the
    # final partial batch, so len(data_source) // batch_size undercounts
    # (and is 0 for a dataset smaller than one batch).
    mean_loss = total_loss / max(num_batches, 1)
    return mean_loss, \
        compute_qa_exact(ans_pred_tokens_samples), \
        compute_qa_f1(ans_pred_tokens_samples)
Ejemplo n.º 2
0
def evaluate(data_source, vocab):
    """Evaluate the global ``model`` on ``data_source`` for extractive QA.

    Runs ``model`` in eval mode without gradient tracking, accumulates the
    mean of the start/end ``criterion`` losses per batch, and collects
    ``(reference answer tokens, predicted answer tokens)`` pairs that are
    scored by ``compute_qa_exact`` and ``compute_qa_f1``.

    Relies on module-level names defined elsewhere: ``model``, ``args``,
    ``criterion``, ``device``, ``collate_batch``, ``compute_qa_exact`` and
    ``compute_qa_f1``.

    Args:
        data_source: Dataset whose samples ``collate_batch`` collates into
            ``(seq_input, ans_pos_list, tok_type)`` batches.
        vocab: Vocabulary whose ``itos`` maps token ids to strings.

    Returns:
        ``(mean_loss, exact, f1)``: ``mean_loss`` is the summed batch loss
        divided by the number of batches actually evaluated (0.0 if there
        were none); ``exact`` / ``f1`` are the results of
        ``compute_qa_exact`` / ``compute_qa_f1`` on the collected samples.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    # Evaluation order does not affect the metrics; a fixed order keeps the
    # batch composition (and therefore the reported loss) reproducible.
    dataloader = DataLoader(data_source,
                            batch_size=args.batch_size,
                            shuffle=False,
                            collate_fn=collate_batch)
    itos = vocab.itos
    total_loss = 0.
    num_batches = 0
    ans_pred_tokens_samples = []

    def span_tokens(token_ids, start, end):
        # Tokens of the inclusive span [start, end]; empty when start > end.
        return [itos[token_ids[i]] for i in range(start, end + 1)]

    with torch.no_grad():
        for seq_input, ans_pos_list, tok_type in dataloader:
            start_logits, end_logits = model(seq_input,
                                             token_type_input=tok_type)
            target_start_pos, target_end_pos = [], []
            for item in ans_pos_list:
                _target_start_pos, _target_end_pos = item.to(device).split(
                    1, dim=-1)
                target_start_pos.append(_target_start_pos.squeeze(-1))
                target_end_pos.append(_target_end_pos.squeeze(-1))

            # The first set of answer positions is used for the loss.
            loss = (criterion(start_logits, target_start_pos[0]) +
                    criterion(end_logits, target_end_pos[0])) / 2
            total_loss += loss.item()
            num_batches += 1

            # softmax is monotonic, so argmax over the raw logits selects the
            # same positions without the extra computation.
            pred_starts = start_logits.argmax(1).tolist()
            pred_ends = end_logits.argmax(1).tolist()
            # Convert tensors to Python lists once per batch instead of doing
            # one device->host conversion per token.
            # seq_input is transposed from (S, N) to (N, S): one row per sample.
            token_rows = seq_input.transpose(0, 1).tolist()
            ans_starts = [pos.tolist() for pos in target_start_pos]
            ans_ends = [pos.tolist() for pos in target_end_pos]

            for num, token_ids in enumerate(token_rows):
                ans_tokens = [span_tokens(token_ids, starts[num], ends[num])
                              for starts, ends in zip(ans_starts, ans_ends)]
                # A predicted start after the end yields an empty prediction,
                # which is scored as wrong. Skipping the sample instead would
                # drop it from the EM/F1 denominator and inflate both scores.
                # NOTE(review): assumes compute_qa_exact/compute_qa_f1 accept
                # an empty prediction token list — confirm.
                pred_tokens = span_tokens(token_ids, pred_starts[num],
                                          pred_ends[num])
                ans_pred_tokens_samples.append((ans_tokens, pred_tokens))

    # Average over the batches actually evaluated: the DataLoader keeps the
    # final partial batch, so len(data_source) // batch_size undercounts
    # (and is 0 for a dataset smaller than one batch).
    mean_loss = total_loss / max(num_batches, 1)
    return mean_loss, \
        compute_qa_exact(ans_pred_tokens_samples), \
        compute_qa_f1(ans_pred_tokens_samples)