Esempio n. 1
0
    def test_dataloader_serializer(self) -> None:
        """Check that serialize() reproduces a fully populated config.

        The configuration passed to the Dataloader constructor is also the
        exact dict expected back from serialize(). Both copies come from one
        local factory that returns a brand-new dict on every call. The
        expected value therefore never shares objects with the constructor
        input.
        """

        def make_config() -> dict:
            """Return a new, independent instance of the test configuration."""
            return {
                "last_batch": "rollover",
                "batch_size": 2,
                "dataset": {
                    "TestDataset": {
                        "dataset_param": "/some/path",
                        "bool_param": True,
                        "list_param": ["item1", "item2"],
                    },
                },
                "transform": {
                    "TestTransform": {
                        "shape": [1000, 224, 224, 3],
                        "some_op": True,
                    },
                    "AnotherTestTransform": {
                        "shape": [10, 299, 299, 3],
                        "some_op": False,
                    },
                },
                "filter": {
                    "LabelBalance": {
                        "size": 1,
                    },
                },
            }

        serialized = Dataloader(make_config()).serialize()
        self.assertDictEqual(serialized, make_config())
Esempio n. 2
0
 def test_dataloader_serializer_defaults(self) -> None:
     """Check serialize() output for a Dataloader built without a config.

     The serialized dict is expected to contain only ``batch_size``,
     with a value of 1.
     """
     actual = Dataloader().serialize()
     expected = {"batch_size": 1}
     self.assertDictEqual(actual, expected)