示例#1
0
def environment():
    """
    Build and return a TaskEnvironment based on the dummy model template.

    The label schema is made of two detection labels ("car" and "person"),
    the hyper parameters are created from the data of the parsed
    ``dummy_template.yaml`` and no model is attached to the environment.
    """
    # (raw id, name) pairs of the labels that make up the schema.
    label_specs = (
        (123456789, "car"),
        (987654321, "person"),
    )
    detection_labels = [
        LabelEntity(
            id=ID(raw_id),
            name=label_name,
            domain=Domain.DETECTION,
            is_empty=True,
        )
        for raw_id, label_name in label_specs
    ]

    template = parse_model_template(__get_path_to_file("./dummy_template.yaml"))
    configurable_params = ote_config_helper.create(template.hyper_parameters.data)

    return TaskEnvironment(
        model=None,
        hyper_parameters=configurable_params,
        label_schema=LabelSchemaEntity.from_labels(detection_labels),
        model_template=template,
    )
示例#2
0
def generate_label_schema(dataset, task_type):
    """
    Build the label schema that corresponds to the given task type.

    * Multilabel classification: every dataset label gets its own exclusive
      group, and an extra "Empty label" group is added that is exclusive
      with all of those groups.
    * Anomaly classification: a schema made of the fixed "Normal" (id 0) and
      "Anomalous" (id 1) labels; the dataset labels are not consulted.
    * Anything else: a schema created directly from the dataset labels.
    """

    is_multilabel_classification = (
        task_type == TaskType.CLASSIFICATION and dataset.is_multilabel()
    )
    if is_multilabel_classification:
        dataset_labels = dataset.get_labels()
        assert len(dataset_labels) > 1

        schema = LabelSchemaEntity()
        no_class_label = LabelEntity(
            name="Empty label", is_empty=True, domain=Domain.CLASSIFICATION
        )
        no_class_group = LabelGroup(
            name="empty",
            labels=[no_class_label],
            group_type=LabelGroupType.EMPTY_LABEL,
        )

        # One single-label group per dataset label, registered as it is built.
        per_label_groups = []
        for dataset_label in dataset_labels:
            group = LabelGroup(
                name=dataset_label.name,
                labels=[dataset_label],
                group_type=LabelGroupType.EXCLUSIVE,
            )
            per_label_groups.append(group)
            schema.add_group(group)

        # The empty group is added last, exclusive with every group above.
        schema.add_group(no_class_group, exclusive_with=per_label_groups)
        return schema

    if task_type == TaskType.ANOMALY_CLASSIFICATION:
        normal_label = LabelEntity(
            name="Normal", domain=Domain.ANOMALY_CLASSIFICATION, id=ID(0)
        )
        anomalous_label = LabelEntity(
            name="Anomalous", domain=Domain.ANOMALY_CLASSIFICATION, id=ID(1)
        )
        return LabelSchemaEntity.from_labels([normal_label, anomalous_label])

    return LabelSchemaEntity.from_labels(dataset.get_labels())
    def test_flat_label_schema_serialization(self):
        """
        Serialize a flat LabelSchema and compare the result with the expected
        dictionary representation, then check that deserializing that
        representation gives back a schema equal to the original one.
        """

        creation_time = now()
        label_names = ["cat", "dog", "mouse"]
        # One color per label, each built from four random values in [0, 255].
        label_colors = [
            Color(*[randint(0, 255) for _channel in range(4)])  # nosec  # noqa
            for _ in range(3)
        ]
        flat_labels = [
            LabelEntity(
                name=label_name,
                domain=Domain.CLASSIFICATION,
                creation_date=creation_time,
                id=ID(index),
                color=label_color,
            )
            for index, (label_name, label_color) in enumerate(
                zip(label_names, label_colors)
            )
        ]
        label_schema = LabelSchemaEntity.from_labels(flat_labels)
        serialized = LabelSchemaMapper.forward(label_schema)

        empty_tree = {"type": "tree", "directed": True, "nodes": [], "edges": []}
        empty_graph = {"type": "graph", "directed": False, "nodes": [], "edges": []}
        expected = {
            "label_tree": empty_tree,
            "exclusivity_graph": empty_graph,
            "label_groups": [
                {
                    "_id": label_schema.get_groups()[0].id,
                    "name": "from_label_list",
                    "label_ids": ["0", "1", "2"],
                    "relation_type": "EXCLUSIVE",
                }
            ],
            # Labels are keyed by their id rendered as a string.
            "all_labels": {
                str(index): {
                    "_id": str(index),
                    "name": label_name,
                    "color": ColorMapper.forward(label_colors[index]),
                    "hotkey": "",
                    "domain": "CLASSIFICATION",
                    "creation_date": DatetimeMapper.forward(creation_time),
                    "is_empty": False,
                }
                for index, label_name in enumerate(label_names)
            },
        }
        assert serialized == expected

        deserialized = LabelSchemaMapper.backward(serialized)
        assert label_schema == deserialized
    def test_model_entity_sets_values(self):
        """
        <b>Description:</b>
        Check that ModelEntity correctly returns the set values

        <b>Expected results:</b>
        Test passes if ModelEntity correctly returns the set values

        <b>Steps</b>
        1. Build a TaskEnvironment from the dummy model template
        2. Create a ModelEntity from self.dataset() and self.configuration()
        3. Set every value of interest on the ModelEntity and check that reading
           the attribute back returns the value that was set
        4. Check that the ModelEntity reports itself as optimized
        """
        # Single leading underscore on purpose: a double-underscore name used
        # inside a class body is subject to private name mangling.
        def _get_path_to_file(filename: str) -> str:
            """
            Return the path to the file named 'filename', which lives in the tests/entities directory
            """
            return str(Path(__file__).parent / Path(filename))

        car = LabelEntity(name="car", domain=Domain.DETECTION)
        labels_list = [car]
        dummy_template = _get_path_to_file("./dummy_template.yaml")
        model_template = parse_model_template(dummy_template)
        hyper_parameters = model_template.hyper_parameters.data
        params = ote_config_helper.create(hyper_parameters)
        labels_schema = LabelSchemaEntity.from_labels(labels_list)
        environment = TaskEnvironment(
            model=None,
            hyper_parameters=params,
            label_schema=labels_schema,
            model_template=model_template,
        )

        item = self.generate_random_image()
        dataset = DatasetEntity(items=[item])
        score_metric = ScoreMetric(name="Model accuracy", value=0.5)

        model_entity = ModelEntity(train_dataset=self.dataset(),
                                   configuration=self.configuration())

        # Attribute name -> value to assign and then read back.
        set_params = {
            "configuration": environment.get_model_configuration(),
            "train_dataset": dataset,
            "id": ID(1234567890),
            "creation_date": self.creation_date,
            "previous_trained_revision": 5,
            "previous_revision": 2,
            "version": 2,
            "tags": ["tree", "person"],
            "model_status": ModelStatus.TRAINED_NO_STATS,
            "model_format": ModelFormat.BASE_FRAMEWORK,
            "performance": Performance(score_metric),
            "training_duration": 5.8,
            "precision": [ModelPrecision.INT8],
            "latency": 328,
            "fps_throughput": 20,
            "target_device": TargetDevice.GPU,
            "target_device_type": "notebook",
            "optimization_methods": [OptimizationMethod.QUANTIZATION],
            "optimization_type": ModelOptimizationType.MO,
            "optimization_objectives": {
                "param": "Test param"
            },
            # A name -> value mapping; written with a colon so that it is a
            # dict and not a two-element set literal.
            "performance_improvement": {"speed": 0.5},
            "model_size_reduction": 1.0,
        }

        for key, value in set_params.items():
            setattr(model_entity, key, value)
            assert getattr(model_entity, key) == value

        # NOTE(review): presumably True because optimization_type was set to
        # ModelOptimizationType.MO above — confirm against ModelEntity.
        assert model_entity.is_optimized() is True