Example #1
def show_predictions(db_filename, predictions):
    """
    Display the claim and its predicted sentences for predictions whose
    predicted evidence does not fully cover at least one gold evidence set.
    :param db_filename: path to the wiki-pages database
    :param predictions: iterable of prediction dicts with 'claim', 'label' and 'predicted_evidence'
    :return: None
    """

    db = FeverDocDB(db_filename)

    for line in predictions:

        if line['label'].upper() != "NOT ENOUGH INFO":
            macro_rec = evidence_macro_recall(line)
            if macro_rec[0] == 1.0:
                continue
            pages = set([page for page, _ in line['predicted_evidence']])
            evidence_set = set([(page, line_num) for page, line_num in line['predicted_evidence']])
            p_lines = []
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    # print(page)
                    continue
                doc_lines = [doc_line.split("\t")[1] if len(doc_line.split("\t")[1]) > 1 else "" for doc_line in
                             doc_lines.split("\n")]
                p_lines.extend(zip(doc_lines, [page] * len(doc_lines), range(len(doc_lines))))

            print("claim: {}".format(line['claim']))
            print(evidence_set)
            count = 0
            for doc_line in p_lines:
                if (doc_line[1], doc_line[2]) in evidence_set:
                    print("the {}st evidence: {}".format(count, doc_line[0]))
                    count += 1
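
A minimal usage sketch, assuming the module-level imports used above (FeverDocDB, JSONLineReader) are in scope; the predictions path is hypothetical:

# Hedged sketch: read a predictions JSONL file and inspect missed evidence.
jlr = JSONLineReader()
with open("data/fever/dev.predictions.jsonl", "r") as f:  # hypothetical path
    predictions = jlr.process(f)
show_predictions("data/fever/fever.db", predictions)
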
def in_doc_sampling(db_filename, datapath, num_sample=1):

    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    X = []
    count = 0
    with open(datapath, "r") as f:
        lines = jlr.process(f)

        for line in tqdm(lines):
            count += 1
            pos_pairs = []
            # count1 += 1
            if line['label'].upper() == "NOT ENOUGH INFO":
                continue
            neg_sents = []
            claim = line['claim']

            pos_set = set()
            for evidence_set in line['evidence']:
                pos_sent = get_whole_evidence(evidence_set, db)
                if pos_sent in pos_set:
                    continue
                pos_set.add(pos_sent)

            p_lines = []
            evidence_set = set([(evidence[2], evidence[3])
                                for evidences in line['evidence']
                                for evidence in evidences])
            page_set = set([evidence[0] for evidence in evidence_set])
            for page in page_set:
                doc_lines = db.get_doc_lines(page)
                p_lines.extend(get_valid_texts(doc_lines, page))
            for doc_line in p_lines:
                if (doc_line[1], doc_line[2]) not in evidence_set:
                    neg_sents.append(doc_line[0])

            num_sampling = num_sample
            if len(neg_sents) < num_sampling:
                num_sampling = len(neg_sents)
                # print(neg_sents)
            if num_sampling == 0:
                continue
            else:
                for pos_sent in pos_set:
                    samples = random.sample(neg_sents, num_sampling)
                    for sample in samples:
                        if not sample:
                            continue
                        X.append((claim, pos_sent, sample))
                        if count % 1000 == 0:
                            print("claim:{} ,evidence :{} sample:{}".format(
                                claim, pos_sent, sample))
    return X
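
A hedged sketch of consuming the triplets returned by in_doc_sampling; the train-set path is hypothetical, while the fever.db path matches the default used elsewhere in these examples:

# Sketch: (claim, positive evidence, in-document negative) triplets.
triplets = in_doc_sampling("data/fever/fever.db",
                           "data/fever/train.jsonl",  # hypothetical path
                           num_sample=2)
for claim, pos_sent, neg_sent in triplets[:3]:
    print(claim, "|", pos_sent, "|", neg_sent)
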
Example #3
def load_words(embedding_file, train_datapath, test_path, db_filename,
               num_sample, sampled_path):

    words = set()

    def _insert(iterable):
        for w in iterable:
            w = Dictionary.normalize(w)
            if valid_words and w not in valid_words:
                continue
            words.add(w)

    valid_words = index_embedding_words(embedding_file)

    X_claim, X_sents, y = load_generate_samples(db_filename, train_datapath,
                                                num_sample, sampled_path)
    X_claim = set(X_claim)
    for claim in X_claim:
        # tokenize into a separate variable so the accumulating `words` set is not rebound
        claim_tokens = nltk.word_tokenize(claim)
        _insert(claim_tokens)

    for sent in X_sents:
        sent_tokens = simple_tokenizer(sent)
        _insert(sent_tokens)

    with open(test_path, "r") as f:
        jlr = JSONLineReader()
        db = FeverDocDB(db_filename)

        lines = jlr.process(f)
        for line in lines:
            claim = line['claim']
            claim_tokens = nltk.word_tokenize(claim)
            _insert(claim_tokens)
            evidence_set = set([(evidence[2], evidence[3])
                                for evidences in line['evidence']
                                for evidence in evidences])
            pages = set()
            pages.update(evidence[0] for evidence in line['predicted_pages'])
            pages.update(evidence[0] for evidence in evidence_set)
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    continue
                doc_lines = [
                    doc_line.split("\t")[1]
                    if len(doc_line.split("\t")[1]) > 1 else ""
                    for doc_line in doc_lines.split("\n")
                ]
                doc_lines = [doc_line for doc_line in doc_lines if doc_line]
                for doc_line in doc_lines:
                    line_tokens = simple_tokenizer(doc_line)
                    _insert(line_tokens)
    return words
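
A hedged sketch of turning the returned vocabulary into a word-to-index mapping; every path except fever.db is hypothetical:

# Sketch: build a simple word -> index dict from the vocabulary set.
words = load_words(embedding_file="data/fasttext/wiki.en.vec",   # hypothetical
                   train_datapath="data/fever/train.jsonl",      # hypothetical
                   test_path="data/fever/dev.jsonl",             # hypothetical
                   db_filename="data/fever/fever.db",
                   num_sample=2,
                   sampled_path="data/fever/train_sample.p")     # hypothetical cache
word_to_idx = {w: i for i, w in enumerate(sorted(words))}
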
Example #4
def test_data_4_siamese(db_filename, dataset_path):
    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    X_claims = []
    X_sents = []
    all_sents_id = []

    with open(dataset_path, "r") as f:
        lines = jlr.process(f)
        # lines = lines[:1000]

        for line in tqdm(lines):
            claims = []
            sents = []
            sents_indexes = []
            p_lines = []
            claim = line['claim']
            pages = set()
            pages.update(evidence[0] for evidence in line['predicted_pages'])
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    continue
                doc_lines = [
                    doc_line.split("\t")[1]
                    if len(doc_line.split("\t")[1]) > 1 else ""
                    for doc_line in doc_lines.split("\n")
                ]
                p_lines.extend(
                    zip(doc_lines, [page] * len(doc_lines),
                        range(len(doc_lines))))
            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                else:
                    claims.append(claim)
                    sents.append(doc_line[0])
                    sents_indexes.append((doc_line[1], doc_line[2]))
            X_claims.append(claims)
            X_sents.append(sents)
            all_sents_id.append(sents_indexes)
    # print(len(X_claims))
    # print(len(X_sents))
    # print(len(all_sents_id))
    # X_claims_indexes, X_sents_indexes = [], []
    # for idx, claims in enumerate(X_claims):
    #     claims_index, sents_index = data_transformer(claims, X_sents[idx], word_dict)
    #     X_claims_indexes.append(claims_index)
    #     X_sents_indexes.append(sents_index)

    return X_claims, X_sents, all_sents_id
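
The three returned lists are parallel per claim; a hedged sketch of iterating them (the dataset path is hypothetical):

# Sketch: claims, sentences and (page, line) ids line up element-wise per example.
X_claims, X_sents, all_sents_id = test_data_4_siamese("data/fever/fever.db",
                                                      "data/fever/dev.jsonl")  # hypothetical path
for claims, sents, ids in zip(X_claims, X_sents, all_sents_id):
    assert len(claims) == len(sents) == len(ids)
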
def _create_token_set_of_db(db):
    logger.debug("start creating token set for DB...")
    if isinstance(db, str):
        db = FeverDocDB(db)
    _token_set = set()
    for doc_id in tqdm(db.get_non_empty_doc_ids()):
        doc_lines = db.get_doc_lines(doc_id)
        for line in doc_lines:
            tokens = tokenize(clean_text(line))
            for token in tokens:
                if token.lower() in _token_set:
                    continue
                _token_set.add(token.lower())
    return _token_set
def _create_db_vocab_idx(db, _global_dict):
    # logger = LogHelper.get_logger("_create_db_vocab_idx")
    logger.debug("start creating vocab indices for DB...")
    if isinstance(db, str):
        db = FeverDocDB(db)
    _vocab_idx = {}
    for doc_id in tqdm(db.get_non_empty_doc_ids()):
        doc_lines = db.get_doc_lines(doc_id)
        for line in doc_lines:
            tokens = tokenize(clean_text(line))
            for token in tokens:
                if token.lower() in _vocab_idx:
                    continue
                if token.lower() in _global_dict:
                    _vocab_idx[token.lower()] = _global_dict[token.lower()]
    _vocab_idx = sorted(list(_vocab_idx.values()))
    return _vocab_idx
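
A hedged sketch combining the two helpers above; the global token-to-index dictionary built here is only a stand-in for a real embedding vocabulary:

# Sketch: restrict a global vocabulary to tokens that occur in the wiki database.
db = FeverDocDB("data/fever/fever.db")
token_set = _create_token_set_of_db(db)
_global_dict = {tok: i for i, tok in enumerate(sorted(token_set))}  # hypothetical global vocab
db_vocab_rows = _create_db_vocab_idx(db, _global_dict)
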
def dev_processing(db_filename, lines):

    db = FeverDocDB(db_filename)
    claims = []
    list_sents = []
    labels = []

    for line in tqdm(lines):
        if line['label'].upper() == "NOT ENOUGH INFO":
            continue

        claims.append(line['claim'])
        sents = []
        label = []

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in line['evidence']
                            for evidence in evidences])
        pages = [
            page[0] for page in line['predicted_pages'] if page[0] is not None
        ]
        for page, num in evidence_set:
            pages.append(page)
        pages = set(pages)

        p_lines = []
        for page in pages:
            doc_lines = db.get_doc_lines(page)
            p_lines.extend(get_valid_texts(doc_lines, page))
        for doc_line in p_lines:
            if not doc_line[0]:
                continue
            if (doc_line[1], doc_line[2]) in evidence_set:
                sents.append(doc_line[0])
                label.append(1)
            else:
                sents.append(doc_line[0])
                label.append(0)
        if len(sents) == 0 or len(label) == 0:
            claims.pop()  # drop the claim appended above so the three lists stay aligned
            continue
        list_sents.append(sents)
        labels.append(label)
    return claims, list_sents, labels
def test_processing(db_filename, lines):
    db = FeverDocDB(db_filename)
    claims = []
    list_sents = []
    sents_indexes = []

    for line in tqdm(lines):
        # if line['label'].upper() == "NOT ENOUGH INFO":
        #     continue

        claims.append(line['claim'])
        sents = []
        sents_index = []

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in line['evidence']
                            for evidence in evidences])
        pages = set([
            page[0] for page in line['predicted_pages'] if page[0] is not None
        ])
        if len(pages) == 0:
            pages.add("Michael_Hutchence")

        p_lines = []
        for page in pages:
            doc_lines = db.get_doc_lines(page)
            p_lines.extend(get_valid_texts(doc_lines, page))
        for doc_line in p_lines:
            if not doc_line[0]:
                continue
            sents.append(doc_line[0])
            sents_index.append((doc_line[1], doc_line[2]))
        list_sents.append(sents)
        sents_indexes.append(sents_index)
    return claims, list_sents, sents_indexes
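
A hedged sketch of mapping sents_indexes back to FEVER-style predicted evidence once a sentence scorer has ranked list_sents; score_sentences and the already-parsed lines are assumptions, not part of the original example:

# Sketch: turn per-claim sentence scores into (page, line) evidence predictions.
claims, list_sents, sents_indexes = test_processing("data/fever/fever.db", lines)
predicted_evidence = []
for sents, indexes in zip(list_sents, sents_indexes):
    scores = score_sentences(sents)                                   # hypothetical scorer
    ranked = sorted(zip(scores, indexes), key=lambda p: p[0], reverse=True)
    predicted_evidence.append([list(idx) for _, idx in ranked[:5]])
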
def dev_processing(db_filename, datapath):
    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    devs = []
    all_indexes = []

    with open(datapath, "rb") as f:
        lines = jlr.process(f)

        for line in tqdm(lines):
            dev = []
            indexes = []
            pages = set()
            pages.update(page[0] for page in line['predicted_pages'])
            if len(pages) == 0:
                pages.add("Michael_Hutchence")
            claim = line['claim']
            p_lines = []
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    continue
                p_lines.extend(get_valid_texts(doc_lines, page))

            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                dev.append((claim, doc_line[0]))
                indexes.append((doc_line[1], doc_line[2]))
            # print(len(dev))
            if len(dev) == 0:
                dev.append((claim, 'no evidence for this claim'))
                indexes.append(('empty', 0))
            devs.append(dev)
            all_indexes.append(indexes)
    return devs, all_indexes
def train_sample(db_filename, lines, list_size=15):

    db = FeverDocDB(db_filename)

    claims = []
    list_sents = []
    labels = []
    count = 0

    for idx, line in tqdm(enumerate(lines)):

        if line['label'].upper() == "NOT ENOUGH INFO":
            continue

        claim = line['claim']
        claims.append(claim)
        sents = []
        label = []

        pos_set = set()
        neg_sents = []
        for evidence_group in line['evidence']:
            pos_sent = get_whole_evidence(evidence_group, db)
            if pos_sent in pos_set:
                continue
            pos_set.add(pos_sent)

        p_lines = []
        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in line['evidence']
                            for evidence in evidences])

        pages = [
            page[0] for page in line['predicted_pages'] if page[0] is not None
        ]
        for page, num in evidence_set:
            pages.append(page)
        pages = set(pages)
        for page in pages:
            doc_lines = db.get_doc_lines(page)
            p_lines.extend(get_valid_texts(doc_lines, page))
        for doc_line in p_lines:
            if not doc_line[0]:
                continue
            if (doc_line[1], doc_line[2]) not in evidence_set:
                neg_sents.append(doc_line[0])

        pos_set = list(pos_set)
        if len(pos_set) > 5:
            pos_set = random.sample(pos_set, 5)
        if len(neg_sents) < (list_size - len(pos_set)):

            count += 1
            continue
        else:
            samples = random.sample(neg_sents, list_size - len(pos_set))
            pos_indexes_sample = random.sample(range(list_size), len(pos_set))
            neg_index = 0
            pos_index = 0
            for i in range(list_size):
                if i in pos_indexes_sample:
                    sents.append(pos_set[pos_index])
                    label.append(1 / len(pos_set))
                    pos_index += 1
                else:
                    sents.append(samples[neg_index])
                    label.append(0.0)
                    neg_index += 1
            if idx % 1000 == 0:
                print(claim)
                print(sents)
                print(label)

        list_sents.append(sents)
        labels.append(label)
    print(count)
    return claims, list_sents, labels
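
Because each positive slot receives probability mass 1 / len(pos_set) and every negative slot gets 0.0, each label vector is a distribution over the list_size slots; a hedged sanity check (lines is assumed to be the JSONL-parsed train set):

# Sketch: labels from train_sample sum to one per claim.
claims, list_sents, labels = train_sample("data/fever/fever.db", lines, list_size=15)
for label in labels[:5]:
    assert len(label) == 15
    assert abs(sum(label) - 1.0) < 1e-6
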
Example #11
class Doc_Retrieval:

    def __init__(self, database_path, add_claim=False, k_wiki_results=None):
        self.db = FeverDocDB(database_path)
        self.add_claim = add_claim
        self.k_wiki_results = k_wiki_results
        self.proter_stemm = nltk.PorterStemmer()
        self.tokenizer = nltk.word_tokenize
        #print("Va a descargar")
        #get_spacy_model('en_core_web_lg',True,False,False)
        self.predictor = Predictor.from_path(
            "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo-constituency-parser-2018.03.14.tar.gz")
        #print("Descargó")

    def get_NP(self, tree, nps):

        if isinstance(tree, dict):
            if "children" not in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    # print(tree)
                    nps.append(tree['word'])
            elif "children" in tree:
                if tree['nodeType'] == "NP":
                    # print(tree['word'])
                    nps.append(tree['word'])
                    self.get_NP(tree['children'], nps)
                else:
                    self.get_NP(tree['children'], nps)
        elif isinstance(tree, list):
            for sub_tree in tree:
                self.get_NP(sub_tree, nps)

        return nps

    def get_subjects(self, tree):
        subject_words = []
        subjects = []
        for subtree in tree['children']:
            if subtree['nodeType'] == "VP" or subtree['nodeType'] == 'S' or subtree['nodeType'] == 'VBZ':
                subjects.append(' '.join(subject_words))
                subject_words.append(subtree['word'])
            else:
                subject_words.append(subtree['word'])
        return subjects

    def get_noun_phrases(self, line):

        claim = line['claim']
        tokens = self.predictor.predict(claim)
        nps = []
        tree = tokens['hierplane_tree']['root']
        noun_phrases = self.get_NP(tree, nps)
        subjects = self.get_subjects(tree)
        for subject in subjects:
            if len(subject) > 0:
                noun_phrases.append(subject)
        if self.add_claim:
            noun_phrases.append(claim)
        return list(set(noun_phrases))

    def get_doc_for_claim(self, noun_phrases):

        predicted_pages = []
        for np in noun_phrases:
            if len(np) > 300:
                continue
            i = 1
            while i < 12:
                try:
                    docs = wikipedia.search(np)
                    if self.k_wiki_results is not None:
                        predicted_pages.extend(docs[:self.k_wiki_results])
                    else:
                        predicted_pages.extend(docs)
                except (ConnectionResetError, ConnectionError, ConnectionAbortedError, ConnectionRefusedError):
                    print("Connection reset error received! Trial #" + str(i))
                    time.sleep(600 * i)
                    i += 1
                else:
                    break

            # sleep_num = random.uniform(0.1,0.7)
            # time.sleep(sleep_num)
        predicted_pages = set(predicted_pages)
        processed_pages = []
        for page in predicted_pages:
            page = page.replace(" ", "_")
            page = page.replace("(", "-LRB-")
            page = page.replace(")", "-RRB-")
            page = page.replace(":", "-COLON-")
            processed_pages.append(page)

        return processed_pages

    def np_conc(self, noun_phrases):

        noun_phrases = set(noun_phrases)
        predicted_pages = []
        for np in noun_phrases:
            page = np.replace('( ', '-LRB-')
            page = page.replace(' )', '-RRB-')
            page = page.replace(' - ', '-')
            page = page.replace(' :', '-COLON-')
            page = page.replace(' ,', ',')
            page = page.replace(" 's", "'s")
            page = page.replace(' ', '_')

            if len(page) < 1:
                continue
            doc_lines = self.db.get_doc_lines(page)
            if doc_lines is not None:
                predicted_pages.append(page)
        return predicted_pages

    def exact_match(self, line):

        noun_phrases = self.get_noun_phrases(line)
        wiki_results = self.get_doc_for_claim(noun_phrases)
        wiki_results = list(set(wiki_results))

        claim = normalize(line['claim'])
        claim = claim.replace(".", "")
        claim = claim.replace("-", " ")
        words = [self.proter_stemm.stem(word.lower()) for word in self.tokenizer(claim)]
        words = set(words)
        predicted_pages = self.np_conc(noun_phrases)

        for page in wiki_results:
            page = normalize(page)
            processed_page = re.sub("-LRB-.*?-RRB-", "", page)
            processed_page = re.sub("_", " ", processed_page)
            processed_page = re.sub("-COLON-", ":", processed_page)
            processed_page = processed_page.replace("-", " ")
            processed_page = processed_page.replace("–", " ")
            processed_page = processed_page.replace(".", "")
            page_words = [self.proter_stemm.stem(word.lower()) for word in self.tokenizer(processed_page) if
                          len(word) > 0]

            if all([item in words for item in page_words]):
                if ':' in page:
                    page = page.replace(":", "-COLON-")
                predicted_pages.append(page)
        predicted_pages = list(set(predicted_pages))
        # print("claim: ",claim)
        # print("nps: ",noun_phrases)
        # print("wiki_results: ",wiki_results)
        # print("predicted_pages: ",predicted_pages)
        # print("evidence:",line['evidence'])
        return noun_phrases, wiki_results, predicted_pages
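
A hedged usage sketch of the retriever on a single claim; exact_match only reads the 'claim' field, and the constructor downloads the AllenNLP constituency parser on first use:

# Sketch: retrieve candidate wiki pages for one claim.
doc_retriever = Doc_Retrieval("data/fever/fever.db", add_claim=True, k_wiki_results=7)
line = {"claim": "Roman Atwood is a content creator."}
noun_phrases, wiki_results, predicted_pages = doc_retriever.exact_match(line)
print(predicted_pages)
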
Example #12
def eval_model(db: FeverDocDB, args) -> None:
    archive = load_archive(args.archive_file,
                           cuda_device=args.cuda_device,
                           overrides=args.overrides)

    config = archive.config
    ds_params = config["dataset_reader"]

    model = archive.model
    model.eval()

    reader = FEVERReader(db,
                         sentence_level=ds_params.pop("sentence_level", False),
                         wiki_tokenizer=Tokenizer.from_params(
                             ds_params.pop('wiki_tokenizer', {})),
                         claim_tokenizer=Tokenizer.from_params(
                             ds_params.pop('claim_tokenizer', {})),
                         token_indexers=TokenIndexer.dict_from_params(
                             ds_params.pop('token_indexers', {})))

    while True:

        claim = input("enter claim (or q to quit) >>")
        if claim.lower() == "q":
            break

        ranker = retriever.get_class('tfidf')(tfidf_path=args.model)

        p_lines = []
        pages, _ = ranker.closest_docs(claim, 5)

        for page in pages:
            lines = db.get_doc_lines(page)
            lines = [
                line.split("\t")[1] if len(line.split("\t")[1]) > 1 else ""
                for line in lines.split("\n")
            ]

            p_lines.extend(zip(lines, [page] * len(lines), range(len(lines))))

        scores = tf_idf_sim(claim, [pl[0] for pl in p_lines])
        scores = list(
            zip(scores, [pl[1] for pl in p_lines], [pl[2] for pl in p_lines],
                [pl[0] for pl in p_lines]))
        scores = list(filter(lambda score: len(score[3].strip()), scores))
        sentences_l = list(
            sorted(scores, reverse=True, key=lambda elem: elem[0]))

        sentences = [s[3] for s in sentences_l[:5]]
        evidence = " ".join(sentences)

        print("Best pages: {0}".format(repr(pages)))

        print("Evidence:")
        for idx, sentence in enumerate(sentences_l[:5]):
            print("{0}\t{1}\t\t{2}\t{3}".format(idx + 1, sentence[0],
                                                sentence[1], sentence[3]))

        item = reader.text_to_instance(evidence, claim)

        prediction = model.forward_on_instance(item, args.cuda_device)
        cls = model.vocab._index_to_token["labels"][np.argmax(
            prediction["label_probs"])]
        print("PREDICTED: {0}".format(cls))
        print()
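
eval_model only reads four attributes from args; a hedged sketch of building them by hand, with every path hypothetical:

# Sketch: the minimal args object eval_model expects.
from argparse import Namespace

args = Namespace(archive_file="logs/da_nn_sent/model.tar.gz",   # hypothetical model archive
                 cuda_device=-1,
                 overrides="",
                 model="data/index/fever-tfidf-ngram=2.npz")    # hypothetical TF-IDF index
eval_model(FeverDocDB("data/fever/fever.db"), args)
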
def prepare_ranking(db_filename, datapath, k=10, num_sample=3):
    """

    :param db_filename: path to the wiki-pages database
    :param datapath: path to the FEVER train set with predicted pages and sentences
    :param k: number of top predicted sentences from which negatives are drawn
    :param num_sample: number of negative samples per positive sentence
    :return: list of (claim, positive sentence, negative samples) triplets
    """

    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    X = []
    with open(datapath, "r") as f:
        lines = jlr.process(f)

        for line in tqdm(lines):
            if line['label'].upper() == "NOT ENOUGH INFO":
                continue
            p_lines = []
            pos_sents = []
            neg_sents = []
            claim = line['claim']
            evidence_set = set([(evidence[2], evidence[3])
                                for evidences in line['evidence']
                                for evidence in evidences])
            sampled_sents_idx = [(id, number)
                                 for id, number in line['predicted_sentences']]
            sampled_sents_idx = [
                index for index in sampled_sents_idx
                if index not in evidence_set
            ]
            if k:
                sampled_sents_idx = sampled_sents_idx[:k]
            pages = set()
            pages.update(evidence[0] for evidence in line['predicted_pages'])
            pages.update(evidence[0] for evidence in evidence_set)
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    continue
                doc_lines = [
                    doc_line.split("\t")[1]
                    if len(doc_line.split("\t")[1]) > 1 else ""
                    for doc_line in doc_lines.split("\n")
                ]
                p_lines.extend(
                    zip(doc_lines, [page] * len(doc_lines),
                        range(len(doc_lines))))
            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                elif (doc_line[1], doc_line[2]) in sampled_sents_idx:
                    neg_sents.append(doc_line[0])
                elif (doc_line[1], doc_line[2]) in evidence_set:
                    pos_sents.append(doc_line[0])
            # print(line)
            # print(sampled_sents_idx)
            # print(neg_sents)
            # skip claims with too few candidate negatives to sample from
            if len(neg_sents) < num_sample:
                continue
            for sent in pos_sents:
                neg_samples = random.sample(neg_sents, num_sample)
                triplet = (claim, sent, neg_samples)
                X.append(triplet)

    return X
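
A hedged sketch of consuming the ranking triplets; the train-set path is hypothetical:

# Sketch: each element pairs one positive sentence with num_sample negatives.
X = prepare_ranking("data/fever/fever.db",
                    "data/fever/train.jsonl",  # hypothetical path
                    k=10, num_sample=3)
claim, pos_sent, neg_samples = X[0]
assert len(neg_samples) == 3
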
Example #14
def sample_ranking_train(db_filename, datapath, k=5, num_sample=2):
    """

    :param db_filename: path to the wiki-pages database
    :param datapath: path to the FEVER train set with predicted pages and sentences
    :param k: number of top predicted sentences from which negatives are selected
    :param num_sample: number of negative examples to sample
    :return: X_claim, X_sents: claim and sentence pairs; y: 1 if the sentence is in the evidence set, else 0
    """

    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    X_claim = []
    X_sents = []
    y = []
    count = 0

    with open(datapath, "r") as f:
        lines = jlr.process(f)
        # lines = lines[:1000]

        for line in tqdm(lines):
            num_sampling = num_sample
            if line['label'].upper() == "NOT ENOUGH INFO":
                continue
            p_lines = []
            neg_sents = []
            claim = line['claim']
            evidence_set = set([(evidence[2], evidence[3])
                                for evidences in line['evidence']
                                for evidence in evidences])
            sampled_sents_idx = [(id, number)
                                 for id, number in line['predicted_sentences']]
            sampled_sents_idx = sampled_sents_idx[0:k + 5]
            sampled_sents_idx = [
                index for index in sampled_sents_idx
                if index not in evidence_set
            ]
            pages = set()
            pages.update(evidence[0] for evidence in line['predicted_pages'])
            pages.update(evidence[0] for evidence in evidence_set)
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    continue
                doc_lines = [
                    doc_line.split("\t")[1]
                    if len(doc_line.split("\t")[1]) > 1 else ""
                    for doc_line in doc_lines.split("\n")
                ]
                p_lines.extend(
                    zip(doc_lines, [page] * len(doc_lines),
                        range(len(doc_lines))))
            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                elif (doc_line[1], doc_line[2]) in sampled_sents_idx:
                    neg_sents.append(doc_line[0])
                elif (doc_line[1], doc_line[2]) in evidence_set:
                    X_claim.append(claim)
                    X_sents.append(doc_line[0])
                    y.append(1)

            if len(sampled_sents_idx) < num_sample:
                count += 1
            # never sample more negatives than are actually available
            num_sampling = min(num_sampling, len(neg_sents))

            samples = random.sample(neg_sents, num_sampling)
            for neg_example in samples:
                X_claim.append(claim)
                X_sents.append(neg_example)
                y.append(0)
        print(count)

    return X_claim, X_sents, y
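
A hedged sketch showing the binary labels returned alongside the claim/sentence pairs; the train-set path is hypothetical:

# Sketch: X_claim[i], X_sents[i] and y[i] describe one claim-sentence pair.
X_claim, X_sents, y = sample_ranking_train("data/fever/fever.db",
                                           "data/fever/train.jsonl",  # hypothetical path
                                           k=5, num_sample=2)
assert len(X_claim) == len(X_sents) == len(y)
assert set(y) <= {0, 1}
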
Example #15
class ELMO_Data(object):
    def __init__(self,
                 base_path,
                 train_file,
                 dev_file,
                 test_file,
                 num_negatives,
                 h_max_length,
                 s_max_length,
                 random_seed=100,
                 db_filepath="data/fever/fever.db"):

        self.random_seed = random_seed

        self.base_path = base_path
        self.train_file = train_file
        self.dev_file = dev_file
        self.test_file = test_file
        self.num_negatives = num_negatives
        self.h_max_length = h_max_length
        self.s_max_length = s_max_length
        self.db_filepath = db_filepath
        self.db = FeverDocDB(self.db_filepath)

        self.data_pipeline()

    def data_pipeline(self):

        np.random.seed(self.random_seed)
        random.seed(self.random_seed)

        # create directory to store sampled and processed data
        base_dir = os.path.join(self.base_path, "data/train_data")
        store_dir = "data.h{}.s{}.seed{}".format(self.h_max_length,
                                                 self.s_max_length,
                                                 self.random_seed)
        absou_dir = os.path.join(base_dir, store_dir)
        if not os.path.exists(absou_dir):
            os.makedirs(absou_dir)

        train_data_path = os.path.join(absou_dir, "train_sample.p")
        X_train = self.train_data_loader(train_data_path,
                                         self.train_file,
                                         num_samples=self.num_negatives)
        dev_datapath = os.path.join(absou_dir, "dev_data.p")
        devs, self.dev_labels = self.dev_data_loader(dev_datapath,
                                                     self.dev_file)
        test_datapath = os.path.join(absou_dir, "test_data.p")
        tests, self.test_location_indexes = self.predict_data_loader(
            test_datapath, self.test_file)

        self.X_train = self.train_data_tokenizer(X_train)
        self.devs = self.predict_data_tokenizer(devs)
        self.tests = self.predict_data_tokenizer(tests)

        return self

    def get_whole_evidence(self, evidence_set, db):
        pos_sents = []
        for evidence in evidence_set:
            page = evidence[2]
            doc_lines = db.get_doc_lines(page)
            doc_lines = self.get_valid_texts(doc_lines, page)
            for doc_line in doc_lines:
                if doc_line[2] == evidence[3]:
                    pos_sents.append(doc_line[0])
        pos_sent = ' '.join(pos_sents)
        return pos_sent

    def get_valid_texts(self, lines, page):
        if not lines:
            return []
        doc_lines = [
            doc_line.split("\t")[1] if len(doc_line.split("\t")[1]) > 1 else ""
            for doc_line in lines.split("\n")
        ]
        doc_lines = zip(doc_lines, [page] * len(doc_lines),
                        range(len(doc_lines)))
        return doc_lines

    def sampling(self, datapath, num_sample=1):

        jlr = JSONLineReader()

        X = []
        count = 0
        with open(datapath, "r") as f:
            lines = jlr.process(f)

            for line in tqdm(lines):
                count += 1
                pos_pairs = []
                # count1 += 1
                if line['label'].upper() == "NOT ENOUGH INFO":
                    continue
                neg_sents = []
                claim = line['claim']

                pos_set = set()
                for evidence_set in line['evidence']:
                    pos_sent = self.get_whole_evidence(evidence_set, self.db)
                    if pos_sent in pos_set:
                        continue
                    pos_set.add(pos_sent)

                p_lines = []
                evidence_set = set([(evidence[2], evidence[3])
                                    for evidences in line['evidence']
                                    for evidence in evidences])

                pages = [
                    page for page in line['predicted_pages']
                    if page is not None
                ]

                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    p_lines.extend(self.get_valid_texts(doc_lines, page))
                for doc_line in p_lines:
                    if (doc_line[1], doc_line[2]) not in evidence_set:
                        neg_sents.append(doc_line[0])

                num_sampling = num_sample
                if len(neg_sents) < num_sampling:
                    num_sampling = len(neg_sents)
                    # print(neg_sents)
                if num_sampling == 0:
                    continue
                else:
                    for pos_sent in pos_set:
                        samples = random.sample(neg_sents, num_sampling)
                        for sample in samples:
                            if not sample:
                                continue
                            X.append((claim, pos_sent, sample))
                            if count % 1000 == 0:
                                print(
                                    "claim:{} ,evidence :{} sample:{}".format(
                                        claim, pos_sent, sample))
        return X

    def predict_processing(self, datapath):

        jlr = JSONLineReader()

        devs = []
        all_indexes = []

        with open(datapath, "rb") as f:
            lines = jlr.process(f)

            for line in tqdm(lines):
                dev = []
                indexes = []
                pages = set()
                # pages = line['predicted_pages']
                pages.update(page for page in line['predicted_pages'])
                # if len(pages) == 0:
                #     pages.add("Michael_Hutchence")
                claim = line['claim']
                p_lines = []
                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    if not doc_lines:
                        continue
                    p_lines.extend(self.get_valid_texts(doc_lines, page))

                for doc_line in p_lines:
                    if not doc_line[0]:
                        continue
                    dev.append((claim, doc_line[0]))
                    indexes.append((doc_line[1], doc_line[2]))
                # print(len(dev))
                if len(dev) == 0:
                    dev.append((claim, 'no evidence for this claim'))
                    indexes.append(('empty', 0))
                devs.append(dev)
                all_indexes.append(indexes)
        return devs, all_indexes

    def dev_processing(self, data_path):

        jlr = JSONLineReader()

        with open(data_path, "r") as f:
            lines = jlr.process(f)

            devs = []
            labels = []
            for line in tqdm(lines):

                dev = []
                label = []
                if line['label'].upper() == "NOT ENOUGH INFO":
                    continue
                evidence_set = set([(evidence[2], evidence[3])
                                    for evidences in line['evidence']
                                    for evidence in evidences])

                pages = [
                    page for page in line['predicted_pages']
                    if page is not None
                ]
                for page, num in evidence_set:
                    pages.append(page)
                pages = set(pages)

                p_lines = []
                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    p_lines.extend(self.get_valid_texts(doc_lines, page))
                for doc_line in p_lines:
                    if not doc_line[0]:
                        continue
                    dev.append((line['claim'], doc_line[0]))
                    if (doc_line[1], doc_line[2]) in evidence_set:
                        label.append(1)
                    else:
                        label.append(0)
                if len(dev) == 0 or len(label) == 0:
                    continue
                devs.append(dev)
                labels.append(label)
        return devs, labels

    def train_data_loader(self, train_sampled_path, data_path, num_samples=1):

        if os.path.exists(train_sampled_path):
            with open(train_sampled_path, 'rb') as f:
                X = pickle.load(f)
        else:
            X = self.sampling(data_path, num_samples)
            with open(train_sampled_path, 'wb') as f:
                pickle.dump(X, f)
        return X

    def dev_data_loader(self, dev_data_path, data_path):

        if os.path.exists(dev_data_path):
            with open(dev_data_path, "rb") as f:
                data = pickle.load(f)
                devs, labels = zip(*data)
        else:
            devs, labels = self.dev_processing(data_path)
            data = list(zip(devs, labels))
            with open(dev_data_path, 'wb') as f:
                pickle.dump(data, f)
        return devs, labels

    def predict_data_loader(self, predict_data_path, data_path):

        if os.path.exists(predict_data_path):
            print(predict_data_path)
            with open(predict_data_path, "rb") as f:
                data = pickle.load(f)
                devs, location_indexes = zip(*data)
        else:
            devs, location_indexes = self.predict_processing(data_path)
            data = list(zip(devs, location_indexes))
            with open(predict_data_path, 'wb') as f:
                pickle.dump(data, f)
        return devs, location_indexes

    def sent_processing(self, sent):
        sent = sent.replace('\n', '')
        sent = sent.replace('-', ' ')
        sent = sent.replace('/', ' ')
        return sent

    def nltk_tokenizer(self, sent):
        # sent = sent_processing(sent)
        return nltk.word_tokenize(sent)

    def proess_sents(self, sents, max_length):

        tokenized_sents = []
        sents_lengths = []
        for sent in sents:
            words = [word.lower() for word in nltk.word_tokenize(sent)]
            # respect the max_length argument (h_max_length for claims, s_max_length for sentences)
            if len(words) < max_length:
                sents_lengths.append(len(words))
                words.extend([""] * (max_length - len(words)))
                tokenized_sents.append(words)
            else:
                sents_lengths.append(max_length)
                words = words[:max_length]
                tokenized_sents.append(words)
        return tokenized_sents, sents_lengths

    def train_data_tokenizer(self, X_train):

        claims = [claim for claim, _, _ in X_train]
        pos_sents = [pos_sent for _, pos_sent, _ in X_train]
        neg_sents = [neg_sent for _, _, neg_sent in X_train]

        tokenized_claims, claims_lengths = self.proess_sents(
            claims, self.h_max_length)
        tokenized_pos_sents, pos_sents_lengths = self.proess_sents(
            pos_sents, self.s_max_length)
        tokenized_neg_sents, neg_sents_lengths = self.proess_sents(
            neg_sents, self.s_max_length)

        new_claims = list(zip(tokenized_claims, claims_lengths))
        new_pos_sents = list(zip(tokenized_pos_sents, pos_sents_lengths))
        new_neg_sents = list(zip(tokenized_neg_sents, neg_sents_lengths))

        return list(zip(new_claims, new_pos_sents, new_neg_sents))

    def predict_data_tokenizer(self, dataset):

        predict_data = []
        for data in dataset:
            claims = [claim for claim, _ in data]
            sents = [sent for _, sent in data]

            tokenized_claims, claims_lengths = self.proess_sents(
                claims, self.h_max_length)
            tokenized_sents, sents_lengths = self.proess_sents(
                sents, self.s_max_length)

            new_claims = list(zip(tokenized_claims, claims_lengths))
            new_sents = list(zip(tokenized_sents, sents_lengths))

            tokenized_data = list(zip(new_claims, new_sents))
            predict_data.append(tokenized_data)
        return predict_data
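
A hedged construction sketch; instantiating the class runs the whole sampling and tokenisation pipeline, and every path except the fever.db default is hypothetical:

# Sketch: build the ELMO_Data pipeline end to end.
data = ELMO_Data(base_path=".",
                 train_file="data/fever/train.wiki7.jsonl",  # hypothetical
                 dev_file="data/fever/dev.wiki7.jsonl",      # hypothetical
                 test_file="data/fever/test.wiki7.jsonl",    # hypothetical
                 num_negatives=1,
                 h_max_length=20,
                 s_max_length=60,
                 random_seed=100,
                 db_filepath="data/fever/fever.db")
X_train = data.X_train  # (claim, positive, negative) token/length tuples
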
def in_class_sampling(db_filename, datapath, num_sample=1, k=5):
    """

    :param db_filename: path to the wiki-pages database
    :param datapath: path to the FEVER train set with predicted pages and sentences
    :param k: number of top predicted sentences from which negatives are selected
    :param num_sample: number of negative examples to sample per positive pair
    :return: X: list of (claim, positive evidence, negative sentence) triplets
    """

    db = FeverDocDB(db_filename)
    jlr = JSONLineReader()

    X = []
    count = 0

    count1 = 1
    with open(datapath, "r") as f:
        lines = jlr.process(f)
        # lines = lines[:1000]

        for line in tqdm(lines):
            pos_pairs = []
            count1 += 1
            num_sampling = num_sample
            if line['label'].upper() == "NOT ENOUGH INFO":
                continue
            p_lines = []
            neg_sents = []
            claim = line['claim']

            for evidence_set in line['evidence']:
                pos_sent = get_whole_evidence(evidence_set, db)
                print("claim:{} pos_sent:{}".format(claim, pos_sent))
                pos_pairs.append((claim, pos_sent))

            evidence_set = set([(evidence[2], evidence[3])
                                for evidences in line['evidence']
                                for evidence in evidences])
            sampled_sents_idx = [(id, number)
                                 for id, number in line['predicted_sentences']]
            sampled_sents_idx = sampled_sents_idx[0:k + 5]
            sampled_sents_idx = [
                index for index in sampled_sents_idx
                if index not in evidence_set
            ]
            pages = set()
            pages.update(evidence[0] for evidence in line['predicted_pages'])
            pages.update(evidence[0] for evidence in evidence_set)
            for page in pages:
                doc_lines = db.get_doc_lines(page)
                p_lines.extend(get_valid_texts(doc_lines, page))
            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                elif (doc_line[1], doc_line[2]) in sampled_sents_idx:
                    neg_sents.append(doc_line[0])
                # elif (doc_line[1], doc_line[2]) in evidence_set:
                #     if count1%10000==0:
                #         print("page_id:{},sent_num:{}".format(doc_line[1],doc_line[2]))
                #         print("evidence_set:{}".format(evidence_set))
                #     pos_pairs.append((claim,doc_line[0]))

            # never sample more negatives than are actually available
            num_sampling = min(num_sampling, len(neg_sents))
            if num_sampling == 0:
                count += 1
                continue
            else:
                for pair in pos_pairs:
                    samples = random.sample(neg_sents, num_sampling)
                    for sample in samples:
                        X.append((pair[0], pair[1], sample))
                        if count1 % 10000 == 0:
                            print("claim:{},pos:{},neg:{}".format(
                                claim, pair[1], sample))
        print(count)

    return X
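
A hedged sketch; like in_doc_sampling this returns (claim, positive, negative) triplets, but here the negatives are drawn from the retrieved (predicted) sentences rather than from all page lines. The train-set path is hypothetical:

# Sketch: (claim, positive evidence, retrieved-sentence negative) triplets.
X = in_class_sampling("data/fever/fever.db",
                      "data/fever/train.jsonl",  # hypothetical path
                      num_sample=1, k=5)
claim, pos_sent, neg_sent = X[0]
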
Example #17
class Data(object):
    def __init__(self,
                 embedding_path,
                 train_file,
                 dev_file,
                 test_file,
                 fasttext_path,
                 num_negatives,
                 h_max_length,
                 s_max_length,
                 random_seed,
                 reserve_embed=False,
                 db_filepath="data/fever/fever.db"):

        self.random_seed = random_seed

        self.embedding_path = embedding_path
        self.train_file = train_file
        self.dev_file = dev_file
        self.test_file = test_file
        self.fasttext_path = fasttext_path
        self.num_negatives = num_negatives
        self.h_max_length = h_max_length
        self.s_max_length = s_max_length
        self.db_filepath = db_filepath
        self.db = FeverDocDB(self.db_filepath)
        self.reserve_embed = reserve_embed

        self.data_pipeline()

    def data_pipeline(self):

        np.random.seed(self.random_seed)
        random.seed(self.random_seed)

        # create directory to store sampled and processed data
        # store_dir = "data.h{}.s{}.seed{}".format(self.h_max_length, self.s_max_length, self.random_seed)
        # self.absou_dir = os.path.join(base_dir, store_dir)
        os.makedirs(self.embedding_path, exist_ok=True)

        train_data_path = os.path.join(self.embedding_path, "train_sample.p")
        X_train = self.train_data_loader(train_data_path,
                                         self.train_file,
                                         num_samples=self.num_negatives)
        dev_datapath = os.path.join(self.embedding_path, "dev_data.p")
        devs, self.dev_labels = self.dev_data_loader(dev_datapath,
                                                     self.dev_file)
        if self.test_file is None:
            self.test_file = self.dev_file

        test_datapath = os.path.join(self.embedding_path, "test_data.p")
        tests, self.test_location_indexes = self.predict_data_loader(
            test_datapath, self.test_file)

        words_dict_path = os.path.join(self.embedding_path, "words_dict.p")
        if os.path.exists(words_dict_path):
            with open(words_dict_path, "rb") as f:
                self.word_dict = pickle.load(f)
        else:
            self.word_dict = self.get_complete_words(words_dict_path, X_train,
                                                     devs, tests)

        self.iword_dict = self.inverse_word_dict(self.word_dict)

        train_indexes_path = os.path.join(self.embedding_path,
                                          "train_indexes.p")
        self.X_train_indexes = self.train_indexes_loader(
            train_indexes_path, X_train)
        dev_indexes_path = os.path.join(self.embedding_path, "dev_indexes.p")
        self.dev_indexes = self.predict_indexes_loader(dev_indexes_path, devs)
        test_indexes_path = os.path.join(self.embedding_path, "test_indexes.p")
        self.test_indexes = self.predict_indexes_loader(
            test_indexes_path, tests)

        embed_dict = self.load_fasttext(self.iword_dict)
        print("embed_dict size {}".format(len(embed_dict)))
        _PAD_ = len(self.word_dict)
        self.word_dict[_PAD_] = '[PAD]'
        self.iword_dict['[PAD]'] = _PAD_
        self.embed = self.embed_to_numpy(embed_dict)

        return self

    def get_whole_evidence(self, evidence_set, db):
        pos_sents = []
        for evidence in evidence_set:
            page = evidence[2]
            doc_lines = db.get_doc_lines(page)
            doc_lines = self.get_valid_texts(doc_lines, page)
            for doc_line in doc_lines:
                if doc_line[2] == evidence[3]:
                    pos_sents.append(doc_line[0])
        pos_sent = ' '.join(pos_sents)
        return pos_sent

    def get_valid_texts(self, lines, page):
        if not lines:
            return []
        doc_lines = [
            doc_line.split("\t")[1] if len(doc_line.split("\t")[1]) > 1 else ""
            for doc_line in lines.split("\n")
        ]
        doc_lines = list(
            zip(doc_lines, [page] * len(doc_lines), range(len(doc_lines))))
        return doc_lines

    def sampling(self, datapath, num_sample=1):

        jlr = JSONLineReader()

        X = []
        count = 0
        with open(datapath, "r") as f:
            lines = jlr.process(f)

            for line in tqdm(lines):
                count += 1
                pos_pairs = []
                # count1 += 1
                if line['label'].upper() == "NOT ENOUGH INFO":
                    continue
                neg_sents = []
                claim = line['claim']

                pos_set = set()
                for evidence_set in line['evidence']:
                    pos_sent = self.get_whole_evidence(evidence_set, self.db)
                    if pos_sent in pos_set:
                        continue
                    pos_set.add(pos_sent)

                p_lines = []
                evidence_set = set([(evidence[2], evidence[3])
                                    for evidences in line['evidence']
                                    for evidence in evidences])

                pages = [
                    page for page in line['predicted_pages']
                    if page is not None
                ]

                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    p_lines.extend(self.get_valid_texts(doc_lines, page))
                for doc_line in p_lines:
                    if (doc_line[1], doc_line[2]) not in evidence_set:
                        neg_sents.append(doc_line[0])

                num_sampling = num_sample
                if len(neg_sents) < num_sampling:
                    num_sampling = len(neg_sents)
                    # print(neg_sents)
                if num_sampling == 0:
                    continue
                else:
                    for pos_sent in pos_set:
                        samples = random.sample(neg_sents, num_sampling)
                        for sample in samples:
                            if not sample:
                                continue
                            X.append((claim, pos_sent, sample))
                            # if count % 1000 == 0:
                            #     print("claim:{} ,evidence :{} sample:{}".format(claim, pos_sent, sample))
        return X

    def predict_processing(self, datapath):

        jlr = JSONLineReader()

        devs = []
        all_indexes = []

        with open(datapath, "rb") as f:
            lines = jlr.process(f)

            for line in tqdm(lines):
                dev = []
                indexes = []
                pages = set()
                # pages = line['predicted_pages']
                pages.update(page for page in line['predicted_pages'])
                # if len(pages) == 0:
                #     pages.add("Michael_Hutchence")
                claim = line['claim']
                p_lines = []
                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    if not doc_lines:
                        continue
                    p_lines.extend(self.get_valid_texts(doc_lines, page))

                for doc_line in p_lines:
                    if not doc_line[0]:
                        continue
                    dev.append((claim, doc_line[0]))
                    indexes.append((doc_line[1], doc_line[2]))
                # print(len(dev))
                if len(dev) == 0:
                    dev.append((claim, 'no evidence for this claim'))
                    indexes.append(('empty', 0))
                devs.append(dev)
                all_indexes.append(indexes)
        return devs, all_indexes

    def dev_processing(self, data_path):

        jlr = JSONLineReader()

        with open(data_path, "r") as f:
            lines = jlr.process(f)

            devs = []
            labels = []
            for line in tqdm(lines):

                dev = []
                label = []
                if line['label'].upper() == "NOT ENOUGH INFO":
                    continue
                evidence_set = set([(evidence[2], evidence[3])
                                    for evidences in line['evidence']
                                    for evidence in evidences])

                pages = [
                    page for page in line['predicted_pages']
                    if page is not None
                ]
                for page, num in evidence_set:
                    pages.append(page)
                pages = set(pages)

                p_lines = []
                for page in pages:
                    doc_lines = self.db.get_doc_lines(page)
                    p_lines.extend(self.get_valid_texts(doc_lines, page))
                for doc_line in p_lines:
                    if not doc_line[0]:
                        continue
                    dev.append((line['claim'], doc_line[0]))
                    if (doc_line[1], doc_line[2]) in evidence_set:
                        label.append(1)
                    else:
                        label.append(0)
                if len(dev) == 0 or len(label) == 0:
                    continue
                devs.append(dev)
                labels.append(label)
        return devs, labels

    def train_data_loader(self, train_sampled_path, data_path, num_samples=1):

        if os.path.exists(train_sampled_path):
            with open(train_sampled_path, 'rb') as f:
                X = pickle.load(f)
        else:
            X = self.sampling(data_path, num_samples)
            with open(train_sampled_path, 'wb') as f:
                pickle.dump(X, f)
        return X

    def dev_data_loader(self, dev_data_path, data_path):

        if os.path.exists(dev_data_path):
            with open(dev_data_path, "rb") as f:
                data = pickle.load(f)
                devs, labels = zip(*data)
        else:
            devs, labels = self.dev_processing(data_path)
            data = list(zip(devs, labels))
            with open(dev_data_path, 'wb') as f:
                pickle.dump(data, f)
        return devs, labels

    def predict_data_loader(self, predict_data_path, data_path):

        if os.path.exists(predict_data_path):
            print(predict_data_path)
            with open(predict_data_path, "rb") as f:
                data = pickle.load(f)
                devs, location_indexes = zip(*data)
        else:
            devs, location_indexes = self.predict_processing(data_path)
            data = list(zip(devs, location_indexes))
            with open(predict_data_path, 'wb') as f:
                pickle.dump(data, f)
        return devs, location_indexes

    def sent_processing(self, sent):
        sent = sent.replace('\n', '')
        sent = sent.replace('-', ' ')
        sent = sent.replace('/', ' ')
        return sent

    def nltk_tokenizer(self, sent):
        # sent = sent_processing(sent)
        return nltk.word_tokenize(sent)

    def get_words(self, claims, sents):

        words = set()
        for claim in claims:
            for idx, word in enumerate(self.nltk_tokenizer(claim)):
                if idx >= self.h_max_length:
                    break
                words.add(word.lower())
        for sent in sents:
            for idx, word in enumerate(self.nltk_tokenizer(sent)):
                if idx >= self.s_max_length:
                    break
                words.add(word.lower())
        return words

    def get_train_words(self, X):
        claims = set()
        sents = []
        for claim, pos, neg in X:
            claims.add(claim)
            sents.append(pos)
            sents.append(neg)

        train_words = self.get_words(claims, sents)
        print("training words processing done!")
        return train_words

    def get_predict_words(self, devs):
        dev_words = set()
        # nlp = StanfordCoreNLP(corenlp_path)
        for dev in tqdm(devs):
            claims = set()
            sents = []
            for pair in dev:
                claims.add(pair[0])
                sents.append(pair[1])
            dev_tokens = self.get_words(claims, sents)
            dev_words.update(dev_tokens)
        print("dev_words processing done!")
        return dev_words

    def word_2_dict(self, words):
        word_dict = {}
        for idx, word in enumerate(words):
            word = word.replace('\n', '')
            word = word.replace('\t', '')
            word_dict[idx] = word

        return word_dict

    def inverse_word_dict(self, word_dict):

        iword_dict = {}
        for key, word in word_dict.items():
            iword_dict[word] = key
        return iword_dict

    def load_fasttext(self, iword_dict):

        embed_dict = {}
        print(self.fasttext_path)
        model = FastText(self.fasttext_path)
        for word, key in iword_dict.items():
            embed_dict[key] = model[word]
            # print(embed_dict[key])
        print('Embedding size: %d' % (len(embed_dict)))
        return embed_dict

    def embed_to_numpy(self, embed_dict):

        feat_size = len(embed_dict[list(embed_dict.keys())[0]])
        if self.reserve_embed:
            embed = np.zeros((len(embed_dict) + 200000, feat_size), np.float32)
        else:
            embed = np.zeros((len(embed_dict), feat_size), np.float32)
        for k in embed_dict:
            embed[k] = np.asarray(embed_dict[k])
        print('Generate numpy embed:', embed.shape)

        return embed

    def sent_2_index(self, sent, word_dict, max_length):
        words = self.nltk_tokenizer(sent)
        word_indexes = []
        for idx, word in enumerate(words):
            if idx >= max_length:
                break
            else:
                word_indexes.append(word_dict[word.lower()])
        return word_indexes

    def train_data_indexes(self, X, word_dict):

        X_indexes = []
        print("start index words into intergers")
        for claim, pos, neg in X:
            claim_indexes = self.sent_2_index(claim, word_dict,
                                              self.h_max_length)
            pos_indexes = self.sent_2_index(pos, word_dict, self.s_max_length)
            neg_indexes = self.sent_2_index(neg, word_dict, self.s_max_length)
            X_indexes.append((claim_indexes, pos_indexes, neg_indexes))
        print('Training data size:', len(X_indexes))
        return X_indexes

    def predict_data_indexes(self, data, word_dict):

        devs_indexes = []
        for dev in data:
            sent_indexes = []
            claim = dev[0][0]
            claim_index = self.sent_2_index(claim, word_dict,
                                            self.h_max_length)
            claim_indexes = [claim_index] * len(dev)
            for claim, sent in dev:
                sent_index = self.sent_2_index(sent, word_dict,
                                               self.s_max_length)
                sent_indexes.append(sent_index)
            assert len(sent_indexes) == len(claim_indexes)
            dev_indexes = list(zip(claim_indexes, sent_indexes))
            devs_indexes.append(dev_indexes)
        return devs_indexes

    def get_complete_words(self, words_dict_path, train_data, dev_data,
                           test_data):

        all_words = set()
        train_words = self.get_train_words(train_data)
        all_words.update(train_words)
        dev_words = self.get_predict_words(dev_data)
        all_words.update(dev_words)
        test_words = self.get_predict_words(test_data)
        all_words.update(test_words)
        word_dict = self.word_2_dict(all_words)
        with open(words_dict_path, "wb") as f:
            pickle.dump(word_dict, f)

        return word_dict

    def train_indexes_loader(self, train_indexes_path, train_data):

        if os.path.exists(train_indexes_path):
            with open(train_indexes_path, "rb") as f:
                X_indexes = pickle.load(f)
        else:
            X_indexes = self.train_data_indexes(train_data, self.iword_dict)
            with open(train_indexes_path, "wb") as f:
                pickle.dump(X_indexes, f)
        return X_indexes

    def predict_indexes_loader(self, predict_indexes_path, predict_data):

        if os.path.exists(predict_indexes_path):
            with open(predict_indexes_path, "rb") as f:
                predicts_indexes = pickle.load(f)
        else:
            predicts_indexes = self.predict_data_indexes(
                predict_data, self.iword_dict)
            with open(predict_indexes_path, "wb") as f:
                pickle.dump(predicts_indexes, f)
        return predicts_indexes

    def update_word_dict(self, test_path):

        self.new_test_datapath = os.path.join(self.embedding_path,
                                              "new_test_data.p")
        new_tests, self.test_location_indexes = self.predict_data_loader(
            self.new_test_datapath, test_path)

        new_test_words = self.get_predict_words(new_tests)
        print(len(self.iword_dict))
        print(len(self.word_dict))
        self.test_words_dict = {}
        for word in new_test_words:
            if word not in self.iword_dict:
                idx = len(self.word_dict)
                self.word_dict[idx] = word
                self.test_words_dict[idx] = word

        self.iword_dict = self.inverse_word_dict(self.word_dict)
        self.test_iword_dict = self.inverse_word_dict(self.test_words_dict)

        print("updated iword dict size: ", len(self.iword_dict))
        print("test iword dict size: ", len(self.test_iword_dict))

    def update_embeddings(self):

        test_embed_dict = self.load_fasttext(self.test_iword_dict)

        for k in test_embed_dict:
            self.embed[k] = np.asarray(test_embed_dict[k])
        print("updated embed size: ", self.embed.shape)

    def get_new_test_indexes(self, test_path):

        new_tests, self.new_test_location_indexes = self.predict_data_loader(
            self.new_test_datapath, test_path)

        new_tests_indexes_path = os.path.join(self.embedding_path,
                                              "new_test_indexes.p")
        self.new_tests_indexes = self.predict_indexes_loader(
            new_tests_indexes_path, new_tests)
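

# Standalone sketch (not the original class): a minimal version of the word-indexing
# scheme used by word_2_dict / inverse_word_dict / sent_2_index above. Each distinct
# lowercased token gets an integer id and a sentence becomes a truncated list of ids.
# The example sentences and max_length are illustrative assumptions; nltk's punkt
# tokenizer data must be available.
import nltk


def build_word_dicts(sentences, max_length=20):
    words = set()
    for sent in sentences:
        for idx, word in enumerate(nltk.word_tokenize(sent)):
            if idx >= max_length:
                break
            words.add(word.lower())
    word_dict = dict(enumerate(words))                           # id -> word
    iword_dict = {word: idx for idx, word in word_dict.items()}  # word -> id
    return word_dict, iword_dict


def sentence_to_indexes(sent, iword_dict, max_length=20):
    return [iword_dict[w.lower()] for w in nltk.word_tokenize(sent)[:max_length]]


example_sents = ["Colin Kaepernick is a quarterback.", "He plays American football."]
_, example_iword_dict = build_word_dicts(example_sents)
print(sentence_to_indexes(example_sents[0], example_iword_dict))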


def label_sents(db_path, data_path, type="train"):
    """
    Label each candidate sentence for the training data: 1 if it appears in the claim's evidence set, 0 otherwise.
    :param db_path:
    :param data_path:
    :param type:
    :return:
    """

    db = FeverDocDB(db_path)
    jsr = JSONLineReader()
    claims = []
    related_pages_sents = []
    pages_sents_indexes = []
    y = []
    with open(data_path, "r") as f:
        lines = jsr.process(f)
        count = 0
        for line in tqdm(lines):
            if line['label'] == "NOT ENOUGH INFO" and type == "train":
                continue
            p_lines = []
            valid_lines = []
            line_labels = []
            sents_indexes = []
            claim = line['claim']
            evidences = line['evidence']
            evidence_set = set()
            pages_list = []
            for evidence in evidences:
                for sent in evidence:
                    evidence_set.add((sent[2], sent[3]))
                    pages_list.append(sent[2])
            # predicted_pages = line['predicted_pages']
            predicted_pages = [page[0] for page in line['predicted_pages']]
            predicted_pages = predicted_pages + pages_list
            predicted_pages = set(predicted_pages)
            if len(predicted_pages) > 5:
                count += 1
            claims.append(claim)
            for page in predicted_pages:
                doc_lines = db.get_doc_lines(page)
                if not doc_lines:
                    # print(page)
                    continue
                doc_lines = [
                    doc_line.split("\t")[1]
                    if len(doc_line.split("\t")[1]) > 1 else ""
                    for doc_line in doc_lines.split("\n")
                ]
                p_lines.extend(
                    zip(doc_lines, [page] * len(doc_lines),
                        range(len(doc_lines))))

            for doc_line in p_lines:
                # ignore empty sentences
                if not doc_line[0]:
                    continue
                else:
                    # print(doc_line[0])
                    sents_indexes.append((doc_line[1], doc_line[2]))
                    valid_lines.append(doc_line[0])
                    is_added = False
                    for sent in evidence_set:
                        if sent[0] == doc_line[1] and sent[1] == doc_line[2]:
                            line_labels.append(1)
                            is_added = True
                            break
                    if not is_added:
                        line_labels.append(0)
            # print(len(p_lines))
            # print(len(line_labels))
            # print(len(valid_lines))
            assert len(line_labels) == len(valid_lines) == len(sents_indexes)
            related_pages_sents.append(valid_lines)
            pages_sents_indexes.append(sents_indexes)
            y.append(line_labels)
    print(count)
    return claims, related_pages_sents, pages_sents_indexes, y
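

# Usage sketch (not in the original code): the db and jsonl paths below are
# illustrative assumptions; any FEVER wiki db plus a split whose lines carry
# 'claim', 'evidence' and 'predicted_pages' fields would do. The four returned
# lists are aligned per claim: sentences, their (page, line) ids and 0/1 labels.
claims, cand_sents, cand_ids, cand_labels = label_sents(
    "data/fever/fever.db", "data/fever/train.jsonl", type="train")
print(len(claims), len(cand_sents), len(cand_ids), len(cand_labels))
print(claims[0], cand_ids[0][:3], cand_labels[0][:3])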


def vocab_map(vocab):
    # map every token to an integer id; the header and enumeration were reconstructed,
    # assuming ids start at 2 so that 0 and 1 stay reserved for PAD and UNK
    voc_dict = {word: idx + 2 for idx, word in enumerate(vocab)}
    voc_dict['PAD'] = 0
    voc_dict['UNK'] = 1
    return voc_dict


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('db', help='/path/to/db/file')
    parser.add_argument('output', help='/path/to/output/pickle/file')
    args = parser.parse_args()
    LogHelper.setup()
    logger = LogHelper.get_logger("generate_vocab_all_wiki")
    db = FeverDocDB(args.db)
    vocab = set()
    for doc in tqdm(db.get_doc_ids()):
        lines = db.get_doc_lines(doc)
        lines = lines.split("\n")
        for line in lines:
            segments = line.split("\t")
            if len(segments) < 2:
                continue
            line = segments[1]
            if line.strip() == "":
                continue
            tokens = set(token.lower() for token in tokenize(clean_text(line)))
            vocab.update(tokens)
    logger.info("total size of vocab: " + str(len(vocab)))
    vocab_dict = vocab_map(vocab)
    del vocab
    with open(args.output, 'wb') as f:
        pickle.dump(vocab_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
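
    # Follow-up sketch (not in the original script): read the vocabulary pickle back
    # and map a claim's tokens to ids, falling back to UNK for out-of-vocabulary
    # tokens. The claim string and the whitespace tokenizer are illustrative
    # assumptions; the builder above uses its own tokenize/clean_text helpers.
    with open(args.output, 'rb') as f:
        loaded_vocab = pickle.load(f)
    sample_claim = "Colin Kaepernick became a free agent."
    token_ids = [loaded_vocab.get(tok.lower(), loaded_vocab['UNK'])
                 for tok in sample_claim.split()]
    logger.info("sample token ids: " + str(token_ids))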
Example #20
0
def eval_model(db: FeverDocDB, args):
    # archive = load_archive(args.archive_file, cuda_device=args.cuda_device, overrides=args.overrides)

    # config = archive.config
    # ds_params = config["dataset_reader"]
    #
    # model = archive.model
    # model.eval()


    # reader = FEVERReader(db,
    #                              sentence_level=ds_params.pop("sentence_level",False),
    #                              wiki_tokenizer=Tokenizer.from_params(ds_params.pop('wiki_tokenizer', {})),
    #                              claim_tokenizer=Tokenizer.from_params(ds_params.pop('claim_tokenizer', {})),
    #                              token_indexers=TokenIndexer.dict_from_params(ds_params.pop('token_indexers', {})))

    model = NeiRteModel()
    reader = FEVERReader(db, sentence_level=False)


    while True:
        ############### CLAIM
        claim = input("enter claim (or q to quit) >>")
        if claim.lower() == "q":
            break

        ############### DOCUMENT RETRIEVAL
        # ranker = retriever.get_class('tfidf')(tfidf_path=args.model)
        # # ranker = retriever.get_class('tfidf')()
        #
        p_lines = []
        # pages,_ = ranker.closest_docs(claim,5)
        doc_retriever = DrqaDocRetriever(args.model)
        pages, _ = doc_retriever.closest_docs(claim, 5)
        print("Fetched Nearest 5 docs")

        for page in pages:
            lines = db.get_doc_lines(page)
            lines = [line.split("\t")[1] if len(line.split("\t")[1]) > 1 else "" for line in lines.split("\n")]

            p_lines.extend(zip(lines, [page] * len(lines), range(len(lines))))

        lines_field = [pl[0] for pl in p_lines]
        line_indices_field = [pl[2] for pl in p_lines]
        pages_field = [pl[1] for pl in p_lines]

        ############### SENTENCE RETRIEVAL

        # this line would be replaced by a call to the new implementation;
        # a possible tf_idf_sim stand-in is sketched after this function
        scores = tf_idf_sim(claim, lines_field)

        scores = list(zip(scores, pages_field, line_indices_field, lines_field))
        scores = list(filter(lambda score: len(score[3].strip()), scores))
        sentences_l = list(sorted(scores, reverse=True, key=lambda elem: elem[0]))

        sentences = [s[3] for s in sentences_l[:5]]
        evidence = " ".join(sentences)
        print("Sentences: ", sentences)

        ############### RTE
        print("Best pages: {0}".format(repr(pages)))

        print("Evidence:")
        for idx, sentence in enumerate(sentences_l[:5]):
            print("{0}\t{1}\t\t{2}\t{3}".format(idx + 1, sentence[0], sentence[1], sentence[3]))

        item = reader.text_to_instance(evidence, claim)

        print(f"item: {item}")
        prediction = model.forward(item)
        # prediction = model.forward_on_instance(item, args.cuda_device)
        # cls = model.vocab._index_to_token["labels"][np.argmax(prediction["label_probs"])]
        # print("PREDICTED: {0}".format(cls))
        print("PREDICTED: {0}".format(prediction))
        print("___________________________________")


def test_data(db_path, dataset_path, type="ranking"):
    """
    generate dev examples to feed into the classifier
    :param db_path:
    :param dataset_path:
    :param type:
    :return:
    """

    db = FeverDocDB(db_path)
    jsr = JSONLineReader()

    inputs = []
    X_claim = []
    X_sents = []
    indexes = []

    with open(dataset_path, "r") as f:
        lines = jsr.process(f)

        for line in tqdm(lines):

            p_lines = []
            valid_lines = []
            claims = []
            sents_indexes = []
            claim = line['claim']
            # X_claim.append([claim])
            predicted_pages = line['predicted_pages']
            for page in predicted_pages:
                # doc_lines = db.get_doc_lines(page[0])
                doc_lines = db.get_doc_lines(page[0])

                if not doc_lines:
                    # print(page)
                    continue
                doc_lines = [doc_line.split("\t")[1] if len(doc_line.split("\t")[1]) > 1 else "" for doc_line in
                             doc_lines.split("\n")]
                p_lines.extend(zip(doc_lines, [page[0]] * len(doc_lines), range(len(doc_lines))))

            for doc_line in p_lines:
                if not doc_line[0]:
                    continue
                else:
                    # print(doc_line[0])
                    if type == "cos":
                        sents_indexes.append((doc_line[1], doc_line[2]))
                        valid_lines.append(doc_line[0])
                        claims.append(claim)
                    elif type == "ranking":
                        sents_indexes.append((doc_line[1], doc_line[2]))
                        valid_lines.append((claim, doc_line[0]))
            if type == "cos":
                X_sents.append(valid_lines)
                X_claim.append(claims)
            elif type == "ranking":
                inputs.append(valid_lines)
            indexes.append(sents_indexes)
        # only the "cos" path builds inputs from the parallel claim/sentence lists;
        # zipping unconditionally would overwrite the "ranking" inputs with an empty list
        if type == "cos":
            inputs = list(zip(X_claim, X_sents))

        return inputs, indexes
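

# Usage sketch (not in the original code): paths are illustrative assumptions. With
# type="ranking", each entry of `inputs` is a list of (claim, candidate_sentence)
# pairs and `indexes` holds the matching (page, line_number) ids per claim.
dev_inputs, dev_indexes = test_data(
    "data/fever/fever.db", "data/fever/dev.jsonl", type="ranking")
print(len(dev_inputs), len(dev_indexes))
print(dev_inputs[0][0], dev_indexes[0][0])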
Example #22
0
    while True:
        # busy-wait until the dataset file appears on disk
        while not os.path.exists(args.dataset):
            pass

        with open(args.dataset, 'r') as dset:
            for line in tqdm.tqdm(dset):
                sample = json.loads(line)
                claim = sample["claim"]
                claim_id = sample["id"]

                pages, _ = doc_retriever.closest_docs(claim, args.n_docs)
                p_lines = []
                for page in pages:
                    lines = db.get_doc_lines(page)
                    lines = [line.split("\t")[1] if len(line.split("\t")[1]) > 1 else "" for line in lines.split("\n")]
                    p_lines.extend(zip(lines, [page] * len(lines), range(len(lines))))
                append_to_file(claim, claim_id, p_lines, db, args.output)

            out_dir = os.path.dirname(os.path.realpath(args.output))
            os.makedirs(out_dir, exist_ok=True)

        with open(args.docs, 'r') as doc_preds:
            for line in tqdm.tqdm(doc_preds):
                sample = json.loads(line)
                claim = sample["claim"]
                claim_id = sample["id"]
                lines_field = sample["lines"]
                line_indices_field = sample["indices"]
                pages_field = sample["page_ids"]
Example #23
0
from retrieval.fever_doc_db import FeverDocDB

db = FeverDocDB("data/fever/fever.db")
print(db.get_doc_lines("United_States"))
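
# Follow-up sketch (not in the original snippet): get_doc_lines returns the page as
# one string of newline-separated, tab-separated records ("<line_no>\t<sentence>\t..."),
# which is why the examples above take split("\t")[1] of every split("\n") line.
# Turn that blob into (line_number, sentence) pairs, skipping empty or malformed lines.
doc_lines = db.get_doc_lines("United_States")
parsed = []
for raw_line in doc_lines.split("\n"):
    segments = raw_line.split("\t")
    if len(segments) < 2 or not segments[0].isdigit() or not segments[1].strip():
        continue
    parsed.append((int(segments[0]), segments[1]))
print(parsed[:3])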