def test_correct_results(self, min_zoom, max_zoom, order, mode, cval, prefilter, use_gpu, keep_size):
    """Compare ``RandZoomd`` output against a per-channel ``zoom_scipy`` reference."""
    key = "img"
    transform = RandZoomd(
        key,
        prob=1.0,
        min_zoom=min_zoom,
        max_zoom=max_zoom,
        interp_order=order,
        mode=mode,
        cval=cval,
        prefilter=prefilter,
        use_gpu=use_gpu,
        keep_size=keep_size,
    )
    transform.set_random_state(234)
    result = transform({key: self.imt[0]})

    # The reference is built channel by channel, using the ``_zoom`` value
    # read from the transform after it has been applied.
    reference = np.stack(
        [
            zoom_scipy(channel, zoom=transform._zoom, mode=mode, order=order, cval=cval, prefilter=prefilter)
            for channel in self.imt[0]
        ]
    ).astype(np.float32)
    self.assertTrue(np.allclose(reference, result[key]))
def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter):
    """Compare ``RandZoomd(use_gpu=True)`` against a per-channel ``zoom_scipy`` reference.

    The test is reported as skipped (rather than silently passing with no
    assertions) when ``cupy`` cannot be found.
    """
    if not importlib.util.find_spec("cupy"):
        self.skipTest("cupy is not installed")

    key = "img"
    random_zoom = RandZoomd(
        key,
        prob=1.0,
        min_zoom=min_zoom,
        max_zoom=max_zoom,
        interp_order=order,
        mode=mode,
        cval=cval,
        prefilter=prefilter,
        use_gpu=True,
        keep_size=False,
    )
    random_zoom.set_random_state(234)
    zoomed = random_zoom({key: self.imt[0]})

    # Reference uses the ``_zoom`` value read from the transform after the call.
    expected = np.stack(
        [
            zoom_scipy(channel, zoom=random_zoom._zoom, mode=mode, order=order, cval=cval, prefilter=prefilter)
            for channel in self.imt[0]
        ]
    ).astype(np.float32)
    self.assertTrue(np.allclose(expected, zoomed[key]))
def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size):
    """Check ``RandZoomd`` against a nearest-neighbour ``zoom_scipy`` reference for each array type."""
    key = "img"
    transform = RandZoomd(
        key,
        prob=1.0,
        min_zoom=min_zoom,
        max_zoom=max_zoom,
        mode=mode,
        align_corners=align_corners,
        keep_size=keep_size,
    )
    for to_array in TEST_NDARRAYS:
        # Re-seed before every call so each array type goes through the same seeded state.
        transform.set_random_state(1234)
        result = transform({key: to_array(self.imt[0])})

        zoom_factor = transform.rand_zoom._zoom
        reference = np.stack(
            [
                zoom_scipy(channel, zoom=zoom_factor, mode="nearest", order=0, prefilter=False)
                for channel in self.imt[0]
            ]
        ).astype(np.float32)
        assert_allclose(result[key], to_array(reference), atol=1.0)
def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size):
    """Check ``RandZoomd`` against a nearest-neighbour ``zoom_scipy`` reference (``atol=1.0``)."""
    key = "img"
    transform = RandZoomd(
        key,
        prob=1.0,
        min_zoom=min_zoom,
        max_zoom=max_zoom,
        mode=mode,
        align_corners=align_corners,
        keep_size=keep_size,
    )
    transform.set_random_state(1234)
    result = transform({key: self.imt[0]})

    # Reference is computed per channel with the ``_zoom`` value read after the call.
    reference = np.stack(
        [
            zoom_scipy(channel, zoom=transform._zoom, mode="nearest", order=0, prefilter=False)
            for channel in self.imt[0]
        ]
    ).astype(np.float32)
    np.testing.assert_allclose(reference, result[key], atol=1.0)
def test_auto_expand_3d(self):
    """Check the zoom factors and output shape when two-element zoom ranges meet a ``[2, 2, 3, 4]`` input."""
    expected_zoom = (1.048844, 1.048844, 0.962637)
    expected_shape = (2, 2, 3, 3)
    transform = RandZoomd(
        keys="img",
        prob=1.0,
        min_zoom=[0.8, 0.7],
        max_zoom=[1.2, 1.3],
        mode="nearest",
        keep_size=False,
    )
    for to_array in TEST_NDARRAYS:
        transform.set_random_state(1234)
        raw = np.random.randint(0, 2, size=[2, 2, 3, 4])
        result = transform({"img": to_array(raw)})
        # Three zoom values are expected even though only two were configured.
        assert_allclose(transform.rand_zoom._zoom, expected_zoom, atol=1e-2)
        assert_allclose(result["img"].shape, expected_shape)
def test_keep_size(self):
    """With ``keep_size=True`` the output shape must equal the input shape for every array type."""
    key = "img"
    transform = RandZoomd(
        keys=key,
        prob=1.0,
        min_zoom=0.6,
        max_zoom=0.7,
        keep_size=True,
        padding_mode="constant",
        constant_values=2,
    )
    for to_array in TEST_NDARRAYS:
        result = transform({key: to_array(self.imt[0])})
        # ``self.imt[0]`` is the input, so its shape is ``self.imt.shape[1:]``.
        np.testing.assert_array_equal(result[key].shape, self.imt.shape[1:])
def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
    """Expect ``raises`` when ``RandZoomd`` is built and applied with the given arguments.

    Both construction and the call sit inside ``assertRaises`` so the
    exception may come from either step.
    """
    key = "img"
    with self.assertRaises(raises):
        random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, interp_order=order)
        # The return value is irrelevant here; only the raised exception matters.
        random_zoom({key: self.imt[0]})
def test_keep_size(self):
    """With ``keep_size=True`` the zoomed output must have the same shape as the input."""
    key = "img"
    transform = RandZoomd(key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
    result = transform({key: self.imt[0]})
    # ``self.imt[0]`` is the input, so its shape is ``self.imt.shape[1:]``.
    same_shape = np.array_equal(result[key].shape, self.imt.shape[1:])
    self.assertTrue(same_shape)
def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
    """Expect ``raises`` when ``RandZoomd`` is built and applied with the given arguments."""
    key = "img"
    zoom_kwargs = {"prob": 1.0, "min_zoom": min_zoom, "max_zoom": max_zoom, "mode": mode}
    with self.assertRaises(raises):
        # Construction and application are both inside the context manager:
        # the exception may come from either step.
        transform = RandZoomd(key, **zoom_kwargs)
        transform({key: self.imt[0]})
def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
    """Expect ``raises`` for every array type in ``TEST_NDARRAYS``."""
    key = "img"
    zoom_kwargs = {"prob": 1.0, "min_zoom": min_zoom, "max_zoom": max_zoom, "mode": mode}
    for to_array in TEST_NDARRAYS:
        with self.assertRaises(raises):
            # Construction and application are both inside the context manager:
            # the exception may come from either step.
            transform = RandZoomd(key, **zoom_kwargs)
            transform({key: to_array(self.imt[0])})
"Zoomd 2d", "2D", 2e-1, Zoomd(KEYS, zoom=0.9), )) TESTS.append(( "Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), )) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append(( "RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0), )) TESTS.append(( "Rotated 2d", "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi),
def test_invert(self):
    """Run a transform chain through an ``Engine`` and check the ``TransformInverter`` outputs.

    Two ``TransformInverter`` handlers are attached: one that inverts with
    ``nearest_interp=True`` (postfix ``inverted1``) and one that uses a
    different ``nearest_interp`` value per output key (postfix ``inverted2``).
    The test then checks output shapes, that nearest-interpolated results are
    integer valued, and that the number of label voxels differing from the
    freshly loaded original is one of three known values.
    """
    # Seed everything; determinism is switched off again after the engine run.
    set_determinism(seed=0)
    im_fname, seg_fname = [
        make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)
    ]
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord(
            "image"
        ),  # test to support both Tensor and Numpy array when inverting
        CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
    ])
    # The same image/label file pair is repeated 12 times; with batch_size=5
    # the final batch therefore holds 2 items.
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        # Every batch item must have shape (1, 100, 100, 100).
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        # Store the batch as the engine output, then fire MODEL_COMPLETED.
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        to_tensor=[True, False],
        device="cpu",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    # test different nearest interpolation values
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        post_func=[lambda x: x + 10, lambda x: x],
        postfix="inverted2",
        num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
    ).attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)
    # ``engine.state.output`` holds the last batch processed (2 items, see above).
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))
    # check the nearest interpolation mode
    # (a round trip through uint8 must leave the values unchanged)
    for i in engine.state.output["image_inverted1"]:
        torch.testing.assert_allclose(
            i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted1"]:
        np.testing.assert_allclose(
            i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match
    reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.output["label_meta_dict"][
        "filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # Known counts of mismatching voxels, depending on the environment:
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values")

    # check the case that different items use different interpolation mode to invert transforms
    for i in engine.state.output["image_inverted2"]:
        # if the interpolation mode is nearest, accumulated diff should be smaller than 1
        self.assertLess(
            torch.sum(
                i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 1.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    for i in engine.state.output["label_inverted2"]:
        # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
        self.assertGreater(
            torch.sum(
                i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 10000.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
) from monai.utils import set_determinism TESTS: List[Tuple] = [] for pad_collate in [pad_list_data_collate, PadListDataCollate()]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) TESTS.append( (list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) TESTS.append( (list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))
"2D", 2e-1, Zoomd(KEYS, zoom=0.9), ) ) TESTS.append( ( "Zoomd 3d", "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), ) ) TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append( ( "RandRotated, prob 0", "2D", 0, RandRotated(KEYS, prob=0), ) ) TESTS.append( ( "Rotated 2d", "2D", 8e-2,
aug_prob = 0.5 keys = ("img", "seg") # use these when interpolating binary segmentations to ensure values are 0 or 1 only zoom_mode = monai.utils.enums.InterpolateMode.NEAREST elast_mode = monai.utils.enums.GridSampleMode.BILINEAR, monai.utils.enums.GridSampleMode.NEAREST trans = Compose( [ ScaleIntensityd(keys=("img",)), # rescale image data to range [0,1] AddChanneld(keys=keys), # add 1-size channel dimension RandRotate90d(keys=keys, prob=aug_prob), RandFlipd(keys=keys, prob=aug_prob), RandZoomd(keys=keys, prob=aug_prob, mode=zoom_mode), Rand2DElasticd(keys=keys, prob=aug_prob, spacing=10, magnitude_range=(-2, 2), mode=elast_mode), RandAffined(keys=keys, prob=aug_prob, rotate_range=1, translate_range=16, mode=elast_mode), ToTensord(keys=keys), # convert to tensor ] ) data = [ {"img": train_images[i], "seg": train_segs[i]} for i in range(len(train_images)) ] ds = CacheDataset(data, trans) loader = DataLoader( dataset=ds, batch_size=batch_size,
def _training_validate(model, val_loader):
    """Run ``model`` over ``val_loader`` and return ``(quadratic kappa, accuracy)``."""
    model.eval()
    with torch.no_grad():
        y_pred = torch.tensor([], dtype=torch.float32, device=device)
        y = torch.tensor([], dtype=torch.long, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
            y_pred = torch.cat([y_pred, model(val_images)], dim=0)
            y = torch.cat([y, val_labels], dim=0)

        predicted = y_pred.argmax(dim=1)
        kappa_value = cohen_kappa_score(y.to("cpu"), predicted.to("cpu"), weights='quadratic')
        acc_value = torch.eq(predicted, y)
        acc_metric = acc_value.sum().item() / len(acc_value)
    return kappa_value, acc_metric


def training(train_files, val_files, log_dir):
    """Train ``DenseNetASPP`` and keep the checkpoint with the best validation kappa.

    The function builds the train/validation pipelines, optionally resumes
    from ``log_dir``, trains with class-weighted cross entropy, validates
    every ``val_interval`` epochs, plots the loss/kappa curves and finally
    calls ``evaluta_model(val_files, log_dir)``.

    Args:
        train_files: data list handed to ``monai.data.Dataset`` for training.
        val_files: data list handed to ``monai.data.Dataset`` for validation.
        log_dir: path of the checkpoint file. It is loaded when it exists and
            overwritten whenever the validation kappa improves.

    Note:
        ``best_metric`` is not stored in the checkpoint, so after a resume the
        first validation always overwrites ``log_dir``. The scheduler state is
        not restored either.
    """
    print(log_dir)
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ToTensord(keys=["inputs"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )

    lr = 0.01
    batch_size = 256
    # ``pin_memory`` must be a bool; the previous ``torch.device`` class object
    # was merely truthy. Pinning is only useful when CUDA is available.
    pin_memory = torch.cuda.is_available()

    # One un-shuffled batch is used to estimate the class weights.
    # NOTE(review): compute_class_weight may fail if a class is absent from
    # this single batch -- confirm against the data.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)
    check_data = monai.utils.misc.first(check_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    # Create the network, class-weighted CrossEntropyLoss and Adam optimizer
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # Keyword arguments: recent scikit-learn versions reject the positional form.
    class_weights = class_weight.compute_class_weight(
        class_weight='balanced', classes=classes, y=check_data["label"].numpy()
    )
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    # If a saved model exists, load it and continue training from it.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
        # The checkpoint stores the last completed (0-based) epoch.
        first_epoch = start_epoch + 1
    else:
        print('无保存模型,将从头开始训练!')
        # A fresh run starts at epoch 0 (previously epoch 0 was skipped).
        first_epoch = 0

    # start a typical PyTorch training
    epoch_num = 300
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # ``len(train_loader)`` is already the number of batches per epoch.
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(first_epoch, epoch_num):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epoch_num}")
            model.train()
            epoch_loss = 0
            step = 0
            for step, batch_data in enumerate(train_loader, start=1):
                inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
                outputs = model(inputs)
                loss = loss_function(outputs, labels.long())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            # Guard against an empty loader, which would otherwise divide by zero.
            epoch_loss /= max(step, 1)
            scheduler.step()
            print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                kappa_value, acc_metric = _training_validate(model, val_loader)
                metric_values.append(kappa_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch,
                    }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    finally:
        # Close the event file even when training is interrupted by an error.
        writer.close()
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    x = [i + 1 for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    # The plotted validation metric is Cohen's kappa, not ROC AUC.
    plt.title("Validation: Quadratic weighted Kappa")
    x = [val_interval * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.show()
    evaluta_model(val_files, log_dir)
def main(): parser = argparse.ArgumentParser(description="training") parser.add_argument( "--checkpoint", type=str, default=None, help="checkpoint full path", ) parser.add_argument( "--factor_ram_cost", default=0.0, type=float, help="factor to determine RAM cost in the searched architecture", ) parser.add_argument( "--fold", action="store", required=True, help="fold index in N-fold cross-validation", ) parser.add_argument( "--json", action="store", required=True, help="full path of .json file", ) parser.add_argument( "--json_key", action="store", required=True, help="selected key in .json data list", ) parser.add_argument( "--local_rank", required=int, help="local process rank", ) parser.add_argument( "--num_folds", action="store", required=True, help="number of folds in cross-validation", ) parser.add_argument( "--output_root", action="store", required=True, help="output root", ) parser.add_argument( "--root", action="store", required=True, help="data root", ) args = parser.parse_args() logging.basicConfig(stream=sys.stdout, level=logging.INFO) if not os.path.exists(args.output_root): os.makedirs(args.output_root, exist_ok=True) amp = True determ = True factor_ram_cost = args.factor_ram_cost fold = int(args.fold) input_channels = 1 learning_rate = 0.025 learning_rate_arch = 0.001 learning_rate_milestones = np.array([0.4, 0.8]) num_images_per_batch = 1 num_epochs = 1430 # around 20k iteration num_epochs_per_validation = 100 num_epochs_warmup = 715 num_folds = int(args.num_folds) num_patches_per_image = 1 num_sw_batch_size = 6 output_classes = 3 overlap_ratio = 0.625 patch_size = (96, 96, 96) patch_size_valid = (96, 96, 96) spacing = [1.0, 1.0, 1.0] print("factor_ram_cost", factor_ram_cost) # deterministic training if determ: set_determinism(seed=0) # initialize the distributed training process, every GPU runs in a process dist.init_process_group(backend="nccl", init_method="env://") # dist.barrier() world_size = dist.get_world_size() with open(args.json, "r") as f: 
json_data = json.load(f) split = len(json_data[args.json_key]) // num_folds list_train = json_data[args.json_key][:( split * fold)] + json_data[args.json_key][(split * (fold + 1)):] list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))] # training data files = [] for _i in range(len(list_train)): str_img = os.path.join(args.root, list_train[_i]["image"]) str_seg = os.path.join(args.root, list_train[_i]["label"]) if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): continue files.append({"image": str_img, "label": str_seg}) train_files = files random.shuffle(train_files) train_files_w = train_files[:len(train_files) // 2] train_files_w = partition_dataset(data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True)[dist.get_rank()] print("train_files_w:", len(train_files_w)) train_files_a = train_files[len(train_files) // 2:] train_files_a = partition_dataset(data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True)[dist.get_rank()] print("train_files_a:", len(train_files_a)) # validation data files = [] for _i in range(len(list_valid)): str_img = os.path.join(args.root, list_valid[_i]["image"]) str_seg = os.path.join(args.root, list_valid[_i]["label"]) if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): continue files.append({"image": str_img, "label": str_seg}) val_files = files val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[dist.get_rank()] print("val_files:", len(val_files)) # network architecture device = torch.device(f"cuda:{args.local_rank}") torch.cuda.set_device(device) train_transforms = Compose([ LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)), CastToTyped(keys=["image"], dtype=(torch.float32)), 
ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True), CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)), CopyItemsd(keys=["label"], times=1, names=["label4crop"]), Lambdad( keys=["label4crop"], func=lambda x: np.concatenate(tuple([ ndimage.binary_dilation( (x == _k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(output_classes) ]), axis=0), overwrite=True, ), EnsureTyped(keys=["image", "label"]), CastToTyped(keys=["image"], dtype=(torch.float32)), SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size, mode=["reflect", "constant", "constant"]), RandCropByLabelClassesd(keys=["image", "label"], label_key="label4crop", num_classes=output_classes, ratios=[ 1, ] * output_classes, spatial_size=patch_size, num_samples=num_patches_per_image), Lambdad(keys=["label4crop"], func=lambda x: 0), RandRotated(keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3, mode=["bilinear", "nearest"], prob=0.2), RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2, mode=["trilinear", "nearest"], align_corners=[True, None], prob=0.16), RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.15), RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5), RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5), RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5), RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5), CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)), ToTensord(keys=["image", "label"]), ]) val_transforms = Compose([ LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, 
True)), CastToTyped(keys=["image"], dtype=(torch.float32)), ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True), CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)), EnsureTyped(keys=["image", "label"]), ToTensord(keys=["image", "label"]) ]) train_ds_a = monai.data.CacheDataset(data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8) train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8) val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2) # monai.data.Dataset can be used as alternatives when debugging or RAM space is limited. # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms) # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms) # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True) train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True) val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False) # DataLoader can be used as alternatives when ThreadDataLoader is less efficient. 
# train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available()) dints_space = monai.networks.nets.TopologySearch( channel_mul=0.5, num_blocks=12, num_depths=4, use_downsample=True, device=device, ) model = monai.networks.nets.DiNTS( dints_space=dints_space, in_channels=input_channels, num_classes=output_classes, use_downsample=True, ) model = model.to(device) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) post_pred = Compose( [EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)]) post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)]) # loss function loss_func = monai.losses.DiceCELoss( include_background=False, to_onehot_y=True, softmax=True, squared_pred=True, batch=True, smooth_nr=0.00001, smooth_dr=0.00001, ) # optimizer optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004) arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0) arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0) print() if torch.cuda.device_count() > 1: if dist.get_rank() == 0: print("Let's use", torch.cuda.device_count(), "GPUs!") model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True) if args.checkpoint != None and os.path.isfile(args.checkpoint): print("[info] fine-tuning pre-trained checkpoint {0:s}".format( args.checkpoint)) model.load_state_dict(torch.load(args.checkpoint, map_location=device)) torch.cuda.empty_cache() else: print("[info] training from 
scratch") # amp if amp: from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() if dist.get_rank() == 0: print("[info] amp enabled") # start a typical PyTorch training val_interval = num_epochs_per_validation best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() idx_iter = 0 metric_values = list() if dist.get_rank() == 0: writer = SummaryWriter( log_dir=os.path.join(args.output_root, "Events")) with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: f.write("epoch\tmetric\tloss\tlr\ttime\titer\n") dataloader_a_iterator = iter(train_loader_a) start_time = time.time() for epoch in range(num_epochs): decay = 0.5**np.sum([ (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones ]) lr = learning_rate * decay for param_group in optimizer.param_groups: param_group["lr"] = lr if dist.get_rank() == 0: print("-" * 10) print(f"epoch {epoch + 1}/{num_epochs}") print("learning rate is set to {}".format(lr)) model.train() epoch_loss = 0 loss_torch = torch.zeros(2, dtype=torch.float, device=device) epoch_loss_arch = 0 loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device) step = 0 for batch_data in train_loader_w: step += 1 inputs, labels = batch_data["image"].to( device), batch_data["label"].to(device) if world_size == 1: for _ in model.weight_parameters(): _.requires_grad = True else: for _ in model.module.weight_parameters(): _.requires_grad = True dints_space.log_alpha_a.requires_grad = False dints_space.log_alpha_c.requires_grad = False optimizer.zero_grad() if amp: with autocast(): outputs = model(inputs) if output_classes == 2: loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) else: loss = loss_func(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: outputs = model(inputs) if output_classes == 2: loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) else: loss = loss_func(outputs, labels) loss.backward() optimizer.step() 
epoch_loss += loss.item() loss_torch[0] += loss.item() loss_torch[1] += 1.0 epoch_len = len(train_loader_w) idx_iter += 1 if dist.get_rank() == 0: print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) if epoch < num_epochs_warmup: continue try: sample_a = next(dataloader_a_iterator) except StopIteration: dataloader_a_iterator = iter(train_loader_a) sample_a = next(dataloader_a_iterator) inputs_search, labels_search = sample_a["image"].to( device), sample_a["label"].to(device) if world_size == 1: for _ in model.weight_parameters(): _.requires_grad = False else: for _ in model.module.weight_parameters(): _.requires_grad = False dints_space.log_alpha_a.requires_grad = True dints_space.log_alpha_c.requires_grad = True # linear increase topology and RAM loss entropy_alpha_c = torch.tensor(0.).to(device) entropy_alpha_a = torch.tensor(0.).to(device) ram_cost_full = torch.tensor(0.).to(device) ram_cost_usage = torch.tensor(0.).to(device) ram_cost_loss = torch.tensor(0.).to(device) topology_loss = torch.tensor(0.).to(device) probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True) entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean() entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) * \ F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean() topology_loss = dints_space.get_topology_entropy(probs_a) ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True) ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape) ram_cost_loss = torch.abs(factor_ram_cost - ram_cost_usage / ram_cost_full) arch_optimizer_a.zero_grad() arch_optimizer_c.zero_grad() combination_weights = (epoch - num_epochs_warmup) / ( num_epochs - num_epochs_warmup) if amp: with autocast(): outputs_search = model(inputs_search) if output_classes == 2: loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search) else: loss = 
loss_func(outputs_search, labels_search) loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + 0.001 * topology_loss) scaler.scale(loss).backward() scaler.step(arch_optimizer_a) scaler.step(arch_optimizer_c) scaler.update() else: outputs_search = model(inputs_search) if output_classes == 2: loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search) else: loss = loss_func(outputs_search, labels_search) loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + 0.001 * topology_loss) loss.backward() arch_optimizer_a.step() arch_optimizer_c.step() epoch_loss_arch += loss.item() loss_torch_arch[0] += loss.item() loss_torch_arch[1] += 1.0 if dist.get_rank() == 0: print( "[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}") writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step) # synchronizes all processes and reduce results dist.barrier() dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM) loss_torch = loss_torch.tolist() loss_torch_arch = loss_torch_arch.tolist() if dist.get_rank() == 0: loss_torch_epoch = loss_torch[0] / loss_torch[1] print( f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}" ) if epoch >= num_epochs_warmup: loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1] print( f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}" ) if (epoch + 1) % val_interval == 0: torch.cuda.empty_cache() model.eval() with torch.no_grad(): metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device) metric_sum = 0.0 metric_count = 0 metric_mat = [] val_images = None val_labels = None val_outputs = None _index = 0 for val_data in val_loader: val_images = val_data["image"].to(device) val_labels = val_data["label"].to(device) roi_size = 
patch_size_valid sw_batch_size = num_sw_batch_size if amp: with torch.cuda.amp.autocast(): pred = sliding_window_inference( val_images, roi_size, sw_batch_size, lambda x: model(x), mode="gaussian", overlap=overlap_ratio, ) else: pred = sliding_window_inference( val_images, roi_size, sw_batch_size, lambda x: model(x), mode="gaussian", overlap=overlap_ratio, ) val_outputs = pred val_outputs = post_pred(val_outputs[0, ...]) val_outputs = val_outputs[None, ...] val_labels = post_label(val_labels[0, ...]) val_labels = val_labels[None, ...] value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False) print(_index + 1, "/", len(val_loader), value) metric_count += len(value) metric_sum += value.sum().item() metric_vals = value.cpu().numpy() if len(metric_mat) == 0: metric_mat = metric_vals else: metric_mat = np.concatenate((metric_mat, metric_vals), axis=0) for _c in range(output_classes - 1): val0 = torch.nan_to_num(value[0, _c], nan=0.0) val1 = 1.0 - torch.isnan(value[0, 0]).float() metric[2 * _c] += val0 * val1 metric[2 * _c + 1] += val1 _index += 1 # synchronizes all processes and reduce results dist.barrier() dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) metric = metric.tolist() if dist.get_rank() == 0: for _c in range(output_classes - 1): print( "evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1]) avg_metric = 0 for _c in range(output_classes - 1): avg_metric += metric[2 * _c] / metric[2 * _c + 1] avg_metric = avg_metric / float(output_classes - 1) print("avg_metric", avg_metric) if avg_metric > best_metric: best_metric = avg_metric best_metric_epoch = epoch + 1 best_metric_iterations = idx_iter node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode( ) torch.save( { "node_a": node_a_d, "arch_code_a": arch_code_a_d, "arch_code_a_max": arch_code_a_max_d, "arch_code_c": arch_code_c_d, "iter_num": idx_iter, "epochs": epoch + 1, "best_dsc": best_metric, "best_path": 
best_metric_iterations, }, os.path.join(args.output_root, "search_code_" + str(idx_iter) + ".pth"), ) print("saved new best metric model") dict_file = {} dict_file["best_avg_dice_score"] = float(best_metric) dict_file["best_avg_dice_score_epoch"] = int( best_metric_epoch) dict_file["best_avg_dice_score_iteration"] = int(idx_iter) with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file: documents = yaml.dump(dict_file, stream=out_file) print( "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}" .format(epoch + 1, avg_metric, best_metric, best_metric_epoch)) current_time = time.time() elapsed_time = (current_time - start_time) / 60.0 with open( os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: f.write( "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n" .format(epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter)) dist.barrier() torch.cuda.empty_cache() print( f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}" ) if dist.get_rank() == 0: writer.close() dist.destroy_process_group() return
def test_invert(self):
    """Check that ``Invertd`` undoes a random spatial pipeline on decollated items.

    Two inverters are exercised on every item produced by the loader:

    * ``inverter`` inverts ``image_inverted``/``label_inverted``/``test_dict``
      with ``nearest_interp=True`` for all keys.
    * ``inverter_1`` inverts ``image_inverted1``/``label_inverted1`` with a
      per-key ``nearest_interp=[True, False]``.

    After the epoch, the last inverted label is compared voxel-wise against
    the label re-loaded from disk.
    """
    set_determinism(seed=0)
    # Registered up front so the global seed/determinism state is restored
    # even when one of the assertions below fails.
    self.addCleanup(set_determinism, seed=None)

    im_fname, seg_fname = (
        make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)
    )
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTensor for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted", "label_inverted", "test_dict"],
        transform=transform,
        orig_keys=["label", "label", "test_dict"],
        meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
        orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
        nearest_interp=True,
        to_tensor=[True, False, False],
        device="cpu",
    )
    inverter_1 = Invertd(
        # `image` was not copied, invert the original value directly
        keys=["image_inverted1", "label_inverted1"],
        transform=transform,
        orig_keys=["image", "image"],
        meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
        orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
        nearest_interp=[True, False],
        to_tensor=[True, True],
        device="cpu",
    )

    # Keys every inverted item must expose, listed in sorted order.
    expected_keys = [
        "image",
        "image_inverted",
        "image_inverted1",
        PostFix.meta("image_inverted1"),
        PostFix.meta("image_inverted"),
        PostFix.meta("image"),
        "image_transforms",
        "label",
        "label_inverted",
        "label_inverted1",
        PostFix.meta("label_inverted1"),
        PostFix.meta("label_inverted"),
        PostFix.meta("label"),
        "label_transforms",
        "test_dict",
        "test_dict_transforms",
    ]

    # execute 1 epoch
    for batch in loader:
        for item in decollate_batch(batch):
            item = inverter(item)
            item = inverter_1(item)

            self.assertListEqual(sorted(item), expected_keys)
            self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
            self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))

            # check the nearest interpolation mode: values must survive a
            # round trip through uint8 unchanged
            for key in ("image_inverted", "label_inverted"):
                inverted = item[key]
                torch.testing.assert_allclose(
                    inverted.to(torch.uint8).to(torch.float), inverted.to(torch.float)
                )
                self.assertTupleEqual(inverted.shape[1:], (100, 101, 107))

            # test inverted test_dict
            self.assertIsInstance(item["test_dict"]["affine"], np.ndarray)
            self.assertIsInstance(item["test_dict"]["filename_or_obj"], str)

            # check the case that different items use different interpolation mode to invert transforms
            inverted = item["image_inverted1"]
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            residual = torch.sum(inverted.to(torch.float) - inverted.to(torch.uint8).to(torch.float)).item()
            self.assertLess(residual, 1.0)
            self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

            inverted = item["label_inverted1"]
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            residual = torch.sum(inverted.to(torch.float) - inverted.to(torch.uint8).to(torch.float)).item()
            self.assertGreater(residual, 10000.0)
            self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

    # check labels match (uses the last item of the last batch)
    reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    n_diff = reverted.size - n_good
    print("invert diff", n_diff)
    # Accepted mismatch counts per environment:
    # 34007: 2 workers (cpu, non-macos)
    #   NOTE(review): the earlier comment listed 25300 for this case while the
    #   assertion accepted 34007 -- confirm which value is current.
    # 1812: 0 workers (gpu or macos)
    # 1821: windows torch 1.10.0
    self.assertIn(n_diff, (34007, 1812, 1821), f"diff. {n_diff}")
def test_train_timing(self):
    """Train a small 3D UNet for a few epochs with the fast-training setup.

    Files matching ``img*.nii.gz`` / ``seg*.nii.gz`` under ``self.data_dir``
    are split into the first 32 pairs for training and the last 9 pairs for
    validation. Training uses AMP, ``CacheDataset``, ``ThreadDataLoader``,
    ``DiceCELoss`` and ``Novograd``; per-step, per-epoch and total wall-clock
    times are printed. The test passes when the best validation mean Dice
    exceeds 0.95.
    """
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]

    device = torch.device("cuda:0")

    # define transforms for train and validation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )

    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # Loop-invariant values: computed once instead of on every step/batch.
    epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # read the scalar once; each `.item()` forces a host/device sync
            loss_value = loss.item()
            epoch_loss += loss_value
            print(
                f"{step}/{epoch_len}, train_loss: {loss_value:.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_samples):
    """Assemble the transform pipeline for one task and one run mode.

    Args:
        mode: ``"train"`` and ``"validation"`` select dedicated pipelines; any
            other value takes the image-only branch. Only ``"test"`` drops the
            ``"label"`` key from the loading and preprocessing steps.
        task_id: index into the ``clip_values``, ``spacing``,
            ``normalize_values`` and ``patch_size`` lookups, which are
            resolved outside this function.
        pos_sample_num: ``pos`` argument of ``RandCropByPosNegLabeld`` (train only).
        neg_sample_num: ``neg`` argument of ``RandCropByPosNegLabeld`` (train only).
        num_samples: ``num_samples`` argument of ``RandCropByPosNegLabeld`` (train only).

    Returns:
        A ``Compose`` chaining loading, anisotropic preprocessing and the
        mode-specific transforms, in that order.
    """
    keys = ["image"] if mode == "test" else ["image", "label"]

    pipeline = [
        # 1. loading
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        # 2. sampling
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]

    # 3. spatial transforms
    if mode == "train":
        pipeline += [
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            # one independent coin flip per spatial axis
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        pipeline += [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        pipeline += [
            CastToTyped(keys=["image"], dtype=(np.float32)),
            EnsureTyped(keys=["image"]),
        ]

    return Compose(pipeline)
def test_invert(self):
    """Check the basic batch-level behaviour of ``Invertd``.

    A random spatial pipeline is applied to 12 copies of one synthetic
    image/label pair; each batch is then inverted with nearest interpolation
    and the ``*_inverted`` entries are checked for integer-valued content and
    the pre-transform spatial shape.
    """
    set_determinism(seed=0)
    # Registered up front so the global seed/determinism state is restored
    # even when one of the assertions below fails.
    self.addCleanup(set_determinism, seed=None)

    im_fname, seg_fname = [
        make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)
    ]
    transform = Compose(
        [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test to support both Tensor and Numpy array when inverting
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ]
    )
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
    inverter = Invertd(
        keys=["image", "label"],
        transform=transform,
        loader=loader,
        orig_keys="label",
        nearest_interp=True,
        postfix="inverted",
        to_tensor=[True, False],
        device="cpu",
        # same worker policy as the loader above
        num_workers=num_workers,
    )

    # execute 1 epoch
    for batch in loader:
        batch = inverter(batch)
        # this unit test only covers basic function, test_handler_transform_inverter covers more
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        self.assertTupleEqual(batch["label"].shape[1:], (1, 100, 100, 100))
        # check the nearest interpolation mode: values must survive a round
        # trip through uint8 unchanged
        for inverted in batch["image_inverted"]:
            torch.testing.assert_allclose(
                inverted.to(torch.uint8).to(torch.float), inverted.to(torch.float)
            )
            self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))
        for inverted in batch["label_inverted"]:
            np.testing.assert_allclose(
                inverted.astype(np.uint8).astype(np.float32), inverted.astype(np.float32)
            )
            self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))
else: _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] TESTS_3D = [( t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 3 ) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), Compose( [RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined(keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device( "cuda" if torch.cuda.is_available() else "cpu")), ]] TESTS_2D = [ (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 2) for collate_fn in [None, pad_list_data_collate] for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), Compose([
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.

    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.

    Raises:
        FileNotFoundError: if ``config_info['training']['model_to_load']`` is set but does not exist.
        ValueError: if ``config_info['training']['batch_size_valid']`` is not 1.
    """

    # ------------------------------------------------------------------
    # Read input and configuration parameters
    # ------------------------------------------------------------------
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    training_cfg = config_info['training']
    seg_labels = training_cfg.get('seg_labels', [1])
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    # patches are 2D slices of the volume: in-plane size plus a singleton last axis
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    # a missing key and an explicit None both mean "train from scratch"
    model_to_load = training_cfg.get('model_to_load')
    if model_to_load is not None:
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        print("Loading model from {}".format(model_to_load))

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    seed = training_cfg.get('manual_seed')
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    # ------------------------------------------------------------------
    # Setup data output directory
    # ------------------------------------------------------------------
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output']:
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    # Read the input files for training and validation
    print("*** Loading input data for training...")
    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        # ValueError is still an Exception, so existing `except Exception` callers keep working
        raise ValueError("Batch size different from 1 at validation is currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    # ------------------------------------------------------------------
    # Network preparation
    # ------------------------------------------------------------------
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        smallest_spacing = min(spacings)
        spacing_ratio = [sp / smallest_spacing for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        # stop once no axis can be downsampled any further
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    # first level keeps the input resolution; last level always uses 3x3 kernels
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    # ------------------------------------------------------------------
    # MONAI evaluator
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet evaluator engine...\n")
    # NOTE: the lambdas below read `trainer`, which is only bound further down;
    # this works because they are first called while the trainer is running.
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        """Evaluator averaging the prediction over the input and its three in-plane flips."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use flipping as data augmentation at inference; each
                # prediction is flipped back before averaging
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # MONAI trainer
    # ------------------------------------------------------------------
    print("*** Preparing the dynUNet trainer engine...\n")
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    # Number of iterations the loader actually yields per epoch. Deriving it
    # as len(train_ds) // batch_size drops the final partial batch and becomes
    # 0 when the dataset is smaller than one batch, which desynchronises (or
    # invalidates) the iteration-based validation interval below.
    epoch_len = len(train_loader)
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        """Trainer applying the loss to every deep-supervision output with halving weights."""

        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                # the label is resized to the spatial shape of each auxiliary prediction
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum(0.5 ** i * self.loss_function(p, lab) for i, (p, lab) in enumerate(zip(preds, labels)))

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    # ------------------------------------------------------------------
    # Run training
    # ------------------------------------------------------------------
    print("*** Run training...")
    try:
        trainer.run()
    finally:
        # flush and release the TensorBoard event file even if training aborts
        writer_train.close()
    print("Done!")
def test_invert(self) -> None:
    """Check that ``TransformInverter`` undoes a dictionary transform chain.

    Two synthetic NIfTI files (image and label) are loaded, pushed through a
    chain of deterministic and random dictionary transforms, batched by a
    ``DataLoader`` and fed to an ignite ``Engine`` that has a
    ``TransformInverter`` attached. Afterwards the test verifies:

    * the shape of the last batch seen by the engine,
    * the shape of every inverted item and that its values survive a
      round trip through ``uint8`` unchanged,
    * that the file name recorded for the last label matches the input, and
    * that the number of label voxels differing from the file on disk is
      one of the accepted values.
    """
    # num workers = 0 for mac or gpu transforms; computed once and shared by
    # the data loader and the inverter so both always agree.
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

    set_determinism(seed=0)
    try:
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose(
            [
                LoadImaged(KEYS),
                AddChanneld(KEYS),
                Orientationd(KEYS, "RPS"),
                Spacingd(
                    KEYS,
                    pixdim=(1.2, 1.01, 0.9),
                    mode=["bilinear", "nearest"],
                    dtype=np.float32,
                ),
                ScaleIntensityd("image", minv=1, maxv=10),
                RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlipd(KEYS, prob=0.5),
                RandRotate90d(KEYS, spatial_axes=(1, 2)),
                RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                RandRotated(
                    KEYS,
                    prob=0.5,
                    range_x=np.pi,
                    mode="bilinear",
                    align_corners=True,
                ),
                RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
                ResizeWithPadOrCropd(KEYS, 100),
                ToTensord(KEYS),
                CastToTyped(KEYS, dtype=torch.uint8),
            ]
        )
        # The same image/label pair is repeated 12 times; with batch_size=5
        # the final batch therefore holds 2 samples.
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            # Every batch must already have the padded/cropped spatial size.
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            # Firing MODEL_COMPLETED by hand stands in for a real model step.
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
    finally:
        # Always undo the global seeding, even when the pipeline or the
        # in-loop assertion fails, so later tests are not left deterministic.
        set_determinism(seed=None)

    output = engine.state.output
    self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))

    for inverted in output["image_inverted"] + output["label_inverted"]:
        # Values must be unchanged by a uint8 round trip.
        torch.testing.assert_allclose(
            inverted.to(torch.uint8).to(torch.float), inverted.to(torch.float)
        )
        self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

    # check labels match
    reverted = output["label_inverted"][-1].detach().cpu().numpy()[0].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))

    reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)

    n_diff = reverted.size - n_good
    print("invert diff", n_diff)
    # NOTE(review): two mismatch counts are accepted; presumably they depend
    # on the platform or library version -- confirm before changing them.
    self.assertIn(n_diff, (25300, 1812), "diff. in two possible values")