def test_seg_no_basedir(self):
    """Datalist entries given as absolute paths load correctly with base_dir=None."""
    with tempfile.TemporaryDirectory() as tempdir:

        def abspath(name):
            # helper: absolute path inside the temp dir
            return os.path.join(tempdir, name)

        content = {
            "name": "Spleen",
            "description": "Spleen Segmentation",
            "labels": {"0": "background", "1": "spleen"},
            "training": [
                {"image": abspath("spleen_19.nii.gz"), "label": abspath("spleen_19.nii.gz")},
                {"image": abspath("spleen_31.nii.gz"), "label": abspath("spleen_31.nii.gz")},
            ],
            "test": [abspath("spleen_15.nii.gz"), abspath("spleen_23.nii.gz")],
        }
        json_path = abspath("test_data.json")
        with open(json_path, "w") as fp:
            json.dump(content, fp)

        # training section keeps the absolute image/label paths untouched
        train_items = load_decathlon_datalist(json_path, True, "training", None)
        self.assertEqual(train_items[0]["image"], abspath("spleen_19.nii.gz"))
        self.assertEqual(train_items[0]["label"], abspath("spleen_19.nii.gz"))

        # test section: the plain list of paths is exposed under the "image" key
        test_items = load_decathlon_datalist(json_path, True, "test", None)
        self.assertEqual(test_items[0]["image"], abspath("spleen_15.nii.gz"))
def test_cls_values(self):
    """Classification datalist: integer labels pass through, image paths join base_dir."""
    with tempfile.TemporaryDirectory() as tempdir:
        content = {
            "name": "ChestXRay",
            "description": "Chest X-ray classification",
            "labels": {"0": "background", "1": "chest"},
            "training": [
                {"image": "chest_19.nii.gz", "label": 0},
                {"image": "chest_31.nii.gz", "label": 1},
            ],
            "test": ["chest_15.nii.gz", "chest_23.nii.gz"],
        }
        json_path = os.path.join(tempdir, "test_data.json")
        with open(json_path, "w") as fp:
            fp.write(json.dumps(content))

        items = load_decathlon_datalist(json_path, False, "training", tempdir)
        # relative image path resolved against the given base dir
        self.assertEqual(items[0]["image"], os.path.join(tempdir, "chest_19.nii.gz"))
        # non-segmentation label stays an int
        self.assertEqual(items[0]["label"], 0)
def test_content(self):
    """The datalist written by create_cross_validation_datalist reloads identically."""
    with tempfile.TemporaryDirectory() as tempdir:
        items = []
        for idx in range(5):
            img = os.path.join(tempdir, f"test_image{idx}.nii.gz")
            seg = os.path.join(tempdir, f"test_label{idx}.nii.gz")
            # create empty files so the missing-file check passes
            Path(img).touch()
            Path(seg).touch()
            items.append({"image": img, "label": seg})

        out_file = os.path.join(tempdir, "test_datalist.json")
        produced = create_cross_validation_datalist(
            datalist=items,
            nfolds=5,
            train_folds=[0, 1, 2, 3],
            val_folds=4,
            train_key="test_train",
            val_key="test_val",
            filename=Path(out_file),
            shuffle=True,
            seed=123,
            check_missing=True,
            keys=["image", "label"],
            root_dir=None,
            allow_missing_keys=False,
            raise_error=True,
        )

        # round-trip: the saved JSON must match the returned datalist
        reloaded = load_decathlon_datalist(out_file, data_list_key="test_train")
        for expected, actual in zip(produced["test_train"], reloaded):
            self.assertEqual(expected["image"], actual["image"])
            self.assertEqual(expected["label"], actual["label"])
def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
    """Load the Decathlon datalist for the configured section and split it.

    The items in the returned list stay plain dicts so they remain
    compatible with the dataloader.
    """
    json_path = Path(dataset_dir) / "dataset.json"
    # both "training" and "validation" read from the "training" section;
    # everything else reads "test"
    if self.section in ("training", "validation"):
        section = "training"
    else:
        section = "test"
    return self._split_datalist(load_decathlon_datalist(json_path, True, section))
def get_task_params(args):
    """
    Print the target spacings of a decathlon dataset task.

    For CT images (task 03, 06, 07, 08, 09 and 10), this function also prints
    the mean and std values (used for normalization), and the min (0.5 percentile)
    and max (99.5 percentile) values (used for clip).
    """
    dataset_path = os.path.join(args.root_dir, task_name[args.task_id])
    datalist_file = os.path.join(args.datalist_path, f"dataset_task{args.task_id}.json")

    # get all training data
    datalist = load_decathlon_datalist(datalist_file, True, "training", dataset_path)
    # get modality info.
    properties = load_decathlon_properties(datalist_file, "modality")

    calculator = DatasetSummary(
        Dataset(data=datalist, transform=LoadImaged(keys=["image", "label"])),
        num_workers=4,
    )
    target_spacing = calculator.get_target_spacing()
    print("spacing: ", target_spacing)

    if properties["modality"]["0"] == "CT":
        print("CT input, calculate statistics:")
        calculator.calculate_statistics()
        print("mean: ", calculator.data_mean, " std: ", calculator.data_std)
        calculator.calculate_percentiles(
            sampling_flag=True, interval=10, min_percentile=0.5, max_percentile=99.5
        )
        print(
            "min: ",
            calculator.data_min_percentile,
            " max: ",
            calculator.data_max_percentile,
        )
    else:
        print("non CT input, skip calculating.")
def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
    """Load the section's datalist and randomly carve out training/validation items."""
    section = "training" if self.section in ("training", "validation") else "test"
    datalist = load_decathlon_datalist(
        os.path.join(dataset_dir, "dataset.json"), True, section
    )
    # the test section is returned unchanged
    if section == "test":
        return datalist

    selected = []
    for item in datalist:
        # draw a fresh random value (stored in self.rann) for every item
        self.randomize()
        if self.section == "training":
            keep = self.rann >= self.val_frac
        else:
            keep = self.rann < self.val_frac
        if keep:
            selected.append(item)
    return selected
def test_additional_items(self):
    """load_decathlon_datalist handles extra keys beside image/label."""
    with tempfile.TemporaryDirectory() as tempdir:
        # a real file on disk that the extra "mask" key can point at
        with open(os.path.join(tempdir, "mask31.txt"), "w") as f:
            f.write("spleen31 mask")

        content = {
            "name": "Spleen",
            "description": "Spleen Segmentation",
            "labels": {"0": "background", "1": "spleen"},
            "training": [
                {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz", "mask": "spleen mask"},
                {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz", "mask": "mask31.txt"},
            ],
            "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
        }
        json_path = os.path.join(tempdir, "test_data.json")
        with open(json_path, "w") as fp:
            fp.write(json.dumps(content))

        items = load_decathlon_datalist(json_path, True, "training", Path(tempdir))
        self.assertEqual(items[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
        self.assertEqual(items[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
        # the mask value that names a real file is resolved against base_dir
        self.assertEqual(items[1]["mask"], os.path.join(tempdir, "mask31.txt"))
        # the free-text mask value is expected to pass through unchanged
        self.assertEqual(items[0]["mask"], "spleen mask")
def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
    """Load the section datalist from dataset.json and delegate the split."""
    # training and validation share the "training" section of the JSON file
    if self.section in ("training", "validation"):
        section = "training"
    else:
        section = "test"
    json_file = os.path.join(dataset_dir, "dataset.json")
    return self._split_datalist(load_decathlon_datalist(json_file, True, section))
def main():
    """Fine-tune UNETR on BTCV-style data, optionally from pretrained ViT weights.

    The placeholder paths (data_dir/json_path/logdir/pretrained_path) must be
    edited before running. Trains for ``max_iterations`` steps, validating and
    checkpointing the best model (plus loss/dice curves) every ``eval_num`` steps.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100  # validate every `eval_num` iterations

    # was: `if os.path.exists(logdir) == False: os.mkdir(logdir)` — makedirs
    # also creates missing parent dirs and is race-free with exist_ok
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
        RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(
        base_dir=data_dir,
        data_list_file_path=json_path,
        is_segmentation=True,
        data_list_key="training",
    )
    val_files = load_decathlon_datalist(
        base_dir=data_dir,
        data_list_file_path=json_path,
        is_segmentation=True,
        data_list_key="validation",
    )
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # sanity-check one validation sample
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']
        # Delete the variables conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight, conv3d_transpose_1.bias: they were used to match
        # dimensions while pretraining with ViTAutoEnc and are not part of the ViT
        # backbone (which is what UNETR reuses here)
        for unused_key in (
            'conv3d_transpose_1.bias',
            'conv3d_transpose_1.weight',
            'conv3d_transpose.bias',
            'conv3d_transpose.weight',
        ):
            vit_weights.pop(unused_key)
        model.vit.load_state_dict(vit_weights)
        # fixed typo: "Succesfully" -> "Successfully"
        print('Pretrained Weights Successfully Loaded !')
    else:
        print('No weights were loaded, all weights being used are randomly initialized!')

    model.to(device)
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        """Run sliding-window inference over the validation loader; return mean dice."""
        model.eval()
        dice_vals = []
        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [post_label(t) for t in val_labels_list]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [post_pred(t) for t in val_outputs_list]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                # progress total was hard-coded as 10.0; report the real batch count
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)"
                    % (global_step, len(epoch_iterator_val), dice)
                )
            dice_metric.reset()
        return np.mean(dice_vals)

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """One pass over train_loader; validates and checkpoints every `eval_num` steps."""
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
        step = 0
        # start=1 replaces the original manual `step += 1` inside the loop
        for step, batch in enumerate(epoch_iterator, start=1):
            x, y = (batch["image"].cuda(), batch["label"].cuda())
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss)
            )
            if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)
                # NOTE(review): epoch_loss keeps accumulating after this division,
                # so subsequent evaluations log a mixed quantity — preserved as-is
                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(), os.path.join(logdir, "best_metric_model.pth"))
                    print("Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))
                else:
                    print("Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                          .format(dice_val_best, dice_val))
                # refresh the loss/dice curves on every evaluation;
                # renamed plot locals (were x/y, shadowing the batch tensors above)
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x_axis = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                plt.xlabel("Iteration")
                plt.plot(x_axis, epoch_loss_values)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x_axis = [eval_num * (i + 1) for i in range(len(metric_values))]
                plt.xlabel("Iteration")
                plt.plot(x_axis, metric_values)
                plt.grid()
                plt.savefig(os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)
            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(torch.load(os.path.join(logdir, "best_metric_model.pth")))
    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
def configure(self):
    """Build the 3D UNet, the train/val data pipelines, and the Ignite engines.

    Populates ``self.eval_engine`` (SupervisedEvaluator) and
    ``self.train_engine`` (SupervisedTrainer); reads configuration from
    instance attributes such as ``self.multi_gpu``, ``self.data_list_file_path``,
    ``self.ckpt_dir``, ``self.learning_rate``, ``self.max_epochs``,
    ``self.val_interval``, ``self.amp`` and ``self.local_rank``.
    """
    self.set_device()
    # 3D UNet: 1 input channel, 2 output channels (background/foreground)
    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(self.device)
    if self.multi_gpu:
        # one device per process; parameters are all used, so skip the
        # unused-parameter search for speed
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    # training pipeline: resample -> window intensity -> crop foreground ->
    # random pos/neg patch sampling -> light intensity augmentation
    train_transforms = Compose([
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        Spacingd(keys=("image", "label"),
                 pixdim=[1.0, 1.0, 1.0],
                 mode=["bilinear", "nearest"]),
        ScaleIntensityRanged(
            keys="image",
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=("image", "label"), source_key="image"),
        RandCropByPosNegLabeld(
            keys=("image", "label"),
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=("image", "label")),
    ])
    train_datalist = load_decathlon_datalist(self.data_list_file_path, True, "training")
    if self.multi_gpu:
        # each rank keeps only its own shard of the shuffled training list
        train_datalist = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )[dist.get_rank()]
    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )

    # validation pipeline: whole-volume, no random cropping
    # NOTE(review): unlike training there is no Spacingd here — confirm the
    # validation data is already at the target spacing
    val_transforms = Compose([
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        ScaleIntensityRanged(
            keys="image",
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=("image", "label"), source_key="image"),
        ToTensord(keys=("image", "label")),
    ])
    val_datalist = load_decathlon_datalist(self.data_list_file_path, True, "validation")
    # positional args: data, transform, cache_num=9, cache_rate=0.0, num_workers=4
    val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
    val_data_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )
    # shared post-processing: softmax on predictions, then argmax/one-hot
    post_transform = Compose([
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=["pred", "label"],
            argmax=[True, False],
            to_onehot=True,
            n_classes=2,
        ),
    ])
    # metric
    key_val_metric = {
        "val_mean_dice": MeanDice(
            include_background=False,
            output_transform=lambda x: (x["pred"], x["label"]),
            device=self.device,
        )
    }
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        # save a checkpoint whenever the key metric improves
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": network},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                output_transform=lambda x: None),
    ]
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        # sliding-window inference because validation volumes are uncropped
        inferer=SlidingWindowInferer(
            roi_size=[160, 160, 160],
            sw_batch_size=4,
            overlap=0.5,
        ),
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )
    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # run the evaluator every `val_interval` epochs
        ValidationHandler(validator=self.eval_engine,
                          interval=self.val_interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=lambda x: x["loss"],
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )
    if self.local_rank > 0:
        # keep non-zero ranks quiet: only rank 0 logs at INFO level
        self.train_engine.logger.setLevel(logging.WARNING)
        self.eval_engine.logger.setLevel(logging.WARNING)
def main_worker(gpu, args):
    """Per-process entry point for (optionally distributed) MIL training.

    Args:
        gpu: local GPU index for this worker process.
        args: parsed command-line namespace (distributed flags, paths,
            hyper-parameters, etc.).
    """
    args.gpu = gpu
    if args.distributed:
        # global rank = node rank * GPUs-per-node + local gpu index
        args.rank = args.rank * torch.cuda.device_count() + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    print(args.rank, " gpu", args.gpu)
    torch.cuda.set_device(
        args.gpu
    )  # use this default device (same as args.device if not distributed)
    torch.backends.cudnn.benchmark = True
    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.epochs)
    #############
    # Create MONAI dataset
    training_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="training",
        base_dir=args.data_root,
    )
    validation_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="validation",
        base_dir=args.data_root,
    )
    if args.quick:  # for debugging on a small subset
        training_list = training_list[:16]
        validation_list = validation_list[:16]

    # training: random tiles + flip/rotate augmentation; intensities are
    # mapped from [255, 0] (i.e. inverted) by ScaleIntensityRangeD
    train_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=args.tile_count,
            tile_size=args.tile_size,
            random_offset=True,
            background_val=255,
            return_list_of_dicts=True,
        ),
        RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
        RandRotate90d(keys=["image"], prob=0.5),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])
    # validation: all tiles (tile_count=None), no random offset, no augmentation
    valid_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=None,
            tile_size=args.tile_size,
            random_offset=False,
            background_val=255,
            return_list_of_dicts=True,
        ),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    dataset_train = Dataset(data=training_list, transform=train_transform)
    dataset_valid = Dataset(data=validation_list, transform=valid_transform)

    train_sampler = DistributedSampler(
        dataset_train) if args.distributed else None
    val_sampler = DistributedSampler(
        dataset_valid, shuffle=False) if args.distributed else None

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=train_sampler,
        collate_fn=list_data_collate,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=val_sampler,
        collate_fn=list_data_collate,
    )
    if args.rank == 0:
        print("Dataset training:", len(dataset_train), "validation:",
              len(dataset_valid))

    model = milmodel.MILModel(num_classes=args.num_classes,
                              pretrained=True,
                              mil_mode=args.mil_mode)

    best_acc = 0
    start_epoch = 0
    # optionally resume from a checkpoint (loaded on CPU first)
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(
            args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu)

    if args.validate:
        # if we only want to validate existing checkpoint
        epoch_time = time.time()
        val_loss, val_acc, qwk = val_epoch(model,
                                           valid_loader,
                                           epoch=0,
                                           args=args,
                                           max_tiles=args.tile_count)
        if args.rank == 0:
            print(
                "Final validation loss: {:.4f}".format(val_loss),
                "acc: {:.4f}".format(val_acc),
                "qwk: {:.4f}".format(qwk),
                "time {:.2f}s".format(time.time() - epoch_time),
            )
        exit(0)

    params = model.parameters()
    if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
        # transformer sub-module gets its own (much smaller) lr and higher decay
        m = model if not args.distributed else model.module
        params = [
            {
                "params":
                    list(m.attention.parameters()) + list(m.myfc.parameters()) +
                    list(m.net.parameters())
            },
            {
                "params": list(m.transformer.parameters()),
                "lr": 6e-6,
                "weight_decay": 0.1
            },
        ]

    optimizer = torch.optim.AdamW(params,
                                  lr=args.optim_lr,
                                  weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0)

    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        # NOTE(review): this inner rank check is redundant (already rank 0)
        if args.rank == 0:
            print("Writing Tensorboard logs to ", writer.log_dir)
    else:
        writer = None

    ###RUN TRAINING
    n_epochs = args.epochs
    val_acc_max = 0.0

    scaler = None
    if args.amp:  # new native amp
        scaler = GradScaler()

    for epoch in range(start_epoch, n_epochs):
        if args.distributed:
            # reshuffle shards per epoch and sync all ranks before training
            train_sampler.set_epoch(epoch)
            torch.distributed.barrier()
        print(args.rank, time.ctime(), "Epoch:", epoch)
        epoch_time = time.time()
        train_loss, train_acc = train_epoch(model,
                                            train_loader,
                                            optimizer,
                                            scaler=scaler,
                                            epoch=epoch,
                                            args=args)
        if args.rank == 0:
            print(
                "Final training {}/{}".format(epoch, n_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "acc: {:.4f}".format(train_acc),
                "time {:.2f}s".format(time.time() - epoch_time),
            )
        if args.rank == 0 and writer is not None:
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("train_acc", train_acc, epoch)
        if args.distributed:
            torch.distributed.barrier()

        b_new_best = False
        val_acc = 0
        if (epoch + 1) % args.val_every == 0:
            epoch_time = time.time()
            val_loss, val_acc, qwk = val_epoch(model,
                                               valid_loader,
                                               epoch=epoch,
                                               args=args,
                                               max_tiles=args.tile_count)
            if args.rank == 0:
                print(
                    "Final validation {}/{}".format(epoch, n_epochs - 1),
                    "loss: {:.4f}".format(val_loss),
                    "acc: {:.4f}".format(val_acc),
                    "qwk: {:.4f}".format(qwk),
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("val_loss", val_loss, epoch)
                    writer.add_scalar("val_acc", val_acc, epoch)
                    writer.add_scalar("val_qwk", qwk, epoch)
                # quadratic weighted kappa is the model-selection metric
                val_acc = qwk
                if val_acc > val_acc_max:
                    print("qwk ({:.6f} --> {:.6f})".format(
                        val_acc_max, val_acc))
                    val_acc_max = val_acc
                    b_new_best = True
        if args.rank == 0 and args.logdir is not None:
            # NOTE(review): best_acc is set to the current val_acc, not
            # val_acc_max — confirm this is intended for the saved metadata
            save_checkpoint(model,
                            epoch,
                            args,
                            best_acc=val_acc,
                            filename="model_final.pt")
            if b_new_best:
                print("Copying to model.pt new best model!!!!")
                shutil.copyfile(os.path.join(args.logdir, "model_final.pt"),
                                os.path.join(args.logdir, "model.pt"))

        scheduler.step()

    print("ALL DONE")
def main(cfg):
    """Train (and optionally validate) a patch-based WSI tumor classifier.

    Args:
        cfg: configuration dict with keys such as ``backend`` ("cucim"/"numpy"),
            ``dataset_json``, ``data_root``, ``n_epochs``, ``batch_size``,
            ``lr``, ``amp``, ``save``, ``validate``, etc.

    Side effects: writes logs, TensorBoard summaries, a config dump, and
    (when ``cfg["save"]``) per-epoch checkpoints under the created log dir.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)

    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing: the heavy image work runs either on GPU (cucim backend)
    # or on CPU (numpy backend); labels are tensorized on CPU in both cases.
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom", min_zoom=0.9, max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image", prob=cfg["prob"], spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image", prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image", dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f" {d} shape: {first_sample[d].shape}\n"
                     f" {d} type: {type(first_sample[d])}\n"
                     f" {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")

    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model: single-logit (binary) fully-convolutional resnet18
    model = TorchVisionFCModel("resnet18",
                               n_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: native AMP requires torch >= 1.6
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    scaler = GradScaler() if cfg["amp"] else None

    # Learning rate scheduler
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"]) if cfg["cos"] else None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # np.inf (np.Inf was removed in NumPy 2.0)
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}
    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(
                model.state_dict(),
                os.path.join(log_dir, f"model_epoch_{train_counter['epoch']}.pt"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss, metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc, metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc, train_counter["epoch"])
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg, metric_dict=metric_summary, run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model
    # Fix: checkpoints are written above with a ".pt" suffix; this section
    # previously looked for ".pth" sources and always raised FileNotFoundError.
    if cfg["validate"] is True:
        copyfile(
            os.path.join(log_dir, f"model_epoch_{metric_summary['best_epoch']}.pt"),
            os.path.join(log_dir, "model_best.pth"),
        )
        copyfile(
            os.path.join(log_dir, f"model_epoch_{cfg['n_epochs']}.pt"),
            os.path.join(log_dir, "model_final.pth"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
def train(cfg):
    """Train and validate a binary patch classifier on whole-slide images.

    Builds MONAI preprocessing pipelines and ``PatchWSIDataset`` loaders from
    a decathlon-style datalist, a ResNet-18 based fully-convolutional
    classifier, and MONAI ``SupervisedTrainer`` / ``SupervisedEvaluator``
    engines, then runs the full training loop (validating every epoch).

    Args:
        cfg: configuration dict; keys read here include "dataset_json",
            "data_root", "region_size", "grid_shape", "patch_size",
            "use_openslide", "num_workers", "batch_size", "pretrain",
            "novograd", "lr", "amp", "n_epochs", "logdir".
            Note: cfg["amp"] is overwritten in place with the effective flag.

    Raises:
        ValueError: if the training dataloader yields no first sample.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        # ColorJitter needs a tensor input; image is converted back to numpy
        # right after so the spatial transforms below operate on arrays.
        TorchVisionD(keys="image", name="ColorJitter", brightness=64.0 / 255.0,
                     contrast=0.75, saturation=0.25, hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        # Scale raw 8-bit intensities [0, 255] into [-1, 1].
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(
        train_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True)
    valid_dataloader = DataLoader(
        valid_dataset, num_workers=cfg["num_workers"], batch_size=cfg["batch_size"], pin_memory=True)
    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    print("image: ")
    print(" shape", first_sample["image"].shape)
    print(" type: ", type(first_sample["image"]))
    print(" dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print(" shape", first_sample["label"].shape)
    print(" type: ", type(first_sample["label"]))
    print(" dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")
    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model: single-logit output, convolutional head (FCN style)
    model = TorchVisionFCModel("resnet18", num_classes=1, use_conv=True, pretrained=cfg["pretrain"])
    model = model.to(device)
    # loss function (binary classification on raw logits)
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)
    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)
    # AMP requires torch >= 1.6; collapse the original if/else ladder into a
    # single boolean expression (same form as the sibling training script).
    cfg["amp"] = bool(cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["n_epochs"])
    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator: saves the best-metric checkpoint and logs validation stats
    val_handlers = [
        CheckpointSaver(save_dir=log_dir, save_dict={"net": model}, save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )
    # Trainer: steps LR each epoch, checkpoints every epoch, validates every epoch
    # NOTE(review): these handlers write to cfg["logdir"] while the evaluator
    # handlers above write to log_dir from create_log_dir(cfg) — confirm both
    # resolve to the same directory, otherwise checkpoints/logs are split.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={"net": model, "opt": optimizer},
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5),
    ])
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def get_data(args, batch_size=1, mode="train"):
    """Build the dataset properties and DataLoader for one phase of a task.

    Args:
        args: parsed CLI namespace; attributes read: ``fold``, ``task_id``,
            ``root_dir``, ``datalist_path``, ``pos_sample_num``,
            ``neg_sample_num``, ``num_samples``, ``multi_gpu``,
            ``cache_rate``, ``train_num_workers``, ``val_num_workers``.
        batch_size: batch size passed to the DataLoader.
        mode: one of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        Tuple ``(properties, data_loader)``: the decathlon dataset properties
        dict and the DataLoader yielding transformed samples for ``mode``.

    Raises:
        ValueError: if ``mode`` is not "train", "validation" or "test".
    """
    # get necessary parameters:
    fold = args.fold
    task_id = args.task_id
    root_dir = args.root_dir
    datalist_path = args.datalist_path
    # `task_name` is a module-level mapping from task id to dataset folder
    # name — presumably the Medical Segmentation Decathlon task names;
    # defined elsewhere in this module.
    dataset_path = os.path.join(root_dir, task_name[task_id])
    transform_params = (args.pos_sample_num, args.neg_sample_num, args.num_samples)
    multi_gpu_flag = args.multi_gpu
    transform = get_task_transforms(mode, task_id, *transform_params)

    # "test" uses a single list; train/validation lists are fold-specific.
    list_key = "test" if mode == "test" else f"{mode}_fold{fold}"
    datalist_name = f"dataset_task{task_id}.json"

    property_keys = [
        "name",
        "description",
        "reference",
        "licence",
        "tensorImageSize",
        "modality",
        "labels",
        "numTraining",
        "numTest",
    ]

    datalist = load_decathlon_datalist(
        os.path.join(datalist_path, datalist_name), True, list_key, dataset_path)
    properties = load_decathlon_properties(
        os.path.join(datalist_path, datalist_name), property_keys)

    if mode in ["validation", "test"]:
        if multi_gpu_flag:
            # Each rank keeps only its own shard; no shuffling and uneven
            # shards are allowed so every sample is evaluated exactly once.
            datalist = partition_dataset(
                data=datalist,
                shuffle=False,
                num_partitions=dist.get_world_size(),
                even_divisible=False,
            )[dist.get_rank()]
        val_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=4,
        )
        data_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=args.val_num_workers,
        )
    elif mode == "train":
        if multi_gpu_flag:
            # Training shards are shuffled and padded to equal size so every
            # rank performs the same number of steps per epoch.
            datalist = partition_dataset(
                data=datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=8,
            cache_rate=args.cache_rate,
        )
        data_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.train_num_workers,
            drop_last=True,
        )
    else:
        # Original raised an f-string with no placeholders; include the
        # offending value to make the failure actionable.
        raise ValueError(f"mode should be train, validation or test, but got {mode}.")
    return properties, data_loader