Example 1
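Tests the cache shape and lifecycle: a small NIfTI volume is written to a temporary directory, a SmartCacheDataset is built over 20 copies of it, the test checks that the cache holds cache_num non-empty items, compares against a dataset created without an explicit transform, and exercises start(), update_cache() and shutdown() while verifying the cached items.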
    def test_shape(self, replace_rate, num_replace_workers, transform):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]),
                                     np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
            nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
            test_data = [{
                "image": os.path.join(tempdir, "test_image1.nii.gz"),
                "label": os.path.join(tempdir, "test_label1.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
            }] * 20
            dataset = SmartCacheDataset(
                data=test_data,
                transform=transform,
                replace_rate=replace_rate,
                cache_num=16,
                num_init_workers=4,
                num_replace_workers=num_replace_workers,
            )
            if transform is None:
                # Check without providing transform
                dataset2 = SmartCacheDataset(
                    data=test_data,
                    replace_rate=replace_rate,
                    cache_num=16,
                    num_init_workers=4,
                    num_replace_workers=num_replace_workers,
                )
                for k in ["image", "label", "extra"]:
                    self.assertEqual(dataset[0][k], dataset2[0][k])

            self.assertEqual(len(dataset._cache), dataset.cache_num)
            for i in range(dataset.cache_num):
                self.assertIsNotNone(dataset._cache[i])

            for _ in range(2):
                dataset.start()
                for _ in range(3):
                    dataset.update_cache()
                    self.assertIsNotNone(dataset[15])
                    if isinstance(dataset[15]["image"],
                                  (np.ndarray, torch.Tensor)):
                        assert_allclose(dataset[15]["image"],
                                        dataset[15]["label"])
                    else:
                        self.assertIsInstance(dataset[15]["image"], str)
                dataset.shutdown()
Example 2
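Tests the cache content when driven by an Ignite engine: with cache_num=5 and replace_rate=0.2, each cache update (triggered once per epoch by the attached SmartCacheHandler) slides the cached window forward by one item, so every epoch's batch must match the corresponding row of the expected list.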
    def test_content(self):
        data = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        expected = [
            [0, 1, 2, 3, 4],
            [1, 2, 3, 4, 5],
            [2, 3, 4, 5, 6],
            [3, 4, 5, 6, 7],
            [4, 5, 6, 7, 8],
        ]

        # set up engine
        def _train_func(engine, batch):
            self.assertListEqual(batch.tolist(),
                                 expected[engine.state.epoch - 1])

        engine = Engine(_train_func)

        # set up testing handler
        dataset = SmartCacheDataset(data,
                                    transform=None,
                                    replace_rate=0.2,
                                    cache_num=5,
                                    shuffle=False)
        workers = 2 if sys.platform == "linux" else 0
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=5,
                                                  num_workers=workers,
                                                  persistent_workers=False)
        SmartCacheHandler(dataset).attach(engine)

        engine.run(data_loader, max_epochs=5)
Example 3
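Tests update_cache() directly: with replace_rate=0.2 and cache_num=10, the first 8 items of the new cache must equal the last 8 items of the old cache, and the remaining items must equal the prepared replacements.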
    def test_update_cache(self):
        # Given
        test_data = [{
            "image": f"test_image{i}.nii.gz",
            "label": f"test_image{i}.nii.gz"
        } for i in range(40)]
        dataset = SmartCacheDataset(
            data=test_data,
            transform=None,
            replace_rate=0.2,
            cache_num=10,
            num_init_workers=4,
            num_replace_workers=4,
            shuffle=False,
        )
        dataset.start()
        start_num = int(0.2 * 10)  # items replaced per update: replace_rate * cache_num
        remain_num = int((1 - 0.2) * 10)  # items kept in the cache across an update

        old_cache = copy.deepcopy(dataset._cache)
        # When
        with dataset._update_lock:
            replacements = copy.deepcopy(dataset._replacements)
        dataset.update_cache()
        new_cache = dataset._cache
        kept_cache = old_cache[start_num:]
        # Then
        for string1, string2 in zip(kept_cache, new_cache[0:remain_num]):
            assert string1 == string2
        for string_new, string_replacement in zip(replacements,
                                                  new_cache[remain_num:]):
            assert string_new == string_replacement
Example 4
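Tests deterministic shuffling: with shuffle=True and a fixed seed, the item found at index 15 after each of three cache updates must match a specific filename.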
    def test_shuffle(self):
        test_data = [{"image": f"test_image{i}.nii.gz"} for i in range(20)]
        dataset = SmartCacheDataset(
            data=test_data,
            transform=None,
            replace_rate=0.1,
            cache_num=16,
            num_init_workers=4,
            num_replace_workers=4,
            shuffle=True,
            seed=123,
        )

        dataset.start()
        for i in range(3):
            dataset.update_cache()

            if i == 0:
                self.assertEqual(dataset[15]["image"], "test_image18.nii.gz")
            elif i == 1:
                self.assertEqual(dataset[15]["image"], "test_image13.nii.gz")
            else:
                self.assertEqual(dataset[15]["image"], "test_image5.nii.gz")

        dataset.shutdown()
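Example 5
Verifies that the constructor does not modify the caller's data list in place, even when shuffle=True.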
    def test_datalist(self):
        data_list = [np.array([i]) for i in range(5)]
        data_list_backup = copy.copy(data_list)

        SmartCacheDataset(data=data_list,
                          transform=None,
                          cache_rate=0.5,
                          replace_rate=0.4,
                          shuffle=True)
        np.testing.assert_allclose(data_list, data_list_backup)
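Example 6
Tests set_data(): after shutting the dataset down, the underlying data list is replaced via set_data(), the dataset is restarted, and the DataLoader output is checked again; each update_cache() call moves the cached window forward by 2 items (5 * 0.4).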
    def test_set_data(self):
        data_list1 = list(range(10))

        transform = Lambda(func=lambda x: np.array([x * 10]))

        dataset = SmartCacheDataset(
            data=data_list1,
            transform=transform,
            cache_rate=0.5,
            replace_rate=0.4,
            num_init_workers=4,
            num_replace_workers=2,
            shuffle=False,
            progress=True,
        )

        num_workers = 2 if sys.platform == "linux" else 0
        dataloader = DataLoader(dataset=dataset,
                                num_workers=num_workers,
                                batch_size=1)

        dataset.start()
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list1[i] * 10]], d)
        # replace cache content, moving the window forward by 2 (5 * 0.4) items
        dataset.update_cache()
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list1[i + 2] * 10]], d)
        # shutdown to update data
        dataset.shutdown()
        # update the datalist and fill the cache content
        data_list2 = list(range(-10, 0))
        dataset.set_data(data=data_list2)
        # restart the dataset
        dataset.start()
        # rerun with updated cache content
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list2[i] * 10]], d)
        # replace cache content, moving the window forward by 2 (5 * 0.4) items
        dataset.update_cache()
        for i, d in enumerate(dataloader):
            np.testing.assert_allclose([[data_list2[i + 2] * 10]], d)
        # finally shutdown the dataset
        dataset.shutdown()
Example 7
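A helper that builds the training dataset for a given context: in multi-GPU runs the datalist is partitioned across ranks (unless it is smaller than the world size), and the dataset class is chosen from CacheDataset, SmartCacheDataset, PersistentDataset or plain Dataset according to context.dataset_type.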
    def _dataset(self, context, datalist, replace_rate=0.25):
        if context.multi_gpu:
            world_size = torch.distributed.get_world_size()
            # every gpu gets the full data when the datalist is smaller than world_size
            if len(datalist) // world_size:
                datalist = partition_dataset(
                    data=datalist,
                    num_partitions=world_size,
                    even_divisible=True)[context.local_rank]

        transforms = self._validate_transforms(
            self.train_pre_transforms(context), "Training", "pre")
        if context.dataset_type == "CacheDataset":
            dataset = CacheDataset(datalist, transforms)
        elif context.dataset_type == "SmartCacheDataset":
            dataset = SmartCacheDataset(datalist, transforms, replace_rate)
        elif context.dataset_type == "PersistentDataset":
            dataset = PersistentDataset(datalist,
                                        transforms,
                                        cache_dir=os.path.join(context.cache_dir, "pds"))
        else:
            dataset = Dataset(datalist, transforms)
        return dataset, datalist
Example 8
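A complete multi-GPU training script: it generates synthetic image/segmentation pairs, partitions them per rank, caches them with SmartCacheDataset, trains a 3D UNet wrapped in DistributedDataParallel, and calls update_cache() at the end of every epoch to replace 20% of the cache.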
def train(args):
    # disable logging for all processes except local rank 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs for training
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # partition the dataset based on the current rank, so every rank trains with its own data
    # this avoids duplicated cache content on each rank, but there is no global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # assuming 2 ranks in this example, every rank has 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all processes should see the same parameters: they start from the same
        # random parameters and gradients are synchronized in backward passes,
        # so saving from a single process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
Example 9
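Tests thread safety of the cached datasets: CacheDataset, SmartCacheDataset and PersistentDataset are iterated through a multi-worker DataLoader with a stateful transform, and the output must stay identical across repeated passes.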
    def test_thread_safe(self, persistent_workers, cache_workers,
                         loader_workers):
        expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
        _kwg = {
            "persistent_workers": persistent_workers
        } if pytorch_after(1, 8) else {}
        data_list = list(range(1, 11))
        dataset = CacheDataset(data=data_list,
                               transform=_StatefulTransform(),
                               cache_rate=1.0,
                               num_workers=cache_workers,
                               progress=False)
        self.assertListEqual(expected, list(dataset))
        loader = DataLoader(
            CacheDataset(
                data=data_list,
                transform=_StatefulTransform(),
                cache_rate=1.0,
                num_workers=cache_workers,
                progress=False,
            ),
            batch_size=1,
            num_workers=loader_workers,
            **_kwg,
        )
        self.assertListEqual(expected, [y.item() for y in loader])
        self.assertListEqual(expected, [y.item() for y in loader])

        dataset = SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        )
        self.assertListEqual(expected[:7], list(dataset))
        loader = DataLoader(
            SmartCacheDataset(
                data=data_list,
                transform=_StatefulTransform(),
                cache_rate=0.7,
                replace_rate=0.5,
                num_replace_workers=cache_workers,
                progress=False,
                shuffle=False,
            ),
            batch_size=1,
            num_workers=loader_workers,
            **_kwg,
        )
        self.assertListEqual(expected[:7], [y.item() for y in loader])
        self.assertListEqual(expected[:7], [y.item() for y in loader])

        with tempfile.TemporaryDirectory() as tempdir:
            pdata = PersistentDataset(data=data_list,
                                      transform=_StatefulTransform(),
                                      cache_dir=tempdir)
            self.assertListEqual(expected, list(pdata))
            loader = DataLoader(
                PersistentDataset(data=data_list,
                                  transform=_StatefulTransform(),
                                  cache_dir=tempdir),
                batch_size=1,
                num_workers=loader_workers,
                shuffle=False,
                **_kwg,
            )
            self.assertListEqual(expected, [y.item() for y in loader])
            self.assertListEqual(expected, [y.item() for y in loader])