Example #1
    def pack(
        self,
        pack: PackType,
        predict_results: Dict[str, List[Any]],
        context: Optional[Annotation] = None,
    ):
        r"""Add corresponding fields to data_pack"""
        if predict_results is None:
            return

        for i in range(len(predict_results["RelationLink"]["parent.tid"])):
            for j in range(
                len(predict_results["RelationLink"]["parent.tid"][i])
            ):
                link = RelationLink(pack)
                link.rel_type = predict_results["RelationLink"]["rel_type"][i][
                    j
                ]
                parent: EntityMention = pack.get_entry(  # type: ignore
                    predict_results["RelationLink"]["parent.tid"][i][j]
                )
                link.set_parent(parent)
                child: EntityMention = pack.get_entry(  # type: ignore
                    predict_results["RelationLink"]["child.tid"][i][j]
                )
                link.set_child(child)
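For orientation, here is a hypothetical sketch of the predict_results layout the loops above walk: parallel lists indexed [batch][item] under the "RelationLink" key. The tid values and relation labels are illustrative placeholders, not values from the source.

# Hypothetical predict_results layout; tids and labels are placeholders.
predict_results = {
    "RelationLink": {
        "parent.tid": [[11, 12], [21]],  # entry ids of parent EntityMentions
        "child.tid": [[13, 14], [22]],   # entry ids of child EntityMentions
        "rel_type": [["per:title", "org:founded_by"], ["per:origin"]],
    }
}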
Example #2
    def cache_data(self, collection: Any, pack: PackType, append: bool):
        r"""Specify the path to the cache directory.

        After you call this method, the dataset reader will use its
        ``cache_directory`` to store a cache of :class:`BasePack` read
        from every document passed to :func:`read`, serialized as one
        string-formatted :class:`BasePack`. If the cache file for a given
        ``file_path`` exists, we read the :class:`BasePack` from the cache.
        If the cache file does not exist, we will `create` it on our first
        pass through the data.

        Args:
            collection: The collection is a piece of data from the
                :meth:`_collect` function, to be read to produce DataPack(s).
                During caching, a cache key is computed based on the data in
                this collection.
            pack: The data pack to be cached.
            append: Whether to allow appending to the cache.
        """
        if not self._cache_directory:
            raise ValueError("Can not cache without a cache_directory!")

        os.makedirs(self._cache_directory, exist_ok=True)

        cache_filename = os.path.join(self._cache_directory,
                                      self._get_cache_location(collection))

        logger.info("Caching pack to %s", cache_filename)
        # 'a' appends to an existing cache file; 'w' truncates and rewrites.
        mode = 'a' if append else 'w'
        with open(cache_filename, mode) as cache:
            cache.write(pack.serialize() + "\n")
Example #3
    def process(self, input_pack: PackType):
        # Set the component for recording purpose.
        input_pack.set_control_component(self.name)
        self._process(input_pack)

        # Change status for pack processors
        q_index = self._process_manager.current_queue_index
        u_index = self._process_manager.unprocessed_queue_indices[q_index]
        current_queue = self._process_manager.current_queue

        for job_i in itertools.islice(current_queue, 0, u_index + 1):
            if job_i.status == ProcessJobStatus.UNPROCESSED:
                job_i.set_status(ProcessJobStatus.PROCESSED)
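itertools.islice walks the first u_index + 1 jobs lazily, without copying the queue. A small runnable illustration of which slots it visits:

import itertools
from collections import deque

# islice(queue, 0, u_index + 1) yields exactly the jobs up to u_index.
queue = deque(["job0", "job1", "job2", "job3"])
u_index = 2
print(list(itertools.islice(queue, 0, u_index + 1)))  # ['job0', 'job1', 'job2']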
Example #4
    def pack(
        self,
        pack: PackType,
        predict_results: Dict,
        context: Optional[Annotation] = None,
    ):
        for tag, batched_predictions in predict_results.items():
            # batched_predictions holds one prediction per batch element.
            if self.do_eval:
                self.__extractor(tag).pre_evaluation_action(pack, context)
            for prediction in batched_predictions:
                self.__extractor(tag).add_to_pack(pack, prediction, context)
        pack.add_all_remaining_entries()
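The object returned by self.__extractor(tag) must expose the two hooks called above. A hypothetical stub showing that contract; the class name and method bodies are illustrative, not the Forte API:

# Hypothetical stub; only the two method names are real requirements here.
class StubExtractor:
    def pre_evaluation_action(self, pack, context):
        pass  # e.g., clear gold entries before predictions are written

    def add_to_pack(self, pack, prediction, context):
        pass  # convert one batch element into entries on `pack`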
Example #5
File: writers.py Project: huzecong/forte
    def _process(self, input_pack: PackType):
        sub_path = self.sub_output_path(input_pack)
        if sub_path == '':
            raise ValueError("No concrete path provided from sub_output_path.")

        p = os.path.join(self.root_output_dir, sub_path)
        ensure_dir(p)

        if self.zip_pack:
            with gzip.open(p + '.gz', 'wt') as out:
                out.write(input_pack.serialize())
        else:
            with open(p, 'w') as out:
                out.write(input_pack.serialize())
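Packs written with zip_pack=True can be read back through gzip in text mode. A minimal sketch; the path is an assumption matching the '.gz' suffix added above:

import gzip

# Read-back sketch; the path is illustrative.
with gzip.open("output/doc.json.gz", "rt") as f:
    serialized = f.read()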
Example #6
    def _get_data_batch(
        self,
        data_pack: PackType,
        context_type: Type[Annotation],
        requests: Optional[DataRequest] = None,
        offset: int = 0,
    ) -> Iterable[Tuple[Dict[Any, Any], int]]:
        r"""Get data batches based on the requests.

        Args:
            data_pack: The data pack to retrieve data from.
            context_type: The context type of the data pack.
                This is not used and is kept only for compatibility reasons.
            requests: The request detail.
                This is not used and is kept only for compatibility reasons.
            offset: The offset for get_data.
                This is not used and is kept only for compatibility reasons.
        """
        packs: List[PackType] = []
        instances: List[Annotation] = []
        features_collection: List[Dict[str, Feature]] = []
        current_size = self.pool_size

        for instance in list(data_pack.get(self.scope)):
            features = {}
            for tag, scheme in self.feature_scheme.items():
                features[tag] = scheme["extractor"].extract(
                    data_pack, instance)
            packs.append(data_pack)
            instances.append(instance)
            features_collection.append(features)

            if len(instances) == self.batch_size - current_size:
                self.batch_is_full = True
                batch = {"dummy": (packs, instances, features_collection)}
                yield batch, len(instances)
                self.batch_is_full = False
                packs = []
                instances = []
                features_collection = []
                current_size = self.pool_size

        # Flush the remaining data.
        if len(instances) > 0:
            batch = {"dummy": (packs, instances, features_collection)}
            yield batch, len(instances)
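The generator accumulates items until the batch quota is met, yields, resets its buffers, and flushes whatever remains at the end. The same accumulate-then-flush pattern in plain Python, ignoring the pool_size adjustment:

from typing import Iterable, Iterator, List, Tuple

# Standalone illustration of the accumulate-then-flush pattern used above.
def chunked(items: Iterable[int], batch_size: int) -> Iterator[Tuple[List[int], int]]:
    buf: List[int] = []
    for item in items:
        buf.append(item)
        if len(buf) == batch_size:
            yield buf, len(buf)
            buf = []
    if buf:  # flush the remaining data
        yield buf, len(buf)

print(list(chunked([1, 2, 3, 4, 5], 2)))  # [([1, 2], 2), ([3, 4], 2), ([5], 1)]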
Example #7
    def _get_data_batch(
        self,
        data_pack: PackType,
    ) -> Iterable[Tuple[Dict[Any, Any], int]]:
        r"""Get data batches based on the requests.

        Args:
            data_pack: The data pack to retrieve data from.
        """
        packs: List[PackType] = []
        contexts: List[Annotation] = []
        features_collection: List[Dict[str, Feature]] = []
        current_size = self.pool_size

        for instance in data_pack.get(self._context_type):
            contexts.append(instance)
            features = {}
            for tag, scheme in self._feature_scheme.items():
                features[tag] = scheme["extractor"].extract(data_pack)
            packs.append(data_pack)
            features_collection.append(features)

            if len(contexts) == self.batch_size - current_size:
                self.batch_is_full = True

                batch = {
                    "packs": packs,
                    "contexts": contexts,
                    "features": features_collection,
                }

                yield batch, len(contexts)
                self.batch_is_full = False
                packs = []
                contexts = []
                features_collection = []
                current_size = self.pool_size

        # Flush the remaining data.
        if len(contexts) > 0:
            batch = {
                "packs": packs,
                "contexts": contexts,
                "features": features_collection,
            }
            yield batch, len(contexts)
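Note the difference from example #6: there the three parallel lists travel under a single "dummy" key as one tuple, while here each list gets its own named key ("packs", "contexts", "features"), so consumers can index the batch dict directly.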
Example #8
def write_pack(input_pack: PackType,
               output_dir: str,
               sub_path: str,
               indent: Optional[int] = None,
               zip_pack: bool = False,
               overwrite: bool = False) -> str:
    """
    Write a pack to a path.

    Args:
        input_pack: A Pack to be written.
        output_dir: The output directory.
        sub_path: The file name for this pack.
        indent: Whether to format JSON with an indent.
        zip_pack: Whether to zip the output JSON.
        overwrite: Whether to overwrite the file if already exists.

    Returns:
        The path of the output file. Note that the code returns this path
        even when an existing file is left untouched because ``overwrite``
        is False.

    """
    output_path = os.path.join(output_dir, sub_path) + '.json'
    if overwrite or not os.path.exists(output_path):
        if zip_pack:
            output_path = output_path + '.gz'

        ensure_dir(output_path)

        out_str: str = input_pack.serialize()

        if indent:
            out_str = json.dumps(json.loads(out_str), indent=indent)

        if zip_pack:
            with gzip.open(output_path, 'wt') as out:
                out.write(out_str)
        else:
            with open(output_path, 'w') as out:
                out.write(out_str)

        # Log only when the pack is actually (re)written.
        logging.info("Writing a pack to %s", output_path)

    return output_path
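The output location is plain path arithmetic; a runnable sketch of the names write_pack produces, with illustrative directory and file names:

import os

# Illustrative path derivation, mirroring write_pack above.
output_dir, sub_path = "output", "doc_0001"
output_path = os.path.join(output_dir, sub_path) + '.json'
print(output_path)          # output/doc_0001.json
print(output_path + '.gz')  # output/doc_0001.json.gz when zip_pack=True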
Example #9
def serialize_instance(instance: PackType) -> str:
    """
    Serialize a pack to a string.
    """
    return instance.serialize()
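Example #8 reflows the serialized string through json.loads/json.dumps, so the output of serialize_instance is a JSON string and can be inspected the same way. A minimal sketch; the literal below stands in for real serializer output:

import json

# `serialized` stands in for serialize_instance(pack); the literal is
# illustrative only.
serialized = '{"name": "example_pack"}'
print(json.dumps(json.loads(serialized), indent=2))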