def test_ill_arg(self):
    """Each out-of-range Novograd hyper-parameter must raise ValueError with a descriptive message."""
    params = {"params": [Variable(torch.randn(10), requires_grad=True)]}
    # (expected error-message pattern, constructor kwargs) pairs, checked in order
    bad_cases = [
        ("Invalid learning rate: -1", {"lr": -1}),
        ("Invalid epsilon value: -1", {"eps": -1}),
        ("Invalid beta parameter at index 0: 1.0", {"betas": (1.0, 0.98)}),
        ("Invalid beta parameter at index 1: -1", {"betas": (0.9, -1)}),
        ("Invalid weight_decay value: -1", {"weight_decay": -1}),
    ]
    for message, kwargs in bad_cases:
        with self.assertRaisesRegex(ValueError, message):
            Novograd(params, **kwargs)
def test_step(self, specify_param, default_param, weight, bias, input):
    """Run 100 Novograd steps on a quadratic loss and assert the loss decreases."""
    optimizer = Novograd(specify_param, **default_param)

    def closure():
        # Recompute loss and gradients from scratch on every call, as
        # required by the optimizer's closure-based step API.
        optimizer.zero_grad()
        projected = weight.mv(input)
        # Move the result onto bias's GPU when the two live on different devices.
        if projected.is_cuda and bias.is_cuda:
            if projected.get_device() != bias.get_device():
                projected = projected.cuda(bias.get_device())
        loss = (projected + bias).pow(2).sum()
        loss.backward()
        return loss

    loss_before = closure().item()
    for _ in range(100):
        optimizer.step(closure)
    self.assertLess(closure().item(), loss_before)
def test_train_timing(self):
    """End-to-end fast-training benchmark: train a 3D UNet with Novograd + AMP on
    cached, GPU-resident data and assert the best validation Dice exceeds 0.95.

    Requires CUDA (everything is pinned to ``cuda:0``) and NIfTI pairs
    ``img*.nii.gz`` / ``seg*.nii.gz`` under ``self.data_dir``.
    """
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    # first 32 pairs for training, last 9 for validation
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]
    device = torch.device("cuda:0")
    # define transforms for train and validation
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # pre-compute foreground and background indexes
        # and cache them to accelerate training
        FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
        # change to execute transforms with Tensor data
        EnsureTyped(keys=["image", "label"]),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=["image", "label"], device=device),
        # randomly crop out patch samples from big
        # image based on pos / neg ratio
        # the image centers of negative samples
        # must be in valid image area
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(64, 64, 64),
            pos=1,
            neg=1,
            num_samples=4,
            fg_indices_key="label_fg",
            bg_indices_key="label_bg",
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=["image", "label"], prob=0.5),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
        RandRotated(
            keys=["image", "label"],
            prob=0.5,
            range_x=np.pi / 4,
            mode=("bilinear", "nearest"),
            align_corners=True,
            dtype=np.float64,
        ),
        RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
        RandGaussianNoised(keys="image", prob=0.5),
        RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        ScaleIntensityd(keys="image"),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
        # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
        ToDeviced(keys=["image", "label"], device=device),
    ])
    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch
    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            # scaled backward + step so AMP gradients do not underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                  f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate mean Dice over the whole validation set, then reset
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                if metric > best_metric:
                    best_metric = metric
                print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def main_worker(args):
    """Per-process worker for multi-GPU (NCCL DDP) BraTS segmentation training.

    Reads from ``args``: ``local_rank``, ``dir`` (dataset root), ``cache_rate``,
    ``batch_size``, ``network``, ``lr``, ``epochs``, ``val_interval``.
    Relies on module-level ``train(...)`` and ``evaluate(...)`` helpers for the
    per-epoch loops.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")
    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True
    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        EnsureTyped(keys=["image", "label"]),
        # keep cached data on the GPU to avoid per-epoch host->device copies
        ToDeviced(keys=["image", "label"], device=device),
        RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ])
    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)
    # validation transforms and dataset
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)
    # create network, loss function and optimizer
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # only rank 0 writes the checkpoint to avoid concurrent writes
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}")
    dist.destroy_process_group()
# --- experiment hyper-parameters and training objects (script-level setup) ---
# NOTE(review): relies on a `device` defined earlier in the file — confirm it is
# set before this point.
max_epochs = 6
learning_rate = 1e-4
val_interval = 2  # run validation every 2 epochs
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceCELoss(
    to_onehot_y=True, softmax=True, squared_pred=True, batch=True
)
# Novograd with a larger LR than typical Adam settings (10x the base rate)
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()  # AMP gradient scaler
dice_metric = DiceMetric(
    include_background=True, reduction="mean", get_not_nans=False
)
# post-transforms: one-hot prediction (argmax) and one-hot label for Dice
post_pred = Compose(
    [EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]
)
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
# bookkeeping for the training loop
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
def main(cfg):
    """Train (and optionally validate) a ResNet-18 patch classifier on WSI data.

    ``cfg`` is a dict of run options; keys read here include: ``backend``
    ("cucim" or "numpy"), ``prob``, ``dataset_json``, ``data_root``,
    ``region_size``, ``grid_shape``, ``patch_size``, ``use_openslide``,
    ``num_workers``, ``batch_size``, ``pin``, ``pretrain``, ``novograd``,
    ``lr``, ``amp``, ``cos``, ``n_epochs``, ``print_step``, ``save``,
    ``validate``, ``benchmark``.

    Fixes vs. previous revision:
    - checkpoints were saved as ``model_epoch_N.pt`` but copied at the end as
      ``model_epoch_N.pth`` (always FileNotFoundError) — unified on ``.pth``;
    - ``np.Inf`` replaced by ``np.inf`` (the ``Inf`` alias is removed in
      NumPy 2.0).
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)

    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        # GPU pipeline via cuCIM; labels converted on CPU, images on GPU
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip", apply_prob=cfg["prob"], spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90", prob=cfg["prob"], max_k=3, spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom", min_zoom=0.9, max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        # CPU pipeline via numpy/torchvision transforms
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image", prob=cfg["prob"], spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image", prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image", prob=cfg["prob"], min_zoom=0.9, max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image", dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]")

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f"  {d} shape: {first_sample[d].shape}\n"
                     f"  {d} type:  {type(first_sample[d])}\n"
                     f"  {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")

    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model
    # NOTE(review): `train()` below constructs the same model with
    # `num_classes=1` while this uses `n_classes=1` — confirm which keyword the
    # installed MONAI version accepts and make the two call sites agree.
    model = TorchVisionFCModel("resnet18", n_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: AMP requires torch >= 1.6
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}
    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}

    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            # ".pth" matches the copyfile() calls at the end of the run.
            # NOTE(review): the final copy step requires cfg["save"] to be on
            # whenever cfg["validate"] is on — confirm the configs guarantee it.
            torch.save(model.state_dict(), os.path.join(log_dir, f"model_epoch_{train_counter['epoch']}.pth"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss, metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc, metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc, train_counter["epoch"])
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s")
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg, metric_dict=metric_summary, run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model
    if cfg["validate"] is True:
        copyfile(
            os.path.join(log_dir, f"model_epoch_{metric_summary['best_epoch']}.pth"),
            os.path.join(log_dir, "model_best.pth"),
        )
        copyfile(
            os.path.join(log_dir, f"model_epoch_{cfg['n_epochs']}.pth"),
            os.path.join(log_dir, "model_final.pth"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
def train(cfg):
    """Train the WSI patch classifier with MONAI Ignite workflow engines.

    ``cfg`` keys read: ``dataset_json``, ``data_root``, ``region_size``,
    ``grid_shape``, ``patch_size``, ``use_openslide``, ``num_workers``,
    ``batch_size``, ``pretrain``, ``novograd``, ``lr``, ``amp``, ``n_epochs``,
    ``logdir``.

    Fixes vs. previous revision: corrected the error message typo
    ("Fist" -> "First") and replaced the ``True if ... else False`` AMP toggle
    with the boolean expression form used elsewhere in this file.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True)
    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")
    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    # NOTE(review): `main()` above uses `n_classes=1` for the same model —
    # confirm which keyword the installed MONAI version accepts and unify.
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: AMP requires torch >= 1.6 (same check as in `main`)
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    # NOTE(review): evaluator handlers write to `log_dir` while trainer handlers
    # write to `cfg["logdir"]` — confirm create_log_dir(cfg) returns the same
    # path, otherwise checkpoints/TB logs end up in two places.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"], save_dict={"net": model, "opt": optimizer}, save_interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"], tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def optimizer(self, context: Context):
    """Build the Novograd optimizer over the task network's parameters (fixed LR 1e-4)."""
    learning_rate = 0.0001
    return Novograd(context.network.parameters(), learning_rate)