Example #1
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(len(dataset), 5986)
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue("image_meta_dict" in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

        try:  # will start downloading if testing_dir doesn't have the MedNIST files
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True)
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            return  # skipping this test due to network connection errors

        _test_dataset(data)

        # test loading from the already-downloaded local files (download=False)
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False)
        data.get_num_classes()  # smoke-test the class-count API (return value not asserted here)
        _test_dataset(data)
        # without a transform, the raw 2-D image array of shape (64, 64) is returned
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))  # remove the data to exercise the missing-directory error below
        try:
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(str(e).startswith("Cannot find dataset directory"))
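The snippet above relies on imports from monai.apps and monai.transforms that it does not show. A minimal stand-alone sketch of the same pattern, assuming a hypothetical writable cache directory ./mednist_root (any path works):

    from monai.apps import MedNISTDataset
    from monai.data import DataLoader
    from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord

    transform = Compose([
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),          # EnsureChannelFirstd in newer MONAI releases
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ])

    # download=True fetches and extracts the MedNIST archive on first use
    dataset = MedNISTDataset(root_dir="./mednist_root", section="training",
                             transform=transform, download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

    for batch in loader:
        images, labels = batch["image"], batch["label"]  # images: (B, 1, 64, 64)
        break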
Example #2
    def test_lr_finder(self):
        # val_frac=0.001 gives 54 examples
        train_ds = MedNISTDataset(
            root_dir=self.root_dir,
            transform=self.transforms,
            section="validation",
            val_frac=0.001,
            download=True,
            num_workers=10,
        )
        train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
        num_classes = train_ds.get_num_classes()

        model = DenseNet(
            spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,)
        )
        loss_function = torch.nn.CrossEntropyLoss()
        learning_rate = 1e-5
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)

        lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
        lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5)
        print(lr_finder.get_steepest_gradient(0, 0)[0])  # learning rate at the steepest loss gradient

        if has_matplotlib:
            ax = plt.subplot()
            plt.show(block=False)
            lr_finder.plot(0, 0, ax=ax)  # to inspect the loss-learning rate graph
            plt.pause(3)
            plt.close()

        lr_finder.reset()  # to reset the model and optimizer to their initial state
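get_steepest_gradient() returns an (lr, loss) pair, or (None, None) when no suggestion can be made, so a common follow-up, sketched here rather than taken from the test, is to rebuild the optimizer with the suggested rate after reset():

    # Sketch only: reuse the learning rate found at the steepest loss gradient
    suggested_lr, _ = lr_finder.get_steepest_gradient()
    if suggested_lr is not None:
        optimizer = torch.optim.Adam(model.parameters(), lr=suggested_lr)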
Example #3
    def test_values(self):
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        transform = Compose([
            LoadImaged(keys="image"),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            self.assertEqual(
                len(dataset),
                int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac))
            self.assertTrue("image" in dataset[0])
            self.assertTrue("label" in dataset[0])
            self.assertTrue(PostFix.meta("image") in dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

        with skip_if_downloading_fails():
            data = MedNISTDataset(root_dir=testing_dir,
                                  transform=transform,
                                  section="test",
                                  download=True,
                                  copy_cache=False)

        _test_dataset(data)

        # test loading from the already-downloaded files, with root_dir given as a pathlib.Path
        data = MedNISTDataset(root_dir=Path(testing_dir),
                              transform=transform,
                              section="test",
                              download=False)
        self.assertEqual(data.get_num_classes(), 6)
        _test_dataset(data)
        data = MedNISTDataset(root_dir=testing_dir,
                              section="test",
                              download=False)
        self.assertTupleEqual(data[0]["image"].shape, (64, 64))
        # test same dataset length with different random seed
        data = MedNISTDataset(root_dir=testing_dir,
                              transform=transform,
                              section="test",
                              download=False,
                              seed=42)
        _test_dataset(data)
        self.assertEqual(data[0]["class_name"], "AbdomenCT")
        self.assertEqual(data[0]["label"].cpu().item(), 0)
        shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
        try:
            MedNISTDataset(root_dir=testing_dir,
                           transform=transform,
                           section="test",
                           download=False)
        except RuntimeError as e:
            print(str(e))
            self.assertTrue(str(e).startswith("Cannot find dataset directory"))
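The final try/except exercises the RuntimeError raised when the MedNIST folder is missing and download=False. Outside a test, a sketch of the usual pattern, again with a hypothetical cache directory, is to fall back to downloading instead of asserting on the message:

    # Sketch (hypothetical cache directory): download on demand instead of failing
    root_dir = "./mednist_root"
    try:
        data = MedNISTDataset(root_dir=root_dir, section="test", download=False)
    except RuntimeError:
        # "Cannot find dataset directory": fetch the archive and retry
        data = MedNISTDataset(root_dir=root_dir, section="test", download=True)
    print(len(data), data.get_num_classes())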