Ejemplo n.º 1
0
def analyze_sentences(
    records:Iterable[Record],
    text_field:str,
    token_field:str="tokens",
    entity_field:str="entities",
)->Iterable[Record]:
  """
  Parses the text fields of all records using SciSpacy.
  Requires that text_util:nlp and text_util:stopwords have both been loaded into
  dask_process_global.

  Each input record is modified in place: `token_field` and `entity_field` are
  written onto the record itself, and the same record objects are returned.

  @param records: A partition of records to parse, each must contain `text_field`
  @param text_field: The name of the field we wish to parse.
  @param token_field: The output field for all basic tokens. These are
  sub-records containing information such as POS tag and lemma.
  @param entity_field: The output field for all entities, which are multi-token
  phrases.
  @return a list of records with token and entity fields
  """
  nlp = dpg.get("text_util:nlp")
  stopwords = dpg.get("text_util:stopwords")

  # The records are walked twice in lock-step below (once directly, once to
  # feed the parser). Materialize them so that a one-shot iterator is not
  # consumed by both walkers at once, which would misalign records and docs.
  records = list(records)

  res = []
  for sent_rec, doc in zip(
      records,
      nlp.pipe(rec[text_field] for rec in records),
  ):
    sent_rec[entity_field] = [
      {
        "tok_start": ent.start,
        "tok_end": ent.end,
        "cha_start": ent.start_char,
        "cha_end": ent.end_char,
        "label": ent.label_,
      }
      for ent in doc.ents
      if ent.end - ent.start > 1  # don't want 1-gram ents
    ]
    sent_rec[token_field] = [
      {
        "cha_start": tok.idx,
        "cha_end": tok.idx + len(tok),
        "lemma": tok.lemma_,
        "pos": tok.pos_,
        "tag": tok.tag_,
        "dep": tok.dep_,
        # A token is a stopword if either its lemma or its normalized
        # surface form is in the stopword collection.
        "stop": (
          tok.lemma_ in stopwords
          or tok.text.strip().lower() in stopwords
        ),
      }
      for tok in doc
    ]
    res.append(sent_rec)
  return res
Ejemplo n.º 2
0
def set_index(
    collection: str,
    field_name: str,
    index_type: MONGO_INDEX = pymongo.TEXT,
) -> None:
    """
    Creates an index on a single field of a collection.

    The database handle is fetched from the process-global store under the
    key "database:db".

    @param collection: Name of the collection to index.
    @param field_name: The field the index is built on.
    @param index_type: The kind of index to create, defaults to a text index.
    """
    index_spec = [(field_name, index_type)]
    target = dpg.get("database:db")[collection]
    target.create_index(index_spec)
Ejemplo n.º 3
0
 def _init():
     """
     Loads the pretrained BERT tokenizer and model named by `bert_model`.

     The model is switched to evaluation mode and moved to the device stored
     under "embedding_util:device" in the process-global store.

     @return a (tokenizer, model) pair
     """
     target_device = dpg.get("embedding_util:device")
     tokenizer = BertTokenizer.from_pretrained(bert_model)
     bert = BertModel.from_pretrained(bert_model)
     bert.eval()
     bert.to(target_device)
     return tokenizer, bert
Ejemplo n.º 4
0
 def _init():
     """
     Builds `model_class(**model_kwargs)` and restores its weights from
     `data_dir`.

     The resulting model is put in evaluation mode and placed on the device
     stored under "embedding_util:device" in the process-global store.

     @return the initialized model
     """
     device = dpg.get("embedding_util:device")
     model = model_class(**model_kwargs)
     model.load_state_dict(torch.load(
         str(data_dir),
         map_location=device,
     ))
     model.eval()
     # `map_location` only controls where the checkpoint tensors are loaded;
     # load_state_dict copies them into the model's existing parameters and
     # does not relocate the module. Move it explicitly so that inputs sent
     # to `device` and the model's weights live in the same place.
     model.to(device)
     return model
Ejemplo n.º 5
0
def get(values: Iterable[str],
        collection: str,
        field_name: str,
        desired_fields: List[str] = None,
        **kwargs) -> Iterable[Record]:
    """
    Looks up one document per value, matching `field_name` against the value.

    The result list is in the same order as `values`; each entry is whatever
    `find_one` produced for that value. Extra keyword arguments are accepted
    and ignored.

    @param values: The values to search for.
    @param collection: Name of the collection to query.
    @param field_name: The field compared against each value.
    @param desired_fields: If given, only these fields are projected. The
    "_id" field is always excluded.
    @return a list with one lookup result per input value
    """
    # Requested fields are included; "_id" is set last so it is always
    # excluded, even if the caller listed it among the desired fields.
    projection = dict.fromkeys(desired_fields or (), 1)
    projection["_id"] = 0

    db = dpg.get("database:db")
    found = []
    for value in values:
        found.append(db[collection].find_one({field_name: value}, projection))
    return found
Ejemplo n.º 6
0
def put(records: Iterable[Record], collection: str, **kwargs) -> None:
    """
    Inserts all the records into the given collection, one at a time.

    Nothing is returned. Additional keyword arguments are accepted and
    ignored, which allows callers to pass delayed dependencies.

    A record that triggers `pymongo.errors.InvalidOperation` is reported on
    stdout and skipped; the remaining records are still inserted.

    @param records: The records to insert.
    @param collection: Name of the destination collection.
    """
    db = dpg.get("database:db")
    target = db[collection]
    for r in records:
        try:
            # `Collection.insert` is deprecated (and removed in PyMongo 4);
            # `insert_one` is the supported single-document equivalent.
            target.insert_one(r)
        except pymongo.errors.InvalidOperation as e:
            # Deliberately best-effort: report and continue with the rest.
            print("Encountered non-fatal issue:", e)
Ejemplo n.º 7
0
def apply_sentence_classifier_to_part(
    records: Iterable[Record],
    batch_size: int,
    sentence_classifier_name="sentence_classifier",
    predicted_type_suffix=":pred",
    sentence_type_field="sent_type",
) -> Iterable[Record]:
    """
    Labels every record with the sentence type predicted by a classifier.

    The classifier is read from the process-global store under
    "embedding_util:<sentence_classifier_name>" and the target device from
    "embedding_util:device". Records are modified in place and returned.

    @param records: The records to classify.
    @param batch_size: Number of records sent through the model at once.
    @param sentence_classifier_name: Name the classifier was registered under.
    @param predicted_type_suffix: Appended to each predicted label.
    @param sentence_type_field: Record field that receives the label.
    @return a list of the input records, each with `sentence_type_field` set
    """
    device = dpg.get("embedding_util:device")
    model = dpg.get(f"embedding_util:{sentence_classifier_name}")

    res = []
    for rec_batch in iter_to_batches(records, batch_size):
        model_input = torch.stack([
            record_to_sentence_classifier_input(r) for r in rec_batch
        ]).to(device)
        # Pure inference: skip autograd bookkeeping so no graph is retained
        # for each batch.
        with torch.no_grad():
            model_output = model(model_input)
        predicted_labels = sentence_classifier_output_to_labels(model_output)
        for r, lbl in zip(rec_batch, predicted_labels):
            r[sentence_type_field] = lbl + predicted_type_suffix
            res.append(r)
    # Reports the number of records processed in this partition on stdout.
    print(len(res))
    return res
Ejemplo n.º 8
0
def embed_records(
    records: Iterable[Record],
    batch_size: int,
    text_field: str,
    max_sequence_length: int,
    out_embedding_field: str = "embedding",
    show_pbar: bool = False,
) -> Iterable[Record]:
    """
    Introduces an embedding field to each record, holding the BERT embedding
    of the supplied text field.

    The tokenizer and model are read from the process-global store under
    "embedding_util:tok,model" and the device from "embedding_util:device".
    Records are modified in place and returned.

    @param records: The records to embed, each must contain `text_field`.
    @param batch_size: Number of records embedded per model call.
    @param text_field: The record field holding the text to embed.
    @param max_sequence_length: Encoded sequences are truncated to this many
    token ids.
    @param out_embedding_field: The record field that receives the embedding.
    @param show_pbar: If set, displays a progress bar over the batches.
    @return a list of the input records, each with `out_embedding_field` set
    """

    dev = dpg.get("embedding_util:device")
    tok, model = dpg.get("embedding_util:tok,model")

    # Number of batches for the progress bar. Ceiling division counts the
    # final partial batch; inputs without a length (e.g. generators) get an
    # unknown total instead of raising.
    try:
        num_batches = (len(records) + batch_size - 1) // batch_size
    except TypeError:
        num_batches = None

    res = []
    # PBar is necessary when using embed_records in helper scripts.
    pbar = tqdm(
        iter_to_batches(records, batch_size),
        total=num_batches,
        disable=not show_pbar,
    )
    for batch in pbar:
        texts = [record[text_field] for record in batch]
        # Encode, truncate, and pad the batch to a rectangular tensor.
        sequs = pad_sequence(
            sequences=[
                torch.tensor(tok.encode(t)[:max_sequence_length])
                for t in texts
            ],
            batch_first=True,
        ).to(dev)
        with torch.no_grad():
            # Average the selected model output over axis 1 to obtain one
            # vector per record.
            embs = (model(sequs)[-2].mean(axis=1).cpu().detach().numpy())
        for record, emb in zip(batch, embs):
            record[out_embedding_field] = emb
            res.append(record)
    return res
Ejemplo n.º 9
0
  def apply_faiss_to_edges(
      hash_and_embedding:Iterable[Record],
  )->Iterable[nx.Graph]:
    """
    Builds a nearest-neighbor graph for a partition of embedded records.

    Each batch of embeddings is searched against the faiss index stored under
    "knn_util:faiss_<faiss_index_name>". Roots and neighbors are hashes, which
    are translated to graph keys via the "strid" field of the inverted index
    collection. An edge with the configured weight is added between every
    root and each of its distinct neighbors.

    @param hash_and_embedding: Records supplying a hash and an embedding.
    @return a single-element list containing the resulting graph
    """
    # NOTE(review): callers presumably ensure the inverted index collection
    # has been written to the database before this runs — confirm.
    index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
    # Cache of hash -> strid, or None when no entry was found for the hash.
    inverted_index = {}

    graph = nx.Graph()
    for batch in iter_to_batches(hash_and_embedding, batch_size):
      hashes, embeddings = records_to_ids_and_embeddings(
          records=batch,
      )
      _, neighs_per_root = index.search(embeddings, num_neighbors)

      # Keep the batch-ordered roots separate from the lookup set: row i of
      # the search result belongs to root i, so this ordering must survive.
      root_hashes = hashes.tolist()
      neigh_hashes_per_root = neighs_per_root.tolist()

      # Only query the database for hashes we have not resolved yet.
      unseen_hashes = list(
          set(root_hashes + flatten_list(neigh_hashes_per_root))
          - set(inverted_index.keys())
      )
      graph_keys = database_util.get(
          values=unseen_hashes,
          collection=inverted_index_collection,
          field_name="hash",
          desired_fields=["strid"]
      )
      for k, v in zip(unseen_hashes, graph_keys):
        # A lookup may come back empty; remember that as None so the hash is
        # skipped below instead of failing on a missing document.
        inverted_index[k] = None if v is None else v["strid"]

      # Create edges
      for root_idx, neigh_indices in zip(root_hashes, neigh_hashes_per_root):
        root = inverted_index[root_idx]
        if root is None:
          continue
        for neigh_idx in neigh_indices:
          if neigh_idx == root_idx:
            continue
          neigh = inverted_index[neigh_idx]
          if neigh is None:
            continue
          # The graph is undirected, so one add_edge covers both directions.
          graph.add_edge(root, neigh, weight=weight)
    return [graph]
Ejemplo n.º 10
0
def clear_collection(collection: str, **kwargs) -> None:
    """
    Drops the named collection from the database.

    The database handle is fetched from the process-global store under the
    key "database:db". Extra keyword arguments are accepted and ignored.

    @param collection: Name of the collection to drop.
    """
    dpg.get("database:db")[collection].drop()