Example #1
def embed_records(
    records: Iterable[Record],
    batch_size: int,
    text_field: str,
    max_sequence_length: int,
    out_embedding_field: str = "embedding",
) -> Iterable[Record]:
    """
  Introduces an embedding field to each record, indicated the bert embedding
  of the supplied text field.
  """

    # Fetch the shared device, tokenizer, and model that were registered
    # elsewhere in the process-global storage helper (dpg).
    dev = dpg.get("embedding_util:device")
    tok, model = dpg.get("embedding_util:tok,model")

    res = []
    for batch in iter_to_batches(records, batch_size):
        texts = [record[text_field] for record in batch]
        sequs = pad_sequence(
            sequences=[
                torch.tensor(tok.encode(t)[:max_sequence_length])
                for t in texts
            ],
            batch_first=True,
        ).to(dev)
        with torch.no_grad():
            # Mean-pool the second-to-last model output across the token
            # dimension to get one fixed-size vector per text.
            embs = model(sequs)[-2].mean(axis=1).cpu().detach().numpy()
        for record, emb in zip(batch, embs):
            record[out_embedding_field] = emb
            res.append(record)
    return res
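
A minimal usage sketch, assuming `Record` is a plain `dict` and the `embedding_util` device, tokenizer, and model have already been registered with `dpg`:

records = [{"text": "malaria is transmitted by mosquitoes"}]
embedded = embed_records(
    records,
    batch_size=32,
    text_field="text",
    max_sequence_length=128,
)
# Each record now carries a fixed-size vector, e.g. shape (768,) for BERT-base.
print(embedded[0]["embedding"].shape)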
Example #2
    def apply_faiss_to_edges(
        hash_and_embedding: Iterable[Record],
    ) -> Iterable[nx.Graph]:

        # parts_written_to_db (closed over from the enclosing function) exists
        # only to ensure that the database writing happens before this point.
        index = dpg.get(f"knn_util:faiss_{faiss_index_name}")

        graph = nx.Graph()
        with sqlite3_lookup.Sqlite3LookupTable(hash2name_db) as hash2names:
            for batch in iter_to_batches(hash_and_embedding, batch_size):
                hashes, embeddings = to_hash_and_embedding(records=batch)
                _, neighs_per_root = index.search(embeddings, num_neighbors)
                hashes = hashes.tolist() + flatten_list(
                    neighs_per_root.tolist())
                # Create edge records; note that zip below stops at the root
                # hashes, so the appended neighbor hashes are never roots.
                for root_hash, neigh_indices in zip(hashes, neighs_per_root):
                    if root_hash in hash2names:
                        for root_name in hash2names[root_hash]:
                            for neigh_hash in neigh_indices:
                                if neigh_hash != root_hash and neigh_hash in hash2names:
                                    for neigh_name in hash2names[neigh_hash]:
                                        graph.add_edge(root_name,
                                                       neigh_name,
                                                       weight=weight)
                                        # nx.Graph is undirected, so this
                                        # reversed add_edge is redundant
                                        # (a harmless no-op).
                                        graph.add_edge(neigh_name,
                                                       root_name,
                                                       weight=weight)
        return [graph]
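
In its source this is a nested function; `faiss_index_name`, `hash2name_db`, `batch_size`, `num_neighbors`, and `weight` are closed over from the enclosing scope. A rough sketch of how those names might be bound (all values hypothetical):

faiss_index_name = "final"                # selects the dpg-registered index
hash2name_db = Path("hash2name.sqlite3")  # hash -> [name] lookup table
batch_size = 1024
num_neighbors = 25
weight = 1.0

graphs = apply_faiss_to_edges(hash_and_embedding_records)
combined = nx.compose_all(graphs)         # the returned list holds one graph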
Example #3
def main(input_db: Path,
         output_dir: Path,
         nodes_per_file: int = 1_000_000,
         output_file_fmt_str: str = "{:08d}.txt",
         disable_pbar: bool = False):
    """Sqlite3Graph -> Edge Json

  Args:
    input_db: A graph sqlite3 table.
    output_dir: The location of a directory that we are going to make and fill
      with json files.
    nodes_per_file: Each file is going to contain at most this number of nodes.
    output_file_fmt_str: This string will be called with `.format(int)` for
      each output file. Must produce unique names for each string.
  """

    input_db = Path(input_db)
    assert input_db.is_file(), f"Failed to find {input_db}"

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    assert output_dir.is_dir(), f"Failed to find dir: {output_dir}"
    assert len(list(output_dir.iterdir())) == 0, f"{output_dir} is not empty."

    nodes_per_file = int(nodes_per_file)
    assert nodes_per_file > 0, \
        "Must supply a positive number of nodes per output file."

    try:
        output_file_fmt_str.format(1)
    except Exception:
        assert False, "output_file_fmt_str must contain format component `{}`"

    print("Opening", input_db)
    graph = Sqlite3Graph(input_db)
    if not disable_pbar:
        num_nodes = len(graph)
        graph = tqdm(graph, total=num_nodes)
        graph.set_description("Reading edges")

    for batch_idx, edge_batch in enumerate(
            iter_to_batches(graph, nodes_per_file)):
        file_path = output_dir.joinpath(
            output_file_fmt_str.format(batch_idx + 1))
        if disable_pbar:
            print(file_path)
        else:
            graph.set_description(f"Writing {file_path}")
        assert not file_path.exists(), \
            f"Error: {file_path} already exists. Is your format string bad?"
        with open(file_path, 'w') as json_file:
            for node, neighbors in edge_batch:
                for neigh in neighbors:
                    json_file.write(json.dumps({"key": node, "value": neigh}))
                    json_file.write("\n")
        if not disable_pbar:
            graph.set_description("Reading edges")
    print("Done!")
Example #4
def add_points_to_index(
    records: Iterable[Record],
    init_index_path: Path,
    batch_size: int,
    output_path: Path,
) -> Path:
    "Loads an initial index, adds the partition to the index, and writes result"
    index = faiss.read_index(str(init_index_path))
    assert index.is_trained

    for batch in iter_to_batches(records, batch_size):
        hashes, embeddings = to_hash_and_embedding(records=batch)
        # add_with_ids requires an index type that supports custom ids,
        # such as an IVF index or an IndexIDMap wrapper.
        index.add_with_ids(embeddings, hashes)
    faiss.write_index(index, str(output_path))
    return output_path
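
A minimal usage sketch, assuming `to_hash_and_embedding` yields int64 ids and float32 vectors as faiss's `add_with_ids` expects (paths are hypothetical):

out_path = add_points_to_index(
    records=partition_records,
    init_index_path=Path("trained.index"),
    batch_size=4096,
    output_path=Path("partition_00.index"),
)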
Example #5
  def predict_from_terms(
      self, term_pairs: List[Tuple[str, str]]
  ) -> List[float]:
    assert self._is_forward_ready, "Must run model.init()"
    res = []
    for pair_batch in iter_to_batches(term_pairs, self.hparams.batch_size):
      model_input = observations_to_tensors([
          generate_predicate_observation(
              subj=subj,
              obj=obj,
              neighbors_per_term=self.hparams.neighbors_per_term,
              graph_index=self.graph_index,
              embedding_index=self.embedding_index,
          )
          # Iterate the current batch rather than the full term_pairs list;
          # otherwise every batch would re-score all pairs.
          for subj, obj in pair_batch
      ])
      res += list(self.forward(model_input).cpu().detach().numpy())
    return res
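
A short usage sketch; `model.init()` comes from the assert above, and the term identifiers are borrowed from Example #7:

model.init()
scores = model.predict_from_terms([("C0006826", "C0040329")])
# scores[i] is the prediction for term_pairs[i]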
Example #6
    def _apply_faiss(hash_and_embedding: Iterable[Record]) -> List[Record]:

        # parts_written_to_db (closed over from the enclosing function) exists
        # only to ensure that the database writing happens before this point.
        index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
        hash2names = sqlite3_lookup.Sqlite3LookupTable(hash2name_db)

        # "id", "neighs"
        res = []
        for batch in iter_to_batches(hash_and_embedding, batch_size):
            hashes, embeddings = to_hash_and_embedding(records=batch)
            _, neighs_per_root = index.search(embeddings, num_neighbors)
            hashes = hashes.tolist() + flatten_list(neighs_per_root.tolist())
            # Create records; note that zip below stops at the root hashes,
            # so the appended neighbor hashes are never treated as roots.
            for root_hash, neigh_indices in zip(hashes, neighs_per_root):
                if root_hash in hash2names:
                    for root_name in hash2names[root_hash]:
                        val = {"id": root_name, "neighs": set()}
                        for neigh_hash in neigh_indices:
                            if neigh_hash != root_hash and neigh_hash in hash2names:
                                for neigh_name in hash2names[neigh_hash]:
                                    val["neighs"].add(neigh_name)
                        res.append(val)
        return res
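
As in Example #2, `faiss_index_name`, `hash2name_db`, `batch_size`, and `num_neighbors` are closed over from an enclosing function. A sketch of the output shape (entity names hypothetical):

records = _apply_faiss(hash_and_embedding_records)
# records == [
#   {"id": "some_entity", "neighs": {"nearby_entity_a", "nearby_entity_b"}},
#   ...
# ]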
Example #7
    def predict_from_terms(
        self,
        terms: List[Tuple[str, str]],
        batch_size: int = 1,
    ) -> List[float]:
        """Evaluates the Agatha model for the given set of predicates.

    For each pair of coded terms in `terms`, we produce a prediction in the
    range 0-1. Behind the scenes this means that we will lookup embeddings for
    the terms themselves as well as samples neighbors of each term. Then, these
    samples will be put through the Agatha transformer model to output a
    ranking criteria in 0-1. If this model has been put on gpu with a command
    like `model.cuda()`, then these predictions will happen on GPU. We will
    batch the predictions according to `batch_size`. This can greatly increase
    performance for large prediction sets.

    Note, behind the scenes there a lot of database accesses and caching. This
    means that your first calls to predict_from_terms will be slow. If you want
    to make many predictions quickly, call `model.preload()` before this
    function.

    Example Usage:

    ```python3
    model = torch.load(...)
    model.configure_paths(...)
    model.predict_from_terms([("C0006826", "C0040329")])
    > [0.9951196908950806]
    ```

    Args:
      terms: A list of coded-term name pairs. Coded terms are any elements that
        agatha names with the `m:` prefix. The prefix is optional when specifying
        terms for this function, meaning "C0040329" and "m:c0040329" will both
        correspond to the same embedding.
      batch_size: The number of predicates to predict at once. This is
        especially important when using the GPU.

    Returns:
      A list of prediction values in the `[0,1]` interval. Higher values
      indicate more plausible results. Output `i` corresponds to `terms[i]`.

    """
        self._assert_configured()
        # This will formulate our input as PredicateEmbeddings examples.
        observation_generator = predicate_util.PredicateObservationGenerator(
            graph=self.graph,
            embeddings=self.embeddings,
            neighbor_sample_rate=self.hparams.neighbor_sample_rate,
        )
        # Clean all of the input terms
        predicates = [
            predicate_util.to_predicate_name(
                predicate_util.clean_coded_term(s),
                predicate_util.clean_coded_term(o),
            ) for s, o in terms
        ]

        result = []
        for predicate_batch in iter_to_batches(predicates, batch_size):
            # Get a tensor representing each stacked sample
            batch = predicate_util.collate_predicate_embeddings(
                [observation_generator[p] for p in predicate_batch])
            # Move batch to device
            batch = batch.to(self.get_device())
            result += self.forward(batch).detach().cpu().numpy().tolist()
        return result
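
The docstring's example covers the basic call; the sketch below adds the GPU and caching steps the docstring describes (file name hypothetical):

model = torch.load("model.pt")
model.configure_paths(...)  # as in the docstring example
model = model.cuda()        # predictions now run on GPU
model.preload()             # warm caches before making many predictions
scores = model.predict_from_terms(many_pairs, batch_size=64)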