def test_plot(self):
    set_determinism(0)
    testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")

    net = torch.nn.Conv2d(1, 1, 3, padding=1)
    opt = torch.optim.Adam(net.parameters())

    img = torch.rand(1, 16, 16)
    data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img}
    loader = DataLoader([data for _ in range(10)])

    trainer = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=net,
        optimizer=opt,
        loss_function=torch.nn.L1Loss(),
    )

    logger = MetricLogger()
    logger.attach(trainer)

    con = ThreadContainer(trainer)
    con.start()
    con.join()

    fig = con.plot_status(logger)

    with tempfile.TemporaryDirectory() as tempdir:
        tempimg = f"{tempdir}/threadcontainer_plot_test.png"
        fig.savefig(tempimg)
        comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-3)
        self.assertIsNone(comp, comp)  # None indicates test passed
def test_training(self):
    """Check the quality of AffineTransform backpropagation."""
    atol = 1e-5
    set_determinism(seed=0)
    out_ref, loss_ref, init_loss_ref = compare_2d(True, self.device)
    print(out_ref.shape, loss_ref, init_loss_ref)

    set_determinism(seed=0)
    out, loss, init_loss = compare_2d(False, self.device)
    print(out.shape, loss, init_loss)
    np.testing.assert_allclose(out_ref, out, atol=atol)
    np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)
    np.testing.assert_allclose(loss_ref, loss, atol=atol)

    set_determinism(seed=0)
    out, loss, init_loss = compare_2d(False, self.device, True)
    print(out.shape, loss, init_loss)
    np.testing.assert_allclose(out_ref, out, atol=atol)
    np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)
    np.testing.assert_allclose(loss_ref, loss, atol=atol)
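# `compare_2d` is defined elsewhere in the test module and is not shown in this
# snippet. A minimal sketch of what such a helper might look like, assuming it
# optimises a 2D affine parameterisation either through MONAI's AffineTransform
# layer or through a plain PyTorch affine_grid/grid_sample pipeline. The helper
# name, arguments, and step count below are illustrative, not MONAI's actual
# test code:
import numpy as np
import torch
from monai.networks.layers import AffineTransform

def compare_2d_sketch(use_monai: bool, device, reverse_indexing: bool = False):
    image = torch.rand(1, 1, 20, 20, device=device)
    target = torch.roll(image, shifts=2, dims=-1)  # fixed registration target
    # identity affine parameters, optimised towards the target
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], device=device, requires_grad=True)
    xform = AffineTransform(normalized=True, align_corners=False, reverse_indexing=reverse_indexing)
    optim = torch.optim.Adam([theta], lr=0.01)
    init_loss = None
    for _ in range(5):
        optim.zero_grad()
        if use_monai:
            out = xform(image, theta)
        else:
            grid = torch.nn.functional.affine_grid(theta, list(image.shape), align_corners=False)
            out = torch.nn.functional.grid_sample(image, grid, align_corners=False)
        loss = ((out - target) ** 2).mean()
        if init_loss is None:
            init_loss = loss.item()
        loss.backward()  # gradients flow through the affine resampling
        optim.step()
    return out.detach().cpu().numpy(), loss.item(), init_loss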
def test_pickle(self):
    set_determinism(0)
    data1 = np.random.rand(10)
    data2 = np.random.rand(10)
    set_determinism(0)
    data3 = np.random.rand(10)
    data4 = np.random.rand(10)
    set_determinism(None)

    h1 = pickle_hashing(data1)
    h2 = pickle_hashing(data3)
    self.assertEqual(h1, h2)

    data_dict1 = {"b": data2, "a": data1}
    data_dict2 = {"a": data3, "b": data4}
    h1 = pickle_hashing(data_dict1)
    h2 = pickle_hashing(data_dict2)
    self.assertEqual(h1, h2)

    with self.assertRaises(TypeError):
        json_hashing(data_dict1)
def tearDown(self):
    set_determinism(seed=None)
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTyped for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num_workers = 0 for macOS or GPU transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted", "test_dict"],
        transform=transform,
        orig_keys=["label", "label", "test_dict"],
        meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
        orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
        nearest_interp=True,
        to_tensor=[True, False, False],
        device="cpu",
    )

    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
        orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
        nearest_interp=[True, False],
        to_tensor=[True, True],
        device="cpu",
    )

    expected_keys = [
        "image",
        "image_inverted",
        "image_inverted1",
        PostFix.meta("image_inverted1"),
        PostFix.meta("image_inverted"),
        PostFix.meta("image"),
        "image_transforms",
        "label",
        "label_inverted",
        "label_inverted1",
        PostFix.meta("label_inverted1"),
        PostFix.meta("label_inverted"),
        PostFix.meta("label"),
        "label_transforms",
        "test_dict",
        "test_dict_transforms",
    ]

    # execute 1 epoch
    for d in loader:
        d = decollate_batch(d)
        for item in d:
            item = inverter(item)
            item = inverter_1(item)

            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
            # check the nearest interpolation mode
            i = item["image_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            i = item["label_inverted"]
            torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape[1:], (100, 101, 107))
            # test inverted test_dict
            self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray))
            self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str))

            # check the case where different items use different interpolation modes to invert transforms
            d = item["image_inverted1"]
            # if the interpolation mode is nearest, the accumulated diff should be smaller than 1
            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

            d = item["label_inverted1"]
            # if the interpolation mode is not nearest, the accumulated diff should be greater than 10000
            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(d.shape, (1, 100, 101, 107))

    # check labels match
    reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 34007: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1821: windows torch 1.10.0
    self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")

    set_determinism(seed=None)
"*.nii.gz"))) train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz"))) data_dicts = [{ "image": image_name, "label": label_name } for image_name, label_name in zip(train_images, train_labels)] #n = len(data_dicts) #train_files, val_files = data_dicts[:-3], data_dicts[-3:] #train_files, val_files = data_dicts[:int(n*0.8)], data_dicts[int(n*0.2):] val_files, train_files, test_files = data_dicts[0:8], data_dicts[ 8:40], data_dicts[40:50] """## Set deterministic training for reproducibility""" set_determinism(seed=0) """## Setup transforms for training and validation Here we use several transforms to augment the dataset: 1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files. 1. `AddChanneld` as the original data doesn't have channel dim, add 1 dim to construct "channel first" shape. 1. `Spacingd` adjusts the spacing by `pixdim=(1.5, 1.5, 2.)` based on the affine matrix. 1. `Orientationd` unifies the data orientation based on the affine matrix. 1. `ScaleIntensityRanged` extracts intensity range [-57, 164] and scales to [0, 1]. 1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels. 1. `RandCropByPosNegLabeld` randomly crop patch samples from big image based on pos / neg ratio. The image centers of negative samples must be in valid body area. 1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform. 1. `ToTensord` converts the numpy array to PyTorch Tensor for further steps. """
def test_compute(self):
    set_determinism(123)
    self._compute()
def setUp(self):
    set_determinism(seed=1234)
def setUp(self) -> None:
    set_determinism(seed=0)
def train(args):
    # load hyperparameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold, args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net, device_ids=[device], find_unused_parameters=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9
    )

    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True),
    ]

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )

    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)]

    train_handlers += [
        ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
    ]

    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # run
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # set up the file handler
    fhandler = logging.FileHandler(log_filename)
    fhandler.setLevel(logging.INFO)
    fhandler.setFormatter(formatter)

    # configure the stream handler for the cells
    chandler = logging.StreamHandler()
    chandler.setLevel(logging.INFO)
    chandler.setFormatter(formatter)

    # add both handlers
    if local_rank == 0:
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)

    trainer.run()
def setUp(self) -> None:
    set_determinism(seed=0)
    im = create_test_image_2d(100, 101)[0]
    self.data_dict = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]
    self.data_list = [make_nifti_image(im) if has_nib else im for _ in range(6)]
def test_training(self):
    repeated = []
    test_rounds = 3 if monai.config.get_torch_version_tuple() >= (1, 6) else 2
    for i in range(test_rounds):
        set_determinism(seed=0)
        repeated.append([])

        best_metric = run_training_test(self.data_dir, device=self.device, amp=(i == 2))
        print("best metric", best_metric)
        if i == 2:
            self.assertTrue(test_integration_value(TASK, key="best_metric_2", data=best_metric, rtol=1e-2))
        else:
            self.assertTrue(test_integration_value(TASK, key="best_metric", data=best_metric, rtol=1e-2))
        repeated[i].append(best_metric)

        model_file = sorted(glob(os.path.join(self.data_dir, "net_key_metric*.pt")))[-1]
        infer_metric = run_inference_test(self.data_dir, model_file, device=self.device, amp=(i == 2))
        print("infer metric", infer_metric)
        # check inference properties
        if i == 2:
            self.assertTrue(test_integration_value(TASK, key="infer_metric_2", data=infer_metric, rtol=1e-2))
        else:
            self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2))
        repeated[i].append(infer_metric)

        output_files = sorted(glob(os.path.join(self.data_dir, "img*", "*.nii.gz")))
        for output in output_files:
            ave = np.mean(nib.load(output).get_fdata())
            repeated[i].append(ave)
        if i == 2:
            self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=repeated[i][2:], rtol=1e-2))
        else:
            self.assertTrue(test_integration_value(TASK, key="output_sums", data=repeated[i][2:], rtol=1e-2))
    np.testing.assert_allclose(repeated[0], repeated[1])
def test_value(self, input_param, input_data, expected_value):
    set_determinism(seed=0)
    result = TorchVision(**input_param)(input_data)
    torch.testing.assert_allclose(result, expected_value)
def test_values(self):
    # check system default flags
    set_determinism(None)
    self.assertTrue(not torch.backends.cudnn.deterministic)
    self.assertTrue(get_seed() is None)
    # set default seed
    set_determinism()
    self.assertTrue(get_seed() is not None)
    self.assertTrue(torch.backends.cudnn.deterministic)
    self.assertTrue(not torch.backends.cudnn.benchmark)
    # resume default
    set_determinism(None)
    self.assertTrue(not torch.backends.cudnn.deterministic)
    self.assertTrue(not torch.backends.cudnn.benchmark)
    self.assertTrue(get_seed() is None)
    # test seeds
    seed = 255
    set_determinism(seed=seed)
    self.assertEqual(seed, get_seed())
    a = np.random.randint(seed)
    b = torch.randint(seed, (1,))
    set_determinism(seed=seed)
    c = np.random.randint(seed)
    d = torch.randint(seed, (1,))
    self.assertEqual(a, c)
    self.assertEqual(b, d)
    self.assertTrue(torch.backends.cudnn.deterministic)
    self.assertTrue(not torch.backends.cudnn.benchmark)
    set_determinism(seed=None)
def prepare_data(self):
    # set deterministic training for reproducibility
    set_determinism(seed=0)
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest import skipUnless

import numpy as np
from numpy.testing import assert_array_equal
from parameterized import parameterized

from monai.data import SlidingPatchWSIDataset
from monai.utils import WSIPatchKeys, optional_import, set_determinism
from tests.utils import download_url_or_skip_test, testing_data_config

set_determinism(0)

cucim, has_cucim = optional_import("cucim")
has_cucim = has_cucim and hasattr(cucim, "CuImage")
openslide, has_osl = optional_import("openslide")
imwrite, has_tiff = optional_import("tifffile", name="imwrite")
_, has_codec = optional_import("imagecodecs")
has_tiff = has_tiff and has_codec

FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension)

FILE_PATH_SMALL_0 = os.path.join(os.path.dirname(__file__), "testing_data",
def tearDown(self):
    set_determinism(seed=None)
    if os.path.exists(self.img_name):
        os.remove(self.img_name)
    if os.path.exists(self.seg_name):
        os.remove(self.seg_name)
def tearDown(self):
    set_determinism(seed=None)
    os.remove(os.path.join(self.data_dir, "best_metric_model.pth"))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./testdata", type=str, help="directory of Brain Tumor dataset.")
    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
    parser.add_argument("--local_rank", type=int, help="node rank for distributed training")
    parser.add_argument(
        "-j", "--workers", default=1, type=int, metavar="N", help="number of data loading workers (default: 1)"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument(
        "-b",
        "--batch_size",
        default=4,
        type=int,
        metavar="N",
        help="mini-batch size (default: 4), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument("-p", "--print_freq", default=10, type=int, metavar="N", help="print frequency (default: 10)")
    parser.add_argument("-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set")
    parser.add_argument("--seed", default=None, type=int, help="seed for initializing training.")
    parser.add_argument("--cache_rate", type=float, default=1.0)
    parser.add_argument("--val_interval", type=int, default=5)
    parser.add_argument("--network", type=str, default="UNet", choices=["UNet", "SegResNet"])
    parser.add_argument("--log_dir", type=str, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        set_determinism(seed=args.seed)
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    main_worker(args=args)
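# For reference, a simplified sketch of what `set_determinism(seed=args.seed)`
# does under the hood (an approximation of MONAI's behaviour, not its exact
# implementation): it seeds Python, NumPy and PyTorch and flips the two CUDNN
# flags the warning above refers to.
import random
import numpy as np
import torch

def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))  # NumPy accepts 32-bit seeds only
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # force deterministic convolution algorithms
    torch.backends.cudnn.benchmark = False     # disable the auto-tuner, a source of nondeterminism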
def prepare_data(self):
    data_images = sorted([os.path.join(data_path, x) for x in os.listdir(data_path) if x.startswith("data")])
    data_labels = sorted([os.path.join(data_path, x) for x in os.listdir(data_path) if x.startswith("label")])
    data_dicts = [
        {
            "image": image_name,
            "label": label_name,
            "patient": image_name.split("/")[-1].replace("data", "").replace(".nii.gz", ""),
        }
        for image_name, label_name in zip(data_images, data_labels)
    ]
    train_files, val_files = train_val_split(data_dicts)
    print(f"Training patients: {len(train_files)}, Validation patients: {len(val_files)}")

    set_determinism(seed=0)

    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=PIXDIM, mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=PATCH_SIZE,
                pos=1,
                neg=1,
                num_samples=16,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(["image", "label"], spatial_axis=[0, 1, 2], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=PIXDIM, mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            StoreShaped(keys=["image"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    self.train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=cache_path)
    self.val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=cache_path)
def tearDown(self) -> None:
    set_determinism(None)
def train(self, train_info, valid_info, hyperparameters, run_data_check=False):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not run_data_check:
        start_dt = datetime.datetime.now()
        start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
        print(f'Training started: {start_dt_string}')

        # 1. Create folders to save the model
        timedate_info = str(datetime.datetime.now()).split(' ')[0] + '_' + \
            str(datetime.datetime.now().strftime("%H:%M:%S")).replace(':', '-')
        path_to_model = os.path.join(self.out_dir, 'trained_models', self.unique_name + '_' + timedate_info)
        os.mkdir(path_to_model)

    # 2. Load hyperparameters
    learning_rate = hyperparameters['learning_rate']
    weight_decay = hyperparameters['weight_decay']
    total_epoch = hyperparameters['total_epoch']
    multiplicator = hyperparameters['multiplicator']
    batch_size = hyperparameters['batch_size']
    validation_epoch = hyperparameters['validation_epoch']
    validation_interval = hyperparameters['validation_interval']
    H = hyperparameters['H']
    L = hyperparameters['L']

    # 3. Consider class imbalance
    negative, positive = 0, 0
    for _, label in train_info:
        if int(label) == 0:
            negative += 1
        elif int(label) == 1:
            positive += 1
    pos_weight = torch.Tensor([(negative / positive)]).to(self.device)

    # 4. Create train and validation loaders; batch_size = 10 for the validation loader (10 central slices)
    train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, train_info)
    valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir, valid_info)
    large_image_splitter(train_data, self.cache_dir)

    set_determinism(seed=100)
    train_trans, valid_trans = self.transformations(H, L)
    train_dataset = PersistentDataset(
        data=train_data[:],
        transform=train_trans,
        cache_dir=self.persistent_dataset_dir
    )
    valid_dataset = PersistentDataset(
        data=valid_data[:],
        transform=valid_trans,
        cache_dir=self.persistent_dataset_dir
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=self.pin_memory,
        num_workers=self.num_workers,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT)
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=self.pin_memory,
        num_workers=self.num_workers,
        collate_fn=PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT)
    )

    # Perform data checks
    if run_data_check:
        check_data = monai.utils.misc.first(train_loader)
        print(check_data["image"].shape, check_data["label"])
        for i in range(batch_size):
            multi_slice_viewer(
                check_data["image"][i, 0, :, :, :],
                check_data["image_meta_dict"]["filename_or_obj"][i]
            )
        exit()

    """c = 1
    for d in train_loader:
        img = d["image"]
        seg = d["seg"][0]
        seg, _ = nrrd.read(seg)
        img_name = d["image_meta_dict"]["filename_or_obj"][0]
        print(c, "Name:", img_name, "Size:", img.nelement()*img.element_size()/1024/1024, "MB", "shape:", img.shape)
        multi_slice_viewer(img[0, 0, :, :, :], d["image_meta_dict"]["filename_or_obj"][0])
        #multi_slice_viewer(seg, d["image_meta_dict"]["filename_or_obj"][0])
        c += 1
    exit()"""

    # 5. Prepare model
    model = ModelCT().to(self.device)

    # 6. Define loss function, optimizer and scheduler
    # pos_weight must be passed as a keyword argument; positionally it would be taken as `weight`
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # pos_weight for class imbalance
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, multiplicator, last_epoch=-1)

    # 7. Create post-validation transforms and handlers
    path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
    writer = SummaryWriter(log_dir=path_to_tensorboard)
    valid_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
        ]
    )
    valid_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(summary_writer=writer, output_transform=lambda x: None),
        CheckpointSaver(save_dir=path_to_model, save_dict={"model": model}, save_key_metric=True),
        MetricsSaver(save_dir=path_to_model, metrics=['Valid_AUC', 'Valid_ACC']),
    ]

    # 8. Create validator
    discrete = AsDiscrete(threshold_values=True)
    evaluator = SupervisedEvaluator(
        device=self.device,
        val_data_loader=valid_loader,
        network=model,
        post_transform=valid_post_transforms,
        key_val_metric={
            "Valid_AUC": ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            # "Valid_ACC" matches the key used by MetricsSaver above and the np.save call below
            "Valid_ACC": Accuracy(output_transform=lambda x: (discrete(x["pred"]), x["label"]))
        },
        val_handlers=valid_handlers,
        amp=self.amp,
    )

    # 9. Create trainer
    # The loss function applies the final sigmoid, so we don't need it here.
    train_post_transforms = Compose(
        [
            # Empty
        ]
    )
    logger = MetricLogger(evaluator=evaluator)
    train_handlers = [
        logger,
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandlerCT(
            validator=evaluator,
            start=validation_epoch,
            interval=validation_interval,
            epoch_level=True
        ),
        StatsHandler(tag_name="loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            summary_writer=writer,
            tag_name="Train_Loss",
            output_transform=lambda x: x["loss"]
        ),
        CheckpointSaver(
            save_dir=path_to_model,
            save_dict={"model": model, "opt": optimizer},
            save_interval=1,
            n_saved=1
        ),
    ]
    trainer = SupervisedTrainer(
        device=self.device,
        max_epochs=total_epoch,
        train_data_loader=train_loader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        post_transform=train_post_transforms,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # 10. Run trainer
    trainer.run()

    # 11. Save results
    np.save(path_to_model + '/AUCS.npy', np.array(logger.metrics['Valid_AUC']))
    np.save(path_to_model + '/ACCS.npy', np.array(logger.metrics['Valid_ACC']))
    np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
    np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

    return path_to_model
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils import optional_import, set_determinism
from tests.utils import skip_if_downloading_fails

if TYPE_CHECKING:
    import matplotlib.pyplot as plt

    has_matplotlib = True
    has_pil = True
else:
    plt, has_matplotlib = optional_import("matplotlib.pyplot")
    _, has_pil = optional_import("PIL.Image")

RAND_SEED = 42
random.seed(RAND_SEED)
set_determinism(seed=RAND_SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"


@unittest.skipUnless(sys.platform == "linux", "requires linux")
@unittest.skipUnless(has_pil, "requires PIL")
class TestLRFinder(unittest.TestCase):
    def setUp(self):
        self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY")
        if not self.root_dir:
            self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")

        self.transforms = Compose(
            [
def setUp(self):
    set_determinism(0)
def tearDown(self):
    set_determinism(seed=None)
    shutil.rmtree(self.data_dir)
def test_value(self, input_param, input_data, expected_value):
    set_determinism(seed=0)
    transform = RandTorchVisiond(**input_param)
    result = transform(input_data)
    self.assertTrue(isinstance(transform, Randomizable))
    torch.testing.assert_allclose(result["img"], expected_value)
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
    * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
    * Load Data: Load training and validation data to PyTorch DataLoader
    * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
    * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
        during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
        on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
    * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
    * Run training: The MONAI trainer is run, performing training and validation during training.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-column filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-column filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))

    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd,
    #       RandScaleIntensityd, RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2, keep_size=True,
                        mode=["bilinear", "nearest"], padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch sizes different from 1 at validation are currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extract the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir,
                        save_dict={"net": net, "opt": opt},
                        save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"],
                                                    interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use flipping as test-time augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"),
                                tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir,
                        save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2,
                        epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # deep supervision: match each prediction scale with a resampled label and weight by 0.5 ** i
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()

            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
def __init__(
        self,
        architecture: SegmentationArchitectures = SegmentationArchitectures.ResidualUNet2D,
        loss: SegmentationLosses = SegmentationLosses.GeneralizedDiceLoss,
        optimizer: Optimizers = Optimizers.Adam,
        mask_type: MaskType = MaskType.TIFF_LABELS,
        in_channels: int = 1,
        out_channels: int = 3,
        roi_size: Tuple[int, int] = (384, 384),
        num_filters_in_first_layer: int = 16,
        learning_rate: float = 0.001,
        weight_decay: float = 0.0001,
        momentum: float = 0.9,
        num_epochs: int = 400,
        batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
        num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
        validation_step: int = 2,
        sliding_window_batch_size: int = 4,
        class_names: Tuple[str, ...] = ("Background", "Object", "Border"),
        experiment_name: str = "Unet",
        model_name: str = "best_model",
        seed: int = 4294967295,
        working_dir: str = '.',
        stdout: TextIOWrapper = sys.stdout,
        stderr: TextIOWrapper = sys.stderr
):
    """Constructor.

    @param mask_type: MaskType
        Type of mask: defines file type, mask geometry and the way pixels
        are assigned to the various classes.

        @see qu.data.model.MaskType

    @param architecture: SegmentationArchitectures
        Core network architecture: one of (SegmentationArchitectures.ResidualUNet2D,
        SegmentationArchitectures.AttentionUNet2D)

    @param loss: SegmentationLosses
        Loss function: currently only SegmentationLosses.GeneralizedDiceLoss is supported.

    @param optimizer: Optimizers
        Optimizer: one of (Optimizers.Adam, Optimizers.SGD)

    @param in_channels: int, optional: default = 1
        Number of channels in the input (e.g. 1 for gray-value images).

    @param out_channels: int, optional: default = 3
        Number of channels in the output (classes).

    @param roi_size: Tuple[int, int], optional: default = (384, 384)
        Crop area (and input size of the U-Net network) used for training and validation/prediction.

    @param num_filters_in_first_layer: int
        Number of filters in the first layer. Every subsequent layer doubles the number of filters.

    @param learning_rate: float, optional: default = 1e-3
        Initial learning rate for the optimizer.

    @param weight_decay: float, optional: default = 1e-4
        Weight decay of the learning rate for the optimizer. Used by the Adam optimizer.

    @param momentum: float, optional: default = 0.9
        Momentum of the accelerated gradient for the optimizer. Used by the SGD optimizer.

    @param num_epochs: int, optional: default = 400
        Number of epochs for training.

    @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
        Batch sizes for training, validation, testing, and prediction, respectively.

    @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
        Number of workers for training, validation, testing, and prediction, respectively.

    @param validation_step: int, optional: default = 2
        Number of training steps before the next validation is performed.

    @param sliding_window_batch_size: int, optional: default = 4
        Number of batches for sliding window inference during validation and prediction.

    @param class_names: Tuple[str, ...], optional: default = ("Background", "Object", "Border")
        Names of the classes for logging validation curves.

    @param experiment_name: str, optional: default = "Unet"
        Name of the experiment that maps to the folder that contains training information
        (to be used by tensorboard). Please note, the current datetime will be appended.

    @param model_name: str, optional: default = "best_model"
        Name of the file that stores the best model. Please note, the current datetime
        will be appended (before the extension).

    @param seed: int, optional; default = 4294967295
        Random seed passed to `set_determinism` to enable deterministic training.

    @param working_dir: str, optional, default = "."
        Working folder where to save the model weights and the logs for tensorboard.
    """

    # Call base constructor
    super().__init__()

    # Standard pipe wrappers
    self._stdout = stdout
    self._stderr = stderr

    # Device (initialize as "cpu")
    self._device = "cpu"

    # Architecture, loss function and optimizer
    self._option_architecture = architecture
    self._option_loss = loss
    self._option_optimizer = optimizer
    self._learning_rate = learning_rate
    self._weight_decay = weight_decay
    self._momentum = momentum

    # Mask type
    self._mask_type = mask_type

    # Input and output channels
    self._in_channels = in_channels
    self._out_channels = out_channels

    # Define hyper parameters
    self._roi_size = roi_size
    self._num_filters_in_first_layer = num_filters_in_first_layer
    self._training_batch_size = batch_sizes[0]
    self._validation_batch_size = batch_sizes[1]
    self._test_batch_size = batch_sizes[2]
    self._prediction_batch_size = batch_sizes[3]
    self._training_num_workers = num_workers[0]
    self._validation_num_workers = num_workers[1]
    self._test_num_workers = num_workers[2]
    self._prediction_num_workers = num_workers[3]
    self._n_epochs = num_epochs
    self._validation_step = validation_step
    self._sliding_window_batch_size = sliding_window_batch_size

    # Other parameters
    self._class_names = out_channels * ["Unknown"]
    for i in range(min(out_channels, len(class_names))):
        self._class_names[i] = class_names[i]

    # Set monai seed
    set_determinism(seed=seed)

    # All file names
    self._train_image_names: list = []
    self._train_mask_names: list = []
    self._validation_image_names: list = []
    self._validation_mask_names: list = []
    self._test_image_names: list = []
    self._test_mask_names: list = []

    # Transforms
    self._train_image_transforms = None
    self._train_mask_transforms = None
    self._validation_image_transforms = None
    self._validation_mask_transforms = None
    self._test_image_transforms = None
    self._test_mask_transforms = None
    self._prediction_image_transforms = None
    self._validation_post_transforms = None
    self._test_post_transforms = None
    self._prediction_post_transforms = None

    # Datasets and data loaders
    self._train_dataset = None
    self._train_dataloader = None
    self._validation_dataset = None
    self._validation_dataloader = None
    self._test_dataset = None
    self._test_dataloader = None
    self._prediction_dataset = None
    self._prediction_dataloader = None

    # Set model architecture, loss function, metric and optimizer
    self._model = None
    self._training_loss_function = None
    self._optimizer = None
    self._validation_metric = None

    # Working directory, model file name and experiment name for Tensorboard logs.
    # The file names will be redefined at the beginning of the training.
    self._working_dir = Path(working_dir).resolve()
    self._raw_experiment_name = experiment_name
    self._raw_model_file_name = model_name

    # Keep track of the full path of the best model
    self._best_model = ''

    # Keep track of last error message
    self._message = ""
def test_training(self):
    set_determinism(seed=0)
    loss, step = run_test(device=self.device)
    print(f"Deterministic loss {loss} at training step {step}")
    np.testing.assert_allclose(step, 4)
    np.testing.assert_allclose(loss, 0.536134, rtol=1e-4)
def main():
    # TODO: define file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # Load JSON & append root path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for idx, each_d in enumerate(train_Data):
        train_Data[idx]['image'] = os.path.join(data_Root, train_Data[idx]['image'])

    for idx, each_d in enumerate(val_Data):
        val_Data[idx]['image'] = os.path.join(data_Root, val_Data[idx]['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set determinism
    set_determinism(seed=123)

    # Define training transforms
    train_Transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
            ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # Please note that if image and image_2 were augmented via the same transform call, then, because of
            # the determinism, they would get augmented exactly the same way, which is not what is wanted here;
            # hence two separate calls are made
            OneOf(transforms=[
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5,
                                   dropout_holes=True, max_spatial_size=32),
                RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20,
                                   dropout_holes=False, max_spatial_size=64),
            ]),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = check_data["image"][0][0]
    print(f"image shape: {image.shape}")

    # Define network (ViT backbone), loss & optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )
    model = model.to(device)

    # Define hyper-parameters for the training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size * 2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define DataLoaders using MONAI; CacheDataset should be used for better performance
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Scale the contrastive loss by the reconstruction loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & recon loss storage
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time - start_time}s"
            )

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch + 1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            with torch.no_grad():  # no gradients needed at validation
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, outputs_v2 = model(inputs)
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, "
                  f"time taken: {end_time - start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {
                    'epoch': max_epochs,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None