示例#1
0
def train_distributed_knn(
    hash_and_embedding: dbag.Bag,
    batch_size: int,
    num_centroids: int,
    num_probes: int,
    num_quantizers: int,
    bits_per_quantizer: int,
    training_sample_prob: float,
    shared_scratch_dir: Path,
    final_index_path: Path,
    id_field: str = "id",
    embedding_field: str = "embedding",
) -> Path:
    """Build a FAISS approximate-nearest-neighbor index in a distributed way.

    Computing all of the embeddings and then performing a KNN does not fit in
    memory. Instead, this function works in three stages:

    1. Train an initial index on a random sample of the embeddings, asking
       ``train_initial_index`` to write it to ``shared_scratch_dir/init.index``.
       If that file already exists it is reused. This stage runs eagerly.
    2. For every partition of ``hash_and_embedding``, schedule
       ``add_points_to_index`` with output path
       ``shared_scratch_dir/part-{i}.index``. Partitions whose file already
       exists are not recomputed (a rudimentary checkpoint).
    3. Schedule ``merge_index`` over the initial index and all partial indices,
       targeting ``final_index_path``.

    Stages 2 and 3 are returned lazily as a dask graph.

    @param hash_and_embedding: bag of records holding a hash value and an
        embedding (the embedding is read via ``embedding_field``).
    @param batch_size: number of records per batch when adding points to an
        index.
    @param num_centroids: number of Voronoi cells in the approximate NN index.
    @param num_probes: number of cells to consider when querying.
    @param num_quantizers: number of sub-vectors to discretize.
    @param bits_per_quantizer: bits per sub-vector.
    @param training_sample_prob: probability that a given point is used for
        training the initial index.
    @param shared_scratch_dir: location to store intermediate results.
    @param final_index_path: output path handed to ``merge_index``.
    @param id_field: NOTE(review): currently unused by this function (it is not
        forwarded to ``add_points_to_index``); kept for API compatibility.
    @param embedding_field: name of the record field holding the embedding.
    @return: a dask Delayed wrapping the ``merge_index`` call; compute it to
        build the final index. NOTE(review): the ``-> Path`` annotation does
        not match the Delayed object actually returned; presumably computing
        it yields the final index path -- confirm against ``merge_index``.
    """
    init_index_path = shared_scratch_dir.joinpath("init.index")

    if not init_index_path.is_file():
        print("\t- Constructing initial index:", init_index_path)
        # First off, we need to get a representative sample for faiss training
        training_data = hash_and_embedding.random_sample(
            prob=training_sample_prob).pluck(embedding_field)

        # Train the initial index; train_initial_index is asked to write it to
        # init_index_path. dask.compute returns a *tuple* of results, so do not
        # rebind init_index_path to its return value: downstream tasks need
        # the Path itself, not a 1-tuple.
        dask.compute(
            dask.delayed(train_initial_index)(
                training_data=training_data,
                num_centroids=num_centroids,
                num_probes=num_probes,
                num_quantizers=num_quantizers,
                bits_per_quantizer=bits_per_quantizer,
                output_path=init_index_path,
            ))
    else:
        print("\t- Using initial index:", init_index_path)

    # For each partition, load embeddings into its own partial index.
    partial_idx_paths = []
    for part_idx, part in enumerate(hash_and_embedding.to_delayed()):
        part_path = shared_scratch_dir.joinpath(f"part-{part_idx}.index")
        if part_path.is_file():  # rudimentary ckpt: reuse existing partial
            partial_idx_paths.append(dask.delayed(part_path))
        else:
            partial_idx_paths.append(
                dask.delayed(add_points_to_index)(
                    records=part,
                    init_index_path=init_index_path,
                    output_path=part_path,
                    batch_size=batch_size,
                ))

    return dask.delayed(merge_index)(
        init_index_path=init_index_path,
        partial_idx_paths=partial_idx_paths,
        final_index_path=final_index_path,
    )
示例#2
0
文件: text_util.py 项目: healx/agatha
def get_frequent_ngrams(analyzed_sentences: dbag.Bag,
                        max_ngram_length: int,
                        min_ngram_support: int,
                        min_ngram_support_per_partition: int,
                        ngram_sample_rate: float,
                        token_field: str = "tokens",
                        ngram_field: str = "ngrams") -> dbag.Bag:
    """Add a field containing the frequent n-grams mined from each sentence.

    N-grams are tuples of token lemmas of length 2 to ``max_ngram_length``
    (inclusive) whose first and last tokens are "interesting": not a stopword
    and with a POS tag in ``INTERESTING_POS_TAGS``.

    Candidates are counted (by occurrence, not by distinct sentence) over a
    random sample of the sentences. Within each sampled partition a candidate
    must occur at least ``min_ngram_support_per_partition`` times. After the
    per-partition counts are merged with ``misc_util.merge_counts``, it must
    reach at least ``min_ngram_support``.

    Every sentence is then scanned left to right, greedily matching the
    longest known n-gram at each position. Matches are stored under
    ``ngram_field`` as ``"_"``-joined lemma strings.

    Records are modified in place (``ngram_field`` is assigned on each record).

    @param analyzed_sentences: bag of records; each holds a list of token dicts
        under ``token_field`` with "lemma", "stop" and "pos" keys.
    @param max_ngram_length: longest n-gram to mine/match; values below 2
        disable mining and yield an empty list for every record.
    @param min_ngram_support: minimum merged occurrence count to keep an n-gram.
    @param min_ngram_support_per_partition: minimum occurrence count within a
        single sampled partition for a candidate to be kept there.
    @param ngram_sample_rate: probability that a sentence is sampled for mining.
    @param token_field: record field holding the token list.
    @param ngram_field: record field that receives the matched n-grams.
    @return: a bag of the same records with ``ngram_field`` populated.
    """
    def is_interesting(token) -> bool:
        # N-grams must begin and end on a non-stopword with a useful POS tag.
        return not token["stop"] and token["pos"] in INTERESTING_POS_TAGS

    def part_to_ngram_counts(
            records: Iterable[Record]
    ) -> Iterable[Dict[Tuple[str, ...], int]]:
        ngram2count: Dict[Tuple[str, ...], int] = {}
        for rec in records:
            tokens = rec[token_field]
            num_tokens = len(tokens)
            # beginning of ngram
            for start_tok_idx in range(num_tokens):
                # ngrams must begin with an interesting word
                if not is_interesting(tokens[start_tok_idx]):
                    continue
                # Lengths 2..max_ngram_length inclusive, matching the lengths
                # parse_ngrams tries when matching.
                for ngram_len in range(2, max_ngram_length + 1):
                    end_tok_idx = start_tok_idx + ngram_len
                    # ngrams cannot extend beyond the sentence; longer ones
                    # would overrun as well, so stop here.
                    if end_tok_idx > num_tokens:
                        break
                    # ngrams must end with an interesting word
                    if not is_interesting(tokens[end_tok_idx - 1]):
                        continue
                    # the ngram is an ordered tuple of lemmas
                    ngram = tuple(
                        tok["lemma"]
                        for tok in tokens[start_tok_idx:end_tok_idx])
                    ngram2count[ngram] = ngram2count.get(ngram, 0) + 1
        # filter out all low-occurrence ngrams in this partition; a one-element
        # list because map_partitions expects an iterable per partition.
        return [{
            n: c
            for n, c in ngram2count.items()
            if c >= min_ngram_support_per_partition
        }]

    def valid_ngrams(
            ngram2count: Dict[Tuple[str, ...], int]) -> Set[Tuple[str, ...]]:
        return {n for n, c in ngram2count.items() if c >= min_ngram_support}

    def parse_ngrams(record: Record, ngram_model: Set[Tuple[str, ...]]):
        tokens = record[token_field]
        num_tokens = len(tokens)
        record[ngram_field] = []
        start_tok_idx = 0
        while start_tok_idx < num_tokens:
            incr = 1  # amount to move start_tok_idx
            # from max -> 2. Match longest
            for ngram_len in range(max_ngram_length, 1, -1):
                # get bounds of ngram and make sure it's within the sentence
                end_tok_idx = start_tok_idx + ngram_len
                if end_tok_idx > num_tokens:
                    continue
                ngram = tuple(tok["lemma"]
                              for tok in tokens[start_tok_idx:end_tok_idx])
                if ngram in ngram_model:
                    record[ngram_field].append("_".join(ngram))
                    # skip over matched terms
                    incr = ngram_len
                    break
            start_tok_idx += incr
        return record

    # Begin the actual function
    if max_ngram_length < 2:
        # An n-gram needs at least two tokens, so nothing can be mined: record
        # an empty field for every sentence without running the mining pass.
        def init_nothing(rec: Record) -> Record:
            rec[ngram_field] = []
            return rec

        return analyzed_sentences.map(init_nothing)

    ngram2count = (analyzed_sentences
                   .random_sample(ngram_sample_rate)
                   .map_partitions(part_to_ngram_counts)
                   .fold(misc_util.merge_counts, initial={}))
    ngram_model = delayed(valid_ngrams)(ngram2count)
    return analyzed_sentences.map(parse_ngrams, ngram_model=ngram_model)