def test_shuffle(self):
    """With shuffle=True and a fixed seed, cache replacement order is deterministic."""
    items = [{"image": f"test_image{i}.nii.gz"} for i in range(20)]
    dataset = SmartCacheDataset(
        data=items,
        transform=None,
        replace_rate=0.1,
        cache_num=16,
        num_init_workers=4,
        num_replace_workers=4,
        shuffle=True,
        seed=123,
    )
    dataset.start()
    # seeded shuffle makes the item at index 15 predictable after each replacement round
    expected_at_15 = ["test_image18.nii.gz", "test_image13.nii.gz", "test_image5.nii.gz"]
    for expected in expected_at_15:
        dataset.update_cache()
        self.assertEqual(dataset[15]["image"], expected)
    dataset.shutdown()
def test_content(self):
    """SmartCacheHandler slides the cached window forward by one item per epoch."""
    data = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    # epoch e (1-based) should see the window [e-1, e+3]
    expected = [[start, start + 1, start + 2, start + 3, start + 4] for start in range(5)]

    # the engine's train step only asserts the batch matches the expected window
    def _train_func(engine, batch):
        self.assertListEqual(batch.tolist(), expected[engine.state.epoch - 1])

    engine = Engine(_train_func)

    dataset = SmartCacheDataset(data, transform=None, replace_rate=0.2, cache_num=5, shuffle=False)
    num_workers = 2 if sys.platform == "linux" else 0
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=5, num_workers=num_workers, persistent_workers=False
    )
    SmartCacheHandler(dataset).attach(engine)
    engine.run(data_loader, max_epochs=5)
def test_datalist(self):
    """Constructing a SmartCacheDataset must not mutate the caller's data list.

    The reference copy is a deep copy so the check detects both a reorder of the
    list (from shuffle=True) and in-place mutation of the array elements; the
    previous shallow ``copy.copy`` shared the underlying numpy arrays and would
    have missed element mutation.
    """
    data_list = [np.array([i]) for i in range(5)]
    data_list_backup = copy.deepcopy(data_list)
    SmartCacheDataset(data=data_list, transform=None, cache_rate=0.5, replace_rate=0.4, shuffle=True)
    np.testing.assert_allclose(data_list, data_list_backup)
def test_shape(self, replace_rate, num_replace_workers, transform):
    """Cache a synthetic NIfTI dataset and verify cache size and item content."""
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]), np.eye(4))
    with tempfile.TemporaryDirectory() as tempdir:
        # the same volume serves as image, label and extra key
        for filename in ("test_image1.nii.gz", "test_label1.nii.gz", "test_extra1.nii.gz"):
            nib.save(test_image, os.path.join(tempdir, filename))
        test_data = [
            {
                "image": os.path.join(tempdir, "test_image1.nii.gz"),
                "label": os.path.join(tempdir, "test_label1.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
            }
        ] * 20
        dataset = SmartCacheDataset(
            data=test_data,
            transform=transform,
            replace_rate=replace_rate,
            cache_num=16,
            num_init_workers=4,
            num_replace_workers=num_replace_workers,
        )
        if transform is None:
            # check the default-transform path yields the same items
            dataset2 = SmartCacheDataset(
                data=test_data,
                replace_rate=replace_rate,
                cache_num=16,
                num_init_workers=4,
                num_replace_workers=num_replace_workers,
            )
            for key in ("image", "label", "extra"):
                self.assertEqual(dataset[0][key], dataset2[0][key])
        self.assertEqual(len(dataset._cache), dataset.cache_num)
        for idx in range(dataset.cache_num):
            self.assertIsNotNone(dataset._cache[idx])
        # run two full start/update/shutdown cycles
        for _ in range(2):
            dataset.start()
            for _ in range(3):
                dataset.update_cache()
                self.assertIsNotNone(dataset[15])
                if isinstance(dataset[15]["image"], (np.ndarray, torch.Tensor)):
                    assert_allclose(dataset[15]["image"], dataset[15]["label"])
                else:
                    # without a loading transform, items are still file paths
                    self.assertIsInstance(dataset[15]["image"], str)
            dataset.shutdown()
def test_update_cache(self):
    """update_cache keeps the tail of the old cache and appends the prepared replacements."""
    # Given: 40 items cached without shuffling so the cache order is deterministic
    test_data = [{"image": f"test_image{i}.nii.gz", "label": f"test_image{i}.nii.gz"} for i in range(40)]
    dataset = SmartCacheDataset(
        data=test_data,
        transform=None,
        replace_rate=0.2,
        cache_num=10,
        num_init_workers=4,
        num_replace_workers=4,
        shuffle=False,
    )
    dataset.start()
    replace_num = int(0.2 * 10)  # items dropped from the head on each update
    keep_num = int((1 - 0.2) * 10)  # items carried over into the new cache
    old_cache = copy.deepcopy(dataset._cache)

    # When: snapshot the pending replacements under the lock, then update
    with dataset._update_lock:
        replacements = copy.deepcopy(dataset._replacements)
    dataset.update_cache()
    new_cache = dataset._cache

    # Then: the kept slice is unchanged and the replacements fill the tail
    for kept, current in zip(old_cache[replace_num:], new_cache[0:keep_num]):
        assert kept == current
    for pending, appended in zip(replacements, new_cache[keep_num:]):
        assert pending == appended
def _dataset(self, context, datalist, replace_rate=0.25):
    """Build the training dataset for the configured ``context.dataset_type``.

    Returns a ``(dataset, datalist)`` pair; in multi-GPU runs the datalist is
    first partitioned across ranks when it is large enough to split.
    """
    if context.multi_gpu:
        world_size = torch.distributed.get_world_size()
        # every gpu gets full data when datalist is smaller than the world size
        if len(datalist) >= world_size:
            datalist = partition_dataset(data=datalist, num_partitions=world_size, even_divisible=True)[
                context.local_rank
            ]
    transforms = self._validate_transforms(self.train_pre_transforms(context), "Training", "pre")
    dataset_type = context.dataset_type
    if dataset_type == "CacheDataset":
        dataset = CacheDataset(datalist, transforms)
    elif dataset_type == "SmartCacheDataset":
        dataset = SmartCacheDataset(datalist, transforms, replace_rate)
    elif dataset_type == "PersistentDataset":
        dataset = PersistentDataset(datalist, transforms, cache_dir=os.path.join(context.cache_dir, "pds"))
    else:
        dataset = Dataset(datalist, transforms)
    return dataset, datalist
def train(args):
    """Distributed SmartCacheDataset training example: one process per GPU over NCCL."""
    if args.local_rank != 0:
        # silence every rank except rank 0 on each node
        devnull = open(os.devnull, "w")
        sys.stdout = sys.stderr = devnull
    elif not os.path.exists(args.dir):
        # create 40 random image, mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # fixed seed so every node generates identical data
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(args.dir, f"img{i:d}.nii.gz"))
            nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # transforms for paired image / segmentation training
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # partition by rank: avoids duplicated caching per rank, at the cost of no
    # global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # we suppose to use 2 ranks in this example, every rank has 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # batch_size=2 combined with RandCropByPosNegLabeld produces 2 x 4 crops per step
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    epoch_loss_values = []
    # start the replacement thread of SmartCache
    train_ds.start()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()

    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all ranks hold identical parameters (same init, synchronized gradients),
        # so saving from a single process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):
    """A stateful transform must yield identical results across caching datasets and loaders."""
    expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
    data_list = list(range(1, 11))
    # persistent_workers is only available from PyTorch 1.8 onward
    loader_kwargs = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {}

    # CacheDataset: direct iteration, then two passes through a DataLoader
    dataset = CacheDataset(
        data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False
    )
    self.assertListEqual(expected, list(dataset))
    loader = DataLoader(
        CacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=1.0,
            num_workers=cache_workers,
            progress=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **loader_kwargs,
    )
    self.assertListEqual(expected, [y.item() for y in loader])
    self.assertListEqual(expected, [y.item() for y in loader])

    # SmartCacheDataset: cache_rate=0.7 caches only the first 7 items
    dataset = SmartCacheDataset(
        data=data_list,
        transform=_StatefulTransform(),
        cache_rate=0.7,
        replace_rate=0.5,
        num_replace_workers=cache_workers,
        progress=False,
        shuffle=False,
    )
    self.assertListEqual(expected[:7], list(dataset))
    loader = DataLoader(
        SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **loader_kwargs,
    )
    self.assertListEqual(expected[:7], [y.item() for y in loader])
    self.assertListEqual(expected[:7], [y.item() for y in loader])

    # PersistentDataset backed by a temporary cache directory
    with tempfile.TemporaryDirectory() as tempdir:
        pdata = PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir)
        self.assertListEqual(expected, list(pdata))
        loader = DataLoader(
            PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir),
            batch_size=1,
            num_workers=loader_workers,
            shuffle=False,
            **loader_kwargs,
        )
        self.assertListEqual(expected, [y.item() for y in loader])
        self.assertListEqual(expected, [y.item() for y in loader])
def test_set_data(self):
    """set_data (after shutdown) refills the cache from a new data list."""
    data_list1 = list(range(10))
    transform = Lambda(func=lambda x: np.array([x * 10]))
    dataset = SmartCacheDataset(
        data=data_list1,
        transform=transform,
        cache_rate=0.5,
        replace_rate=0.4,
        num_init_workers=4,
        num_replace_workers=2,
        shuffle=False,
        progress=True,
    )
    num_workers = 2 if sys.platform == "linux" else 0
    dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)

    def _assert_window(source, offset):
        # each cached item should be source[i + offset] scaled by the transform
        for idx, item in enumerate(dataloader):
            np.testing.assert_allclose([[source[idx + offset] * 10]], item)

    dataset.start()
    _assert_window(data_list1, 0)
    # replace cache content, moving forward 2 (= 5 * 0.4) items
    dataset.update_cache()
    _assert_window(data_list1, 2)
    # shutdown is required before the underlying data may be replaced
    dataset.shutdown()
    data_list2 = list(range(-10, 0))
    dataset.set_data(data=data_list2)
    # restart and rerun with the updated cache content
    dataset.start()
    _assert_window(data_list2, 0)
    dataset.update_cache()
    _assert_window(data_list2, 2)
    # finally shutdown the dataset
    dataset.shutdown()