示例#1
0
    def __init__(
        self,
        hparams: OTEClassificationParameters,
        label_schema: LabelSchemaEntity,
        model_file: Union[str, bytes],
        weight_file: Union[str, bytes, None] = None,
        device: str = "CPU",
        num_requests: int = 1,
    ):
        """
        Inferencer implementation for OTE classification using OpenVINO backend.

        :param hparams: Hyper parameters that the model should use.
            NOTE(review): not read by this constructor; kept for interface
            compatibility.
        :param label_schema: Label schema of the task. It decides whether the
            model runs in multi-label mode and is handed to the annotation
            converter.
        :param model_file: Model to load (e.g. an `.xml` or `.onnx` file),
            given as a path or raw bytes; passed to ``OpenvinoAdapter``.
        :param weight_file: Optional weights (e.g. a `.bin` file), given as a
            path or raw bytes; passed to ``OpenvinoAdapter``. Defaults to None.
        :param device: Device to run inference on, such as CPU, GPU or MYRIAD.
            Defaults to "CPU".
        :param num_requests: Maximum number of requests that the inferencer can
            make (forwarded as ``max_num_requests``). Good value is the number
            of available cores. Defaults to 1.
        """
        # Query the non-empty groups once instead of twice.
        groups = label_schema.get_groups(False)
        # Multi-label mode: more than one group, and as many groups as
        # non-empty labels (presumably one label per group).
        multilabel = len(groups) > 1 and len(groups) == len(
            label_schema.get_labels(include_empty=False)
        )

        self.label_schema = label_schema

        model_adapter = OpenvinoAdapter(
            create_core(),
            model_file,
            weight_file,
            device=device,
            max_num_requests=num_requests,
        )
        self.configuration = {'multilabel': multilabel}
        self.model = Model.create_model(
            "ote_classification",
            model_adapter,
            self.configuration,
            preload=True,
        )

        self.converter = ClassificationToAnnotationConverter(self.label_schema)
    def __init__(self, label_schema: LabelSchemaEntity):
        """
        Set up the label mapping state derived from ``label_schema``.

        Sets ``labels`` (the label list used for mapping), ``empty_label``,
        ``hierarchical`` and ``label_schema`` on the instance.

        :param label_schema: Label schema whose labels and groups determine
            the label list.
        """
        groups = label_schema.get_groups(False)
        non_empty_labels = label_schema.get_labels(False)

        # With exactly one non-empty label the label list also keeps the
        # empty label(s); otherwise only non-empty labels are used.
        keep_empty = len(non_empty_labels) == 1
        self.labels = label_schema.get_labels(include_empty=keep_empty)
        self.empty_label = get_empty_label(label_schema)

        # Multi-label: several groups and as many groups as non-empty labels.
        is_multilabel = len(groups) > 1 and len(groups) == len(non_empty_labels)

        # Several groups that do not form the multi-label layout are treated
        # as a hierarchy, in which case the leaf labels replace the list.
        self.hierarchical = not is_multilabel and len(groups) > 1
        if self.hierarchical:
            self.labels = get_leaf_labels(label_schema)

        self.label_schema = label_schema
    def forward(
        instance: LabelSchemaEntity,
    ) -> dict:
        """
        Serialize a label schema into a plain dictionary.

        :param instance: Label schema to serialize.
        :return: Dict with the keys ``label_tree``, ``exclusivity_graph``,
            ``label_groups`` (list of serialized groups, empty groups
            included) and ``all_labels`` (serialized label keyed by its
            serialized id, empty labels included).
        """
        serialized_groups = [
            LabelGroupMapper().forward(label_group)
            for label_group in instance.get_groups(include_empty=True)
        ]
        serialized_tree = LabelGraphMapper().forward(instance.label_tree)
        serialized_exclusivity = LabelGraphMapper().forward(instance.exclusivity_graph)
        serialized_labels = {
            IDMapper().forward(label_entity.id): LabelMapper().forward(label_entity)
            for label_entity in instance.get_labels(include_empty=True)
        }

        return {
            "label_tree": serialized_tree,
            "exclusivity_graph": serialized_exclusivity,
            "label_groups": serialized_groups,
            "all_labels": serialized_labels,
        }