def train_distributed_knn(
    hash_and_embedding: dbag.Bag,
    batch_size: int,
    num_centroids: int,
    num_probes: int,
    num_quantizers: int,
    bits_per_quantizer: int,
    training_sample_prob: float,
    shared_scratch_dir: Path,
    final_index_path: Path,
    id_field: str = "id",
    embedding_field: str = "embedding",
) -> Path:
  """
  Computing all of the embeddings and then performing a KNN is a problem for
  memory. So, what we need to do instead is compute batches of embeddings, and
  use them in Faiss to reduce their dimensionality and process them
  appropriately. I'm so sorry this one function has to do so much...

  @param hash_and_embedding: bag of hash value and embedding values
  @param batch_size: number of sentences per batch
  @param num_centroids: number of voronoi cells in approx nn
  @param num_probes: number of cells to consider when querying
  @param num_quantizers: number of sub-vectors to discretize
  @param bits_per_quantizer: bits per sub-vector
  @param training_sample_prob: chance a point is trained on
  @param shared_scratch_dir: location to store intermediate results
  @param final_index_path: location to write the final merged index
  @param id_field: field we use to store numeric hashes
  @param embedding_field: field containing the embedding vectors
  @return The path you can load the resulting FAISS index
  """
  init_index_path = shared_scratch_dir.joinpath("init.index")
  if not init_index_path.is_file():
    print("\t- Constructing initial index:", init_index_path)
    # First off, we need to get a representative sample for faiss training
    training_data = hash_and_embedding.random_sample(
        prob=training_sample_prob
    ).pluck(embedding_field)
    # Train initial index, store result in init_index_path.
    # Note: dask.compute returns a tuple of results, so unpack the single
    # element rather than assigning the tuple itself.
    (init_index_path,) = dask.compute(
        dask.delayed(train_initial_index)(
            training_data=training_data,
            num_centroids=num_centroids,
            num_probes=num_probes,
            num_quantizers=num_quantizers,
            bits_per_quantizer=bits_per_quantizer,
            output_path=init_index_path,
        )
    )
  else:
    print("\t- Using initial index:", init_index_path)

  # For each partition, load embeddings to idx
  partial_idx_paths = []
  for part_idx, part in enumerate(hash_and_embedding.to_delayed()):
    part_path = shared_scratch_dir.joinpath(f"part-{part_idx}.index")
    if part_path.is_file():  # rudimentary checkpoint
      partial_idx_paths.append(dask.delayed(part_path))
    else:
      partial_idx_paths.append(
          dask.delayed(add_points_to_index)(
              records=part,
              init_index_path=init_index_path,
              output_path=part_path,
              batch_size=batch_size,
          )
      )

  return dask.delayed(merge_index)(
      init_index_path=init_index_path,
      partial_idx_paths=partial_idx_paths,
      final_index_path=final_index_path,
  )
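
# NOTE (illustrative sketch): `train_initial_index`, `add_points_to_index`, and
# `merge_index` are helpers defined elsewhere in this module. The function
# below is a hypothetical example of how the training step could be
# implemented with the Faiss API, assuming an IVF-PQ index and that
# `training_data` is a Dask bag of float32-compatible embedding vectors. It is
# deliberately named `_sketch_train_initial_index` so it does not shadow the
# real helper, and it is not the project's actual implementation.
def _sketch_train_initial_index(
    training_data: dbag.Bag,
    num_centroids: int,
    num_probes: int,
    num_quantizers: int,
    bits_per_quantizer: int,
    output_path: Path,
) -> Path:
  import faiss  # assumed available; the real helper may organize imports differently
  import numpy as np
  # Materialize the sampled embeddings into a dense float32 matrix.
  embeddings = np.vstack([
      np.asarray(emb, dtype=np.float32) for emb in training_data.compute()
  ])
  dim = embeddings.shape[1]
  # The coarse quantizer partitions the space into `num_centroids` Voronoi
  # cells; product quantization then compresses each vector into
  # `num_quantizers` sub-vectors of `bits_per_quantizer` bits each.
  coarse_quantizer = faiss.IndexFlatL2(dim)
  index = faiss.IndexIVFPQ(
      coarse_quantizer, dim, num_centroids, num_quantizers, bits_per_quantizer
  )
  index.nprobe = num_probes
  index.train(embeddings)
  faiss.write_index(index, str(output_path))
  return output_path
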
def get_frequent_ngrams(
    analyzed_sentences: dbag.Bag,
    max_ngram_length: int,
    min_ngram_support: int,
    min_ngram_support_per_partition: int,
    ngram_sample_rate: float,
    token_field: str = "tokens",
    ngram_field: str = "ngrams",
) -> dbag.Bag:
  """
  Adds a new field containing a list of all mined n-grams. N-grams are tuples
  of strings such that at least one string is not a stopword. Strings are
  collected from the lemmas of sentences. To be counted, an n-gram must occur
  in at least `min_ngram_support` sentences.
  """

  def part_to_ngram_counts(
      records: Iterable[Record],
  ) -> Iterable[Dict[Tuple[str, ...], int]]:
    ngram2count = {}
    for rec in records:

      def interesting(idx):
        t = rec[token_field][idx]
        return not t["stop"] and t["pos"] in INTERESTING_POS_TAGS

      # beginning of ngram
      for start_tok_idx in range(len(rec[token_field])):
        # ngrams must begin with an interesting word
        if not interesting(start_tok_idx):
          continue
        # for each potential n-gram size, from 2 up to and including
        # max_ngram_length (matching the longest-match loop in parse_ngrams)
        for ngram_len in range(2, max_ngram_length + 1):
          end_tok_idx = start_tok_idx + ngram_len
          # ngrams cannot extend beyond the sentence
          if end_tok_idx > len(rec[token_field]):
            continue
          # ngrams must end with an interesting word
          if not interesting(end_tok_idx - 1):
            continue
          # the ngram is an ordered tuple of lemmas
          ngram = tuple(
              rec[token_field][tok_idx]["lemma"]
              for tok_idx in range(start_tok_idx, end_tok_idx)
          )
          if ngram in ngram2count:
            ngram2count[ngram] += 1
          else:
            ngram2count[ngram] = 1
    # filter out all low-occurrence ngrams in this partition
    return [{
        n: c
        for n, c in ngram2count.items()
        if c >= min_ngram_support_per_partition
    }]

  def valid_ngrams(
      ngram2count: Dict[Tuple[str, ...], int],
  ) -> Set[Tuple[str, ...]]:
    ngrams = {n for n, c in ngram2count.items() if c >= min_ngram_support}
    return ngrams

  def parse_ngrams(record: Record, ngram_model: Set[Tuple[str, ...]]):
    record[ngram_field] = []
    start_tok_idx = 0
    while start_tok_idx < len(record[token_field]):
      incr = 1  # amount to move start_tok_idx
      # from max -> 2. Match longest
      for ngram_len in range(max_ngram_length, 1, -1):
        # get bounds of ngram and make sure it's within the sentence
        end_tok_idx = start_tok_idx + ngram_len
        if end_tok_idx > len(record[token_field]):
          continue
        ngram = tuple(
            record[token_field][tok_idx]["lemma"]
            for tok_idx in range(start_tok_idx, end_tok_idx)
        )
        # if match
        if ngram in ngram_model:
          record[ngram_field].append("_".join(ngram))
          # skip over matched terms
          incr = ngram_len
          break
      start_tok_idx += incr
    return record

  # Begin the actual function
  if max_ngram_length < 1:
    # disabled: record an empty ngram field for every sentence
    def init_nothing(rec: Record) -> Record:
      rec[ngram_field] = []
      return rec

    return analyzed_sentences.map(init_nothing)
  else:
    ngram2count = (
        analyzed_sentences
        .random_sample(ngram_sample_rate)
        .map_partitions(part_to_ngram_counts)
        .fold(misc_util.merge_counts, initial={})
    )
    ngram_model = delayed(valid_ngrams)(ngram2count)
    return analyzed_sentences.map(parse_ngrams, ngram_model=ngram_model)
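
# NOTE (illustrative sketch): the fold above combines per-partition
# ngram->count dictionaries with `misc_util.merge_counts`, which is defined
# elsewhere. A plausible equivalent, shown here only to clarify the fold
# semantics and not the project's actual implementation, is:
def _sketch_merge_counts(
    a: Dict[Tuple[str, ...], int],
    b: Dict[Tuple[str, ...], int],
) -> Dict[Tuple[str, ...], int]:
  # Sum the counts of two partial n-gram count dictionaries.
  merged = dict(a)
  for ngram, count in b.items():
    merged[ngram] = merged.get(ngram, 0) + count
  return merged
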