Example No. 1
def predict(logger,
            args,
            model,
            eval_dataloader,
            eval_examples,
            eval_features,
            device,
            write_prediction=True,
            varying_n_paragraphs=False):
    all_results = []

    if args.verbose:
        eval_dataloader = tqdm(eval_dataloader)

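    # run the model over the entire eval set, collecting the start/end/switch
    # logits of every feature as RawResult entries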
    for batch in eval_dataloader:
        example_indices = batch[-1]
        batch_to_feed = [t.to(device) for t in batch[:-1]]
        with torch.no_grad():
            batch_start_logits, batch_end_logits, batch_switch = model(
                batch_to_feed)
            assert len(batch_start_logits) == len(batch_end_logits) == len(
                batch_switch)
        for i, example_index in enumerate(example_indices):
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            switch = batch_switch[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            all_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits,
                          switch=switch))

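    # decode the collected logits into predictions and return the F1 score;
    # output files are only written when write_prediction is True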
    output_prediction_file = os.path.join(args.output_dir,
                                          args.prefix + "predictions.json")
    output_nbest_file = os.path.join(args.output_dir,
                                     args.prefix + "nbest_predictions.json")
    f1 = write_predictions(
        logger,
        eval_examples,
        eval_features,
        all_results,
        args.n_best_size if write_prediction else 1,
        args.do_lower_case,
        output_prediction_file if write_prediction else None,
        output_nbest_file if write_prediction else None,
        args.verbose,
        write_prediction=write_prediction,
        n_paragraphs=[int(n) for n in args.n_paragraphs.split(',')]
        if varying_n_paragraphs else None)
    return f1
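The loop above relies on module-level helpers from the surrounding project (RawResult, write_predictions, logger, plus the os, torch, and tqdm imports). Below is a minimal, self-contained sketch of the same collection pattern, with a toy model and dataset standing in for the project's QA model and features; every name in it is a stand-in, not part of the original code.

import collections

import torch
from torch.utils.data import DataLoader, TensorDataset

RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"])


class ToyQAModel(torch.nn.Module):
    """Maps a 4-dim input to start/end logits over an 8-token sequence."""

    def __init__(self, seq_len=8):
        super().__init__()
        self.start = torch.nn.Linear(4, seq_len)
        self.end = torch.nn.Linear(4, seq_len)

    def forward(self, x):
        return self.start(x), self.end(x)


inputs = torch.randn(6, 4)
example_indices = torch.arange(6)  # last tensor of every batch holds the indices
dataloader = DataLoader(TensorDataset(inputs, example_indices), batch_size=2)

device = torch.device("cpu")
model = ToyQAModel().to(device).eval()

all_results = []
for batch in dataloader:
    example_idx = batch[-1]                            # indices stay on the CPU
    batch_to_feed = [t.to(device) for t in batch[:-1]]
    with torch.no_grad():                              # no autograd graph in eval
        start, end = model(*batch_to_feed)
    for i, idx in enumerate(example_idx):
        all_results.append(RawResult(unique_id=int(idx),
                                     start_logits=start[i].cpu().tolist(),
                                     end_logits=end[i].cpu().tolist()))

print(len(all_results))  # 6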
Example No. 2
def predict(args, model, eval_dataloader, eval_examples, eval_features, device,
            write_prediction=True):
    all_results = []

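    # an ensemble (a list of models) is only supported for the 'qa' model type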
    assert args.model == "qa" or type(model) != list

    if args.model == 'qa':

        RawResult = collections.namedtuple(
            "RawResult", ["unique_id", "start_logits", "end_logits", "switch"])

        def _get_raw_results(model1):
            raw_results = []
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                example_indices = batch[-1]
                batch_to_feed = [t.to(device) for t in batch[:-1]]
                with torch.no_grad():
                    batch_start_logits, batch_end_logits, batch_switch = model1(
                        batch_to_feed)

                for i, example_index in enumerate(example_indices):
                    start_logits = batch_start_logits[i].detach().cpu().tolist()
                    end_logits = batch_end_logits[i].detach().cpu().tolist()
                    switch = batch_switch[i].detach().cpu().tolist()
                    eval_feature = eval_features[example_index.item()]
                    unique_id = int(eval_feature.unique_id)
                    raw_results.append(
                        RawResult(unique_id=unique_id,
                                  start_logits=start_logits,
                                  end_logits=end_logits,
                                  switch=switch))
            return raw_results

        if type(model) == list:
            all_raw_results = [_get_raw_results(m) for m in model]
            for i in range(len(all_raw_results[0])):
                result = [
                    all_raw_result[i] for all_raw_result in all_raw_results
                ]
                assert all(
                    [r.unique_id == result[0].unique_id for r in result])
                start_logits = sum([np.array(r.start_logits)
                                    for r in result]).tolist()
                end_logits = sum([np.array(r.end_logits)
                                  for r in result]).tolist()
                switch = sum([np.array(r.switch) for r in result]).tolist()
                all_results.append(
                    RawResult(unique_id=result[0].unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              switch=switch))
        else:
            all_results = _get_raw_results(model)

        output_prediction_file = os.path.join(args.output_dir,
                                              args.prefix + "predictions.json")
        output_nbest_file = os.path.join(
            args.output_dir, args.prefix + "nbest_predictions.json")

        f1 = write_predictions(
            logger,
            eval_examples,
            eval_features,
            all_results,
            args.n_best_size if write_prediction else 1,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file if write_prediction else None,
            output_nbest_file if write_prediction else None,
            args.verbose_logging,
            write_prediction=write_prediction)
        return f1

    elif args.model == 'classifier':

        all_results = collections.defaultdict(list)
        all_results_per_key = collections.defaultdict(list)
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            example_indices = batch[-1]
            batch_to_feed = tuple(t.to(device) for t in batch[:-1])
            with torch.no_grad():
                batch_predicted_label = model(batch_to_feed)
            for i, example_index in enumerate(example_indices):
                logit = batch_predicted_label[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                eval_example = eval_examples[eval_feature.example_index]
                all_results[eval_example.qas_id].append(
                    (logit, eval_feature.switch, eval_example.all_answers,
                     eval_example.question_text))
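        # sum the logits collected for each qas_id (question key plus sentence
        # index) and turn them into a positive-class probability via a softmax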
        for example_index, results in all_results.items():
            example_index_, sent_index = example_index[:-2], example_index[-1]
            logit = [0, 0]
            switch = results[0][1]
            f1 = results[0][2]
            for (logit_, switch_, f1_, _) in results:
                logit[0] += logit_[0]
                logit[1] += logit_[1]
                assert switch == switch_ and f1 == f1_
            logit_indicator = (np.exp(logit) / sum(np.exp(logit))).tolist()[1]
            #logit_indicator = logit[1]/np.linalg.norm(logit) #/len(results)
            assert len(switch) == 1  #and switch[0] == int(f1>0.6)
            all_results_per_key[example_index_].append(
                (logit_indicator, f1, int(sent_index), results[0][3]))
        accs = {}
        sents_scores = {}
        for key, results in all_results_per_key.items():
            ranked_labels = sorted(results, key=lambda x: (-x[0], x[2]))
            acc = ranked_labels[0][1]
            accs[key] = acc
            sents_scores[key] = sorted(results, key=lambda x: x[2])
        if write_prediction:
            output_prediction_file = os.path.join(
                args.output_dir, args.prefix + "class_scores.json")
            logger.info("Save score file into: " + output_prediction_file)
            with open(output_prediction_file, "w") as f:
                json.dump(sents_scores, f)

        return np.mean(list(accs.values()))

    elif args.model == "span-predictor":

        RawResult = collections.namedtuple("RawResult", [
            "unique_id", "start_logits", "end_logits", "keyword_logits",
            "switch"
        ])

        has_keyword = args.with_key
        em_all_results = collections.defaultdict(list)
        accs = []
        for batch in eval_dataloader:
            example_indices = batch[-1]
            batch_to_feed = [t.to(device) for t in batch[:-1]]
            with torch.no_grad():
                if has_keyword:
                    batch_start_logits, batch_end_logits, batch_keyword_logits, batch_switch = model(
                        batch_to_feed)
                else:
                    batch_start_logits, batch_end_logits, batch_switch = model(
                        batch_to_feed)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                switch = batch_switch[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                gold_start_positions = eval_feature.start_position
                gold_end_positions = eval_feature.end_position
                gold_switch = eval_feature.switch
                if has_keyword:
                    keyword_logits = (
                        batch_keyword_logits[i].detach().cpu().tolist())
                    gold_keyword_positions = eval_feature.keyword_position
                else:
                    keyword_logits = None
                if gold_switch == [1]:
                    acc = np.argmax(switch) == 1
                elif has_keyword:
                    start_logits = start_logits[:len(eval_feature.tokens)]
                    end_logits = end_logits[:len(eval_feature.tokens)]
                    scores = []
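                    # score every candidate (start, end, keyword) triple by the
                    # sum of its start, end, and keyword logits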
                    for (i, s) in enumerate(start_logits):
                        for (j, e) in enumerate(end_logits[i:]):
                            for (k, key) in enumerate(
                                    keyword_logits[i:i + j + 1]):
                                scores.append(((i, i + j, i + k), s + e + key))
                    scores = sorted(scores, key=lambda x: x[1], reverse=True)
                    acc = scores[0][0] in zip(gold_start_positions,
                                              gold_end_positions,
                                              gold_keyword_positions)
                else:
                    start_logits = start_logits[:len(eval_feature.tokens)]
                    end_logits = end_logits[:len(eval_feature.tokens)]
                    scores = []
                    for (i, s) in enumerate(start_logits):
                        for (j, e) in enumerate(end_logits[i:]):
                            scores.append(((i, i + j), s + e))
                    scores = sorted(scores, key=lambda x: x[1], reverse=True)
                    acc = scores[0][0] in zip(gold_start_positions,
                                              gold_end_positions)

                em_all_results[eval_feature.example_index].append(
                    (unique_id, acc))
                all_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              keyword_logits=keyword_logits,
                              switch=switch))

        output_prediction_file = os.path.join(args.output_dir,
                                              args.prefix + "predictions.json")
        output_nbest_file = os.path.join(
            args.output_dir, args.prefix + "nbest_predictions.json")

        for example_index, results in em_all_results.items():
            acc = sorted(results, key=lambda x: x[0])[0][1]
            accs.append(acc)

        if write_prediction:
            is_bridge = 'bridge' in args.predict_file
            is_intersec = 'intersec' in args.predict_file
            assert (is_bridge and not is_intersec) or (is_intersec
                                                       and not is_bridge)

            print("Accuracy", np.mean(accs))
            f1 = span_write_predictions(
                logger,
                eval_examples,
                eval_features,
                all_results,
                args.n_best_size if write_prediction else 1,
                args.max_answer_length,
                args.do_lower_case,
                output_prediction_file if write_prediction else None,
                output_nbest_file if write_prediction else None,
                args.verbose_logging,
                write_prediction=write_prediction,
                with_key=args.with_key,
                is_bridge=is_bridge)

        return np.mean(accs)

    raise NotImplementedError()
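In the 'classifier' branch above, logits for the same question are summed across features, turned into a probability for the positive class with a softmax, and each question's sentences are then ranked by that probability (ties broken by sentence index). Below is a tiny, self-contained sketch of that aggregation; the keys and logit values are made up for illustration.

import collections

import numpy as np

# (question_key, sentence_index) -> 2-class logits from each feature of that sentence
raw_logits = {
    ("q1", 0): [[0.2, 1.1], [0.1, 0.9]],
    ("q1", 1): [[1.5, 0.3]],
}

per_question = collections.defaultdict(list)
for (question_key, sent_index), logits in raw_logits.items():
    summed = np.sum(logits, axis=0)                # sum the logits over features
    probs = np.exp(summed) / np.exp(summed).sum()  # softmax over the two classes
    per_question[question_key].append((probs[1], sent_index))

for question_key, scored in per_question.items():
    ranked = sorted(scored, key=lambda x: (-x[0], x[1]))   # highest probability first
    print(question_key, "top sentence:", ranked[0][1])     # q1 top sentence: 0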
Example No. 3
def predict(args, model, eval_dataloader, eval_examples, eval_features, device,
            write_prediction=True):
    all_results = []

    def _get_raw_results(model1):
        raw_results = []
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            example_indices = batch[-1]
            batch_to_feed = [t.to(device) for t in batch[:-1]]
            with torch.no_grad():
                batch_start_logits, batch_end_logits, batch_switch = model1(
                    batch_to_feed)

            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                switch = batch_switch[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                raw_results.append(
                    RawResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              switch=switch))
        return raw_results

    if type(model) == list:
        all_raw_results = [_get_raw_results(m) for m in model]
        for i in range(len(all_raw_results[0])):
            result = [all_raw_result[i] for all_raw_result in all_raw_results]
            assert all([r.unique_id == result[0].unique_id for r in result])
            start_logits = sum([np.array(r.start_logits)
                                for r in result]).tolist()
            end_logits = sum([np.array(r.end_logits) for r in result]).tolist()
            switch = sum([np.array(r.switch) for r in result]).tolist()
            all_results.append(
                RawResult(unique_id=result[0].unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits,
                          switch=switch))
    else:
        all_results = _get_raw_results(model)

    output_prediction_file = os.path.join(args.output_dir,
                                          args.prefix + "predictions.json")
    output_nbest_file = os.path.join(args.output_dir,
                                     args.prefix + "nbest_predictions.json")

    f1 = write_predictions(
        logger,
        eval_examples,
        eval_features,
        all_results,
        args.n_best_size if write_prediction else 1,
        args.max_answer_length,
        args.do_lower_case,
        output_prediction_file if write_prediction else None,
        output_nbest_file if write_prediction else None,
        args.verbose_logging,
        write_prediction=write_prediction)
    return f1
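Example 2 (its 'qa' branch) and Example 3 ensemble a list of models by aligning the per-example raw results and summing their logits element-wise before decoding. Below is a minimal sketch of that step; the RawResult fields are trimmed and the logit values are invented for illustration.

import collections

import numpy as np

RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"])

# one result list per model, aligned by position (same example order everywhere)
model_a_results = [RawResult(0, [1.0, 3.0], [2.0, 0.5])]
model_b_results = [RawResult(0, [0.5, 1.0], [1.0, 0.5])]
all_raw_results = [model_a_results, model_b_results]

ensembled = []
for per_model in zip(*all_raw_results):
    # results at the same position must refer to the same feature
    assert all(r.unique_id == per_model[0].unique_id for r in per_model)
    ensembled.append(RawResult(
        unique_id=per_model[0].unique_id,
        start_logits=sum(np.array(r.start_logits) for r in per_model).tolist(),
        end_logits=sum(np.array(r.end_logits) for r in per_model).tolist()))

print(ensembled[0].start_logits)  # [1.5, 4.0]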