Ejemplo n.º 1
0
def get_predictions(example, result, n_best_size, max_answer_length):
    """Return the n-best answer candidates for a single example.

    Mostly copied from the SQuAD post-processing in run_squad.py, but
    restricted to one example and changed to return the n-best entries
    (including document-token offsets) instead of writing them out.

    Candidate spans come from ``get_nbest_bounds_from_membership`` rather than
    from start/end logits. Every span candidate therefore gets the placeholder
    logits ``start_logit=1, end_logit=1``, and the empty answer gets ``0, 0``.

    :param example: object exposing ``tokens``, ``token_to_orig_map`` and
        ``token_is_max_context``. ``doc_tokens`` is optional: when it is
        missing, span entries get empty text but keep their offsets.
    :param result: model output exposing a ``membership`` attribute.
    :param n_best_size: forwarded to ``get_nbest_bounds_from_membership`` and
        used as the cap on span candidates (the empty answer may be added on
        top of that cap).
    :param max_answer_length: maximum span length, in tokens.
    :return: list of OrderedDicts with keys ``text``, ``start_offset``,
        ``end_offset`` (exclusive), ``probability``, ``start_logit`` and
        ``end_logit``.
    """
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction",
        ["text", "start_logit", "end_logit", "start_offset", "end_offset"])

    # No logits are available on this path, so the empty ("null") answer is
    # scored with fixed zero logits.
    null_start_logit = 0
    null_end_logit = 0

    start_indexes, end_indexes = get_nbest_bounds_from_membership(
        result.membership, n_best_size)

    prelim_predictions = []
    for start_index in start_indexes:
        for end_index in end_indexes:
            # We could hypothetically create invalid predictions, e.g. predict
            # that the start of the span is in the question. Throw out spans
            # that are out of range, not mapped to the document, not in their
            # max-context window, reversed, or too long.
            if (start_index >= len(example.tokens)
                    or end_index >= len(example.tokens)
                    or start_index not in example.token_to_orig_map
                    or end_index not in example.token_to_orig_map
                    or not example.token_is_max_context.get(start_index, False)
                    or end_index < start_index):
                continue
            if end_index - start_index + 1 > max_answer_length:
                continue
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=0,
                    start_index=start_index,
                    end_index=end_index,
                    # Placeholder scores: membership output carries no logits.
                    start_logit=1,
                    end_logit=1))

    # With the placeholder scores every key is equal, so this stable sort keeps
    # the membership order. It only matters if real logits are used later.
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    seen_predictions = set()
    nbest = []
    for pred in prelim_predictions:
        if len(nbest) >= n_best_size:
            break
        if pred.start_index > 0:  # this is a non-null prediction
            start_offset = example.token_to_orig_map[pred.start_index]
            end_offset = example.token_to_orig_map[pred.end_index] + 1
            try:
                final_text = ' '.join(example.doc_tokens[start_offset:end_offset])
            except AttributeError:
                # Examples without doc_tokens still yield offsets, just no text.
                final_text = ''
            # Deduplicate on document offsets rather than on text.
            if (start_offset, end_offset) in seen_predictions:
                continue
            seen_predictions.add((start_offset, end_offset))
        else:
            final_text = ""
            start_offset = 0
            end_offset = 0
            seen_predictions.add((0, 0))

        nbest.append(
            _NbestPrediction(
                text=final_text,
                start_logit=pred.start_logit,
                end_logit=pred.end_logit,
                start_offset=start_offset,
                end_offset=end_offset))

    # If we didn't include the empty option in the n-best, include it. This
    # also guarantees nbest is never empty.
    if (0, 0) not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="",
                start_logit=null_start_logit,
                end_logit=null_end_logit,
                start_offset=0,
                end_offset=0))

    assert len(nbest) >= 1

    total_scores = [entry.start_logit + entry.end_logit for entry in nbest]
    probs = run_squad._compute_softmax(total_scores)

    nbest_json = []
    for i, entry in enumerate(nbest):
        output = collections.OrderedDict()
        output["text"] = entry.text
        output["start_offset"] = entry.start_offset
        output["end_offset"] = entry.end_offset
        output["probability"] = probs[i]
        output["start_logit"] = entry.start_logit
        output["end_logit"] = entry.end_logit
        nbest_json.append(output)

    assert len(nbest_json) >= 1

    return nbest_json
    def write_predictions(self, all_examples, all_features, all_results,
                          n_best_size, max_answer_length, do_lower_case,
                          output_prediction_file, output_nbest_file,
                          output_null_log_odds_file):
        """Compute final predictions, n-best lists and null log-odds if needed.

        Adapted from run_squad.py. Despite the name, nothing is written to
        disk. The prediction and n-best paths are only logged,
        ``output_null_log_odds_file`` is unused, and the results are returned.

        :param all_examples: examples in index order; each exposes ``qas_id``
            and ``doc_tokens``.
        :param all_features: features exposing ``example_index``,
            ``unique_id``, ``tokens``, ``token_to_orig_map`` and
            ``token_is_max_context``.
        :param all_results: model outputs exposing ``unique_id``,
            ``start_logits`` and ``end_logits``.
        :param n_best_size: number of best start/end indexes considered per
            feature, and the cap on span candidates kept per example.
        :param max_answer_length: maximum answer span length, in tokens.
        :param do_lower_case: forwarded to ``run_squad.get_final_text``.
        :param output_prediction_file: only logged.
        :param output_nbest_file: only logged.
        :param output_null_log_odds_file: unused.
        :return: ``(all_predictions, all_nbest_json)``, or
            ``(all_predictions, all_nbest_json, scores_diff_json)`` when
            ``FLAGS.version_2_with_negative`` is set. All are OrderedDicts
            keyed by ``qas_id``.
        """
        tf.logging.info("Writing predictions to: %s" %
                        (output_prediction_file))
        tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

        example_index_to_features = collections.defaultdict(list)
        for feature in all_features:
            example_index_to_features[feature.example_index].append(feature)

        unique_id_to_result = {
            result.unique_id: result for result in all_results
        }

        _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "PrelimPrediction", [
                "feature_index", "start_index", "end_index", "start_logit",
                "end_logit"
            ])
        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        all_predictions = collections.OrderedDict()
        all_nbest_json = collections.OrderedDict()
        scores_diff_json = collections.OrderedDict()

        for (example_index, example) in enumerate(all_examples):
            features = example_index_to_features[example_index]

            prelim_predictions = []
            # Track the minimum null score (start + end logit at position 0)
            # across this example's features.
            score_null = 1000000  # large and positive
            min_null_feature_index = 0  # the feature with the min null score
            null_start_logit = 0  # start logit of the feature with min null score
            null_end_logit = 0  # end logit of the feature with min null score
            for (feature_index, feature) in enumerate(features):
                result = unique_id_to_result[feature.unique_id]
                start_indexes = run_squad._get_best_indexes(
                    result.start_logits, n_best_size)
                end_indexes = run_squad._get_best_indexes(
                    result.end_logits, n_best_size)
                # If we could have irrelevant answers, get the min null score.
                if FLAGS.version_2_with_negative:
                    feature_null_score = (result.start_logits[0] +
                                          result.end_logits[0])
                    if feature_null_score < score_null:
                        score_null = feature_null_score
                        min_null_feature_index = feature_index
                        null_start_logit = result.start_logits[0]
                        null_end_logit = result.end_logits[0]
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # We could hypothetically create invalid predictions,
                        # e.g. predict that the start of the span is in the
                        # question. Throw out spans that are out of range, not
                        # mapped to the document, not in their max-context
                        # window, reversed, or too long.
                        if (start_index >= len(feature.tokens)
                                or end_index >= len(feature.tokens)
                                or start_index not in feature.token_to_orig_map
                                or end_index not in feature.token_to_orig_map
                                or not feature.token_is_max_context.get(
                                    start_index, False)
                                or end_index < start_index):
                            continue
                        if end_index - start_index + 1 > max_answer_length:
                            continue
                        prelim_predictions.append(
                            _PrelimPrediction(
                                feature_index=feature_index,
                                start_index=start_index,
                                end_index=end_index,
                                start_logit=result.start_logits[start_index],
                                end_logit=result.end_logits[end_index]))

            if FLAGS.version_2_with_negative:
                prelim_predictions.append(
                    _PrelimPrediction(feature_index=min_null_feature_index,
                                      start_index=0,
                                      end_index=0,
                                      start_logit=null_start_logit,
                                      end_logit=null_end_logit))
            prelim_predictions = sorted(prelim_predictions,
                                        key=lambda x:
                                        (x.start_logit + x.end_logit),
                                        reverse=True)

            seen_predictions = set()
            nbest = []
            for pred in prelim_predictions:
                if len(nbest) >= n_best_size:
                    break
                if pred.start_index > 0:  # this is a non-null prediction
                    # Look up the feature only for real spans; the null
                    # prediction may point at a feature that does not exist
                    # when the example has no features.
                    feature = features[pred.feature_index]
                    tok_tokens = feature.tokens[pred.start_index:(
                        pred.end_index + 1)]
                    orig_doc_start = feature.token_to_orig_map[
                        pred.start_index]
                    orig_doc_end = feature.token_to_orig_map[pred.end_index]
                    orig_tokens = example.doc_tokens[orig_doc_start:(
                        orig_doc_end + 1)]
                    tok_text = " ".join(tok_tokens)

                    # De-tokenize WordPieces that have been split off.
                    tok_text = tok_text.replace(" ##", "")
                    tok_text = tok_text.replace("##", "")

                    # Clean whitespace
                    tok_text = tok_text.strip()
                    tok_text = " ".join(tok_text.split())
                    orig_text = " ".join(orig_tokens)

                    final_text = run_squad.get_final_text(
                        tok_text, orig_text, do_lower_case)
                    if final_text in seen_predictions:
                        continue
                else:
                    final_text = ""
                seen_predictions.add(final_text)

                nbest.append(
                    _NbestPrediction(text=final_text,
                                     start_logit=pred.start_logit,
                                     end_logit=pred.end_logit))

            if FLAGS.version_2_with_negative:
                # If we didn't include the empty option in the n-best, include it.
                if "" not in seen_predictions:
                    nbest.append(
                        _NbestPrediction(text="",
                                         start_logit=null_start_logit,
                                         end_logit=null_end_logit))
                # In rare edge cases every candidate has empty text (e.g. only
                # the null prediction survived). That would leave no non-null
                # entry for the score-diff comparison below, so add a nonce
                # non-null prediction to avoid failure.
                if not any(entry.text for entry in nbest):
                    nbest.append(
                        _NbestPrediction(text="empty",
                                         start_logit=0.0,
                                         end_logit=0.0))
            # In very rare edge cases we could have no valid predictions. So we
            # just create a nonce prediction in this case to avoid failure.
            if not nbest:
                nbest.append(
                    _NbestPrediction(text="empty",
                                     start_logit=0.0,
                                     end_logit=0.0))

            assert len(nbest) >= 1

            total_scores = []
            best_non_null_entry = None
            for entry in nbest:
                total_scores.append(entry.start_logit + entry.end_logit)
                if best_non_null_entry is None and entry.text:
                    best_non_null_entry = entry

            probs = run_squad._compute_softmax(total_scores)

            nbest_json = []
            for (i, entry) in enumerate(nbest):
                output = collections.OrderedDict()
                output["text"] = entry.text
                output["probability"] = probs[i]
                output["start_logit"] = entry.start_logit
                output["end_logit"] = entry.end_logit
                nbest_json.append(output)

            assert len(nbest_json) >= 1

            if not FLAGS.version_2_with_negative:
                all_predictions[example.qas_id] = nbest_json[0]["text"]
            else:
                # Predict "" iff the null score minus the score of the best
                # non-null entry exceeds the threshold.
                score_diff = score_null - best_non_null_entry.start_logit - (
                    best_non_null_entry.end_logit)
                scores_diff_json[example.qas_id] = score_diff
                if score_diff > FLAGS.null_score_diff_threshold:
                    all_predictions[example.qas_id] = ""
                else:
                    all_predictions[example.qas_id] = best_non_null_entry.text

            all_nbest_json[example.qas_id] = nbest_json

        if FLAGS.version_2_with_negative:
            return all_predictions, all_nbest_json, scores_diff_json
        return all_predictions, all_nbest_json
def get_answer(doc_tokens, tokens_for_postprocessing, start_logits, end_logits,
               args):
    """Pick the answer for one input and build its n-best list.

    :param doc_tokens: original document tokens, handed to ``get_answer_text``
        as the ``doc_tokens`` attribute of a ``SimpleNamespace``.
    :param tokens_for_postprocessing: forwarded to
        ``get_valid_prelim_predictions`` and ``get_answer_text``.
    :param start_logits: start logits, wrapped in a ``RawResult``.
    :param end_logits: end logits, wrapped in a ``RawResult``.
    :param args: namespace read for ``n_best_size``,
        ``version_2_with_negative`` and ``null_score_diff_threshold`` (and
        forwarded to the helpers).
    :return: ``(answer, nbest_answers)``. ``nbest_answers`` is a list of
        OrderedDicts with keys ``text``, ``probability``, ``start_logit`` and
        ``end_logit``.
    """
    result = RawResult(start_logits=start_logits, end_logits=end_logits)
    Prediction = collections.namedtuple('Prediction',
                                        ['text', 'start_logit', 'end_logit'])

    start_indices = _get_best_indices(result.start_logits, args.n_best_size)
    end_indices = _get_best_indices(result.end_logits, args.n_best_size)
    prelim_predictions = get_valid_prelim_predictions(
        start_indices, end_indices, tokens_for_postprocessing, result, args)
    prelim_predictions = sorted(prelim_predictions,
                                key=lambda x: (x.start_logit + x.end_logit),
                                reverse=True)

    if args.version_2_with_negative:
        # (null score, start logit, end logit) of the no-answer span at
        # position 0. The inf sentinel is kept when the score is not < inf
        # (e.g. NaN).
        null_val = (float("inf"), 0, 0)
        score = result.start_logits[0] + result.end_logits[0]
        if score < null_val[0]:
            null_val = (score, result.start_logits[0], result.end_logits[0])

    doc_tokens_obj = SimpleNamespace(doc_tokens=doc_tokens)

    predictions = []
    seen_predictions = set()
    for pred in prelim_predictions:
        if len(predictions) == args.n_best_size:
            break
        if pred.end_index > 0:  # this is a non-null prediction
            final_text = get_answer_text(doc_tokens_obj,
                                         tokens_for_postprocessing, pred, args)
            if final_text in seen_predictions:
                continue
        else:
            final_text = ""

        seen_predictions.add(final_text)
        predictions.append(
            Prediction(final_text, pred.start_logit, pred.end_logit))

    # add empty prediction
    if args.version_2_with_negative:
        predictions.append(Prediction('', null_val[1], null_val[2]))

    nbest = sorted(predictions,
                   key=lambda x: (x.start_logit + x.end_logit),
                   reverse=True)[:args.n_best_size]

    # In very rare edge cases we could have no valid predictions. So we just
    # create a nonce prediction in this case to avoid failure.
    if not nbest:
        nbest.append(Prediction("empty", 0.0, 0.0))

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
        total_scores.append(entry.start_logit + entry.end_logit)
        if best_non_null_entry is None and entry.text:
            best_non_null_entry = entry
    probs = _compute_softmax(total_scores)

    nbest_answers = []
    for (i, entry) in enumerate(nbest):
        output = collections.OrderedDict()
        output["text"] = entry.text
        output["probability"] = probs[i]
        output["start_logit"] = entry.start_logit
        output["end_logit"] = entry.end_logit
        nbest_answers.append(output)

    if args.version_2_with_negative:
        if best_non_null_entry is None:
            # Only empty-text candidates survived, so there is no span to
            # compare against the null score: predict no answer.
            answer = ""
        else:
            score_diff = (null_val[0] - best_non_null_entry.start_logit -
                          best_non_null_entry.end_logit)
            if score_diff > args.null_score_diff_threshold:
                answer = ""
            else:
                answer = best_non_null_entry.text
    else:
        answer = nbest_answers[0]['text']

    return answer, nbest_answers
Ejemplo n.º 4
0
    def predict(self, all_examples, all_features, all_results, n_best_size,
                max_answer_length, do_lower_case):
        """Compute the n-best answer list for every example.

        No null-answer handling is done, and nothing is written out.

        :param all_examples: examples in index order; each exposes ``qas_id``
            and ``doc_tokens``.
        :param all_features: features exposing ``example_index``,
            ``unique_id``, ``tokens``, ``token_to_orig_map`` and
            ``token_is_max_context``.
        :param all_results: model outputs exposing ``unique_id``,
            ``start_logits`` and ``end_logits``.
        :param n_best_size: number of best start/end indexes considered per
            feature, and the cap on n-best entries per example.
        :param max_answer_length: maximum answer span length, in tokens.
        :param do_lower_case: forwarded to ``get_final_text``.
        :return: OrderedDict mapping each ``qas_id`` to a list of OrderedDicts
            (keys ``text``, ``probability``, ``start_logit``, ``end_logit``),
            highest-scoring first.
        """
        example_index_to_features = collections.defaultdict(list)
        for feature in all_features:
            example_index_to_features[feature.example_index].append(feature)

        unique_id_to_result = {
            result.unique_id: result for result in all_results
        }

        _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "PrelimPrediction", [
                "feature_index", "start_index", "end_index", "start_logit",
                "end_logit"
            ])
        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        all_nbest_json = collections.OrderedDict()

        for (example_index, example) in enumerate(all_examples):
            features = example_index_to_features[example_index]

            prelim_predictions = []
            for (feature_index, feature) in enumerate(features):
                result = unique_id_to_result[feature.unique_id]
                start_indexes = _get_best_indexes(result.start_logits,
                                                  n_best_size)
                end_indexes = _get_best_indexes(result.end_logits, n_best_size)

                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # We could hypothetically create invalid predictions,
                        # e.g. predict that the start of the span is in the
                        # question. Throw out spans that are out of range, not
                        # mapped to the document, not in their max-context
                        # window, reversed, or too long.
                        if (start_index >= len(feature.tokens)
                                or end_index >= len(feature.tokens)
                                or start_index not in feature.token_to_orig_map
                                or end_index not in feature.token_to_orig_map
                                or not feature.token_is_max_context.get(
                                    start_index, False)
                                or end_index < start_index):
                            continue
                        if end_index - start_index + 1 > max_answer_length:
                            continue

                        prelim_predictions.append(
                            _PrelimPrediction(
                                feature_index=feature_index,
                                start_index=start_index,
                                end_index=end_index,
                                start_logit=result.start_logits[start_index],
                                end_logit=result.end_logits[end_index]))

            prelim_predictions = sorted(prelim_predictions,
                                        key=lambda x:
                                        (x.start_logit + x.end_logit),
                                        reverse=True)

            seen_predictions = set()
            nbest = []
            for pred in prelim_predictions:
                if len(nbest) >= n_best_size:
                    break
                feature = features[pred.feature_index]
                if pred.start_index > 0:  # this is a non-null prediction
                    tok_tokens = feature.tokens[pred.start_index:(
                        pred.end_index + 1)]
                    orig_doc_start = feature.token_to_orig_map[
                        pred.start_index]
                    orig_doc_end = feature.token_to_orig_map[pred.end_index]
                    orig_tokens = example.doc_tokens[orig_doc_start:(
                        orig_doc_end + 1)]
                    tok_text = " ".join(tok_tokens)

                    # De-tokenize WordPieces that have been split off.
                    tok_text = tok_text.replace(" ##", "")
                    tok_text = tok_text.replace("##", "")

                    # Clean whitespace
                    tok_text = tok_text.strip()
                    tok_text = " ".join(tok_text.split())
                    orig_text = " ".join(orig_tokens)

                    final_text = get_final_text(tok_text, orig_text,
                                                do_lower_case)
                    if final_text in seen_predictions:
                        continue
                else:
                    final_text = ""
                seen_predictions.add(final_text)

                nbest.append(
                    _NbestPrediction(text=final_text,
                                     start_logit=pred.start_logit,
                                     end_logit=pred.end_logit))

            # In very rare edge cases we could have no valid predictions. So we
            # just create a nonce prediction in this case to avoid failure.
            if not nbest:
                nbest.append(
                    _NbestPrediction(text="empty",
                                     start_logit=0.0,
                                     end_logit=0.0))

            assert len(nbest) >= 1

            total_scores = [
                entry.start_logit + entry.end_logit for entry in nbest
            ]
            probs = _compute_softmax(total_scores)

            nbest_json = []
            for (i, entry) in enumerate(nbest):
                output = collections.OrderedDict()
                output["text"] = entry.text
                output["probability"] = probs[i]
                output["start_logit"] = entry.start_logit
                output["end_logit"] = entry.end_logit
                nbest_json.append(output)

            assert len(nbest_json) >= 1

            all_nbest_json[example.qas_id] = nbest_json

        return all_nbest_json
Ejemplo n.º 5
0
                         start_logit=pred.start_logit,
                         end_logit=pred.end_logit))

# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
    nbest.append(_NbestPrediction(text="empty", start_logit=0.0,
                                  end_logit=0.0))

assert len(nbest) >= 1

# Combined (start + end) logit score of every n-best entry, converted to
# probabilities with run_squad._compute_softmax.
# NOTE(review): probs is not used in the visible part of this script; it may
# be consumed further down — confirm before removing.
total_scores = []
for entry in nbest:
    total_scores.append(entry.start_logit + entry.end_logit)

probs = run_squad._compute_softmax(total_scores)

# Print the text of the first n-best entry.
print(nbest[0].text)

# END SECTION TRYING A NEW APPROACH

# Take the first item (index 0) of the "routes" batch outputs.
i = 0
example_index_routes = example_indices_routes[0]
# start_logits and end_logits are both lists of len 384
# (NOTE(review): presumably the max sequence length — confirm).
# .detach().cpu().tolist() turns each row (presumably a torch tensor) into a
# plain Python list.
start_logits_routes = batch_start_logits_routes[i].detach().cpu().tolist()
end_logits_routes = batch_end_logits_routes[i].detach().cpu().tolist()
#remember from above, eval_features is list of objects, each w/ fields for tokens, input_mask, input_ids, segment_ids, etc.
eval_feature_routes = eval_features_routes[