Example #1
def embed_records(
    records: Iterable[Record],
    batch_size: int,
    text_field: str,
    max_sequence_length: int,
    out_embedding_field: str = "embedding",
) -> Iterable[Record]:
    """
  Introduces an embedding field to each record, indicated the bert embedding
  of the supplied text field.
  """

    dev = dpg.get("embedding_util:device")
    tok, model = dpg.get("embedding_util:tok,model")

    res = []
    for batch in iter_to_batches(records, batch_size):
        texts = list(map(lambda x: x[text_field], batch))
        sequs = pad_sequence(
            sequences=[
                torch.tensor(tok.encode(t)[:max_sequence_length])
                for t in texts
            ],
            batch_first=True,
        ).to(dev)
        with torch.no_grad():
            # Mean-pool the token-level outputs across the sequence
            # dimension to get one fixed-size vector per text.
            embs = model(sequs)[-2].mean(axis=1).cpu().detach().numpy()
        for record, emb in zip(batch, embs):
            record[out_embedding_field] = emb
            res.append(record)
    return res
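The core step above is mean-pooling a BERT output over the sequence dimension. Below is a minimal, self-contained sketch of the same idea with the Hugging Face transformers API; the model name, the use of the last hidden state, and the helper name embed_texts are illustrative assumptions rather than this project's exact configuration.

# Minimal sketch: mean-pooled BERT embeddings for a batch of texts.
# "bert-base-uncased" and the last-hidden-state pooling are assumptions,
# not necessarily what the project above uses.
import torch
from transformers import BertModel, BertTokenizer

def embed_texts(texts, max_sequence_length=128, device="cpu"):
    tok = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()
    batch = tok(
        texts,
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state  # (batch, seq, dim)
    # One vector per text: average over the sequence dimension.
    return hidden.mean(dim=1).cpu().numpy()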
Example #2
def generate_predicates(
    abstract_text: str,
    pred_patt_opts=None,
) -> Iterable[Tuple[str, str, str]]:
  "Requires that pred_util:nlp and pred_util:stopwords be initialized"
  nlp = dpg.get("pred_util:nlp")
  parser = Spacy2ConllParser(nlp=nlp)
  stopwords = dpg.get("pred_util:stopwords")

  doc = nlp(abstract_text)
  for sent in doc.sents:
    # if the sentence is very long
    if len(sent) >= 20:
      word_count = defaultdict(int)
      for tok in sent:
        word_count[str(tok)] += 1
      # if one word dominates the long sentence
      if max(word_count.values()) >= len(sent)*0.2:
        continue  # we likely generated the same word over-and-over
    conllu = "".join(list(parser.parse(input_str=str(sent))))
    for _, pred_patt_parse in load_conllu(conllu):
      predicates = PredPatt(
        pred_patt_parse,
        opts=pred_patt_opts
      ).instances
      for predicate in predicates:
        # We only care about 2-entity predicates
        if len(predicate.arguments) == 2:
          a_ents, b_ents = [
              # Get the set of entities
              filter(
                # Not in the stopword list
                lambda x: x not in stopwords,
                [str(e).strip() for e in nlp(args.phrase()).ents]
              )
              # For each argument
              for args in predicate.arguments
          ]
          # Slight cleaning needed to better match the predicate phrase
          # Note that PredPatt predicates use ?a and ?b placeholders
          predicate_stmt = (
              re.match(
                r".*\?a(.*)\?b.*", # get text between placeholders
                predicate.phrase()
              )
              .group(1) # get the group matched between the placeholders
              .strip()
          )
          if len(predicate_stmt) > 0:
            # We're going to iterate all predicates
            for a, b in product(a_ents, b_ents):
              if a != b:
                yield (a, predicate_stmt, b)
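The least obvious step above is recovering the relation text between PredPatt's ?a and ?b placeholders. A standalone sketch of just that regex step, on a hypothetical phrase:

# Standalone sketch of the placeholder-stripping step above.
# The sample phrase is hypothetical, mimicking PredPatt's ?a/?b output.
import re

phrase = "?a is associated with ?b"
match = re.match(r".*\?a(.*)\?b.*", phrase)
predicate_stmt = match.group(1).strip() if match else ""
print(predicate_stmt)  # -> "is associated with"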
Example #3
def analyze_sentences(
    records: Iterable[Record],
    text_field: str,
    token_field: str = "tokens",
    entity_field: str = "entities",
) -> Iterable[Record]:
    """
  Parses the text fields of all records using SciSpacy.
  Requires that text_util:nlp and text_util:stopwords have both been loaded into
  dask_process_global.

  @param records: A partition of records to parse, each must contain `text_field`
  @param text_field: The name of the field we wish to parse.
  @param token_field: The output field for all basic tokens. These are
  sub-records containing information such as POS tag and lemma.
  @param entity_field: The output field for all entities, which are multi-token
  phrases.
  @return a list of records with token and entity fields
  """
    nlp = dpg.get("text_util:nlp")
    stopwords = dpg.get("text_util:stopwords")

    res = []
    for sent_rec, doc in zip(
        records,
        nlp.pipe(map(lambda x: x[text_field], records)),
    ):
        sent_rec[entity_field] = [
            {
                "tok_start": ent.start,
                "tok_end": ent.end,
                "cha_start": ent.start_char,
                "cha_end": ent.end_char,
                "label": ent.label_,
            } for ent in doc.ents
            if ent.end - ent.start > 1  # don't want 1-gram ents
        ]
        sent_rec[token_field] = [
            {
              "cha_start": tok.idx,
              "cha_end": tok.idx + len(tok),
              "lemma": tok.lemma_,
              "pos": tok.pos_,
              "tag": tok.tag_,
              "dep": tok.dep_,
              "stop": \
                  tok.lemma_ in stopwords or tok.text.strip().lower() in stopwords
            }
            for tok in doc
        ]
        res.append(sent_rec)
    return res
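A minimal, self-contained version of the same spaCy pattern, assuming the en_core_web_sm model is installed (the project itself loads a SciSpacy model via dpg):

# Minimal sketch of the pattern above: batch texts through nlp.pipe and
# collect entity/token sub-records. Assumes `en_core_web_sm` is installed.
import spacy

nlp = spacy.load("en_core_web_sm")
texts = ["Aspirin reduces inflammation.", "BRCA1 mutations raise cancer risk."]

for doc in nlp.pipe(texts):
    entities = [
        {"cha_start": ent.start_char, "cha_end": ent.end_char, "label": ent.label_}
        for ent in doc.ents
        if ent.end - ent.start > 1  # skip 1-token entities, as above
    ]
    tokens = [
        {"lemma": tok.lemma_, "pos": tok.pos_, "dep": tok.dep_}
        for tok in doc
    ]
    print(entities, tokens)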
Example #4
def _init():
    # Load the BERT tokenizer and model named by the enclosing scope's
    # `bert_model`, moving the model to the configured device.
    device = dpg.get("embedding_util:device")
    tok = BertTokenizer.from_pretrained(bert_model)
    model = BertModel.from_pretrained(bert_model)
    model.eval()
    model.to(device)
    return (tok, model)
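This _init is presumably returned by a factory so each worker can load the model lazily; a hedged sketch of that closure pattern, where get_bert_initializer is an illustrative name rather than the project's API:

# Hedged sketch of the lazy-initializer pattern suggested above.
# `get_bert_initializer` is an illustrative name, not the project's API.
from transformers import BertModel, BertTokenizer

def get_bert_initializer(bert_model: str, device: str = "cpu"):
    def _init():
        tok = BertTokenizer.from_pretrained(bert_model)
        model = BertModel.from_pretrained(bert_model)
        model.eval()
        model.to(device)
        return tok, model
    return _init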
Example #5
    def apply_faiss_to_edges(
        hash_and_embedding: Iterable[Record],
    ) -> Iterable[nx.Graph]:

        # The only reason we need parts_written_to_db is to make sure that the
        # writing happens before this point
        index = dpg.get(f"knn_util:faiss_{faiss_index_name}")

        graph = nx.Graph()
        with sqlite3_lookup.Sqlite3LookupTable(hash2name_db) as hash2names:
            for batch in iter_to_batches(hash_and_embedding, batch_size):
                hashes, embeddings = to_hash_and_embedding(records=batch)
                _, neighs_per_root = index.search(embeddings, num_neighbors)
                hashes = hashes.tolist() + flatten_list(
                    neighs_per_root.tolist())
                # Create records
                for root_hash, neigh_indices in zip(hashes, neighs_per_root):
                    if root_hash in hash2names:
                        for root_name in hash2names[root_hash]:
                            for neigh_hash in neigh_indices:
                                if neigh_hash != root_hash and neigh_hash in hash2names:
                                    for neigh_name in hash2names[neigh_hash]:
                                        graph.add_edge(root_name,
                                                       neigh_name,
                                                       weight=weight)
                                        graph.add_edge(neigh_name,
                                                       root_name,
                                                       weight=weight)
        return [graph]
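The index.search call above takes a float32 query matrix and returns a (distances, neighbor indices) pair; a self-contained FAISS sketch on random vectors:

# Self-contained sketch of the faiss search used above, on random data.
import faiss
import numpy as np

dim, num_vectors, num_neighbors = 64, 1000, 5
vectors = np.random.rand(num_vectors, dim).astype("float32")

index = faiss.IndexFlatL2(dim)   # exact L2 index; the project may use another type
index.add(vectors)

queries = vectors[:10]
distances, neighbors = index.search(queries, num_neighbors)
print(neighbors.shape)  # (10, 5): neighbor ids per query row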
Example #6
def _sentence_partition_to_records(
    records: List[Record],
    unicode_to_ascii_jar_path: Path,
    input_path: Path,
    semrep_install_dir: Path,
    lexicon_year: int,
    mm_data_year: str,
    mm_data_version: str,
) -> List[Record]:
  input_path = Path(input_path)
  # output_path = Path(output_path)
  # Convert Sentences for SemRep Input
  if not input_path.is_file():
    with open(input_path, 'w') as input_file:
      for line in sentences_to_semrep_input(records, unicode_to_ascii_jar_path):
        input_file.write(f"{line}\n")
  # Process text with SemRep
  # if not output_path.is_file():
  records = SemRepRunner(
      semrep_install_dir=semrep_install_dir,
      metamap_server=dpg.get("semrep:metamap_server"),
      lexicon_year=lexicon_year,
      mm_data_year=mm_data_year,
      mm_data_version=mm_data_version,
  ).run(input_path)
  return records
Example #7
def _init():
    # Restore the model weights onto the worker's device; `model_class`,
    # `model_kwargs`, and `data_dir` come from the enclosing scope.
    device = dpg.get("embedding_util:device")
    model = model_class(**model_kwargs)
    model.load_state_dict(torch.load(
        str(data_dir),
        map_location=device,
    ))
    model.eval()
    return model
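map_location is what lets a checkpoint saved on one device be restored onto whatever device the worker has; a small sketch using a stand-in model:

# Small sketch of checkpoint loading with map_location; the tiny linear
# model is a stand-in for the project's model_class / model_kwargs.
import torch

model = torch.nn.Linear(8, 2)
torch.save(model.state_dict(), "checkpoint.pt")

restored = torch.nn.Linear(8, 2)
restored.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
restored.eval()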
Example #8
    def _apply_faiss(hash_and_embedding: Iterable[Record]) -> List[Record]:

        # The only reason we need parts_written_to_db is to make sure that the
        # writing happens before this point
        index = dpg.get(f"knn_util:faiss_{faiss_index_name}")
        hash2names = sqlite3_lookup.Sqlite3LookupTable(hash2name_db)

        # "id", "neighs"
        res = []
        for batch in iter_to_batches(hash_and_embedding, batch_size):
            hashes, embeddings = to_hash_and_embedding(records=batch)
            _, neighs_per_root = index.search(embeddings, num_neighbors)
            hashes = hashes.tolist() + flatten_list(neighs_per_root.tolist())
            # Create records
            for root_hash, neigh_indices in zip(hashes, neighs_per_root):
                if root_hash in hash2names:
                    for root_name in hash2names[root_hash]:
                        val = {"id": root_name, "neighs": set()}
                        for neigh_hash in neigh_indices:
                            if neigh_hash != root_hash and neigh_hash in hash2names:
                                for neigh_name in hash2names[neigh_hash]:
                                    val["neighs"].add(neigh_name)
                        res.append(val)
        return res
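Sqlite3LookupTable is used above like a read-only dict backed by SQLite (membership tests and indexing). A rough standalone equivalent using the standard sqlite3 module; the table name and JSON value encoding here are assumptions, not the project's actual schema:

# Rough stand-in for the hash -> names lookup used above, built on the
# standard sqlite3 module. The (key, value) schema is an assumption; the
# project's Sqlite3LookupTable has its own format.
import json
import sqlite3

class SqliteDict:
    def __init__(self, path):
        self.db = sqlite3.connect(path)

    def __contains__(self, key):
        row = self.db.execute(
            "SELECT 1 FROM lookup_table WHERE key=?", (str(key),)
        ).fetchone()
        return row is not None

    def __getitem__(self, key):
        row = self.db.execute(
            "SELECT value FROM lookup_table WHERE key=?", (str(key),)
        ).fetchone()
        if row is None:
            raise KeyError(key)
        return json.loads(row[0])  # e.g. a list of names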