def setup(self, stage: Optional[str] = None):
    """Creates persistent training and validation sets based on provided splits.

    Args:
        stage (Optional[str], optional): Stage (e.g., "fit", "eval") for more
            efficient setup. Defaults to None.
    """
    set_determinism(seed=self.seed)
    if stage == "fit" or stage is None:
        train_scans, val_scans = self.splits
        self.train_dicts = [
            {"image": nod["image"], "label": nod[self.target]}
            for scan in train_scans
            for nod in scan["nodules"]
            if nod["annotations"] >= self.min_anns and nod[self.target] not in self.exclude_labels
        ]
        self.val_dicts = [
            {"image": nod["image"], "label": nod[self.target]}
            for scan in val_scans
            for nod in scan["nodules"]
            if nod["annotations"] >= self.min_anns and nod[self.target] not in self.exclude_labels
        ]
        self.train_ds = PersistentDataset(self.train_dicts, transform=self.train_transforms,
                                          cache_dir=self.cache_dir)
        self.val_ds = PersistentDataset(self.val_dicts, transform=self.val_transforms,
                                        cache_dir=self.cache_dir)
    return
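# A minimal sketch (not from the source) of how a LightningDataModule like the
# one above would typically expose the persistent datasets created in setup().
# The batch_size and num_workers attributes are assumed for illustration.
def train_dataloader(self):
    # PersistentDataset computes and caches each transformed item on first
    # access, so later epochs read pre-transformed data from cache_dir.
    return DataLoader(self.train_ds, batch_size=self.batch_size,
                      shuffle=True, num_workers=self.num_workers)

def val_dataloader(self):
    return DataLoader(self.val_ds, batch_size=self.batch_size,
                      shuffle=False, num_workers=self.num_workers)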
def test_mp_dataset(self):
    print("persistent", dist.get_rank())
    items = [[list(range(i))] for i in range(5)]
    ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir)
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
    ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir)
    self.assertEqual(list(ds1), list(ds))
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

    ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, hash_func=json_hashing)
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
    ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, hash_func=json_hashing)
    self.assertEqual(list(ds1), list(ds))
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
def test_cache(self):
    """testing no inplace change to the hashed item"""
    items = [[list(range(i))] for i in range(5)]

    with tempfile.TemporaryDirectory() as tempdir:
        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
def test_mp_dataset(self):
    print("persistent", dist.get_rank())
    items = [[list(range(i))] for i in range(5)]
    cache_dir = os.path.join(self.tempdir, "test")
    ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
    ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)
    self.assertEqual(list(ds1), list(ds))
    self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
def test_shape(self, expected_shape):
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
    tempdir = tempfile.mkdtemp()
    nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
    test_data = [
        {
            "image": os.path.join(tempdir, "test_image1.nii.gz"),
            "label": os.path.join(tempdir, "test_label1.nii.gz"),
            "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
        },
        {
            "image": os.path.join(tempdir, "test_image2.nii.gz"),
            "label": os.path.join(tempdir, "test_label2.nii.gz"),
            "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
        },
    ]
    test_transform = Compose([
        LoadNiftid(keys=["image", "label", "extra"]),
        SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]),
    ])

    dataset_precached = PersistentDataset(data=test_data, transform=test_transform, cache_dir=tempdir)
    data1_precached = dataset_precached[0]
    data2_precached = dataset_precached[1]

    dataset_postcached = PersistentDataset(data=test_data, transform=test_transform, cache_dir=tempdir)
    data1_postcached = dataset_postcached[0]
    data2_postcached = dataset_postcached[1]
    shutil.rmtree(tempdir)

    self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
    self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
    self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
    self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
    self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
    self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)

    self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
    self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
    self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
    self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
    self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
    self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
def test_shape(self, transform, expected_shape):
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
    with tempfile.TemporaryDirectory() as tempdir:
        nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
        nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
        test_data = [
            {
                "image": os.path.join(tempdir, "test_image1.nii.gz"),
                "label": os.path.join(tempdir, "test_label1.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
            },
            {
                "image": os.path.join(tempdir, "test_image2.nii.gz"),
                "label": os.path.join(tempdir, "test_label2.nii.gz"),
                "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
            },
        ]
        cache_dir = os.path.join(tempdir, "cache", "data")

        dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)
        data1_precached = dataset_precached[0]
        data2_precached = dataset_precached[1]

        dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)
        data1_postcached = dataset_postcached[0]
        data2_postcached = dataset_postcached[1]
        data3_postcached = dataset_postcached[0:2]

        if transform is None:
            self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
            self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
            self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
            self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
        else:
            self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
            self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
            self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
            self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)
            self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
            self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
            self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
            self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
            for d in data3_postcached:
                self.assertTupleEqual(d["image"].shape, expected_shape)
def test_shape(self, transform, expected_shape):
    test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
    tempdir = tempfile.mkdtemp()
    nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
    nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
    test_data = [
        {
            "image": os.path.join(tempdir, "test_image1.nii.gz"),
            "label": os.path.join(tempdir, "test_label1.nii.gz"),
            "extra": os.path.join(tempdir, "test_extra1.nii.gz"),
        },
        {
            "image": os.path.join(tempdir, "test_image2.nii.gz"),
            "label": os.path.join(tempdir, "test_label2.nii.gz"),
            "extra": os.path.join(tempdir, "test_extra2.nii.gz"),
        },
    ]

    dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir)
    data1_precached = dataset_precached[0]
    data2_precached = dataset_precached[1]

    dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir)
    data1_postcached = dataset_postcached[0]
    data2_postcached = dataset_postcached[1]
    shutil.rmtree(tempdir)

    if transform is None:
        self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
        self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
        self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
        self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
    else:
        self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
        self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
        self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
        self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)
        self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
        self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
        self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
        self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)
def setup(self, stage: Optional[str] = None):
    """Set up persistent datasets for training and validation.

    Args:
        stage (Optional[str], optional): Stage in the model lifecycle, e.g.,
            `fit` or `test`. Only needed datasets will be created.
            Defaults to None.
    """
    self.scans = pd.read_csv(self.data_dir / "meta/scans.csv", index_col="PatientID")
    set_determinism(seed=self.seed)
    if stage == "fit" or stage is None:
        train_scans, val_scans = self.splits
        self.train_dicts = [{"image": scan["image"], "label": scan["mask"]} for scan in train_scans]
        self.val_dicts = [{"image": scan["image"], "label": scan["mask"]} for scan in val_scans]
        self.train_ds = PersistentDataset(self.train_dicts, transform=self.train_transforms,
                                          cache_dir=self.cache_dir)
        self.val_ds = PersistentDataset(self.val_dicts, transform=self.val_transforms,
                                        cache_dir=self.cache_dir)
    return
def test_cache(self):
    """testing no inplace change to the hashed item"""
    items = [[list(range(i))] for i in range(5)]

    class _InplaceXform(Transform):
        def __call__(self, data):
            if data:
                data[0] = data[0] + np.pi
            else:
                data.append(1)
            return data

    with tempfile.TemporaryDirectory() as tempdir:
        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])

        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)
        self.assertEqual(list(ds1), list(ds))
        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
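# A hedged companion to the test above: the assertions verify that
# PersistentDataset never lets the transform mutate the source `items`,
# because the cache key is a hash of each input item. A user transform can
# also copy defensively. _CopyingXform below is hypothetical, not from the
# original tests.
import copy

import numpy as np
from monai.transforms import Transform

class _CopyingXform(Transform):
    def __call__(self, data):
        data = copy.deepcopy(data)  # never mutate the hashed/cached input
        if data:
            data[0] = data[0] + np.pi
        else:
            data.append(1)
        return data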
def _dataset(self, context, datalist, replace_rate=0.25):
    if context.multi_gpu:
        world_size = torch.distributed.get_world_size()
        if len(datalist) // world_size:  # every gpu gets full data when datalist is smaller
            datalist = partition_dataset(data=datalist, num_partitions=world_size,
                                         even_divisible=True)[context.local_rank]

    transforms = self._validate_transforms(self.train_pre_transforms(context), "Training", "pre")
    if context.dataset_type == "CacheDataset":
        dataset = CacheDataset(datalist, transforms)
    elif context.dataset_type == "SmartCacheDataset":
        dataset = SmartCacheDataset(datalist, transforms, replace_rate)
    elif context.dataset_type == "PersistentDataset":
        dataset = PersistentDataset(datalist, transforms, cache_dir=os.path.join(context.cache_dir, "pds"))
    else:
        dataset = Dataset(datalist, transforms)
    return dataset, datalist
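# A small standalone sketch (assumed values) of what partition_dataset does in
# the multi-GPU branch above: each rank receives a non-overlapping shard of
# the data list, padded so all shards have equal length.
from monai.data import partition_dataset

data = list(range(10))
shards = partition_dataset(data=data, num_partitions=4, even_divisible=True)
# e.g., rank 0 would train on shards[0]; with even_divisible=True the data is
# extended so every partition has the same number of items.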
def train(self, train_info, valid_info, hyperparameters, run_data_check=False):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not run_data_check:
        start_dt = datetime.datetime.now()
        start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
        print(f'Training started: {start_dt_string}')

        # 1. Create folders to save the model
        timedate_info = str(datetime.datetime.now()).split(' ')[0] + '_' + \
            str(datetime.datetime.now().strftime("%H:%M:%S")).replace(':', '-')
        path_to_model = os.path.join(self.out_dir, 'trained_models', self.unique_name + '_' + timedate_info)
        os.mkdir(path_to_model)

    # 2. Load hyperparameters
    learning_rate = hyperparameters['learning_rate']
    weight_decay = hyperparameters['weight_decay']
    total_epoch = hyperparameters['total_epoch']
    multiplicator = hyperparameters['multiplicator']
    batch_size = hyperparameters['batch_size']
    validation_epoch = hyperparameters['validation_epoch']
    validation_interval = hyperparameters['validation_interval']
    H = hyperparameters['H']
    L = hyperparameters['L']

    # 3. Consider class imbalance
    negative, positive = 0, 0
    for _, label in train_info:
        if int(label) == 0:
            negative += 1
        elif int(label) == 1:
            positive += 1
    pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

    # 4. Create train and validation loaders, batch_size = 10 for validation loader (10 central slices)
    train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, train_info)
    valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, valid_info)
    large_image_splitter(train_data, self.cache_dir)

    set_determinism(seed=100)
    train_trans, valid_trans = self.transformations(H, L)
    train_dataset = PersistentDataset(data=train_data[:], transform=train_trans,
                                      cache_dir=self.persistent_dataset_dir)
    valid_dataset = PersistentDataset(data=valid_data[:], transform=valid_trans,
                                      cache_dir=self.persistent_dataset_dir)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              pin_memory=self.pin_memory, num_workers=self.num_workers,
                              collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True,
                              pin_memory=self.pin_memory, num_workers=self.num_workers,
                              collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Perform data checks
    if run_data_check:
        check_data = monai.utils.misc.first(train_loader)
        print(check_data["image"].shape, check_data["label"])
        for i in range(batch_size):
            multi_slice_viewer(check_data["image"][i, 0, :, :, :],
                               check_data["image_meta_dict"]["filename_or_obj"][i])
        exit()

    """c = 1
    for d in train_loader:
        img = d["image"]
        seg = d["seg"][0]
        seg, _ = nrrd.read(seg)
        img_name = d["image_meta_dict"]["filename_or_obj"][0]
        print(c, "Name:", img_name, "Size:", img.nelement()*img.element_size()/1024/1024, "MB", "shape:", img.shape)
        multi_slice_viewer(img[0, 0, :, :, :], d["image_meta_dict"]["filename_or_obj"][0])
        #multi_slice_viewer(seg, d["image_meta_dict"]["filename_or_obj"][0])
        c += 1
    exit()"""

    # 5. Prepare model
    model = ModelCT().to(self.device)

    # 6. Define loss function, optimizer and scheduler
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # pos_weight for class imbalance
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, multiplicator, last_epoch=-1)

    # 7. Create post validation transforms and handlers
    path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
    writer = SummaryWriter(log_dir=path_to_tensorboard)
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    valid_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(summary_writer=writer, output_transform=lambda x: None),
        CheckpointSaver(save_dir=path_to_model, save_dict={"model": model}, save_key_metric=True),
        MetricsSaver(save_dir=path_to_model, metrics=['Valid_AUC', 'Valid_ACC']),
    ]

    # 8. Create validator
    discrete = AsDiscrete(threshold_values=True)
    evaluator = SupervisedEvaluator(
        device=self.device,
        val_data_loader=valid_loader,
        network=model,
        post_transform=valid_post_transforms,
        key_val_metric={"Valid_AUC": ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))},
        additional_metrics={"Valid_ACC": Accuracy(output_transform=lambda x: (discrete(x["pred"]), x["label"]))},
        val_handlers=valid_handlers,
        amp=self.amp,
    )

    # 9. Create trainer
    # Loss function does the last sigmoid, so we don't need it here.
    train_post_transforms = Compose([
        # Empty
    ])
    logger = MetricLogger(evaluator=evaluator)
    train_handlers = [
        logger,
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandlerCT(validator=evaluator, start=validation_epoch,
                            interval=validation_interval, epoch_level=True),
        StatsHandler(tag_name="loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer, tag_name="Train_Loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=path_to_model, save_dict={"model": model, "opt": optimizer},
                        save_interval=1, n_saved=1),
    ]
    trainer = SupervisedTrainer(
        device=self.device,
        max_epochs=total_epoch,
        train_data_loader=train_loader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        post_transform=train_post_transforms,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # 10. Run trainer
    trainer.run()

    # 11. Save results
    np.save(path_to_model + '/AUCS.npy', np.array(logger.metrics['Valid_AUC']))
    np.save(path_to_model + '/ACCS.npy', np.array(logger.metrics['Valid_ACC']))
    np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
    np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

    return path_to_model
def prepare_data(self):
    data_images = sorted([
        os.path.join(data_path, x) for x in os.listdir(data_path) if x.startswith("data")
    ])
    data_labels = sorted([
        os.path.join(data_path, x) for x in os.listdir(data_path) if x.startswith("label")
    ])
    data_dicts = [
        {
            "image": image_name,
            "label": label_name,
            "patient": image_name.split("/")[-1].replace("data", "").replace(".nii.gz", ""),
        }
        for image_name, label_name in zip(data_images, data_labels)
    ]
    train_files, val_files = train_val_split(data_dicts)
    print(f"Training patients: {len(train_files)}, Validation patients: {len(val_files)}")

    set_determinism(seed=0)

    train_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=PIXDIM, mode=("bilinear", "nearest")),
        DataStatsdWithPatient(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=-100, a_max=300, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=PATCH_SIZE,
            pos=1,
            neg=1,
            num_samples=16,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(["image", "label"], spatial_axis=[0, 1, 2], prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=PIXDIM, mode=("bilinear", "nearest")),
        DataStatsdWithPatient(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=-100, a_max=300, b_min=0.0, b_max=1.0, clip=True),
        StoreShaped(keys=["image"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    self.train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=cache_path)
    self.val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=cache_path)
def main(train_output):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images
    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [map_fn(entry.strip().split(',')) for entry in fp.readlines()]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir, seg_dir, train_info_hackathon, dual_output=False)
    large_image_splitter(_train_data_hackathon, dirs["cache"])
    balance_training_data(_train_data_hackathon, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(
        _train_data_hackathon, test_size=0.2, shuffle=True, random_state=42)
    #train_data_hackathon, valid_data_hackathon = train_test_split(train_split, test_size=0.2, shuffle=True, random_state=43)

    # Setup transforms
    # Crop foreground
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
        #select_fn = lambda x: x != 0
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"], spatial_size=(int(512 * 0.50), int(512 * 0.50), -1), mode="trilinear")

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    test_dataset = PersistentDataset(data=test_data_hackathon[:], transform=hackathon_train_transform,
                                     cache_dir=dirs["persistent"])
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True, pin_memory=using_gpu,
                             num_workers=1,
                             collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
def get_dataflow(seed, data_dir, cache_dir, batch_size):
    img = nib.load(str(data_dir / "average_smwc1.nii"))
    img_data_1 = img.get_fdata()
    img_data_1 = np.expand_dims(img_data_1, axis=0)
    img = nib.load(str(data_dir / "average_smwc2.nii"))
    img_data_2 = img.get_fdata()
    img_data_2 = np.expand_dims(img_data_2, axis=0)
    img = nib.load(str(data_dir / "average_smwc3.nii"))
    img_data_3 = img.get_fdata()
    img_data_3 = np.expand_dims(img_data_3, axis=0)
    mask = np.concatenate((img_data_1, img_data_2, img_data_3))
    mask[mask > 0.3] = 1
    mask[mask <= 0.3] = 0

    # Define transformations
    train_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"]),
    ])
    val_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"]),
    ])

    # Get img paths
    df = pd.read_csv(data_dir / "banc2019_training_dataset.csv")
    df = df.sample(frac=1, random_state=seed)
    df["NormAge"] = (((df["Age"] - 18) / (92 - 18)) * 2) - 1

    data_dicts = []
    for index, row in df.iterrows():
        study_dir = data_dir / row["Study"] / "derivatives" / "spm"
        subj = list(study_dir.glob(f"sub-{row['Subject']}"))
        if subj == []:
            subj = list(study_dir.glob(f"*sub-{row['Subject']}*"))
            if subj == []:
                subj = list(study_dir.glob(f"*sub-{row['Subject'].rstrip('_S1')}*"))
                if subj == []:
                    if row["Study"] == "SALD":
                        subj = list(study_dir.glob(f"sub-{int(row['Subject']):06d}*"))
                        if subj == []:
                            print(f"{row['Study']} {row['Subject']}")
                            continue
                    else:
                        print(f"{row['Study']} {row['Subject']}")
                        continue
        c1_img = list(subj[0].glob("./smwc1*"))
        c2_img = list(subj[0].glob("./smwc2*"))
        c3_img = list(subj[0].glob("./smwc3*"))
        if c1_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue
        if c2_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue
        if c3_img == []:
            print(f"{row['Study']} {row['Subject']}")
            continue
        data_dicts.append({
            "c1": str(c1_img[0]),
            "c2": str(c2_img[0]),
            "c3": str(c3_img[0]),
            "label": row["NormAge"],
        })
    print(f"Found {len(data_dicts)} subjects.")
    val_size = len(data_dicts) // 10

    # Create datasets and dataloaders
    train_ds = PersistentDataset(data=data_dicts[:-val_size], transform=train_transforms, cache_dir=cache_dir)
    # train_ds = Dataset(data=data_dicts[:-val_size], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
                              collate_fn=list_data_collate)
    val_ds = PersistentDataset(data=data_dicts[-val_size:], transform=val_transforms, cache_dir=cache_dir)
    # val_ds = Dataset(data=data_dicts[-val_size:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)
    return train_loader, val_loader
def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):
    expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]
    _kwg = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {}
    data_list = list(range(1, 11))
    dataset = CacheDataset(data=data_list, transform=_StatefulTransform(), cache_rate=1.0,
                           num_workers=cache_workers, progress=False)
    self.assertListEqual(expected, list(dataset))
    loader = DataLoader(
        CacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=1.0,
            num_workers=cache_workers,
            progress=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **_kwg,
    )
    self.assertListEqual(expected, [y.item() for y in loader])
    self.assertListEqual(expected, [y.item() for y in loader])

    dataset = SmartCacheDataset(
        data=data_list,
        transform=_StatefulTransform(),
        cache_rate=0.7,
        replace_rate=0.5,
        num_replace_workers=cache_workers,
        progress=False,
        shuffle=False,
    )
    self.assertListEqual(expected[:7], list(dataset))
    loader = DataLoader(
        SmartCacheDataset(
            data=data_list,
            transform=_StatefulTransform(),
            cache_rate=0.7,
            replace_rate=0.5,
            num_replace_workers=cache_workers,
            progress=False,
            shuffle=False,
        ),
        batch_size=1,
        num_workers=loader_workers,
        **_kwg,
    )
    self.assertListEqual(expected[:7], [y.item() for y in loader])
    self.assertListEqual(expected[:7], [y.item() for y in loader])

    with tempfile.TemporaryDirectory() as tempdir:
        pdata = PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir)
        self.assertListEqual(expected, list(pdata))
        loader = DataLoader(
            PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir),
            batch_size=1,
            num_workers=loader_workers,
            shuffle=False,
            **_kwg,
        )
        self.assertListEqual(expected, [y.item() for y in loader])
        self.assertListEqual(expected, [y.item() for y in loader])
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI.
    It is composed of the following main blocks:
    * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
    * Load Data: Load training and validation data to PyTorch DataLoader
    * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
    * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
        during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
        on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
    * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
    * Run training: The MONAI trainer is run, performing training and validation during training.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(),
                                                           torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd,
    #       RandScaleIntensityd, RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2, keep_size=True,
                        mode=["bilinear", "nearest"], padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("A batch size different from 1 at validation is currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_key_metric=True, file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"],
                                                    interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use random flipping as data augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"),
                                tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir,
                        save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2,
                        epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l)
                            for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
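# A quick standalone check (hypothetical spacing/patch values) of the nnU-Net
# rule used in the "Network preparation" block above: levels keep stride 2 and
# 3x3 kernels while the anisotropy ratio stays <= 2 and the feature map would
# remain at least 8 voxels; otherwise that axis gets stride 1.
spacings, sizes = [0.8, 0.8], [448, 512]
strides, kernels = [], []
while True:
    ratios = [sp / min(spacings) for sp in spacings]
    stride = [2 if r <= 2 and s >= 8 else 1 for r, s in zip(ratios, sizes)]
    kernel = [3 if r <= 2 else 1 for r in ratios]
    if all(s == 1 for s in stride):
        break
    sizes = [i / j for i, j in zip(sizes, stride)]
    spacings = [i * j for i, j in zip(spacings, stride)]
    kernels.append(kernel)
    strides.append(stride)
print("kernels:", kernels)  # per-level convolution kernel sizes
print("strides:", strides)  # per-level downsampling strides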
def main():
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images
    # HACKATHON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [map_fn(entry.strip().split(',')) for entry in fp.readlines()]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir, seg_dir, train_info_hackathon, dual_output=False)
    _train_data_hackathon = large_image_splitter(_train_data_hackathon, dirs["cache"])
    copy_list = transform_and_copy(_train_data_hackathon, dirs['cache'])
    balance_training_data2(_train_data_hackathon, copy_list, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""

    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(
        _train_data_hackathon, test_size=0.2, shuffle=True, random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)
    #balance_training_data(train_data_hackathon, seed=72)
    #balance_training_data(valid_data_hackathon, seed=73)
    #balance_training_data(test_data_hackathon, seed=74)

    # Setup transforms
    # Crop foreground
    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(5, 5, 0),
                                      select_fn=lambda x: x != 0)
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    # Random axis flip
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Rand affine transform
    rand_affine = RandAffined(keys=["image"],
                              prob=0.5,
                              rotate_range=(0, 0, np.pi / 12),
                              shear_range=(0.07, 0.07, 0.0),
                              translate_range=(0, 0, 0),
                              scale_range=(0.07, 0.07, 0.0),
                              padding_mode="zeros")
    # Pad image to have height at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"], spatial_size=(int(512 * 0.50), int(512 * 0.50), -1), mode="trilinear")
    # Apply Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(keys=["image"], prob=0.25, mean=0.0, std=0.1)

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_valid_transform = Compose([
        common_transform,
        #rand_x_flip,
        #rand_y_flip,
        #rand_z_flip,
        #rand_affine,
        ToTensord(keys=["image"]),
    ]).flatten()
    hackathon_test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    train_dataset = PersistentDataset(data=train_data_hackathon[:],
                                      transform=hackathon_train_transform,
                                      cache_dir=dirs["persistent"])
    valid_dataset = PersistentDataset(data=valid_data_hackathon[:],
                                      transform=hackathon_valid_transform,
                                      cache_dir=dirs["persistent"])
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_test_transform,
                                     cache_dir=dirs["persistent"])
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        #shuffle=True,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(train_data_hackathon,
                                         callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(valid_data_hackathon,
                                         callback_get_label=lambda x, i: x[i]['_label']),
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))
    test_loader = DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
    # pos_weight for class imbalance
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        #Activationsd(keys="pred", softmax=True),
    ])
    validator = Validator(device=device,
                          val_data_loader=valid_loader,
                          network=network,
                          post_transform=valid_post_transforms,
                          amp=using_gpu,
                          non_blocking=using_gpu)
    trainer = Trainer(device=device,
                      out_dir=dirs["out"],
                      out_name="DenseNet121",
                      max_epochs=120,
                      validation_epoch=1,
                      validation_interval=1,
                      train_data_loader=train_loader,
                      network=network,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      lr_scheduler=None,
                      validator=validator,
                      amp=using_gpu,
                      non_blocking=using_gpu)

    """x_max, y_max, z_max, size_max = 0, 0, 0, 0
    for data in valid_loader:
        image = data["image"]
        label = data["label"]
        print()
        print(len(data['image_transforms']))
        #print(data['image_transforms'])
        print(label)
        shape = image.shape
        x_max = max(x_max, shape[-3])
        y_max = max(y_max, shape[-2])
        z_max = max(z_max, shape[-1])
        size = int(image.nelement()*image.element_size()/1024/1024)
        size_max = max(size_max, size)
        print("shape:", shape, "size:", str(size)+"MB")
        #multi_slice_viewer(image[0, 0, :, :, :], str(label))
    print(x_max, y_max, z_max, str(size_max)+"MB")
    exit()"""

    # Run trainer
    train_output = trainer.run()

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()