Example #1
    def run_interaction(self, train, compose):
        data = [{
            "image": np.ones((1, 2, 2, 2)).astype(np.float32),
            "label": np.ones((1, 2, 2, 2))
        } for _ in range(5)]
        network = torch.nn.Linear(2, 2)
        lr = 1e-3
        opt = torch.optim.SGD(network.parameters(), lr)
        loss = torch.nn.L1Loss()
        train_transforms = Compose([
            FindAllValidSlicesd(label="label", sids="sids"),
            AddInitialSeedPointd(label="label",
                                 guidance="guidance",
                                 sids="sids"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ])
        dataset = Dataset(data, transform=train_transforms)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

        iteration_transforms = [
            Activationsd(keys="pred", sigmoid=True),
            ToNumpyd(keys=["image", "label", "pred"]),
            FindDiscrepancyRegionsd(label="label",
                                    pred="pred",
                                    discrepancy="discrepancy"),
            AddRandomGuidanced(guidance="guidance",
                               discrepancy="discrepancy",
                               probability="probability"),
            AddGuidanceSignald(image="image", guidance="guidance"),
            ToTensord(keys=("image", "label")),
        ]
        # Interaction should accept the transforms either as a bare list or
        # pre-composed, depending on the `compose` flag
        iteration_transforms = Compose(
            iteration_transforms) if compose else iteration_transforms

        i = Interaction(transforms=iteration_transforms,
                        train=train,
                        max_interactions=5)
        self.assertEqual(len(i.transforms.transforms), 6,
                         "Mismatch in expected transforms")

        # set up engine
        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            iteration_update=i,
        )
        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED,
                                 add_one)
        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED,
                                 add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch[0].get("guidance"),
                             "guidance is missing")
        self.assertEqual(engine.state.best_metric, 9)
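
Note: add_one is a handler defined outside this snippet. With max_interactions=5 and a single batch, the INNER_ITERATION_STARTED and INNER_ITERATION_COMPLETED events each fire five times, so the final best_metric of 9 is consistent with a handler that initializes the metric on its first call and increments it on each of the nine calls that follow.
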
Example #2
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
         # TorchVision's ColorJitter expects a tensor, hence the
         # ToTensord / ToNumpyd round-trip around it
         ToTensord(keys="image"),
         TorchVisiond(keys="image",
                      name="ColorJitter",
                      brightness=64.0 / 255.0,
                      contrast=0.75,
                      saturation=0.25,
                      hue=0.04),
         ToNumpyd(keys="image"),
         RandRotate90d(keys=("image", "label"),
                       prob=0.5,
                       spatial_axes=(0, 1)),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys=("image", "label")),
     ]
Example #3
    def run_interaction(self, train, compose):
        data = []
        for i in range(5):
            data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])})
        network = torch.nn.Linear(1, 1)
        lr = 1e-3
        opt = torch.optim.SGD(network.parameters(), lr)
        loss = torch.nn.L1Loss()
        dataset = Dataset(data, transform=None)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

        iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
        iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

        i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
        self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms")

        # set up engine
        engine = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=data_loader,
            network=network,
            optimizer=opt,
            loss_function=loss,
            iteration_update=i,
        )
        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)
        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)

        engine.run()
        self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")
        self.assertEqual(engine.state.best_metric, 9)
Example #4
 def test_numpy_input(self):
     test_data = np.array([[1, 2], [3, 4]])
     test_data = np.rot90(test_data)
     self.assertFalse(test_data.flags["C_CONTIGUOUS"])
     result = ToNumpyd(keys="img")({"img": test_data})["img"]
     self.assertTrue(isinstance(result, np.ndarray))
     self.assertTrue(result.flags["C_CONTIGUOUS"])
     np.testing.assert_allclose(result, test_data)
Example #5
 def test_tensor_input(self):
     test_data = torch.tensor([[1, 2], [3, 4]])
     test_data = test_data.rot90()
     self.assertFalse(test_data.is_contiguous())
     result = ToNumpyd(keys="img")({"img": test_data})["img"]
     self.assertTrue(isinstance(result, np.ndarray))
     self.assertTrue(result.flags["C_CONTIGUOUS"])
     np.testing.assert_allclose(result, test_data.numpy())
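
Taken together, Examples #4 and #5 pin down the core contract: ToNumpyd converts the value under each key to a C-contiguous numpy array, regardless of the input type. A minimal self-contained sketch (assuming ToNumpyd comes from monai.transforms, as the surrounding examples suggest):

import numpy as np
import torch
from monai.transforms import ToNumpyd

transform = ToNumpyd(keys="img")
# a non-contiguous tensor input, as in Example #5
out = transform({"img": torch.tensor([[1, 2], [3, 4]]).rot90()})["img"]
assert isinstance(out, np.ndarray)
assert out.flags["C_CONTIGUOUS"]
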
Example #6
 def post_transforms(self, data=None):
     return [
         Activationsd(keys="pred", sigmoid=True),
         AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
         ToNumpyd(keys="pred"),
         RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
         AsChannelLastd(keys="pred"),
     ]
Example #7
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred", device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
     ]
Example #8
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
         BoundingBoxd(keys="pred", result="result", bbox="bbox"),
     ]
Example #9
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred", sigmoid=True),
         AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
         ToNumpyd(keys="pred"),
         RestoreLabeld(keys="pred", ref_image="image", mode="nearest"),
         AsChannelLastd(keys="pred"),
     ]
Example #10
 def post_transforms(self, data=None) -> Sequence[Callable]:
     return [
         EnsureTyped(keys="pred", device=data.get("device") if data else None),
         Activationsd(keys="pred", softmax=len(self.labels) > 1, sigmoid=len(self.labels) == 1),
         AsDiscreted(keys="pred", argmax=len(self.labels) > 1, threshold=0.5 if len(self.labels) == 1 else None),
         SqueezeDimd(keys="pred", dim=0),
         ToNumpyd(keys=("image", "pred")),
         PostFilterLabeld(keys="pred", image="image"),
         FindContoursd(keys="pred", labels=self.labels),
     ]
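
In the example above, the len(self.labels) checks switch the post-processing between the multi-class case (softmax followed by argmax) and the single-label case (sigmoid followed by a 0.5 threshold); the same pattern reappears in Example #16 below.
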
Example #11
def _default_transforms(image_key, label_key, pixdim):
    keys = [image_key] if label_key is None else [image_key, label_key]
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST
            ] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    return Compose([
        LoadImaged(keys=keys),
        AsChannelFirstd(keys=keys),
        Orientationd(keys=keys, axcodes="RAS"),
        Spacingd(keys=keys, pixdim=pixdim, mode=mode),
        FromMetaTensord(keys=keys),
        ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
    ])
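
A hedged usage sketch for the factory above (the file path is hypothetical; passing label_key=None exercises the image-only branch, so keys becomes [image_key] and only bilinear resampling is used):

pre = _default_transforms("image", None, pixdim=(1.0, 1.0, 1.0))
data = pre({"image": "/path/to/volume.nii.gz"})  # hypothetical path
image = data["image"]                 # numpy array, RAS-oriented, resampled
meta = data[PostFix.meta("image")]    # accompanying metadata entry
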
Example #12
 def get_click_transforms(self, context: Context):
     return [
         Activationsd(keys="pred", sigmoid=True),
         ToNumpyd(keys=("image", "label", "pred")),
         FindDiscrepancyRegionsd(label="label",
                                 pred="pred",
                                 discrepancy="discrepancy"),
         AddRandomGuidanced(guidance="guidance",
                            discrepancy="discrepancy",
                            probability="probability"),
         AddGuidanceSignald(image="image", guidance="guidance"),
         ToTensord(keys=("image", "label")),
     ]
Example #13
def get_click_transforms():
    return Compose([
        Activationsd(keys="pred", sigmoid=True),
        ToNumpyd(keys=("image", "label", "pred")),
        FindDiscrepancyRegionsd(label="label",
                                pred="pred",
                                discrepancy="discrepancy"),
        AddRandomGuidanced(
            guidance="guidance",
            discrepancy="discrepancy",
            probability="probability",
        ),
        AddGuidanceSignald(image="image", guidance="guidance"),
        EnsureTyped(keys=("image", "label")),
    ])
Example #14
def get_click_transforms():
    return Compose([
        Activationsd(keys='pred', sigmoid=True),
        ToNumpyd(keys=('image', 'label', 'pred', 'probability', 'guidance')),
        FindDiscrepancyRegionsd(label='label',
                                pred='pred',
                                discrepancy='discrepancy',
                                batched=True),
        AddRandomGuidanced(guidance='guidance',
                           discrepancy='discrepancy',
                           probability='probability',
                           batched=True),
        AddGuidanceSignald(image='image', guidance='guidance', batched=True),
        ToTensord(keys=('image', 'label'))
    ])
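
Unlike Examples #12 and #13, this variant passes batched=True to the click-simulation transforms and includes 'probability' and 'guidance' among the ToNumpyd keys, suggesting a version of these transforms that operates on whole batches rather than per-sample dictionaries.
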
Example #15
 def get_click_transforms(self, context: Context):
     return [
         Activationsd(keys="pred", softmax=True),
         AsDiscreted(keys="pred", argmax=True),
         ToNumpyd(keys=("image", "label", "pred")),
         # Transforms for click simulation
         FindDiscrepancyRegionsCustomd(keys="label",
                                       pred="pred",
                                       discrepancy="discrepancy"),
         AddRandomGuidanceCustomd(
             keys="NA",
             guidance="guidance",
             discrepancy="discrepancy",
             probability="probability",
         ),
         AddGuidanceSignalCustomd(
             keys="image",
             guidance="guidance",
             number_intensity_ch=self.number_intensity_ch),
         # end of click-simulation transforms
         ToTensord(keys=("image", "label")),
     ]
Example #16
 def post_transforms(self, data=None) -> Sequence[Callable]:
     largest_cc = False if not data else data.get("largest_cc", False)
     applied_labels = list(self.labels.values()) if isinstance(
         self.labels, dict) else self.labels
     t = [
         EnsureTyped(keys="pred",
                     device=data.get("device") if data else None),
         Activationsd(keys="pred",
                      softmax=len(self.labels) > 1,
                      sigmoid=len(self.labels) == 1),
         AsDiscreted(keys="pred",
                     argmax=len(self.labels) > 1,
                     threshold=0.5 if len(self.labels) == 1 else None),
     ]
     if largest_cc:
         t.append(
             KeepLargestConnectedComponentd(keys="pred",
                                            applied_labels=applied_labels))
     t.extend([
         ToNumpyd(keys="pred"),
         Restored(keys="pred", ref_image="image"),
     ])
     return t
Example #17
    def test_values(self):
        with tempfile.TemporaryDirectory() as tempdir:
            test_data1 = [
                ["subject_id", "label", "image", "ehr_0", "ehr_1", "ehr_2"],
                [
                    "s000000", 5, "./imgs/s000000.png", 2.007843256,
                    2.29019618, 2.054902077
                ],
                [
                    "s000001", 0, "./imgs/s000001.png", 6.839215755,
                    6.474509716, 5.862744808
                ],
                [
                    "s000002", 4, "./imgs/s000002.png", 3.772548914,
                    4.211764812, 4.635294437
                ],
                [
                    "s000003", 1, "./imgs/s000003.png", 3.333333254,
                    3.235294342, 3.400000095
                ],
                [
                    "s000004", 9, "./imgs/s000004.png", 6.427451134,
                    6.254901886, 5.976470947
                ],
            ]
            test_data2 = [
                [
                    "subject_id", "ehr_3", "ehr_4", "ehr_5", "ehr_6", "ehr_7",
                    "ehr_8"
                ],
                [
                    "s000000", 3.019608021, 3.807843208, 3.584313869,
                    3.141176462, 3.1960783, 4.211764812
                ],
                [
                    "s000001", 5.192157269, 5.274509907, 5.250980377,
                    4.647058964, 4.886274338, 4.392156601
                ],
                [
                    "s000002", 5.298039436, 9.545097351, 12.57254887,
                    6.799999714, 2.1960783, 1.882352948
                ],
                [
                    "s000003", 3.164705753, 3.086274624, 3.725490093,
                    3.698039293, 3.698039055, 3.701960802
                ],
                [
                    "s000004", 6.26274538, 7.717647076, 9.584313393,
                    6.082352638, 2.662744999, 2.34117651
                ],
            ]
            test_data3 = [
                [
                    "subject_id", "ehr_9", "ehr_10", "meta_0", "meta_1",
                    "meta_2"
                ],
                ["s000000", 6.301961422, 6.470588684, "TRUE", "TRUE", "TRUE"],
                [
                    "s000001", 5.219608307, 7.827450752, "FALSE", "TRUE",
                    "FALSE"
                ],
                ["s000002", 1.882352948, 2.031372547, "TRUE", "FALSE", "TRUE"],
                [
                    "s000003", 3.309803963, 3.729412079, "FALSE", "FALSE",
                    "TRUE"
                ],
                ["s000004", 2.062745094, 2.34117651, "FALSE", "TRUE", "TRUE"],
            ]

            def prepare_csv_file(data, filepath):
                with open(filepath, "a") as f:
                    for d in data:
                        f.write((",".join([str(i) for i in d])) + "\n")

            filepath1 = os.path.join(tempdir, "test_data1.csv")
            filepath2 = os.path.join(tempdir, "test_data2.csv")
            filepath3 = os.path.join(tempdir, "test_data3.csv")
            prepare_csv_file(test_data1, filepath1)
            prepare_csv_file(test_data2, filepath2)
            prepare_csv_file(test_data3, filepath3)

            # test single CSV file
            dataset = CSVIterableDataset(filepath1)
            for i, item in enumerate(dataset):
                if i == 2:
                    self.assertDictEqual(
                        {
                            k: round(v, 4) if not isinstance(v, str) else v
                            for k, v in item.items()
                        },
                        {
                            "subject_id": "s000002",
                            "label": 4,
                            "image": "./imgs/s000002.png",
                            "ehr_0": 3.7725,
                            "ehr_1": 4.2118,
                            "ehr_2": 4.6353,
                        },
                    )
                    break
            # test reset iterables
            dataset.reset(filename=filepath3)
            for i, item in enumerate(dataset):
                if i == 3:
                    self.assertEqual(item["meta_0"], False)

            # test multiple CSV files, join tables with kwargs
            dataset = CSVIterableDataset([filepath1, filepath2, filepath3],
                                         on="subject_id")
            for i, item in enumerate(dataset):
                if i == 3:
                    self.assertDictEqual(
                        {
                            k: round(v, 4)
                            if not isinstance(v, (str, np.bool_)) else v
                            for k, v in item.items()
                        },
                        {
                            "subject_id": "s000003",
                            "label": 1,
                            "image": "./imgs/s000003.png",
                            "ehr_0": 3.3333,
                            "ehr_1": 3.2353,
                            "ehr_2": 3.4000,
                            "ehr_3": 3.1647,
                            "ehr_4": 3.0863,
                            "ehr_5": 3.7255,
                            "ehr_6": 3.6980,
                            "ehr_7": 3.6980,
                            "ehr_8": 3.7020,
                            "ehr_9": 3.3098,
                            "ehr_10": 3.7294,
                            "meta_0": False,
                            "meta_1": False,
                            "meta_2": True,
                        },
                    )

            # test selected columns and chunk size
            dataset = CSVIterableDataset(
                filename=[filepath1, filepath2, filepath3],
                chunksize=2,
                col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
            )
            for i, item in enumerate(dataset):
                if i == 3:
                    self.assertDictEqual(
                        {
                            k: round(v, 4)
                            if not isinstance(v, (str, np.bool_)) else v
                            for k, v in item.items()
                        },
                        {
                            "subject_id": "s000003",
                            "image": "./imgs/s000003.png",
                            "ehr_1": 3.2353,
                            "ehr_7": 3.6980,
                            "meta_1": False,
                        },
                    )

            # test group columns
            dataset = CSVIterableDataset(
                filename=[filepath1, filepath2, filepath3],
                col_names=[
                    "subject_id", "image", *[f"ehr_{i}" for i in range(11)],
                    "meta_0", "meta_1", "meta_2"
                ],
                col_groups={
                    "ehr": [f"ehr_{i}" for i in range(11)],
                    "meta12": ["meta_1", "meta_2"]
                },
            )
            for i, item in enumerate(dataset):
                if i == 3:
                    np.testing.assert_allclose(
                        [round(i, 4) for i in item["ehr"]],
                        [
                            3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255,
                            3.6980, 3.6980, 3.7020, 3.3098, 3.7294
                        ],
                    )
                    np.testing.assert_allclose(item["meta12"], [False, True])

            # test transform
            dataset = CSVIterableDataset(
                chunksize=2,
                buffer_size=4,
                filename=[filepath1, filepath2, filepath3],
                col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
                transform=ToNumpyd(keys="ehr"),
                shuffle=True,
            )
            dataset.set_random_state(123)
            expected = [
                [3.7725, 4.2118, 4.6353, 5.298, 9.5451],
                [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],
                [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],
                [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],
            ]
            for item, exp in zip(dataset, expected):
                self.assertTrue(isinstance(item["ehr"], np.ndarray))
                np.testing.assert_allclose(np.around(item["ehr"], 4), exp)

            # test multiple processes loading
            dataset = CSVIterableDataset(filepath1,
                                         transform=ToNumpyd(keys="label"))
            # set num workers = 0 for mac / win
            num_workers = 2 if sys.platform == "linux" else 0
            dataloader = DataLoader(dataset=dataset,
                                    num_workers=num_workers,
                                    batch_size=2)
            for item in dataloader:
                # test the last item which only has 1 data
                if len(item) == 1:
                    self.assertListEqual(item["subject_id"], ["s000002"])
                    np.testing.assert_allclose(item["label"], [4])
                    self.assertListEqual(item["image"], ["./imgs/s000002.png"])
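
Example #18 below exercises the map-style counterpart of this dataset: CSVDataset supports len() and integer indexing (dataset[2], dataset[-1]), whereas CSVIterableDataset above is consumed by iteration and adds streaming-oriented options such as chunksize, buffer_size, and shuffle.
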
Example #18
    def test_values(self):
        with tempfile.TemporaryDirectory() as tempdir:
            test_data1 = [
                ["subject_id", "label", "image", "ehr_0", "ehr_1", "ehr_2"],
                [
                    "s000000", 5, "./imgs/s000000.png", 2.007843256,
                    2.29019618, 2.054902077
                ],
                [
                    "s000001", 0, "./imgs/s000001.png", 6.839215755,
                    6.474509716, 5.862744808
                ],
                [
                    "s000002", 4, "./imgs/s000002.png", 3.772548914,
                    4.211764812, 4.635294437
                ],
                [
                    "s000003", 1, "./imgs/s000003.png", 3.333333254,
                    3.235294342, 3.400000095
                ],
                [
                    "s000004", 9, "./imgs/s000004.png", 6.427451134,
                    6.254901886, 5.976470947
                ],
            ]
            test_data2 = [
                [
                    "subject_id", "ehr_3", "ehr_4", "ehr_5", "ehr_6", "ehr_7",
                    "ehr_8"
                ],
                [
                    "s000000", 3.019608021, 3.807843208, 3.584313869,
                    3.141176462, 3.1960783, 4.211764812
                ],
                [
                    "s000001", 5.192157269, 5.274509907, 5.250980377,
                    4.647058964, 4.886274338, 4.392156601
                ],
                [
                    "s000002", 5.298039436, 9.545097351, 12.57254887,
                    6.799999714, 2.1960783, 1.882352948
                ],
                [
                    "s000003", 3.164705753, 3.086274624, 3.725490093,
                    3.698039293, 3.698039055, 3.701960802
                ],
                [
                    "s000004", 6.26274538, 7.717647076, 9.584313393,
                    6.082352638, 2.662744999, 2.34117651
                ],
            ]
            test_data3 = [
                [
                    "subject_id", "ehr_9", "ehr_10", "meta_0", "meta_1",
                    "meta_2"
                ],
                ["s000000", 6.301961422, 6.470588684, "TRUE", "TRUE", "TRUE"],
                [
                    "s000001", 5.219608307, 7.827450752, "FALSE", "TRUE",
                    "FALSE"
                ],
                ["s000002", 1.882352948, 2.031372547, "TRUE", "FALSE", "TRUE"],
                [
                    "s000003", 3.309803963, 3.729412079, "FALSE", "FALSE",
                    "TRUE"
                ],
                ["s000004", 2.062745094, 2.34117651, "FALSE", "TRUE", "TRUE"],
                # generate NaN values in the row
                ["s000005", 3.353655643, 1.675674543, "TRUE", "TRUE", "FALSE"],
            ]

            def prepare_csv_file(data, filepath):
                with open(filepath, "a") as f:
                    for d in data:
                        f.write((",".join([str(i) for i in d])) + "\n")

            filepath1 = os.path.join(tempdir, "test_data1.csv")
            filepath2 = os.path.join(tempdir, "test_data2.csv")
            filepath3 = os.path.join(tempdir, "test_data3.csv")
            filepaths = [filepath1, filepath2, filepath3]
            prepare_csv_file(test_data1, filepath1)
            prepare_csv_file(test_data2, filepath2)
            prepare_csv_file(test_data3, filepath3)

            # test single CSV file
            dataset = CSVDataset(filepath1)
            self.assertDictEqual(
                {
                    k: round(v, 4) if not isinstance(v, str) else v
                    for k, v in dataset[2].items()
                },
                {
                    "subject_id": "s000002",
                    "label": 4,
                    "image": "./imgs/s000002.png",
                    "ehr_0": 3.7725,
                    "ehr_1": 4.2118,
                    "ehr_2": 4.6353,
                },
            )

            # test multiple CSV files, join tables with kwargs
            dataset = CSVDataset(filepaths, on="subject_id")
            self.assertDictEqual(
                {
                    k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v
                    for k, v in dataset[3].items()
                },
                {
                    "subject_id": "s000003",
                    "label": 1,
                    "image": "./imgs/s000003.png",
                    "ehr_0": 3.3333,
                    "ehr_1": 3.2353,
                    "ehr_2": 3.4000,
                    "ehr_3": 3.1647,
                    "ehr_4": 3.0863,
                    "ehr_5": 3.7255,
                    "ehr_6": 3.6980,
                    "ehr_7": 3.6980,
                    "ehr_8": 3.7020,
                    "ehr_9": 3.3098,
                    "ehr_10": 3.7294,
                    "meta_0": False,
                    "meta_1": False,
                    "meta_2": True,
                },
            )

            # test selected rows and columns
            dataset = CSVDataset(
                src=filepaths,
                row_indices=[[0, 2], 3],  # load row: 0, 1, 3
                col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"],
            )
            self.assertEqual(len(dataset), 3)
            self.assertDictEqual(
                {
                    k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v
                    for k, v in dataset[-1].items()
                },
                {
                    "subject_id": "s000003",
                    "image": "./imgs/s000003.png",
                    "ehr_1": 3.2353,
                    "ehr_7": 3.6980,
                    "meta_1": False,
                },
            )

            # test group columns
            dataset = CSVDataset(
                src=filepaths,
                row_indices=[1, 3],  # load row: 1, 3
                col_names=[
                    "subject_id", "image", *[f"ehr_{i}" for i in range(11)],
                    "meta_0", "meta_1", "meta_2"
                ],
                col_groups={
                    "ehr": [f"ehr_{i}" for i in range(11)],
                    "meta12": ["meta_1", "meta_2"]
                },
            )
            np.testing.assert_allclose(
                [round(i, 4) for i in dataset[-1]["ehr"]],
                [
                    3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980,
                    3.6980, 3.7020, 3.3098, 3.7294
                ],
            )
            np.testing.assert_allclose(dataset[-1]["meta12"], [False, True])

            # test transform
            dataset = CSVDataset(
                src=filepaths,
                col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
                transform=ToNumpyd(keys="ehr"))
            self.assertEqual(len(dataset), 5)
            expected = [
                [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],
                [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],
                [3.7725, 4.2118, 4.6353, 5.2980, 9.5451],
                [3.3333, 3.2353, 3.4000, 3.1647, 3.0863],
                [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],
            ]
            for item, exp in zip(dataset, expected):
                self.assertTrue(isinstance(item["ehr"], np.ndarray))
                np.testing.assert_allclose(np.around(item["ehr"], 4), exp)

            # test default values and dtype
            dataset = CSVDataset(
                src=filepaths,
                col_names=["subject_id", "image", "ehr_1", "ehr_9", "meta_1"],
                col_types={
                    "image": {
                        "type": str,
                        "default": "No image"
                    },
                    "ehr_1": {
                        "type": int,
                        "default": 0
                    }
                },
                how="outer",  # generate NaN values in this merge mode
            )
            self.assertEqual(len(dataset), 6)
            self.assertEqual(dataset[-1]["image"], "No image")
            self.assertEqual(type(dataset[-1]["ehr_1"]), int)
            np.testing.assert_allclose(dataset[-1]["ehr_9"], 3.3537, rtol=1e-2)

            # test pre-loaded DataFrame
            df = pd.read_csv(filepath1)
            dataset = CSVDataset(src=df)
            self.assertDictEqual(
                {
                    k: round(v, 4) if not isinstance(v, str) else v
                    for k, v in dataset[2].items()
                },
                {
                    "subject_id": "s000002",
                    "label": 4,
                    "image": "./imgs/s000002.png",
                    "ehr_0": 3.7725,
                    "ehr_1": 4.2118,
                    "ehr_2": 4.6353,
                },
            )

            # test pre-loaded multiple DataFrames, join tables with kwargs
            dfs = [pd.read_csv(i) for i in filepaths]
            dataset = CSVDataset(src=dfs, on="subject_id")
            self.assertEqual(dataset[3]["subject_id"], "s000003")
            self.assertEqual(dataset[3]["label"], 1)
            self.assertEqual(round(dataset[3]["ehr_0"], 4), 3.3333)
            self.assertEqual(dataset[3]["meta_0"], False)