def main():
    """Train a GAN on the MedNIST "Hand" images and save sample outputs.

    The function downloads and extracts the MedNIST archive into ``data/``,
    builds a cached dataset of the hand images with random augmentations,
    trains a MONAI ``Generator``/``Discriminator`` pair with ``GanTrainer``
    and finally writes ten generated images as PNG files into ``model_out``.

    The device is fixed to ``cuda:0``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(12345)
    device = torch.device("cuda:0")

    # load real data
    mednist_url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5_value = "0bc7306e7427e00ad1c5526a6677552d"
    extract_dir = "data"
    tar_save_path = os.path.join(extract_dir, "MedNIST.tar.gz")
    download_and_extract(mednist_url, tar_save_path, extract_dir, md5_value)
    hand_dir = os.path.join(extract_dir, "MedNIST", "Hand")
    # os.listdir() yields entries in arbitrary, filesystem-dependent order.
    # Sorting fixes the dataset order so the seeded shuffle is reproducible.
    real_data = [
        {"hand": os.path.join(hand_dir, filename)}
        for filename in sorted(os.listdir(hand_dir))
    ]

    # define real data transforms
    train_transforms = Compose(
        [
            LoadPNGD(keys=["hand"]),
            AddChannelD(keys=["hand"]),
            ScaleIntensityD(keys=["hand"]),
            RandRotateD(keys=["hand"], range_x=15, prob=0.5, keep_size=True),
            RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
            RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
            ToTensorD(keys=["hand"]),
        ]
    )

    # create dataset and dataloader
    real_dataset = CacheDataset(real_data, train_transforms)
    batch_size = 300
    real_dataloader = DataLoader(real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)

    # define function to process batchdata for input into discriminator
    def prepare_batch(batchdata):
        """
        Process Dataloader batchdata dict object and return image tensors for D Inferer
        """
        return batchdata["hand"]

    # define networks
    disc_net = Discriminator(
        in_shape=(1, 64, 64),
        channels=(8, 16, 32, 64, 1),
        strides=(2, 2, 2, 2, 1),
        num_res_units=1,
        kernel_size=5,
    ).to(device)

    latent_size = 64
    gen_net = Generator(
        latent_shape=latent_size,
        start_shape=(latent_size, 8, 8),
        channels=[32, 16, 8, 1],
        strides=[2, 2, 2, 1],
    )

    # initialize both networks
    disc_net.apply(normal_init)
    gen_net.apply(normal_init)

    # input images are scaled to [0,1] so enforce the same of generated outputs
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)

    # create optimizers and loss functions
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
    gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)

    disc_loss_criterion = torch.nn.BCELoss()
    gen_loss_criterion = torch.nn.BCELoss()
    real_label = 1
    fake_label = 0

    def discriminator_loss(gen_images, real_images):
        """
        The discriminator loss is calculated by comparing D
        prediction for real and generated images.
        """
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)

        realloss = disc_loss_criterion(disc_net(real_images), real)
        # detach so this loss does not back-propagate into the generator
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)

        return (genloss + realloss) / 2

    def generator_loss(gen_images):
        """
        The generator loss is calculated by determining how realistic
        the discriminator classifies the generated images.
        """
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # initialize current run dir
    run_dir = "model_out"
    print("Saving model output to: %s " % run_dir)

    # create workflow handlers
    handlers = [
        StatsHandler(
            name="batch_training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS],
            },
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={"g_net": gen_net, "d_net": disc_net},
            save_interval=10,
            save_final=True,
            epoch_level=True,
        ),
    ]

    # define key metric
    key_train_metric = None

    # create adversarial trainer
    disc_train_steps = 5
    num_epochs = 50

    trainer = GanTrainer(
        device,
        num_epochs,
        real_dataloader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_prepare_batch=prepare_batch,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=handlers,
    )

    # run GAN training
    trainer.run()

    # Training completed, save a few random generated images.
    print("Saving trained generator sample output.")
    test_img_count = 10
    test_latents = make_latent(test_img_count, latent_size).to(device)
    # No gradients are needed for sampling; skip building the autograd graph.
    with torch.no_grad():
        fakes = gen_net(test_latents)
    for i, image in enumerate(fakes):
        filename = "gen-fake-final-%d.png" % (i)
        save_path = os.path.join(run_dir, filename)
        # image[0] selects the single output channel of each generated sample
        img_array = image[0].detach().cpu().numpy()
        png_writer.write_png(img_array, save_path, scale=255)
def train(cfg):
    """Train and validate a ResNet18 patch classifier with MONAI/Ignite engines.

    Args:
        cfg: configuration mapping. Keys read here: ``dataset_json``,
            ``data_root``, ``region_size``, ``grid_shape``, ``patch_size``,
            ``use_openslide``, ``num_workers``, ``batch_size``, ``pretrain``,
            ``novograd``, ``lr``, ``amp``, ``n_epochs`` and ``logdir``.
            ``cfg["amp"]`` is overwritten in place with the effective AMP
            setting.

    Raises:
        ValueError: if the training dataloader yields no first batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)

    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose(
        [
            ToTensorD(keys="image"),
            TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            ),
            ToNumpyD(keys="image"),
            RandFlipD(keys="image", prob=0.5),
            RandRotate90D(keys="image", prob=0.5),
            CastToTypeD(keys="image", dtype=np.float32),
            RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            ToTensorD(keys=("image", "label")),
        ]
    )
    valid_preprocess = Compose(
        [
            CastToTypeD(keys="image", dtype=np.float32),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            ToTensorD(keys=("image", "label")),
        ]
    )

    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(
        train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
    )
    valid_dataloader = DataLoader(
        valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True
    )

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")

    print("image: ")
    print(" shape", first_sample["image"].shape)
    print(" type: ", type(first_sample["image"]))
    print(" dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print(" shape", first_sample["label"].shape)
    print(" type: ", type(first_sample["label"]))
    print(" dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: only enabled when requested AND torch is at least 1.6.
    # The version lookup is skipped entirely when AMP was not requested.
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose(
        [
            ActivationsD(keys="pred", sigmoid=True),
            AsDiscreteD(keys="pred", threshold=0.5),
        ]
    )
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer
    # NOTE(review): the trainer handlers write to cfg["logdir"] while the
    # evaluator handlers write to log_dir (from create_log_dir) -- confirm
    # that using two different directories is intended.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(
            save_dir=cfg["logdir"],
            save_dict={"net": model, "opt": optimizer},
            save_interval=1,
            epoch_level=True,
        ),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(
            log_dir=cfg["logdir"],
            tag_name="train_loss",
            output_transform=from_engine(["loss"], first=True),
        ),
    ]
    train_postprocessing = Compose(
        [
            ActivationsD(keys="pred", sigmoid=True),
            AsDiscreteD(keys="pred", threshold=0.5),
        ]
    )

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def main(cfg):
    """Train (and optionally validate) a ResNet18 patch classifier.

    Sets up file and TensorBoard logging under the directory returned by
    ``create_log_dir``, builds the preprocessing pipeline for the selected
    backend, then runs the epoch loop through the ``training`` and
    ``validation`` helpers.

    Args:
        cfg: configuration mapping. Keys read here: ``benchmark``,
            ``backend`` (``"cucim"`` or ``"numpy"``), ``prob``,
            ``dataset_json``, ``data_root``, ``region_size``, ``grid_shape``,
            ``patch_size``, ``use_openslide``, ``num_workers``,
            ``batch_size``, ``pin``, ``pretrain``, ``novograd``, ``lr``,
            ``amp``, ``cos``, ``n_epochs``, ``print_step``, ``save`` and
            ``validate``. ``cfg["amp"]`` is overwritten in place with the
            effective AMP setting.

    Raises:
        ValueError: if ``cfg["backend"]`` is neither ``"cucim"`` nor
            ``"numpy"``, or if the training dataloader yields no first batch.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    def epoch_checkpoint_path(epoch):
        """Return the per-epoch checkpoint path; shared by saving and copying."""
        return os.path.join(log_dir, f"model_epoch_{epoch}.pth")

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)

    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose(
            [
                Range()(ToCupy()),
                Range("ColorJitter")(
                    RandCuCIM(
                        name="color_jitter",
                        brightness=64.0 / 255.0,
                        contrast=0.75,
                        saturation=0.25,
                        hue=0.04,
                    )
                ),
                Range("RandomFlip")(RandCuCIM(name="image_flip", apply_prob=cfg["prob"], spatial_axis=-1)),
                Range("RandomRotate90")(
                    RandCuCIM(name="rand_image_rotate_90", prob=cfg["prob"], max_k=3, spatial_axis=(-2, -1))
                ),
                Range()(CastToType(dtype=np.float32)),
                Range("RandomZoom")(RandCuCIM(name="rand_zoom", min_zoom=0.9, max_zoom=1.1)),
                Range("ScaleIntensity")(
                    CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
                ),
                Range()(ToTensor(device=device)),
            ]
        )
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose(
            [
                Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
                Range("ValidScaleIntensity")(
                    CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
                ),
                Range("ValidToTensor")(ToTensor(device=device)),
            ]
        )
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose(
            [
                Range()(ToTensorD(keys=("image", "label"))),
                Range("ColorJitter")(
                    TorchVisionD(
                        keys="image",
                        name="ColorJitter",
                        brightness=64.0 / 255.0,
                        contrast=0.75,
                        saturation=0.25,
                        hue=0.04,
                    )
                ),
                Range()(ToNumpyD(keys="image")),
                Range("RandomFlip")(RandFlipD(keys="image", prob=cfg["prob"], spatial_axis=-1)),
                Range("RandomRotate90")(RandRotate90D(keys="image", prob=cfg["prob"])),
                Range()(CastToTypeD(keys="image", dtype=np.float32)),
                Range("RandomZoom")(RandZoomD(keys="image", prob=cfg["prob"], min_zoom=0.9, max_zoom=1.1)),
                Range("ScaleIntensity")(
                    ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
                ),
                Range()(ToTensorD(keys="image")),
            ]
        )
        preprocess_cpu_valid = Compose(
            [
                Range("ValidCastType")(CastToTypeD(keys="image", dtype=np.float32)),
                Range("ValidScaleIntensity")(
                    ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
                ),
                Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
            ]
        )
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing
    postprocess = Compose(
        [
            Activations(sigmoid=True),
            AsDiscrete(threshold=0.5),
        ]
    )

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(
        train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=cfg["pin"]
    )
    valid_dataloader = DataLoader(
        valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=cfg["pin"]
    )

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(
            f"[{d}] \n"
            f" {d} shape: {first_sample[d].shape}\n"
            f" {d} type: {type(first_sample[d])}\n"
            f" {d} dtype: {first_sample[d].dtype}"
        )
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")

    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model (``num_classes`` replaces the deprecated ``n_classes``)
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}

    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(model.state_dict(), epoch_checkpoint_path(train_counter["epoch"]))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss, metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc, metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc, train_counter["epoch"])

            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s")
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg, metric_dict=metric_summary, run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model.
    # Per-epoch checkpoints only exist when cfg["save"] is set, so skip the
    # copy otherwise instead of failing on a missing file.
    # NOTE(review): the final copy assumes the last checkpoint is numbered
    # cfg["n_epochs"]; that depends on how training() advances
    # train_counter["epoch"] -- confirm against training().
    if cfg["validate"] is True and cfg["save"]:
        copyfile(
            epoch_checkpoint_path(metric_summary["best_epoch"]),
            os.path.join(log_dir, "model_best.pth"),
        )
        copyfile(
            epoch_checkpoint_path(cfg["n_epochs"]),
            os.path.join(log_dir, "model_final.pth"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
num_test = int(len(all_filenames) * test_frac) num_train = len(all_filenames) - num_test train_datadict = [{"im": fname} for fname in all_filenames[:num_train]] test_datadict = [{"im": fname} for fname in all_testNames[:len(all_testNames)]] print(all_filenames) print(f"total number of images: {len(all_filenames)}") print(f"number of images for training: {len(train_datadict)}") print(f"number of images for testing: {len(test_datadict)}") train_transforms = Compose([ LoadImageD(keys=["im"]), AddChannelD(keys=["im"]), ScaleIntensityD(keys=["im"]), RandRotateD(keys=["im"], range_x=np.pi / 12, prob=0.5, keep_size=True), RandFlipD(keys=["im"], spatial_axis=0, prob=0.5), RandZoomD(keys=["im"], min_zoom=0.9, max_zoom=1.1, prob=0.5), ToTensorD(keys=["im"]) ]) test_transforms = Compose([ LoadImageD(keys=["im"]), AddChannelD(keys=["im"]), # ScaleIntensityD(keys=["im"]), ToTensorD(keys=["im"]) ]) batch_size = 100 # error if greater than 0, https://github.com/Project-MONAI/MONAI/pull/307 num_workers = 0 train_ds = CacheDataset(train_datadict,