Example #1
# These imports assume the surrounding project's shared "common" package;
# adjust the module paths to match your own layout.
import os
from typing import Optional

from common import directories, file_utils
from common.types import ArxivId, SerializableEntity


def count_detected_entities(
    arxiv_id: ArxivId,
    detected_entities_dirkey: str,
    entities_filename: str = "entities.csv",
) -> Optional[int]:
    """Count detected entities for a paper; None if no detection output exists."""
    num_entities_detected = None
    if directories.registered(detected_entities_dirkey):
        detected_entities_path = os.path.join(
            directories.arxiv_subdir(detected_entities_dirkey, arxiv_id),
            entities_filename,
        )
        if os.path.exists(detected_entities_path):
            # load_from_csv returns an iterator, so materialize it to count rows.
            num_entities_detected = len(
                list(file_utils.load_from_csv(detected_entities_path, SerializableEntity))
            )

    return num_entities_detected
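A minimal usage sketch; the arXiv ID and the "detected-sentences" directory key are illustrative, not taken from the example:

# Hypothetical call: counts the rows of entities.csv saved for this paper.
count = count_detected_entities("1601.00978", "detected-sentences")
if count is None:
    print("No detection output found (or the directory key is unregistered).")
else:
    print(f"{count} entities detected.")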
Example #2
# As above, the imports assume the project's shared "common" package.
import os
from typing import Optional

from common import directories, file_utils
from common.types import ArxivId, HueLocationInfo


def count_hues_located(
    arxiv_id: ArxivId,
    hue_locations_dirkey: str,
    hue_locations_filename: str = "hue_locations.csv",
) -> Optional[int]:
    """Count located hues for a paper; None if no location output exists."""
    num_hues_located = None
    if directories.registered(hue_locations_dirkey):
        hue_locations_path = os.path.join(
            directories.arxiv_subdir(hue_locations_dirkey, arxiv_id),
            hue_locations_filename,
        )
        if os.path.exists(hue_locations_path):
            # As above, materialize the iterator to count rows.
            num_hues_located = len(
                list(file_utils.load_from_csv(hue_locations_path, HueLocationInfo))
            )

    return num_hues_located
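A matching sketch for the hue counter; the directory key is again illustrative:

hue_count = count_hues_located("1601.00978", "hue-locations-for-citations")
print(hue_count if hue_count is not None else "no hue locations saved")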
Example #3
    # Assumes the module-level imports used below: os, glob, logging,
    # typing (Iterator, List), and the project's "common" package
    # (directories, file_utils, and the entity / result types).
    def load(self) -> Iterator[PaperProcessingResult]:
        """Yield a PaperProcessingResult for each paper with complete processing output."""
        for arxiv_id in self.arxiv_ids:

            # Load the S2 ID for this paper
            s2_id_path = os.path.join(
                directories.arxiv_subdir("s2-metadata", arxiv_id), "s2_id")
            if not os.path.exists(s2_id_path):
                logging.warning("Could not find S2 ID file for %s. Skipping",
                                arxiv_id)
                continue
            with open(s2_id_path) as s2_id_file:
                s2_id = s2_id_file.read()

            # Load all extracted entities. See the note in 'colorize_tex.py' for why entities
            # might be saved in multiple files. If they are, then for this upload function to
            # work, each entity needs a unique pair of 'ID' and 'tex_path'.
            entities_dir = directories.arxiv_subdir(
                f"detected-{self.get_entity_name()}", arxiv_id)
            entities: List[SerializableEntity] = []
            for entities_path in glob.glob(
                    os.path.join(entities_dir, "entities*.csv")):
                entities.extend(
                    file_utils.load_from_csv(
                        entities_path,
                        self.get_detected_entity_type(
                            os.path.basename(entities_path)),
                    ))

            # Load locations for entities.
            locations_path = os.path.join(
                directories.arxiv_subdir(f"{self.get_entity_name()}-locations",
                                         arxiv_id),
                "entity_locations.csv",
            )
            if not os.path.exists(locations_path):
                logging.warning(  # pylint: disable=logging-not-lazy
                    "No locations have been saved for entities in command '%s' for paper %s. No entities "
                    + "will be uploaded for this paper.",
                    str(self.get_name()),
                    arxiv_id,
                )
                continue
            entity_location_infos = list(
                file_utils.load_from_csv(locations_path, EntityLocationInfo))

            # Load contexts for all entities, if a context extractor ran for this entity type.
            contexts_loaded = False
            contexts_by_entity = {}
            if directories.registered(
                    f"contexts-for-{self.get_entity_name()}"):
                contexts_path = os.path.join(
                    directories.arxiv_subdir(
                        f"contexts-for-{self.get_entity_name()}", arxiv_id),
                    "contexts.csv",
                )
                if os.path.exists(contexts_path):
                    contexts = file_utils.load_from_csv(contexts_path, Context)
                    contexts_by_entity = {c.entity_id: c for c in contexts}
                    contexts_loaded = True

            if not contexts_loaded:
                logging.warning(  # pylint: disable=logging-not-lazy
                    "No contexts have been saved for entities in command '%s' for paper %s. No "
                    + "contexts will be saved for any of these entities.",
                    str(self.get_name()),
                    arxiv_id,
                )

            # Group each entity with its locations and context, then pass all
            # entity information to the upload function.
            entity_summaries = []
            for entity in entities:
                matching_locations = []
                for location in entity_location_infos:
                    if location.entity_id == entity.id_ and location.tex_path == entity.tex_path:
                        matching_locations.append(location)

                entity_summaries.append(
                    EntityExtractionResult(entity, matching_locations,
                                           contexts_by_entity.get(entity.id_)))

            yield PaperProcessingResult(
                arxiv_id=arxiv_id,
                s2_id=s2_id,
                entities=entity_summaries,
            )
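A sketch of how a caller might consume this generator; job stands in for an instance of the command class that defines load(), and upload_entities() is a hypothetical upload step:

for result in job.load():
    logging.info(
        "Uploading %d entities for paper %s (S2 ID %s).",
        len(result.entities),
        result.arxiv_id,
        result.s2_id,
    )
    upload_entities(result)  # hypothetical: persist the entity summaries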