def test_read_patches_openslide(self, input_parameters, expected):
    """Compare the first item of a ``PatchWSIDataset`` against ``expected``.

    A dataset is built from ``input_parameters`` and its first item is
    fetched. For every sample in that item, the shapes of the "label" and
    "image" entries are compared first, then their contents, against the
    entry of ``expected`` at the same position.
    """
    patches = PatchWSIDataset(**input_parameters)[0]
    for idx, patch in enumerate(patches):
        reference = expected[idx]
        # Shapes are checked before contents, label before image.
        for key in ("label", "image"):
            self.assertTupleEqual(patch[key].shape, reference[key].shape)
        for key in ("label", "image"):
            # assert_array_equal returns None when the arrays match.
            self.assertIsNone(assert_array_equal(patch[key], reference[key]))
def main(cfg):
    """Run the training loop, with optional per-epoch validation.

    The function sets up logging and a TensorBoard writer in the directory
    returned by ``create_log_dir(cfg)``, builds the preprocessing pipelines
    for the selected backend, creates the datasets/dataloaders, the model,
    loss, optimizer, optional AMP scaler and LR scheduler, and then trains
    for ``cfg["n_epochs"]`` epochs.

    Args:
        cfg: configuration mapping. Keys read directly in this function:
            "benchmark", "backend", "prob", "dataset_json", "data_root",
            "region_size", "grid_shape", "patch_size", "use_openslide",
            "num_workers", "batch_size", "pin", "pretrain", "novograd",
            "lr", "amp", "cos", "n_epochs", "print_step", "save" and
            "validate". ``cfg`` is dumped with ``json``, so it must be JSON
            serializable. ``cfg["amp"]`` is overwritten in place: it is
            disabled when the installed torch is older than 1.6.

    Raises:
        ValueError: if ``cfg["backend"]`` is neither "cucim" nor "numpy",
            or if the training dataloader yields no first batch.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger (console via basicConfig, plus a file inside log_dir)
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)

    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing: the "cpu" pipelines are attached to the datasets, the
    # "gpu" pipelines are handed to training()/validation().
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip", apply_prob=cfg["prob"], spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90", prob=cfg["prob"], max_k=3, spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom", min_zoom=0.9, max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image", prob=cfg["prob"], spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image", prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image", prob=cfg["prob"], min_zoom=0.9, max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image", dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f" {d} shape: {first_sample[d].shape}\n"
                     f" {d} type: {type(first_sample[d])}\n"
                     f" {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")

    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model (`num_classes` matches the keyword used by train())
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: AMP is only kept enabled for torch >= 1.6
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}
    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # np.inf instead of np.Inf: the capitalized alias was removed in NumPy 2.0
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}

    # Paths of the checkpoints written by this run. Tracking the actual paths
    # guarantees that the copies made after the loop refer to files that exist.
    last_ckpt_path = None
    best_ckpt_path = None

    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            # ".pth" matches the extension expected by the copies after the loop
            last_ckpt_path = os.path.join(
                log_dir, f"model_epoch_{train_counter['epoch']}.pth")
            torch.save(model.state_dict(), last_ckpt_path)
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss, metric_summary["loss"])
                # NOTE(review): accuracy is the running max over improving
                # epochs, not necessarily the accuracy of the best-loss epoch.
                metric_summary["accuracy"] = max(valid_acc, metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
                best_ckpt_path = last_ckpt_path
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc, train_counter["epoch"])
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics (max(..., 1) avoids dividing by zero for 0 epochs)
    metric_summary["train_time_per_epoch"] = total_train_time / max(cfg["n_epochs"], 1)
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model. Only possible when per-epoch checkpoints
    # were actually written (cfg["save"]); otherwise there is nothing to copy.
    if cfg["validate"]:
        if best_ckpt_path is not None:
            copyfile(best_ckpt_path, os.path.join(log_dir, "model_best.pth"))
        if last_ckpt_path is not None:
            copyfile(last_ckpt_path, os.path.join(log_dir, "model_final.pth"))
        if last_ckpt_path is None or best_ckpt_path is None:
            logging.warning(
                "Best/final model copy skipped: no matching epoch checkpoint was saved.")

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
def train(cfg):
    """Train and validate the classifier with MONAI's Ignite-based engines.

    Builds the dictionary-based preprocessing pipelines, the
    ``PatchWSIDataset`` datasets and dataloaders, a ``TorchVisionFCModel``
    ("resnet18"), the loss, optimizer and cosine LR scheduler, then wires a
    ``SupervisedEvaluator`` into a ``SupervisedTrainer`` (validation after
    every epoch) and runs the trainer.

    All checkpoints and TensorBoard logs of both the trainer and the
    evaluator are written to the directory returned by
    ``create_log_dir(cfg)``, so one run's artifacts stay together.

    Args:
        cfg: configuration mapping. Keys read directly in this function:
            "dataset_json", "data_root", "region_size", "grid_shape",
            "patch_size", "use_openslide", "num_workers", "batch_size",
            "pretrain", "novograd", "lr", "amp" and "n_epochs".
            ``cfg["amp"]`` is overwritten in place with a bool: it is
            disabled when the installed torch is older than 1.6.

    Raises:
        ValueError: if the training dataloader yields no first batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)

    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    print("image: ")
    print(" shape", first_sample["image"].shape)
    print(" type: ", type(first_sample["image"]))
    print(" dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print(" shape", first_sample["label"].shape)
    print(" type: ", type(first_sample["label"]))
    print(" dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # initialize model
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP is only kept enabled for torch >= 1.6; the result is always a bool
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer: handlers write to log_dir, the same directory the evaluator
    # handlers use, instead of the separate cfg["logdir"].
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=log_dir,
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=log_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()