Ejemplo n.º 1
0
 def get_training_data(self,
                       train_features,
                       example_ids,
                       out,
                       token_types=None):
     """Score a batch of span predictions as input rows for the verifier.

     ``out`` is split along axis 2 into start and end logits. Each example
     is decoded with ``predict``, and the numbers it returns are recorded.

     Args:
         train_features: mapping from example id to that example's list of
             features; ``features[0].is_impossible`` supplies the label.
         example_ids: NDArray of example ids, one per batch row.
         out: span logits whose axis 2 holds the (start, end) scores.
         token_types: unused; kept for signature compatibility.

     Returns:
         A list with one ``[score_diff, non_empty_top, label]`` row per
         example:
         - ``score_diff`` is the second value returned by ``predict``.
         - ``non_empty_top`` is 1.0 when ``predict``'s third value is truthy,
           else 0.0.
         - ``label`` is 1 for answerable examples and 0 for impossible ones.
     """
     start_logits, end_logits = mx.nd.split(out, axis=2, num_outputs=2)
     batch_ids = example_ids.asnumpy().tolist()
     # reshape((0, -3)) keeps axis 0 and merges the next two axes.
     start_rows = start_logits.reshape((0, -3)).asnumpy()
     end_rows = end_logits.reshape((0, -3)).asnumpy()

     rows = []
     for eid, start_row, end_row in zip(batch_ids, start_rows, end_rows):
         feats = train_features[eid]
         is_answerable = 0 if feats[0].is_impossible else 1
         _, score_diff, top_choice = predict(
             features=feats,
             results=[PredResult(start=start_row, end=end_row)],
             tokenizer=self.tokenizer,
             max_answer_length=self.max_answer_length,
             n_best_size=self.n_best_size,
             version_2=self.version_2)
         rows.append([score_diff, 1. if top_choice else 0., is_answerable])
     return rows
Ejemplo n.º 2
0
 def parse_sentences(self,
                     train_features,
                     example_ids,
                     out,
                     token_types=None):
     """Build (text, answer, label) triples for training a text verifier.

     Each example's span is decoded with ``predict``. The answer paired with
     the example depends on its label:
     - answerable (label 1): the gold answer;
     - unanswerable (label 0): the predicted answer;
     - unanswerable with an empty prediction: the example is skipped.

     The text is one of the following, followed by ``'. '`` and the question:
     - the sentence chosen by ``self.find_sentence``, when
       ``self.extract_sentence`` is set;
     - otherwise, the whole context.

     Args:
         train_features: mapping from example id to its list of features.
         example_ids: NDArray of example ids, one per batch row.
         out: span logits whose axis 2 holds the (start, end) scores.
         token_types: unused; kept for signature compatibility.

     Returns:
         A list of ``[text, answer, label]`` rows.
     """
     start_logits, end_logits = mx.nd.split(out, axis=2, num_outputs=2)
     batch_ids = example_ids.asnumpy().tolist()
     start_rows = start_logits.reshape((0, -3)).asnumpy()
     end_rows = end_logits.reshape((0, -3)).asnumpy()

     rows = []
     for eid, start_row, end_row in zip(batch_ids, start_rows, end_rows):
         feats = train_features[eid]
         head = feats[0]
         label = 0 if head.is_impossible else 1
         context_text = ' '.join(head.doc_tokens)
         question_text = head.question_text
         gold_answer = head.orig_answer_text
         # TODO: use this more wisely, for example, GAN
         prediction, _, _ = predict(
             features=feats,
             results=[PredResult(start=start_row, end=end_row)],
             tokenizer=self.tokenizer,
             max_answer_length=self.max_answer_length,
             n_best_size=self.n_best_size,
             version_2=self.version_2,
             offsets=self.offsets)

         use_sentences = self.extract_sentence
         if use_sentences:
             # Non-blank pieces of the context split on the module-level regex.
             sentences = [piece for piece in re.split(pattern, context_text)
                          if len(piece.strip()) > 0]

         if label == 1:
             paired_answer = gold_answer
         elif len(prediction) > 0:
             paired_answer = prediction
         else:
             continue  # unanswerable and nothing predicted: emit no row

         if use_sentences:
             lead = self.find_sentence(sentences, paired_answer)
         else:
             lead = context_text
         rows.append([lead + '. ' + question_text, paired_answer, label])
     return rows
Ejemplo n.º 3
0
 def parse_sentences(self, train_features, example_ids, out):
     """Build a ``VerifierDataset`` of (sentence + question, answer, label).

     Answerable examples (label 1) pair the first context sentence that
     contains the gold answer ('' if none does) with that answer.
     Unanswerable examples (label 0) get a synthetic negative: a random
     context sentence, paired with a random whitespace-separated word from it.

     Args:
         train_features: mapping from example id to its list of features;
             only ``features[0]`` is read.
         example_ids: NDArray of example ids for the batch.
         out: span logits for the batch. Currently unused, because the
             model's predicted span is not consulted; kept for interface
             compatibility.

     Returns:
         A ``VerifierDataset`` built from
         ``[sentence + ' ' + question, answer, label]`` rows.

     Raises:
         IndexError: from ``random.choice`` when an unanswerable example's
             context splits into no non-blank sentence.
     """
     # TODO: use the model's predicted span from ``out`` (e.g. via
     # ``predict``) to build harder negatives, for example with a GAN.
     raw_data = []
     for example_id in example_ids.asnumpy().tolist():
         head = train_features[example_id][0]
         label = 0 if head.is_impossible else 1
         context_text = ' '.join(head.doc_tokens)
         question_text = head.question_text
         answer_text = head.orig_answer_text
         # Non-blank pieces of the context split on the module-level regex.
         sentences = [s for s in re.split(pattern, context_text) if s.strip()]
         if label == 1:
             sentence_text = next((s for s in sentences if answer_text in s),
                                  '')
         else:
             sentence_text = random.choice(sentences)
             answer_text = random.choice(sentence_text.split())
         raw_data.append(
             [sentence_text + ' ' + question_text, answer_text, label])
     return VerifierDataset(raw_data)
Ejemplo n.º 4
0
def evaluate():
    """Evaluate the model on the SQuAD dev set and write predictions.json.

    Steps:
    1. Load SQuAD dev: version 2.0 when ``version_2``, else 1.1. Only the
       first three records are kept when ``args.debug`` is set.
    2. Run ``net`` over the preprocessed batches and collect start/end
       logits per example id.
    3. Decode answers with ``predict`` and write ``{qas_id: answer}`` to
       ``<output_dir>/predictions.json``.
    4. For SQuAD 1.1, log the F1/EM scores. For 2.0, the user is told to run
       the official evaluation script.

    Relies on module-level globals (``args``, ``net``, ``ctx``,
    ``tokenizer``, ``batchify_fn`` and hyper-parameters such as
    ``max_seq_length``).
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = [dev_data[0], dev_data[1], dev_data[2]]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    # One list of features per dev example; iterated below to decode answers.
    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=False,
        is_training=False)._transform,
                                     lazy=False)

    # Flattened features fed to the network in batches.
    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=False,
                       is_training=False))
    log.info('The number of examples after preprocessing:{}'.format(
        len(dev_data_transform)))

    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              num_workers=4,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              last_batch='keep')

    log.info('start prediction')

    # example_id -> list of PredResult, appended in loader order
    all_results = collections.defaultdict(list)

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))

        # Axis 2 of ``out`` holds (start, end) logits.
        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))

    epoch_toc = time.time()
    # Clamp so an empty loader or a coarse clock cannot divide by zero.
    elapsed = max(epoch_toc - epoch_tic, 1e-9)
    log.info('Time cost={:.2f} s, Thoughput={:.2f} samples/s'.format(
        elapsed, total_num / elapsed))

    log.info('Get prediction results...')

    # The basic tokenizer is the same for every feature: build it once,
    # not once per example.
    basic_tokenizer = nlp.data.BERTBasicTokenizer(lower=lower)
    all_predictions = collections.OrderedDict()

    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id

        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=basic_tokenizer,
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)

        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w',
                 encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
Ejemplo n.º 5
0
def evaluate():
    """Evaluate on the dev set, optionally filtering answers with a verifier.

    Runs ``net``, which returns span logits and the BERT sequence output,
    over ``dev_dataloader`` and gathers per-example start/end logits.

    For SQuAD 2.0, every non-empty prediction is then checked for
    answerability:
    - The score-difference threshold gives a 0/1 baseline.
    - The verifier selected by ``VERIFIER_ID`` gives ``has_ans_prob``.
    - ``args.verifier_mode`` ("takeover", "joint" or "all") combines the two.
      Any other mode keeps the baseline.
    - Predictions whose combined score is below ``answerable_threshold`` are
      replaced by ``""``.

    Predictions are written to ``<output_dir>/predictions.json``. F1/EM is
    logged for SQuAD 1.1.

    Relies on module-level globals (``net``, ``verifier``,
    ``dev_dataloader``, ``dev_dataset``, ``dev_features``, ``args``,
    thresholds, ...).
    """
    log.info('Start Evaluation')

    # example_id -> list of PredResult, appended in loader order
    all_results = collections.defaultdict(list)

    if VERIFIER_ID == 2:
        # example_id -> per-feature answerability scores from the verifier
        all_pre_na_prob = collections.defaultdict(list)

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)

        # One-hot masks shaped like token_types. They mark three positions:
        # - position 0;
        # - the last type-0 position of each row (presumably the [SEP] that
        #   closes the query; assumes padding is not type 0 — confirm);
        # - the last valid position, valid_length - 1.
        cls_mask = mx.nd.zeros(token_types.shape)
        sep_mask_1 = mx.nd.zeros(token_types.shape)
        sep_mask_2 = mx.nd.zeros(token_types.shape)
        cls_mask[:, 0] = 1.
        range_row_index = mx.nd.array(np.arange(len(example_ids)))
        valid_query_length = (1 - token_types).sum(axis=1)
        sep_mask_1[range_row_index, valid_query_length - 1] = 1.
        sep_mask_2[range_row_index, valid_length - 1] = 1.
        additional_masks = (cls_mask.astype('float32').as_in_context(ctx),
                            sep_mask_1.astype('float32').as_in_context(ctx),
                            sep_mask_2.astype('float32').as_in_context(ctx))

        out, bert_out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx),
            additional_masks)

        if VERIFIER_ID == 2:
            has_answer_tmp = verifier.evaluate(dev_features, example_ids, out,
                                               token_types,
                                               bert_out).asnumpy().tolist()

        # Axis 2 of ``out`` holds (start, end) logits.
        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))
        if VERIFIER_ID == 2:
            for example_id, has_ans_prob in zip(example_ids, has_answer_tmp):
                all_pre_na_prob[example_id].append(has_ans_prob)

    epoch_toc = time.time()
    # Clamp so an empty loader or a coarse clock cannot divide by zero.
    elapsed = max(epoch_toc - epoch_tic, 1e-9)
    log.info('Time cost={:.2f} s, Thoughput={:.2f} samples/s'.format(
        elapsed, total_num / elapsed))

    log.info('Get prediction results...')

    # The basic tokenizer is the same for every feature: build it once,
    # not once per example.
    basic_tokenizer = nlp.data.BERTBasicTokenizer(lower=lower)
    all_predictions = collections.OrderedDict()

    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id
        # predict returns (prediction, score_diff, best_pred); under
        # version_2 the prediction may be the empty string.
        prediction, score_diff, best_pred = predict(
            features=features,
            results=results,
            tokenizer=basic_tokenizer,
            max_answer_length=max_answer_length,
            n_best_size=n_best_size,
            version_2=version_2,
            offsets=offsets)
        if version_2 and prediction != '':
            # Baseline verifier: a score difference above the threshold
            # means "unanswerable".
            if score_diff > null_score_diff_threshold:
                answerable = 0.
            else:
                answerable = 1.

            if VERIFIER_ID == 0:
                best_pred_score = 1. if best_pred else 0.
                has_ans_prob = verifier.evaluate(score_diff, best_pred_score)
            elif VERIFIER_ID == 1:
                has_ans_prob = verifier.evaluate(features, prediction)
            elif VERIFIER_ID == 2:
                # Mean over the scores of all features of this example.
                has_ans_prob_list = all_pre_na_prob[features[0].example_id]
                has_ans_prob = sum(has_ans_prob_list) / max(
                    len(has_ans_prob_list), 1)
            else:
                has_ans_prob = 1.

            if args.verifier_mode == "takeover":
                answerable = has_ans_prob
            elif args.verifier_mode == "joint":
                answerable = answerable * has_ans_prob
            elif args.verifier_mode == "all":
                answerable = (answerable + has_ans_prob) * 0.5

            if answerable < answerable_threshold:
                prediction = ""

        # qas_id -> answer string
        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w',
                 encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)
Ejemplo n.º 6
0
    def parse_sentences(self, all_features, example_ids, out, token_types,
                        bert_out):
        """Assemble verifier inputs by re-packing BERT token embeddings.

        For every batch row, the predicted span is located with
        ``predict_span``. A new embedding sequence is then built from
        ``bert_out`` in the order::

            [CLS] <context or answer sentence> <query> [SEP] <answer> [SEP]

        The index arithmetic assumes each original sequence is laid out as
        ``[CLS] query [SEP] context [SEP]``. It also assumes ``token_types``
        marks exactly those query-side positions with 0, with no type-0
        padding (confirm).

        Args:
            all_features: mapping from example id to its list of features.
                Only ``features[0]`` is used: ``tokens`` and
                ``is_impossible``.
            example_ids: NDArray of example ids for the batch.
            out: span logits; axis 2 holds the (start, end) scores.
            token_types: per-row segment ids for the batch.
            bert_out: BERT sequence output, indexed as
                ``[row, position, hidden]``.

        Returns:
            ``(verifier_input, labels)``:
            - ``verifier_input`` is zero-initialized with shape
              ``(batch, bert_out.shape[1] + max_answer_length,
              bert_out.shape[2])`` on ``self.ctx``.
            - ``labels`` is a ``(batch, 1)`` NDArray holding 1 for answerable
              examples and 0 for impossible ones.
        """
        import bisect

        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()
        verifier_input_shape = (bert_out.shape[0],
                                bert_out.shape[1] + self.max_answer_length,
                                bert_out.shape[2])
        verifier_input = mx.nd.zeros(verifier_input_shape, ctx=self.ctx)
        labels = mx.nd.array([[0 if all_features[eid][0].is_impossible else 1]
                              for eid in example_ids]).as_in_context(self.ctx)
        for idx, (example_id, start, end, token) in enumerate(
                zip(example_ids, pred_start, pred_end, token_types)):
            results = [PredResult(start=start, end=end)]
            features = all_features[example_id]
            prediction = predict_span(features=features,
                                      results=results,
                                      max_answer_length=self.max_answer_length,
                                      n_best_size=self.n_best_size,
                                      offsets=self.offsets,
                                      version_2=self.version_2)
            num_total_tokens = len(features[0].tokens)
            num_query_tokens = int((1 - token).sum().max().asscalar()) - 2
            num_contx_tokens = num_total_tokens - num_query_tokens - 3
            num_answr_tokens = prediction[1] - prediction[0] + 1
            # Positions in the original sequence under the layout above.
            contx_first = num_query_tokens + 2
            contx_last = num_query_tokens + num_contx_tokens + 1
            final_sep = num_query_tokens + num_contx_tokens + 2

            # Inclusive [span_first, span_last] range copied after [CLS]. It
            # is the whole context, except when extract_sentence is on and a
            # span was predicted: then it is the sentence(s) around the
            # answer. The empty-span case now stops at the last context token
            # instead of the final [SEP]. Previously that made the copied
            # slice one row longer than its destination.
            span_first, span_last = contx_first, contx_last
            if self.extract_sentence and num_answr_tokens != 0:
                sequence_tokens = features[0].tokens
                # Inclusive sentence ends: any token containing '.', '?' or
                # '!', plus the last context token.
                sentence_ends = {
                    i for i, tok in enumerate(sequence_tokens)
                    if '.' in tok or '?' in tok or '!' in tok
                }
                sentence_ends.add(num_total_tokens - 2)
                # A sentence begins right after the previous one ends. The
                # positions of the final [SEP] and the query's [SEP] never
                # start a sentence.
                sentence_begins = {i + 1 for i in sentence_ends}
                sentence_begins.discard(num_total_tokens - 1)
                sentence_begins.discard(num_query_tokens + 1)
                sentence_begins.add(1)
                sentence_begins.add(contx_first)
                begin_idxs = sorted(sentence_begins)
                end_idxs = sorted(sentence_ends)
                # The nearest sentence start at or before the answer start.
                # This also covers an answer in the last sentence, which a
                # pairwise scan of neighbouring starts misses.
                pos = bisect.bisect_right(begin_idxs, prediction[0])
                span_first = begin_idxs[pos - 1] if pos > 0 else contx_first
                # The nearest sentence end at or after the answer end. This
                # also covers an answer in the first sentence.
                pos = bisect.bisect_left(end_idxs, prediction[1])
                span_last = end_idxs[pos] if pos < len(end_idxs) else final_sep
            num_sentc_tokens = span_last - span_first + 1

            # the beginning ([CLS])
            verifier_input[idx, 0, :] = bert_out[idx, 0, :]
            # the context / sentence embedding
            verifier_input[idx, 1:num_sentc_tokens + 1, :] = \
                bert_out[idx, span_first:span_last + 1, :]
            # the query embedding
            query_at = num_sentc_tokens + 1
            verifier_input[idx, query_at:query_at + num_query_tokens, :] = \
                bert_out[idx, 1:num_query_tokens + 1, :]
            # the separator that followed the query
            sep_at = query_at + num_query_tokens
            verifier_input[idx, sep_at, :] = \
                bert_out[idx, num_query_tokens + 1, :]
            # the predicted answer
            answer_at = sep_at + 1
            if num_answr_tokens > 0:
                verifier_input[idx, answer_at:answer_at + num_answr_tokens, :] = \
                    bert_out[idx, prediction[0]:prediction[1] + 1, :]
            # the ending (final [SEP])
            verifier_input[idx, answer_at + num_answr_tokens, :] = \
                bert_out[idx, final_sep, :]
        return verifier_input, labels
Ejemplo n.º 7
0
def evaluate():
    """Evaluate the model on the SQuAD dev set, optionally with a verifier.

    Loads SQuAD dev: version 2.0 when ``version_2``, else 1.1. Only the first
    ten records are kept when ``args.debug`` is set. Runs ``net`` over the
    preprocessed batches and decodes answers with ``predict``.

    When ``args.verify`` is set:
    - ``VERIFIER_ID`` 2 or 3: ``verifier.evaluate`` scores every batch. An
      example whose mean score is below 0.5 is answered with ``""``.
    - ``VERIFIER_ID`` 1: a non-empty prediction is blanked when
      ``verifier.evaluate(features, prediction)`` is falsy.

    Predictions are written to ``<output_dir>/predictions.json``. F1/EM is
    logged for SQuAD 1.1.
    """
    log.info('Loading dev data...')
    if version_2:
        dev_data = SQuAD('dev', version='2.0')
    else:
        dev_data = SQuAD('dev', version='1.1')
    if args.debug:
        sampled_data = dev_data[:10]
        dev_data = mx.gluon.data.SimpleDataset(sampled_data)
    log.info('Number of records in dev data:{}'.format(len(dev_data)))

    # NOTE(review): dev features are built with is_pad=True and
    # is_training=True; confirm this is intended for evaluation.
    dev_dataset = dev_data.transform(SQuADTransform(
        copy.copy(tokenizer),
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_pad=True,
        is_training=True)._transform,
                                     lazy=False)

    dev_data_transform, _ = preprocess_dataset(
        dev_data,
        SQuADTransform(copy.copy(tokenizer),
                       max_seq_length=max_seq_length,
                       doc_stride=doc_stride,
                       max_query_length=max_query_length,
                       is_pad=True,
                       is_training=True))

    # example_id -> list of features, consumed by the verifier.
    dev_features = {
        features[0].example_id: features
        for features in dev_dataset
    }

    # Keep loader order stable. all_results[example_id] is filled in loader
    # order and handed to ``predict`` together with that example's feature
    # list. Shuffling would reorder a multi-feature example's results
    # relative to its features (assuming predict pairs them positionally —
    # confirm).
    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform,
                                              batchify_fn=batchify_fn,
                                              batch_size=test_batch_size,
                                              num_workers=4,
                                              shuffle=False,
                                              last_batch='keep')
    log.info('start prediction')

    # example_id -> list of PredResult, appended in loader order
    all_results = collections.defaultdict(list)

    if args.verify and VERIFIER_ID in [2, 3]:
        # example_id -> per-feature answerability scores from the verifier
        all_pre_na_prob = collections.defaultdict(list)
    else:
        all_pre_na_prob = None

    epoch_tic = time.time()
    total_num = 0
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        total_num += len(inputs)
        out = net(
            inputs.astype('float32').as_in_context(ctx),
            token_types.astype('float32').as_in_context(ctx),
            valid_length.astype('float32').as_in_context(ctx))

        if all_pre_na_prob is not None:
            has_answer_tmp = verifier.evaluate(dev_features, example_ids,
                                               out).asnumpy().tolist()

        # Axis 2 of ``out`` holds (start, end) logits.
        output = mx.nd.split(out, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = output[0].reshape((0, -3)).asnumpy()
        pred_end = output[1].reshape((0, -3)).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(PredResult(start=start, end=end))
        if all_pre_na_prob is not None:
            for example_id, has_ans_prob in zip(example_ids, has_answer_tmp):
                all_pre_na_prob[example_id].append(has_ans_prob)

    epoch_toc = time.time()
    # Clamp so an empty loader or a coarse clock cannot divide by zero.
    elapsed = max(epoch_toc - epoch_tic, 1e-9)
    log.info('Time cost={:.2f} s, Thoughput={:.2f} samples/s'.format(
        elapsed, total_num / elapsed))

    log.info('Get prediction results...')

    # The basic tokenizer is the same for every feature: build it once,
    # not once per example.
    basic_tokenizer = nlp.data.BERTBasicTokenizer(lower=lower)
    all_predictions = collections.OrderedDict()

    for features in dev_dataset:
        results = all_results[features[0].example_id]
        example_qas_id = features[0].qas_id

        if all_pre_na_prob is not None:
            # Mean over the scores of all features of this example.
            has_ans_prob_list = all_pre_na_prob[features[0].example_id]
            has_ans_prob = sum(has_ans_prob_list) / max(
                len(has_ans_prob_list), 1)
            if has_ans_prob < 0.5:
                all_predictions[example_qas_id] = ""
                continue

        prediction, _ = predict(
            features=features,
            results=results,
            tokenizer=basic_tokenizer,
            max_answer_length=max_answer_length,
            null_score_diff_threshold=null_score_diff_threshold,
            n_best_size=n_best_size,
            version_2=version_2)

        if args.verify and VERIFIER_ID == 1:
            if len(prediction) > 0:
                has_answer = verifier.evaluate(features, prediction)
                if not has_answer:
                    prediction = ""

        # qas_id -> answer string
        all_predictions[example_qas_id] = prediction

    with io.open(os.path.join(output_dir, 'predictions.json'),
                 'w',
                 encoding='utf-8') as fout:
        data = json.dumps(all_predictions, ensure_ascii=False)
        fout.write(data)

    if version_2:
        log.info(
            'Please run evaluate-v2.0.py to get evaluation results for SQuAD 2.0'
        )
    else:
        F1_EM = get_F1_EM(dev_data, all_predictions)
        log.info(F1_EM)