def run_infer(weights_folder_path, cfg):
    """Ensemble every ``*.pth`` checkpoint in ``weights_folder_path`` on the test set.

    Each checkpoint's ``"model"`` state dict is loaded into a fresh ``RanzcrNet`` whose
    ``decoder`` and ``segmentation_head`` are then deleted. For every test batch each
    network produces ``"logits"``. The per-model logits go through a sigmoid (``expit``)
    and are averaged over models. The averages are written into the ``cfg.label_cols``
    columns of the test CSV and saved as ``submission.csv`` in the working directory.

    Args:
        weights_folder_path: directory containing the ``*.pth`` checkpoints to ensemble.
        cfg: configuration object. This function sets ``cfg.pretrained`` and
            ``cfg.data_folder`` and reads ``data_dir``, ``device``, ``test_df``,
            ``mixed_precision`` and ``label_cols``.

    Raises:
        FileNotFoundError: if ``weights_folder_path`` contains no ``*.pth`` file.
    """
    cfg.pretrained = False
    # for local test, please modify the following path into actual path.
    cfg.data_folder = cfg.data_dir + "test/"
    to_device_transform = ToDeviced(keys=("input", "target", "mask", "is_annotated"), device=cfg.device)

    # Sorted so the ensemble order (and float summation order) is reproducible.
    all_path = sorted(glob.iglob(os.path.join(weights_folder_path, "*.pth")))
    if not all_path:
        raise FileNotFoundError(f"no *.pth checkpoints found in {weights_folder_path}")

    prefix = "module."
    nets = []
    for path in all_path:
        # Load onto the CPU first: load_state_dict copies into the model's own device, and
        # this avoids failures or extra device memory for checkpoints saved on another GPU.
        state_dict = torch.load(path, map_location="cpu")["model"]
        # Strip only the leading "module." that (Distributed)DataParallel adds to every key;
        # str.replace would also mangle keys of submodules that are themselves named "module".
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()
        }
        net = RanzcrNet(cfg).eval().to(cfg.device)
        net.load_state_dict(new_state_dict)
        # Only "logits" is used below; drop the segmentation branch to save memory.
        del net.decoder
        del net.segmentation_head
        nets.append(net)

    test_df = pd.read_csv(cfg.test_df)
    test_dataset = get_test_dataset(test_df, cfg)
    test_dataloader = get_test_dataloader(test_dataset, cfg)

    with torch.no_grad():
        fold_preds = [[] for _ in nets]
        for batch in tqdm(test_dataloader):
            batch = to_device_transform(batch)
            for i, net in enumerate(nets):
                with autocast(enabled=bool(cfg.mixed_precision)):
                    logits = net(batch)["logits"]
                # Upcast so that mixed-precision logits are not averaged in float16.
                fold_preds[i].append(logits.float().cpu().numpy())

    # Concatenate each model's batches, stack on a new leading model axis, then average the
    # sigmoid probabilities over that axis.
    preds = np.stack([np.concatenate(p) for p in fold_preds])
    preds = expit(preds)
    preds = np.mean(preds, axis=0)

    # NOTE(review): assumes the test dataloader neither shuffles nor drops rows, so that
    # preds lines up row-for-row with test_df -- confirm in get_test_dataloader.
    sub_df = test_df.copy()
    sub_df[cfg.label_cols] = preds
    submission = pd.read_csv(cfg.test_df)
    submission.loc[sub_df.index, cfg.label_cols] = sub_df[cfg.label_cols]
    submission.to_csv("submission.csv", index=False)
def test_value(self):
    """Check that GPU-cached items come back unchanged, in order, through ``ThreadDataLoader``.

    Four scalar tensors are moved to ``cuda:0`` by ``ToDeviced`` inside a fully cached
    ``CacheDataset`` (``cache_rate=1.0``). Each batch of size 1 must equal ``[i]`` on that device.
    """
    device = "cuda:0"
    data = [{"img": torch.tensor(i)} for i in range(4)]
    dataset = CacheDataset(
        data=data,
        transform=ToDeviced(keys="img", device=device, non_blocking=True),
        cache_rate=1.0,
    )
    # ThreadDataLoader iterates with threads, so multi-process workers stay disabled.
    dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1)
    num_batches = 0
    for i, d in enumerate(dataloader):
        # assert_close supersedes the deprecated torch.testing.assert_allclose.
        torch.testing.assert_close(d["img"], torch.tensor([i], device=device))
        num_batches += 1
    # Without this, an empty loader would make the loop above pass vacuously.
    assert num_batches == len(data), f"expected {len(data)} batches, got {num_batches}"
def test_train_timing(self):
    """Train a small 3D UNet with MONAI's fast-training setup and check the best validation Dice.

    The first 32 ``img*.nii.gz``/``seg*.nii.gz`` pairs in ``self.data_dir`` are used for
    training and the last 9 for validation. Both sets are fully cached on ``cuda:0`` and
    iterated with ``ThreadDataLoader``. Training runs ``max_epochs`` epochs with AMP,
    ``DiceCELoss`` and ``Novograd``, validating with sliding-window inference every
    ``val_interval`` epochs. Per-step, per-epoch and total timings are printed.

    Requires a CUDA device. Fails unless the best mean Dice exceeds 0.95.
    """
    images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:32], segs[:32])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-9:], segs[-9:])]
    device = torch.device("cuda:0")
    # define transforms for train and validation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            # (uses the "label_fg"/"label_bg" indices produced by FgBgToIndicesd above)
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=["image", "label"], prob=0.5),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),
            RandRotated(
                keys=["image", "label"],
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
                dtype=np.float64,
            ),
            RandAffined(keys=["image", "label"], prob=0.5, rotate_range=np.pi / 2, mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image", prob=0.5, factors=0.05, nonzero=True),
        ]
    )
    # Validation uses the same deterministic preprocessing without random augmentation.
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
            ScaleIntensityd(keys="image"),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )
    max_epochs = 5
    learning_rate = 2e-4
    val_interval = 1  # do validation for every epoch

    # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=5)
    # disable multi-workers because `ThreadDataLoader` works with multi-threads
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
    # 1 input channel, 2 output channels (background / foreground after one-hot below)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    # Novograd paper suggests to use a bigger LR than Adam,
    # because Adam does normalization by element-wise second moments
    optimizer = Novograd(model.parameters(), learning_rate * 10)
    scaler = torch.cuda.amp.GradScaler()

    # Predictions: argmax over channels then one-hot to 2 classes; labels: one-hot to 2 classes.
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    best_metric = -1
    total_start = time.time()
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            optimizer.zero_grad()
            # set AMP for training
            with torch.cuda.amp.autocast():
                outputs = model(batch_data["image"])
                loss = loss_function(outputs, batch_data["label"])
            # GradScaler scales the loss before backward and unscales before the optimizer step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            # NOTE(review): loop-invariant; recomputed every step (only used for the progress print).
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}" f" step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    # Validation windows (96^3) are larger than the 64^3 training patches.
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    # set AMP for validation
                    with torch.cuda.amp.autocast():
                        val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model)

                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_data["label"])]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # Mean Dice over the whole validation set, then clear state for the next epoch.
                metric = dice_metric.aggregate().item()
                dice_metric.reset()

            if metric > best_metric:
                best_metric = metric
            print(f"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}")
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    total_time = time.time() - total_start
    print(f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}")
    # test expected metrics
    self.assertGreater(best_metric, 0.95)
def main_worker(args):
    """Run one process of distributed (DDP) BraTS segmentation training.

    Each process joins an NCCL process group (rendezvous via ``init_method="env://"``) and
    drives ``cuda:{args.local_rank}``. It builds cached training and validation datasets
    from ``args.dir``. It trains ``SegResNet`` when ``args.network == "SegResNet"`` and a
    ``UNet`` otherwise, using ``DiceFocalLoss``, ``Novograd``, cosine LR annealing and a
    ``GradScaler``. It validates every ``args.val_interval`` epochs and saves the best
    weights from global rank 0 to ``best_metric_model.pth``.

    Args:
        args: parsed arguments; reads ``local_rank``, ``dir``, ``cache_rate``, ``batch_size``,
            ``network``, ``lr``, ``epochs`` and ``val_interval``.

    Raises:
        FileNotFoundError: if ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        # The handle is deliberately left open: it backs stdout/stderr for the rest of the process.
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    # use amp to accelerate training
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    total_start = time.time()
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"]),
            # Everything after this point runs on the GPU.
            ToDeviced(keys=["image", "label"], device=device),
            RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
    )

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.batch_size, shuffle=True)

    # validation transforms and dataset
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.0, 1.0, 1.0),
                mode=("bilinear", "nearest"),
            ),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device=device),
        ]
    )
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    # ThreadDataLoader can be faster if no IO operations when caching all the data in memory
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=args.batch_size, shuffle=False)

    # create network, loss function and optimizer
    # 4 input channels (the stacked images) and 3 output channels (the multi-channel BraTS label)
    if args.network == "SegResNet":
        model = SegResNet(
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=4,
            out_channels=3,
            dropout_prob=0.0,
        ).to(device)
    else:
        model = UNet(
            spatial_dims=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

    # sigmoid (not softmax) with no one-hot: each output channel is scored independently.
    loss_function = DiceFocalLoss(
        smooth_nr=1e-5,
        smooth_dr=1e-5,
        squared_pred=True,
        to_onehot_y=False,
        sigmoid=True,
        batch=True,
    )
    optimizer = Novograd(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # Overall mean Dice, plus a per-channel ("mean_batch") variant for the tc/wt/et scores.
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    print(f"time elapsed before training: {time.time() - total_start}")
    train_start = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{args.epochs}")
        epoch_loss = train(train_loader, model, loss_function, optimizer, lr_scheduler, scaler)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(model, val_loader, dice_metric, dice_metric_batch, post_trans)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # Only global rank 0 writes the file. The state dict comes from the DDP wrapper,
                # so its keys carry the "module." prefix.
                if dist.get_rank() == 0:
                    torch.save(model.state_dict(), "best_metric_model.pth")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )
        print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch},"
        f" total train time: {(time.time() - train_start):.4f}"
    )
    # NOTE(review): not reached if anything above raises; the process group is then left undestroyed.
    dist.destroy_process_group()
clip=True, ) ), Range()(CropForegroundd(keys=["image", "label"], source_key="image")), # pre-compute foreground and background indexes # and cache them to accelerate training Range("Indexing")( FgBgToIndicesd( keys="label", fg_postfix="_fg", bg_postfix="_bg", image_key="image", ) ), EnsureTyped(keys=["image", "label"]), ToDeviced(keys=["image", "label"], device="cuda:0"), Range("RandCrop")( RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4, fg_indices_key="label_fg", bg_indices_key="label_bg", ) ) ] )
def main(cfg):
    """Train ``RanzcrNet`` on one fold of ``cfg.train_df`` and checkpoint the best validation loss.

    Rows whose ``fold`` column equals ``cfg.fold`` are validated on and every other row is
    trained on. With ``cfg.fold == -1``, fold 0 is used for validation while all rows
    (presumably none is labelled -1) are trained on. Whenever a fresh validation loss
    improves on the best so far, a checkpoint is written to
    ``{cfg.output_dir}/fold{cfg.fold}/checkpoint_best_seed{cfg.seed}.pth``. A
    ``SummaryWriter`` logs to the same folder and is closed on exit.

    Args:
        cfg: experiment configuration. ``cfg.seed`` is replaced by a random value when
            negative. ``cfg.to_device_transform`` is set here, presumably for use by
            ``run_train``/``run_eval`` through ``cfg``.
    """
    fold_dir = str(cfg.output_dir + f"/fold{cfg.fold}/")
    os.makedirs(fold_dir, exist_ok=True)

    # set random seed, works when use all data to train
    if cfg.seed < 0:
        cfg.seed = np.random.randint(1_000_000)
    set_seed(cfg.seed)

    # set dataset, dataloader
    train = pd.read_csv(cfg.train_df)

    if cfg.fold == -1:
        val_df = train[train["fold"] == 0]
    else:
        val_df = train[train["fold"] == cfg.fold]
    train_df = train[train["fold"] != cfg.fold]

    train_dataset = get_train_dataset(train_df, cfg)
    val_dataset = get_val_dataset(val_df, cfg)
    train_dataloader = get_train_dataloader(train_dataset, cfg)
    val_dataloader = get_val_dataloader(val_dataset, cfg)

    if cfg.train_val is True:
        # Evaluate on the training rows with the validation pipeline.
        train_val_dataset = get_val_dataset(train_df, cfg)
        train_val_dataloader = get_val_dataloader(train_val_dataset, cfg)

    to_device_transform = ToDeviced(keys=("input", "target", "mask", "is_annotated"), device=cfg.device)
    cfg.to_device_transform = to_device_transform

    # set model
    model = RanzcrNet(cfg)
    model.to(cfg.device)

    # set optimizer, lr scheduler
    total_steps = len(train_dataset)
    optimizer = get_optimizer(model, cfg)
    scheduler = get_scheduler(cfg, optimizer, total_steps)

    # set other tools
    scaler = GradScaler() if cfg.mixed_precision else None
    writer = SummaryWriter(fold_dir)

    # train and val loop
    # NOTE(review): `step` and `i` are never updated here, so run_train receives 0 every
    # epoch -- confirm run_train does not rely on them accumulating.
    step = 0
    i = 0
    best_val_loss = np.inf
    optimizer.zero_grad()
    try:
        for epoch in range(cfg.epochs):
            print("EPOCH:", epoch)
            gc.collect()
            if cfg.train is True:
                run_train(
                    model=model,
                    train_dataloader=train_dataloader,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    cfg=cfg,
                    scaler=scaler,
                    writer=writer,
                    epoch=epoch,
                    iteration=i,
                    step=step,
                )

            # Reset every epoch. Epochs without evaluation (eval_epochs > 1) must neither hit an
            # unbound `val_loss` on the first epoch nor re-compare a stale loss.
            val_loss = None
            if (epoch + 1) % cfg.eval_epochs == 0 or (epoch + 1) == cfg.epochs:
                val_loss = run_eval(
                    model=model,
                    val_dataloader=val_dataloader,
                    cfg=cfg,
                    writer=writer,
                    epoch=epoch,
                )

            if cfg.train_val is True:
                if (epoch + 1) % cfg.eval_train_epochs == 0 or (epoch + 1) == cfg.epochs:
                    train_val_loss = run_eval(model, train_val_dataloader, cfg, writer, epoch)
                    print(f"train_val_loss {train_val_loss:.5}")

            if val_loss is not None and val_loss < best_val_loss:
                print(f"SAVING CHECKPOINT: val_loss {best_val_loss:.5} -> {val_loss:.5}")
                best_val_loss = val_loss
                checkpoint = create_checkpoint(
                    model,
                    optimizer,
                    epoch,
                    scheduler=scheduler,
                    scaler=scaler,
                )
                torch.save(
                    checkpoint,
                    f"{cfg.output_dir}/fold{cfg.fold}/checkpoint_best_seed{cfg.seed}.pth",
                )
    finally:
        # Flush pending log events and release the writer even if training fails.
        writer.close()