Exemple #1
0
    def test_dataloader_serializer(self) -> None:
        """Check that serialize() round-trips a fully populated dataloader config.

        The input and the expected output are the same mapping, so both come
        from one local factory. Each call returns an independent dict, so the
        expected value is untouched by whatever Dataloader does with its input.
        """

        def build_config() -> dict:
            return {
                "last_batch": "rollover",
                "batch_size": 2,
                "dataset": {
                    "TestDataset": {
                        "dataset_param": "/some/path",
                        "bool_param": True,
                        "list_param": ["item1", "item2"],
                    },
                },
                "transform": {
                    "TestTransform": {
                        "shape": [1000, 224, 224, 3],
                        "some_op": True,
                    },
                    "AnotherTestTransform": {
                        "shape": [10, 299, 299, 3],
                        "some_op": False,
                    },
                },
                "filter": {
                    "LabelBalance": {
                        "size": 1,
                    },
                },
            }

        loader = Dataloader(build_config())
        self.assertDictEqual(loader.serialize(), build_config())
Exemple #2
0
 def test_dataloader_serializer_defaults(self) -> None:
     """A default-constructed dataloader serializes to only its batch size."""
     serialized = Dataloader().serialize()
     expected = {"batch_size": 1}
     self.assertDictEqual(serialized, expected)
Exemple #3
0
    def test_dataloader_constructor_defaults_batch_overwrite(self) -> None:
        """Passing only batch_size=8 sets that size and leaves other fields unset."""
        requested_batch = 8
        loader = Dataloader(batch_size=requested_batch)

        # Only batch_size was supplied; every other field must stay empty.
        self.assertIsNone(loader.last_batch)
        self.assertEqual(loader.batch_size, requested_batch)
        self.assertIsNone(loader.dataset)
        self.assertDictEqual(loader.transform, {})
        self.assertIsNone(loader.filter)
Exemple #4
0
    def test_dataloader_constructor_defaults(self) -> None:
        """A Dataloader built with no arguments exposes the default field values."""
        loader = Dataloader()

        self.assertIsNone(loader.last_batch)
        self.assertEqual(loader.batch_size, 1)  # default batch size
        self.assertIsNone(loader.dataset)
        self.assertDictEqual(loader.transform, {})  # no transforms configured
        self.assertIsNone(loader.filter)
Exemple #5
0
 def test_dataloader_constructor_batch_overwrite(self) -> None:
     """The batch_size keyword wins over the batch_size stored in ``data``."""
     config = {"batch_size": 2}
     loader = Dataloader(data=config, batch_size=32)

     self.assertIsNone(loader.last_batch)
     # The explicit keyword argument takes precedence over the config value (2).
     self.assertEqual(loader.batch_size, 32)
     self.assertIsNone(loader.dataset)
     self.assertDictEqual(loader.transform, {})
     self.assertIsNone(loader.filter)
Exemple #6
0
    def __init__(self, data: Optional[Dict[str, Any]] = None):
        """Initialize Configuration Performance class.

        Args:
            data: Performance section of the configuration. The optional keys
                read are "warmup" (default 10), "iteration" (default -1),
                "configs" (always wrapped in Configs, default empty mapping),
                "dataloader" and "postprocess". ``None`` (the default) is
                treated as an empty mapping.
        """
        super().__init__()
        # Use None instead of a shared mutable default; None means "no settings".
        if data is None:
            data = {}

        self.warmup: int = data.get("warmup", 10)

        self.iteration: int = data.get("iteration", -1)

        self.configs: Configs = Configs(data.get("configs", {}))

        # Optional sub-sections are built only when their value is a dict;
        # any other value (missing, None, ...) leaves the attribute as None.
        self.dataloader: Optional[Dataloader] = None
        dataloader_data = data.get("dataloader")
        if isinstance(dataloader_data, dict):
            self.dataloader = Dataloader(dataloader_data)

        self.postprocess: Optional[Postprocess] = None
        postprocess_data = data.get("postprocess")
        if isinstance(postprocess_data, dict):
            self.postprocess = Postprocess(postprocess_data)
Exemple #7
0
    def __init__(self, data: Optional[Dict[str, Any]] = None):
        """Initialize Configuration Accuracy class.

        Args:
            data: Accuracy section of the configuration. Each optional key
                "metric", "configs", "dataloader" and "postprocess" becomes
                its configuration object only when its value is a dict. In
                any other case the matching attribute stays None. ``None``
                (the default) is treated as an empty mapping.
        """
        super().__init__()
        # Use None instead of a shared mutable default; None means "no settings".
        if data is None:
            data = {}

        self.metric: Optional[Metric] = None
        metric_data = data.get("metric")
        if isinstance(metric_data, dict):
            self.metric = Metric(metric_data)

        self.configs: Optional[Configs] = None
        configs_data = data.get("configs")
        if isinstance(configs_data, dict):
            self.configs = Configs(configs_data)

        self.dataloader: Optional[Dataloader] = None
        dataloader_data = data.get("dataloader")
        if isinstance(dataloader_data, dict):
            self.dataloader = Dataloader(dataloader_data)

        self.postprocess: Optional[Postprocess] = None
        postprocess_data = data.get("postprocess")
        if isinstance(postprocess_data, dict):
            self.postprocess = Postprocess(postprocess_data)
Exemple #8
0
    def test_dataloader_constructor(self) -> None:
        """Check that every section of a full config is parsed into the dataloader."""
        config = {
            "last_batch": "rollover",
            "batch_size": 2,
            "dataset": {
                "TestDataset": {
                    "dataset_param": "/some/path",
                    "bool_param": True,
                    "list_param": ["item1", "item2"],
                },
            },
            "transform": {
                "TestTransform": {"shape": [1000, 224, 224, 3], "some_op": True},
                "AnotherTestTransform": {"shape": [10, 299, 299, 3], "some_op": False},
            },
            "filter": {"LabelBalance": {"size": 1}},
        }
        loader = Dataloader(config)

        self.assertEqual(loader.last_batch, "rollover")
        self.assertEqual(loader.batch_size, 2)

        dataset = loader.dataset
        self.assertIsNotNone(dataset)
        self.assertEqual(dataset.name, "TestDataset")
        self.assertDictEqual(
            dataset.params,
            {
                "dataset_param": "/some/path",
                "bool_param": True,
                "list_param": ["item1", "item2"],
            },
        )

        # The first two transforms must appear in the same order as in `config`;
        # indexing (rather than zip) keeps a missing transform a hard failure.
        expected_transforms = [
            ("TestTransform", {"shape": [1000, 224, 224, 3], "some_op": True}),
            ("AnotherTestTransform", {"shape": [10, 299, 299, 3], "some_op": False}),
        ]
        transform_items = list(loader.transform.items())
        for position, (expected_name, expected_params) in enumerate(expected_transforms):
            key, transform = transform_items[position]
            self.assertEqual(key, expected_name)
            self.assertEqual(transform.name, expected_name)
            self.assertDictEqual(transform.parameters, expected_params)

        label_filter = loader.filter
        self.assertIsNotNone(label_filter)
        self.assertIsNotNone(label_filter.LabelBalance)
        self.assertEqual(label_filter.LabelBalance.size, 1)