def nearest_neighbors_network_from_index(
    hash_and_embedding: dbag.Bag,
    inverted_index_collection: str,
    batch_size: int,
    num_neighbors: int,
    faiss_index_name="final",
    weight: float = 1.0,
) -> Iterable[nx.Graph]:
  """
  Applies faiss and runs results through inverted index.
  Requires knn_util:faiss_index and knn_util:inverted_index to be initialized.
  """

  def apply_faiss_to_edges(
      hash_and_embedding: Iterable[Record],
  ) -> Iterable[nx.Graph]:
    # The inverted index must have been written to the database before this
    # partition runs.
    index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
    inverted_index = {}
    graph = nx.Graph()
    for batch in iter_to_batches(hash_and_embedding, batch_size):
      hashes, embeddings = records_to_ids_and_embeddings(
          records=batch,
      )
      _, neighs_per_root = index.search(embeddings, num_neighbors)
      root_hashes = hashes.tolist()
      # Look up graph keys for any root or neighbor hash not yet cached.
      # root_hashes is kept separate so root order stays aligned with
      # neighs_per_root below.
      new_hashes = list(
          set(root_hashes + flatten_list(neighs_per_root.tolist()))
          - set(inverted_index.keys())
      )
      graph_keys = database_util.get(
          values=new_hashes,
          collection=inverted_index_collection,
          field_name="hash",
          desired_fields=["strid"],
      )
      for k, v in zip(new_hashes, graph_keys):
        inverted_index[k] = v["strid"]
      # Create edges between each root and its nearest neighbors
      for root_idx, neigh_indices in zip(root_hashes, neighs_per_root):
        root = inverted_index[root_idx]
        if root is None:
          continue
        for neigh_idx in neigh_indices:
          if neigh_idx == root_idx:
            continue
          neigh = inverted_index[neigh_idx]
          if neigh is None:
            continue
          graph.add_edge(root, neigh, weight=weight)
          graph.add_edge(neigh, root, weight=weight)
    return [graph]

  return hash_and_embedding.map_partitions(apply_faiss_to_edges)
def to_training_database(bag: dbag.Bag, database_dir: Path):
  assert database_dir.is_dir()
  done_file = database_dir.joinpath("__done__")
  if not done_file.is_file():

    def part_to_db(records):
      # Each partition is written to its own randomly-named sqlite file.
      r_name = "".join([random.choice(string.ascii_letters) for _ in range(10)])
      db_path = database_dir.joinpath(r_name + ".sqlite")
      with SqliteDict(db_path, journal_mode="OFF", flag="n") as db:
        for idx, rec in enumerate(records):
          db[str(idx)] = rec
        db.commit()
      # Return a single-element list so the resulting bag contains one path
      # per partition.
      return [db_path]

    db_paths = bag.map_partitions(part_to_db).compute()
    with open(done_file, 'w') as f:
      for p in db_paths:
        f.write(f"{p}\n")
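# Illustrative usage sketch, not part of the original source: build a tiny
# dask bag of placeholder records and write it out with to_training_database.
# The directory name and record contents are assumptions for demonstration.
def _example_to_training_database() -> None:
  import dask.bag as dbag
  example_dir = Path("./example_training_db")
  example_dir.mkdir(parents=True, exist_ok=True)
  records = dbag.from_sequence(
      [{"text": "first example"}, {"text": "second example"}],
      npartitions=2,
  )
  to_training_database(records, example_dir)
  # Afterwards, example_dir holds one randomly named .sqlite file per
  # partition, plus a __done__ file listing the generated paths.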
def nearest_neighbors_network_from_index(
    hash_and_embedding: dbag.Bag,
    hash2name_db: Path,
    batch_size: int,
    num_neighbors: int,
    faiss_index_name="final",
    weight: float = 1.0,
) -> Iterable[nx.Graph]:
  """
  Applies faiss and runs results through inverted index.
  """
  assert hash2name_db.is_file(), "Missing hash2names sqlite3 db."

  def apply_faiss_to_edges(
      hash_and_embedding: Iterable[Record],
  ) -> Iterable[nx.Graph]:
    # The faiss index and the hash2name database must both have been written
    # before this partition runs.
    index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
    graph = nx.Graph()
    with sqlite3_lookup.Sqlite3LookupTable(hash2name_db) as hash2names:
      for batch in iter_to_batches(hash_and_embedding, batch_size):
        hashes, embeddings = to_hash_and_embedding(records=batch)
        _, neighs_per_root = index.search(embeddings, num_neighbors)
        hashes = hashes.tolist() + flatten_list(neighs_per_root.tolist())
        # Create edges between every name of each root and every name of its
        # neighbors. zip stops after the roots, so the appended neighbor
        # hashes are never used as roots here.
        for root_hash, neigh_indices in zip(hashes, neighs_per_root):
          if root_hash in hash2names:
            for root_name in hash2names[root_hash]:
              for neigh_hash in neigh_indices:
                if neigh_hash != root_hash and neigh_hash in hash2names:
                  for neigh_name in hash2names[neigh_hash]:
                    graph.add_edge(root_name, neigh_name, weight=weight)
                    graph.add_edge(neigh_name, root_name, weight=weight)
    return [graph]

  return hash_and_embedding.map_partitions(apply_faiss_to_edges)
def nearest_neighbors_network_from_index(
    hash_and_embedding: dbag.Bag,
    hash2name_db: Path,
    batch_size: int,
    num_neighbors: int,
    faiss_index_name="final",
    weight: float = 1.0,
) -> Iterable[str]:
  """
  Applies faiss and runs results through inverted index.
  """
  assert hash2name_db.is_file(), "Missing hash2names sqlite3 db."

  def _apply_faiss(
      hash_and_embedding: Iterable[Record],
  ) -> List[Record]:
    # The faiss index and the hash2name database must both have been written
    # before this partition runs.
    index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
    hash2names = sqlite3_lookup.Sqlite3LookupTable(hash2name_db)
    # Produces records of the form {"id": <name>, "neighs": <set of names>}
    res = []
    for batch in iter_to_batches(hash_and_embedding, batch_size):
      hashes, embeddings = to_hash_and_embedding(records=batch)
      _, neighs_per_root = index.search(embeddings, num_neighbors)
      hashes = hashes.tolist() + flatten_list(neighs_per_root.tolist())
      # Create records
      for root_hash, neigh_indices in zip(hashes, neighs_per_root):
        if root_hash in hash2names:
          for root_name in hash2names[root_hash]:
            val = {"id": root_name, "neighs": set()}
            for neigh_hash in neigh_indices:
              if neigh_hash != root_hash and neigh_hash in hash2names:
                for neigh_name in hash2names[neigh_hash]:
                  val["neighs"].add(neigh_name)
            res.append(val)
    return res

  return graph_util.record_to_bipartite_edges(
      hash_and_embedding.map_partitions(_apply_faiss),
      bidirectional=True,
      get_neighbor_keys_fn=lambda r: r["neighs"],
  )
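# Illustrative usage sketch, not part of the original source. Assumes a faiss
# index has already been registered through dpg under "knn_util:faiss_final"
# and that ./hash2name.sqlite3 maps sentence hashes back to names; the record
# fields, embedding size, and file path are placeholders that must match
# whatever to_hash_and_embedding and the registered index expect.
def _example_nearest_neighbors_network_from_index() -> dbag.Bag:
  import dask.bag as dbag
  hash_and_embedding = dbag.from_sequence(
      [{"id": 1234567, "embedding": [0.1] * 512}],
      npartitions=1,
  )
  # Returns a lazy bag of edge data; nothing runs until .compute() or a
  # checkpoint forces evaluation.
  return nearest_neighbors_network_from_index(
      hash_and_embedding=hash_and_embedding,
      hash2name_db=Path("./hash2name.sqlite3"),
      batch_size=64,
      num_neighbors=25,
  )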
def perform_document_independent_tasks(
    config: cpb.ConstructConfig,
    documents: dbag.Bag,
    ckpt_prefix: str,
    semrep_work_dir: Optional[Path] = None,
) -> None:
  """Performs tasks that don't require communication between documents

  Performs all of the document processing operations that are required to
  happen on each document separately. This is important to separate between
  different input textual features because this allows us to update/invalidate
  particular sets of checkpoints faster.

  Args:
    config: Construction configuration
    documents: Collection of texts to process
    ckpt_prefix: To avoid collisions, and to improve caching, each call to
      this function should have a different prefix indicating the type of the
      corresponding documents. For instance, calling this with medline
      documents could get the `medline` prefix.
    semrep_work_dir: The location to store semrep intermediate files. Only
      used if semrep has been installed and configured.

  """
  ckpt("documents", ckpt_prefix)

  # Split documents into sentences, filter out too-long and too-short
  # sentences.
  sentences = documents.map_partitions(
      text_util.split_sentences,
      # --
      min_sentence_len=config.parser.min_sentence_len,
      max_sentence_len=config.parser.max_sentence_len,
  )
  ckpt("sentences", ckpt_prefix)

  # Get metadata terms from each sentence
  coded_term_edges = graph_util.record_to_bipartite_edges(
      records=sentences,
      get_neighbor_keys_fn=text_util.get_mesh_keys,
  )
  ckpt("coded_term_edges", ckpt_prefix, textfile=True)

  # Make edges between each adjacent sentence
  adj_sent_edges = graph_util.record_to_bipartite_edges(
      records=sentences,
      get_neighbor_keys_fn=text_util.get_adjacent_sentences,
      # We can store only one side of the connection because each sentence
      # will get its own neighbors. Additionally, these should all have the
      # same sort of connections.
      bidirectional=False,
  )
  ckpt("adj_sent_edges", ckpt_prefix, textfile=True)

  # Apply lemmatization and entity extraction to sentences
  parsed_sentences = sentences.map_partitions(
      text_util.analyze_sentences,
      # --
      text_field="sent_text",
  )
  ckpt("parsed_sentences", ckpt_prefix)

  # Get lemma edges
  lemma_edges = graph_util.record_to_bipartite_edges(
      records=parsed_sentences,
      get_neighbor_keys_fn=text_util.get_interesting_token_keys,
  )
  ckpt("lemma_edges", ckpt_prefix, textfile=True)

  # Get entity edges
  entity_edges = graph_util.record_to_bipartite_edges(
      records=parsed_sentences,
      get_neighbor_keys_fn=text_util.get_entity_keys,
  )
  ckpt("entity_edges", ckpt_prefix, textfile=True)

  # If we're running semrep
  if (
      config.semrep.HasField("semrep_install_dir")
      and config.semrep.HasField("metamap_install_dir")
      and semrep_work_dir is not None
  ):
    prefixed_semrep_work_dir = semrep_work_dir.joinpath(ckpt_prefix)
    prefixed_semrep_work_dir.mkdir(parents=True, exist_ok=True)
    semrep_sentences = \
        semrep_util.extract_entities_and_predicates_from_sentences(
            sentence_records=sentences,
            unicode_to_ascii_jar_path=config.semrep.unicode_to_ascii_jar_path,
            semrep_install_dir=config.semrep.semrep_install_dir,
            work_dir=prefixed_semrep_work_dir,
            lexicon_year=config.semrep.lexicon_year,
            mm_data_year=config.semrep.mm_data_year,
            mm_data_version=config.semrep.mm_data_version,
        )
    ckpt("semrep_sentences", ckpt_prefix)

  # Embed each sentence
  embedded_sentences = (
      sentences
      .map_partitions(
          embedding_util.embed_records,
          # --
          batch_size=config.sys.batch_size,
          text_field="sent_text",
          max_sequence_length=config.parser.max_sequence_length,
      )
  )
  ckpt("embedded_sentences", ckpt_prefix)

  # hash each sentence id
  hashed_embeddings = (
      embedded_sentences
      .map(lambda x: {
          "id": misc_util.hash_str_to_int(x["id"]),
          "embedding": x["embedding"],
      })
  )
  ckpt("hashed_embeddings", ckpt_prefix)

  hashed_names = (
      sentences
      .map(lambda rec: {
          "name": rec["id"],
          "hash": misc_util.hash_str_to_int(rec["id"]),
      })
  )
  ckpt("hashed_names", ckpt_prefix)
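# Illustrative usage sketch, not part of the original source: run the
# per-document tasks over a small in-memory bag. It assumes `config` is a
# populated cpb.ConstructConfig and that the ckpt(...) checkpoint helper has
# been configured by the surrounding construction script; the document record
# fields are placeholders for whatever schema text_util.split_sentences
# expects.
def _example_perform_document_independent_tasks(
    config: cpb.ConstructConfig,
) -> None:
  import dask.bag as dbag
  documents = dbag.from_sequence(
      [{"id": "doc:example:1", "text": "An example abstract."}],
      npartitions=1,
  )
  perform_document_independent_tasks(
      config=config,
      documents=documents,
      ckpt_prefix="example",
      semrep_work_dir=None,  # skip semrep unless installed and configured
  )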
def record_to_bipartite_edges(
    records: dbag.Bag,
    get_neighbor_keys_fn: Callable[[Record], List[str]],
    weight_by_tf_idf: bool = True,
    minimum_document_frequency: int = 2,
    bidirectional: bool = True,
    default_weight_multiplier: float = 1.0,
    get_source_key_fn: Callable[[Record], str] = lambda x: x["id"],
) -> dbag.Bag:
  """
  This function is responsible for extracting edges from records. For
  example, if you had a bag of records, each containing a set of terms, you
  might want to get the set of edges between records and terms.

  @param records: The collection of records we wish to extract edges from.
  @param get_neighbor_keys_fn: Given a record, return a list of graph keys
    that are adjacent to the given record
  @param weight_by_tf_idf: If true, perform tf-idf weighting on edges. In
    this case, if t is a term, d is a document and C is a corpus, then we
    calculate 1/((times t occurs in d / log(size of d)) * log(size of C /
    number of d with t))
  @param minimum_document_frequency: only used if weight_by_tf_idf is true.
    Removes nodes among neighbors that don't occur frequently enough.
  @param bidirectional: If true, we write record->neighbor and
    neighbor->record. If false, we only write record->neighbor.
  @param default_weight_multiplier: All weights are multiplied by this. If we
    aren't calculating tf-idf, this is the value of every weight.
  @param get_source_key_fn: Given a record, return a graph key that uniquely
    identifies the root. By default we get the "id" field
  @return A collection of networkx subgraphs
  """

  def to_id_term_freq_len(records):
    res = []
    for record in records:
      id_ = get_source_key_fn(record)
      tfs = {}
      neighs = get_neighbor_keys_fn(record)
      for n in neighs:
        if n in tfs:
          tfs[n] += 1
        else:
          tfs[n] = 1
      res += [
          (id_, term, freq, len(neighs))
          for term, freq in tfs.items()
      ]
    # columns=id, term, freq, doc_len
    return res

  def to_partial_doc_freqs(records):
    t2df = {}
    for record in records:
      for t in set(get_neighbor_keys_fn(record)):
        if t in t2df:
          t2df[t] += 1
        else:
          t2df[t] = 1
    # columns=term, doc_freq
    return list(t2df.items())

  def calculate_tf_idf_part(part, corpus_size):
    res = []
    for row in part.itertuples():
      tfidf = 1.0 / (
          (row.freq / log(row.doc_len + 1))
          * (log(float(corpus_size) / row.doc_freq))
      )
      res.append([row.id, row.term, tfidf])
    return pd.DataFrame(res, columns=["id", "term", "freq"])

  def part_to_graph(id_term_freqs):
    graph = nx.Graph()
    for row in id_term_freqs:
      i, t, f = row[:3]
      f *= default_weight_multiplier
      graph.add_edge(i, t, weight=f)
      if bidirectional:
        graph.add_edge(t, i, weight=f)
    return [graph]

  # a list of (id, term, freq)
  term_df = (
      records
      .map_partitions(to_id_term_freq_len)
      .to_dataframe(
          meta={
              "id": str,
              "term": str,
              "freq": float,
              "doc_len": int,
          }
      )
  )
  if weight_by_tf_idf:
    # A list of (term, doc_freq)
    document_frequencies = (
        records
        .map_partitions(to_partial_doc_freqs)
        .to_dataframe(
            meta={
                "term": str,
                "doc_freq": int,
            }
        )
        .groupby("term")
        .sum()
    )
    # filter
    document_frequencies = document_frequencies[
        document_frequencies["doc_freq"] >= minimum_document_frequency
    ]
    corpus_size = records.count()
    term_df = (
        term_df
        .join(document_frequencies, how="inner", on="term")
        .map_partitions(
            calculate_tf_idf_part,
            corpus_size=corpus_size,
            meta={
                "id": str,
                "term": str,
                "freq": float,
            }
        )
    )
  return term_df.to_bag().map_partitions(part_to_graph)
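# Illustrative usage sketch, not part of the original source: extract
# record->term edges from a toy bag. With three documents, the term mesh:D1
# occurs twice in the four-term record doc_a and appears in two documents, so
# its edge weight is 1 / ((2 / log(4 + 1)) * log(3 / 2)), roughly 1.98,
# before default_weight_multiplier is applied; mesh:D3 and mesh:D4 are
# dropped by the minimum_document_frequency filter.
def _example_record_to_bipartite_edges() -> dbag.Bag:
  import dask.bag as dbag
  records = dbag.from_sequence([
      {"id": "doc_a", "terms": ["mesh:D1", "mesh:D1", "mesh:D2", "mesh:D3"]},
      {"id": "doc_b", "terms": ["mesh:D1", "mesh:D2"]},
      {"id": "doc_c", "terms": ["mesh:D4"]},
  ], npartitions=1)
  # Returns a lazy bag of networkx subgraphs, one per partition.
  return record_to_bipartite_edges(
      records=records,
      get_neighbor_keys_fn=lambda r: r["terms"],
  )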
def get_frequent_ngrams(
    analyzed_sentences: dbag.Bag,
    max_ngram_length: int,
    min_ngram_support: int,
    min_ngram_support_per_partition: int,
    token_field: str = "tokens",
    ngram_field: str = "ngrams",
) -> dbag.Bag:
  """
  Adds a new field containing a list of all mined n-grams.  N-grams are
  tuples of strings such that at least one string is not a stopword.  Strings
  are collected from the lemmas of sentences.  To be counted, an ngram must
  occur in at least `min_ngram_support` sentences.
  """

  def part_to_ngram_counts(
      records: Iterable[Record],
  ) -> Iterable[Dict[Tuple[str], int]]:
    ngram2count = {}
    for rec in records:

      def interesting(idx):
        t = rec[token_field][idx]
        return not t["stop"] and t["pos"] in INTERESTING_POS_TAGS

      # beginning of ngram
      for start_tok_idx in range(len(rec[token_field])):
        # ngrams must begin with an interesting word
        if not interesting(start_tok_idx):
          continue
        # for each potential n-gram size
        for ngram_len in range(2, max_ngram_length):
          end_tok_idx = start_tok_idx + ngram_len
          # ngrams cannot extend beyond the sentence
          if end_tok_idx > len(rec[token_field]):
            continue
          # ngrams must end with an interesting word
          if not interesting(end_tok_idx - 1):
            continue
          # the ngram is an ordered tuple of lemmas
          ngram = tuple(
              rec[token_field][tok_idx]["lemma"]
              for tok_idx in range(start_tok_idx, end_tok_idx)
          )
          if ngram in ngram2count:
            ngram2count[ngram] += 1
          else:
            ngram2count[ngram] = 1
    # filter out all low-occurrence ngrams in this partition
    return [{
        n: c
        for n, c in ngram2count.items()
        if c >= min_ngram_support_per_partition
    }]

  def valid_ngrams(ngram2count: Dict[str, int]) -> Set[Tuple[str]]:
    ngrams = {
        n
        for n, c in ngram2count.items()
        if c >= min_ngram_support
    }
    return ngrams

  def parse_ngrams(record: Record, ngram_model: Set[Tuple[str]]):
    record[ngram_field] = []
    start_tok_idx = 0
    while start_tok_idx < len(record[token_field]):
      incr = 1  # amount to move start_tok_idx
      # from max -> 2. Match longest
      for ngram_len in range(max_ngram_length, 1, -1):
        # get bounds of ngram and make sure it's within the sentence
        end_tok_idx = start_tok_idx + ngram_len
        if end_tok_idx > len(record[token_field]):
          continue
        ngram = tuple(
            record[token_field][tok_idx]["lemma"]
            for tok_idx in range(start_tok_idx, end_tok_idx)
        )
        # if match
        if ngram in ngram_model:
          record[ngram_field].append("_".join(ngram))
          # skip over matched terms
          incr = ngram_len
          break
      start_tok_idx += incr
    return record

  # Begin the actual function
  if max_ngram_length < 1:
    # disable, record empty field for all ngrams
    def init_nothing(rec: Record) -> Record:
      rec[ngram_field] = []
      return rec
    return analyzed_sentences.map(init_nothing)
  else:
    ngram2count = analyzed_sentences.map_partitions(
        part_to_ngram_counts
    ).fold(
        misc_util.merge_counts,
        initial={},
    )
    ngram_model = delayed(valid_ngrams)(ngram2count)
    return analyzed_sentences.map(parse_ngrams, ngram_model=ngram_model)
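# Illustrative usage sketch, not part of the original source: mine n-grams
# from a single toy analyzed sentence. The token dicts carry only the fields
# the code above reads ("lemma", "pos", "stop"); the POS tags used here are
# assumed to appear in INTERESTING_POS_TAGS, and the support thresholds are
# set low because there is only one sentence.
def _example_get_frequent_ngrams() -> dbag.Bag:
  import dask.bag as dbag
  tokens = [
      {"lemma": "tobacco", "pos": "NOUN", "stop": False},
      {"lemma": "smoke", "pos": "NOUN", "stop": False},
      {"lemma": "cause", "pos": "VERB", "stop": False},
      {"lemma": "cancer", "pos": "NOUN", "stop": False},
  ]
  sentences = dbag.from_sequence(
      [{"id": "s:example:1:1", "tokens": tokens}],
      npartitions=1,
  )
  # Returns a lazy bag of the same records with an added "ngrams" field,
  # e.g. ["tobacco_smoke"] if that bigram clears the support thresholds.
  return get_frequent_ngrams(
      analyzed_sentences=sentences,
      max_ngram_length=3,
      min_ngram_support=1,
      min_ngram_support_per_partition=1,
  )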