Esempio n. 1
0
def filter_user_questions(input_to_filter, faq_contents):
    '''
    Load a json file and keep only the examples whose passage header is an FAQ question.

    An example is kept when get_passage_last_header() of its related passage is a
    member of faq_contents. Both get_passages_by_source() and get_examples() are
    called with keep_ood=True.

    input_to_filter: str, filename of the json to filter
    faq_contents: set, all faq questions in a set.

    Returns a dict with the kept 'examples' and the unmodified 'passages' list of the
    input file.
    '''
    with open(input_to_filter, 'r', encoding='utf-8') as in_stream:
        input_data = json.load(in_stream)

    _, pid2passage, _ = get_passages_by_source(input_data, keep_ood=True)

    examples = get_examples(input_data, keep_ood=True)
    filtered_example = [
        example for example in examples
        if get_passage_last_header(pid2passage[get_passage_id(example)]) in faq_contents
    ]

    logger.info(
        'file {}: passage size {} / pre-filtering example size {} / post filtering examples size'
        ' {}'.format(input_to_filter, len(input_data['passages']),
                     len(examples), len(filtered_example)))

    return {'examples': filtered_example, 'passages': input_data['passages']}
Esempio n. 2
0
def generate_embeddings(ret_trainee, input_file, out_file):
    """Embed all questions and in-distribution passage headers, then pickle them.

    Each example's question is embedded with ret_trainee.retriever.embed_question and
    labelled 'id' or 'ood' according to is_in_distribution() of its related passage.
    For every passage, grouped by source, the last header is embedded with
    ret_trainee.retriever.embed_paragraph when the passage is in distribution;
    otherwise the passage is skipped and counted. The dict of embeddings and labels
    is pickled to out_file.
    """
    with open(input_file, "r", encoding="utf-8") as in_stream:
        data = json.load(in_stream)

    source_to_passages, passage_by_id, _ = get_passages_by_source(data)

    question_embs, labels = [], []
    for example in tqdm(data["examples"]):
        related = passage_by_id[get_passage_id(example)]
        labels.append('id' if is_in_distribution(related) else 'ood')
        question_embs.append(ret_trainee.retriever.embed_question(get_question(example)))

    passage_header_embs = []
    skipped_ood = 0
    for source, passages in source_to_passages.items():
        logger.info('embedding passages for source {}'.format(source))
        for passage in tqdm(passages):
            if not is_in_distribution(passage):
                skipped_ood += 1
                continue
            header = get_passage_last_header(passage, return_error_for_ood=True)
            passage_header_embs.append(ret_trainee.retriever.embed_paragraph(header))

    to_serialize = {
        "question_embs": question_embs,
        "passage_header_embs": passage_header_embs,
        "question_labels": labels,
    }
    with open(out_file, "wb") as out_stream:
        pickle.dump(to_serialize, out_stream)
    logger.info(
        'saved {} question embeddings and {} passage header embeddings ({} skipped because '
        'out-of-distribution)'.format(
            len(question_embs), len(passage_header_embs), skipped_ood))
Esempio n. 3
0
def main():
    """Filter a json file, keeping its in-distribution and/or OOD examples.

    An example is classified with is_in_distribution() on its related passage and kept
    when --keep-id (for ID) or --keep-ood (for OOD) is set. Only the passages referenced
    by at least one kept example are written to --output.
    """
    # description= rather than the first positional argument: that one is `prog`, and
    # passing the docstring there replaces the program name in usage/--help output.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input", help="json file to filter", required=True)
    parser.add_argument("--output", help="output file", required=True)
    parser.add_argument("--keep-id", help="will keep id", action="store_true")
    parser.add_argument("--keep-ood", help="will keep ood", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    with open(args.input, 'r', encoding='utf-8') as in_stream:
        input_data = json.load(in_stream)

    filtered_examples = []
    filtered_passage_ids = set()
    _, pid2passages, _ = get_passages_by_source(input_data)

    id_kept = 0
    ood_kept = 0
    total = 0

    for example in get_examples(input_data, True):
        total += 1
        example_pid = get_passage_id(example)
        is_id = is_in_distribution(pid2passages[example_pid])
        if is_id and args.keep_id:
            id_kept += 1
        elif not is_id and args.keep_ood:
            ood_kept += 1
        else:
            continue
        filtered_examples.append(example)
        filtered_passage_ids.add(example_pid)

    # Walk pid2passages (presumably a dict, so insertion-ordered) instead of the set:
    # set iteration follows hash order, which made the output passage order arbitrary.
    filtered_passages = [passage for pid, passage in pid2passages.items()
                         if pid in filtered_passage_ids]

    logger.info('kept {} ID and {} OOD (from a total of {} examples)'.format(
        id_kept, ood_kept, total))
    logger.info('kept {} passages (from a total of {} passages)'.format(
        len(filtered_passages), len(pid2passages)))

    with open(args.output, "w", encoding="utf-8") as ostream:
        json.dump(
            {
                'examples': filtered_examples,
                'passages': filtered_passages
            },
            ostream,
            indent=4,
            ensure_ascii=False)
Esempio n. 4
0
def generate_embeddings(ret_trainee, input_file=None, out_file=None, json_data=None,
                        embed_passages=True):
    """Embed questions and (optionally) in-distribution passage headers.

    Data comes from input_file when it is truthy, otherwise from json_data; a
    ValueError is raised when both are falsy. Each question is embedded with
    ret_trainee.retriever.embed_question and labelled 'id'/'ood' from its related
    passage. When embed_passages is true, the last header of every in-distribution
    passage is embedded with ret_trainee.retriever.embed_paragraph; OOD passages are
    counted and skipped. The resulting dict is pickled to out_file when out_file is
    truthy, and is always returned.
    """
    if input_file:
        with open(input_file, "r", encoding="utf-8") as in_stream:
            json_data = json.load(in_stream)
    elif not json_data:
        raise ValueError("You should specify either the input file or the json_data")

    source2passages, pid2passage, _ = get_passages_by_source(json_data)

    question_embs, question_texts, labels = [], [], []
    examples = json_data.get("examples")
    if examples:
        for example in tqdm(examples):
            related = pid2passage[get_passage_id(example)]
            labels.append('id' if is_in_distribution(related) else 'ood')
            question = get_question(example)
            question_embs.append(ret_trainee.retriever.embed_question(question))
            question_texts.append(question)

    passage_header_embs, passage_texts = [], []
    skipped_ood = 0
    if embed_passages:
        for source, passages in source2passages.items():
            logger.info('embedding passages for source {}'.format(source))
            for passage in tqdm(passages):
                if not is_in_distribution(passage):
                    skipped_ood += 1
                    continue
                header = get_passage_last_header(passage, return_error_for_ood=True)
                passage_header_embs.append(ret_trainee.retriever.embed_paragraph(header))
                passage_texts.append(header)

    to_serialize = {
        "question_embs": question_embs,
        "passage_header_embs": passage_header_embs,
        "question_labels": labels,
        "passage_texts": passage_texts,
        "question_texts": question_texts,
    }
    if out_file:
        with open(out_file, "wb") as out_stream:
            pickle.dump(to_serialize, out_stream)
    logger.info(
        'generated {} question embeddings and {} passage header embeddings ({} skipped because '
        'out-of-distribution)'.format(
            len(question_embs), len(passage_header_embs), skipped_ood))

    return to_serialize
Esempio n. 5
0
def main():
    """Append the OOD examples of --input-ood to the OOD-free --input file.

    --input must not reference any OOD passage (ValueError otherwise). One new passage
    with reference_type "ood" is added, and every appended OOD example is re-pointed to
    it and given a fresh example id. Both new ids start one past the highest ids already
    in --input, so nothing collides. In-distribution examples of --input-ood are skipped
    (only counted). --max-ood caps the number of appended OOD examples; any negative
    value (default -1) means no cap.
    """
    # description= rather than the first positional argument: that one is `prog`, and
    # passing the docstring there replaces the program name in usage/--help output.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input", help="main json file", required=True)
    parser.add_argument("--input-ood", help="json file where to extract the ood", required=True)
    parser.add_argument("--output", help="output file", required=True)
    parser.add_argument("--max-ood", help="max amount of ood to use. -1 means all of them",
                        type=int, default=-1)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # parse input data - check there is no OOD
    with open(args.input, 'r', encoding='utf-8') as in_stream:
        input_data = json.load(in_stream)

    _, pid2passages, _ = get_passages_by_source(input_data)
    for example in input_data['examples']:
        related_passage_pid = get_passage_id(example)
        if not is_in_distribution(pid2passages[related_passage_pid]):
            raise ValueError(
                '--input file must not have any OOD -- see example {} / passage {}'.format(
                    example['id'], related_passage_pid))

    # New ids are the current maximum plus 1, so they cannot collide. default=-1 lets an
    # empty passage/example list start at 0 instead of crashing max() on an empty input.
    new_pid_for_ood = max(pid2passages.keys(), default=-1) + 1
    next_example_id = max((ex['id'] for ex in input_data['examples']), default=-1) + 1

    result = input_data  # extended in place, then written to --output

    with open(args.input_ood, 'r', encoding='utf-8') as in_stream:
        ood_data = json.load(in_stream)

    # Single placeholder passage that every appended OOD example points to.
    new_ood_passage = {
        "passage_id": new_pid_for_ood,
        "source": "ood_source",
        "uri": None,
        "reference_type": "ood",
        "reference": {
            "page_title": None,
            "section_headers": [],
            "section_content": None,
            "selected_span": None
        }
    }
    result['passages'].append(new_ood_passage)

    _, ood_pid2passages, _ = get_passages_by_source(ood_data)
    id_count = 0
    added_ood = 0
    for example in ood_data['examples']:
        # Enforce the cap before handling the next example. Checking only after an
        # append would let one OOD example through with --max-ood 0.
        if args.max_ood > -1 and added_ood >= args.max_ood:
            break
        related_passage = ood_pid2passages[get_passage_id(example)]
        if is_in_distribution(related_passage):
            id_count += 1
            continue
        example['passage_id'] = new_pid_for_ood
        example['id'] = next_example_id
        next_example_id += 1
        result['examples'].append(example)
        added_ood += 1

    logger.info('kept {} ood from {} (and skipped {} id from the same file)'.format(
        added_ood, args.input_ood, id_count
    ))

    with open(args.output, "w", encoding="utf-8") as ostream:
        json.dump(result, ostream, indent=4, ensure_ascii=False)