def update(self, y_pred, y, batched=True):
    """Accumulate the batch-mean foreground Dice for one prediction/target pair.

    When ``batched`` is False, both tensors are promoted to a batch of one
    before the metric is computed.
    """
    if not batched:
        # add a leading batch axis so compute_meandice sees (1, C, ...)
        y_pred, y = y_pred[None], y[None]
    dice = compute_meandice(y_pred=y_pred, y=y, include_background=False)
    self.data.append(dice.mean().item())
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window inference over the NIfTI pairs found in ``root_dir``.

    Expects ``im*.nii.gz`` / ``seg*.nii.gz`` image/label pairs and a trained
    ``best_metric_model.pth`` checkpoint in ``root_dir``.  Thresholded
    predictions are written to ``root_dir/output`` and the average Dice
    (background included, sigmoid applied) is returned as a float.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # pairing relies on the sorted file lists lining up one-to-one
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=int)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            # per-sample dice on the raw logits (sigmoid applied inside the metric)
            value = compute_meandice(y_pred=val_outputs, y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False, add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            # binarize at 0.5 probability before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
def update(self, y_pred, y):
    """Record the mean Dice (background included) for one pair, skipping NaNs.

    NumPy arrays are converted to tensors first; tensors pass through as-is.
    A NaN overall score (e.g. empty ground truth) is discarded rather than
    accumulated.
    """
    if not torch.is_tensor(y_pred):
        y_pred = torch.from_numpy(y_pred)
    if not torch.is_tensor(y):
        y = torch.from_numpy(y)
    score = compute_meandice(y_pred=y_pred, y=y,
                             include_background=True).mean().item()
    if not math.isnan(score):
        self.data.append(score)
def main():
    """Generate a small synthetic 3D dataset, run sliding-window inference with
    a pretrained UNet, save thresholded predictions, and print the mean Dice.

    Reads the checkpoint from ``best_metric_model.pth`` in the working
    directory and writes predictions to ``./output``; the temporary dataset
    directory is removed at the end.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    # five synthetic volume/segmentation pairs, single foreground class
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    # image_only=False so the loader also yields metadata (val_data[2]) for saving
    val_ds = NiftiDataset(images, segs, transform=imtrans,
                          seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1,
                            pin_memory=torch.cuda.is_available())

    device = torch.device('cuda:0')
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    model.load_state_dict(torch.load('best_metric_model.pth'))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.
        metric_count = 0
        saver = NiftiSaver(output_dir='./output')
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            # dice on raw logits; sigmoid is applied inside the metric
            value = compute_meandice(y_pred=val_outputs, y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False, add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            # binarize at 0.5 probability before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print('evaluation metric:', metric)
    shutil.rmtree(tempdir)
def validation_step(self, batch, batch_idx):
    """Score one validation batch with sliding-window inference.

    Returns a dict with the validation loss (computed on the raw network
    output) under ``val_loss`` and the post-processed foreground Dice under
    ``val_dice``.
    """
    imgs, lbls = batch["image"], batch["label"]
    # one PATCH_SIZE window at a time
    preds = sliding_window_inference(imgs, PATCH_SIZE, 1, self.forward)
    loss = self.loss_function(preds, lbls)
    dice = compute_meandice(
        y_pred=self.post_pred(preds),
        y=self.post_label(lbls),
        include_background=False,
    )
    return {"val_loss": loss, "val_dice": dice}
def MeanDice(model, data_loader, device):
    """Return the average Dice score of ``model`` over ``data_loader``.

    Each batch dict must provide ``image`` and ``label`` tensors; the sum of
    per-sample dice values is divided by the total number of samples seen.

    Parameters:
        model: callable network mapping an image batch to logits.
        data_loader: iterable of dicts with ``image`` and ``label`` keys.
        device: torch device to move the tensors to.
    """
    metric_sum = 0.0
    metric_count = 0
    for data in data_loader:
        inputs, labels = (
            data["image"].to(device),
            data["label"].to(device),
        )
        outputs = model(inputs)
        # BUG FIX: dice was previously computed against ``inputs`` (the image),
        # not the ground-truth ``labels`` — the metric was meaningless.
        value = compute_meandice(outputs, labels, sigmoid=True, logit_thresh=0.5)
        metric_count += len(value)
        metric_sum += value.sum().item()
    return metric_sum / metric_count
def update(self, output: Sequence[Union[torch.Tensor, dict]]):
    """Fold one ``(y_pred, y)`` pair into the running Dice statistics.

    Per sample, NaN class scores (e.g. classes absent from the ground truth)
    are masked out before averaging; a sample whose classes are all NaN is
    skipped entirely and does not count as an example.
    """
    assert len(
        output) == 2, 'MeanDice metric can only support y_pred and y.'
    y_pred, y = output
    scores = compute_meandice(
        y_pred,
        y,
        self.include_background,
        self.to_onehot_y,
        self.mutually_exclusive,
        self.add_sigmoid,
        self.logit_thresh,
    )
    # add all items in current batch
    for sample in scores:
        valid = ~torch.isnan(sample)
        if valid.any():
            self._sum += sample[valid].mean().item()
            self._num_examples += 1
def test_nans(self, input_data, expected_value):
    """The NaN mask of the computed dice must match ``expected_value``."""
    nan_mask = np.isnan(compute_meandice(**input_data).cpu().numpy())
    self.assertTrue(np.allclose(nan_mask, expected_value))
def test_value(self, input_data, expected_value):
    """Computed dice values must equal ``expected_value`` within 1e-4."""
    dice = compute_meandice(**input_data)
    np.testing.assert_allclose(dice.cpu().numpy(), expected_value, atol=1e-4)
# Validation pass for the current epoch: score the model over val_loader,
# track the best mean dice, and checkpoint when it improves.
# NOTE(review): relies on epoch / model / val_loader / writer / best_metric /
# best_metric_epoch / metric_values defined by the enclosing training loop.
with torch.no_grad():
    metric_sum = 0.
    metric_count = 0
    val_images = None
    val_labels = None
    val_outputs = None
    for val_data in val_loader:
        val_images, val_labels = val_data['img'].to(
            device), val_data['seg'].to(device)
        # sliding-window size and number of windows scored per forward pass
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_images, roi_size,
                                               sw_batch_size, model)
        # dice on raw logits; sigmoid applied inside the metric
        value = compute_meandice(y_pred=val_outputs, y=val_labels,
                                 include_background=True,
                                 to_onehot_y=False, add_sigmoid=True)
        metric_count += len(value)
        metric_sum += value.sum().item()
    metric = metric_sum / metric_count
    metric_values.append(metric)
    if metric > best_metric:
        best_metric = metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(), 'best_metric_model.pth')
        print('saved new best metric model')
    print(
        "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
        % (epoch + 1, metric, best_metric, best_metric_epoch))
    writer.add_scalar('val_mean_dice', metric, epoch + 1)
metric_sum = 0.0 metric_count = 0 for val_data in val_loader: val_inputs, val_labels = ( val_data['image'].to(device), val_data['label'].to(device), ) roi_size = PATCH_SIZE sw_batch_size = 1 val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model) value = compute_meandice( y_pred=val_outputs, y=val_labels, include_background=False, to_onehot_y=True, mutually_exclusive=True, ) metric_count += len(value) metric_sum += value.sum().item() metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save(model.state_dict(), output_path / 'best_metric_model.pth') print('saved new best metric model') print(
def inference(data_loader, model, criterion, model_name):
    """Evaluate ``model`` on ``data_loader``, plot each predicted LA mask
    against the ground truth, export VTK volumes, and write per-case metrics
    (loss, Dice, Hausdorff, time) to ``test_results/<model_name>/metrics_vals.csv``.
    """
    total_batch = len(data_loader)
    test_loss = average_metrics()
    dice, hausdorff = average_metrics(), average_metrics()
    # Softmax / one-hot post-processing for metric computation
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    # set model eval mode!
    model.eval()
    results = []
    dice_collection, hausdorff_collection = [], []
    start2 = time.time()
    for batch_idx, (data, y) in enumerate(data_loader):
        result = None
        start = time.time()
        data = data.to(device)
        data = data.type(torch.cuda.FloatTensor)
        y = y.to(device)
        # BUG FIX: the forward pass and loss are now computed under no_grad and
        # the previous ``loss.backward()`` call is removed — backpropagating
        # during inference accumulated gradients for no purpose.
        with torch.no_grad():
            # Get prediction
            out = model(data.to(device))
            # Get loss
            loss = criterion(out.to(device), y.to(device))
        # Update loss
        test_loss.update(loss.item())
        # Evaluation
        outputs = post_pred(out.to(device))
        labels = post_label(y.to(device))
        # Post-processing: keep only the large connected components of the
        # argmax prediction (first sample of the batch)
        post_processing = remove_small_objects(
            torch.argmax(out, dim=1).detach().cpu()[0, :, :, :])
        # Metrics
        dice.update(
            compute_meandice(
                y_pred=post_processing.to(device),
                y=labels,
                include_background=False,
            ).item())
        hausdorff.update(
            compute_hausdorff_distance(y_pred=post_processing.to(device),
                                       y=labels,
                                       distance_metric='euclidean').item())
        dice_collection.append(dice.val)
        hausdorff_collection.append(hausdorff.val)
        print(
            f'Iteration {(batch_idx + 1)}/{total_batch} - Loss: {test_loss.val} -Dice: {dice.val} - Hausdorff: {hausdorff.val} '
        )
        end = time.time()
        result = [
            batch_idx + 1, test_loss.val, dice.val, hausdorff.val, end - start
        ]
        results.append(result)
        # PLOTS — central axial slice of MRI, ground truth and prediction
        plt.figure("check", (10, 5))
        mid_slice = out.shape[4] // 2  # renamed: 'slice' shadowed the builtin
        ground_truth = rotate(y.detach().cpu()[0, 0, :, :, mid_slice], 90)
        mri_image = rotate(data.detach().cpu()[0, 0, :, :, mid_slice], 90)
        predicted = rotate(
            torch.argmax(post_processing,
                         dim=1).detach().cpu()[0, :, :, mid_slice], 90)
        custom = matplotlib.colors.ListedColormap(['gray', 'red'])
        # LGE - MRI
        plt.subplot(1, 3, 1)
        plt.title(f'MRI test {batch_idx+1}')
        plt.axis('off')
        plt.imshow(mri_image, cmap="gray")
        # Ground truth mask + LGE-MRI
        plt.subplot(1, 3, 2)
        plt.axis('off')
        plt.title('Ground Truth')
        plt.imshow(mri_image, cmap="gray")
        plt.imshow(ground_truth, alpha=0.4, cmap=custom)
        # Predicted mask + LGE-MRI
        plt.subplot(1, 3, 3)
        plt.title(f'DSC:{round(dice.val,3)} - HD:{round(hausdorff.val,2)}')
        plt.axis('off')
        plt.imshow(mri_image, cmap="gray")
        plt.imshow(predicted, alpha=0.4, cmap=custom)
        plt.show()
        # Save mri, gt masks and mask predictions into VTK format
        y_predicted_array = torch.argmax(
            post_processing,
            dim=1).detach().cpu().numpy().astype('float')[0, :, :, :]
        y_true_array = y.detach().cpu().numpy()[0, 0, :, :, :]
        mri_array = data.detach().cpu().numpy()[0, 0, :, :, :]
        file = f'patient_{batch_idx+1}.vtk'
        generate_vtk_from_numpy(
            y_predicted_array,
            'test_results/' + model_name + '/predicted_' + file)
        generate_vtk_from_numpy(y_true_array,
                                'test_results/' + model_name + '/true_' + file)
        generate_vtk_from_numpy(mri_array,
                                'test_results/' + model_name + '/mri_' + file)
    print(
        f'Test loss:{test_loss.avg} - Dice: {dice.avg} +/- {np.std(dice_collection)} - Hausdorff:{hausdorff.avg} +/- {np.std(hausdorff_collection)}'
    )
    end2 = time.time()
    # Save metric results into csv file
    # BUG FIX: the summary row used hausdorff.val (last case) instead of the
    # average, unlike every other column in this row.
    results.append(
        ['average', test_loss.avg, dice.avg, hausdorff.avg, end2 - start2])
    results_df = pd.DataFrame(
        results, columns=["test num", "Loss", "Dice", "Hausdorff", 'Time'])
    results_df.to_csv('test_results/' + model_name + '/metrics_vals.csv',
                      index=False)
def __call__(self, engine: Engine):
    """Engine event handler: for every sample of the current batch, compute its
    foreground mean Dice and save the image/label/prediction triplet to
    ``self.output_dir`` (as .npz when ``self.save_np``; as a PNG grid for 3D
    arrays or per-channel NIfTI volumes for 4D arrays when ``self.images``).
    File names embed rank tag, region, iteration, batch index and dice score.
    """
    batch_data = engine.state.batch
    output_data = engine.state.output
    device = engine.state.device
    tag = ""
    # prefix file names with the process rank under distributed training
    if torch.distributed.is_initialized():
        tag = "r{}-".format(torch.distributed.get_rank())
    for bidx in range(len(batch_data.get("image"))):
        step = engine.state.iteration
        region = batch_data.get("region")[bidx]
        region = region.item() if torch.is_tensor(region) else region
        # first channel of the input only; keep a leading channel axis
        # NOTE(review): assumes image is (C, ...) with extra guidance channels
        # at indices 1 and 2 (used as pos/neg below) — confirm with caller.
        image = batch_data["image"][bidx][0].detach().cpu().numpy()[
            np.newaxis]
        label = batch_data["label"][bidx].detach().cpu().numpy()
        pred = output_data["pred"][bidx].detach().cpu().numpy()
        # per-sample foreground dice, used only for the output file name
        dice = compute_meandice(
            y_pred=output_data["pred"][bidx][None].to(device),
            y=batch_data["label"][bidx][None].to(device),
            include_background=False,
        ).mean()
        if self.save_np:
            np.savez(
                os.path.join(
                    self.output_dir,
                    "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(
                        tag, region, step, bidx, dice),
                ),
                image,
                label,
                pred,
            )
        # 3D (2D image + channel): save a side-by-side PNG grid
        if self.images and len(image.shape) == 3:
            img = make_grid(torch.from_numpy(
                rescale_array(image, 0, 1)[0]))
            lab = make_grid(torch.from_numpy(
                rescale_array(label, 0, 1)[0]))
            # channels 1 and 2 of the network input, rescaled to [0, 1]
            pos = rescale_array(
                output_data["image"][bidx][1].detach().cpu().numpy()[
                    np.newaxis], 0, 1)[0]
            neg = rescale_array(
                output_data["image"][bidx][2].detach().cpu().numpy()[
                    np.newaxis], 0, 1)[0]
            pre = make_grid(
                torch.from_numpy(
                    np.array([rescale_array(pred, 0, 1)[0], pos, neg])))
            torchvision.utils.save_image(
                tensor=[img, lab, pre],
                nrow=3,
                pad_value=2,
                fp=os.path.join(
                    self.output_dir,
                    "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".
                    format(tag, region, step, bidx, dice),
                ),
            )
        # 4D (3D volume + channel): save each array as a NIfTI volume
        if self.images and len(image.shape) == 4:
            samples = {
                "image": image[0],
                "label": label[0],
                "pred": pred[0]
            }
            for sample in samples:
                # move the depth axis last for NIfTI orientation
                img = np.moveaxis(samples[sample], -3, -1)
                img = nib.Nifti1Image(img, np.eye(4))
                nib.save(
                    img,
                    os.path.join(
                        self.output_dir,
                        "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(
                            tag, sample, step, bidx, dice)),
                )
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """Train the GPipe-partitioned UNet on the HaN dataset.

    Parameters:
        n_feat: base feature count of the UNet.
        crop_size: comma-separated crop size string, e.g. ``"96,96,96"``;
            any ``-1`` component means "use the full image".
        bs: batch size (number of crops per step in the cropping branch).
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"``.
        lr: learning rate.
        pretrain: optional checkpoint path whose ``"weight"`` dict is used to
            warm-start matching parameters.

    Saves per-class best checkpoints named ``<model_name><class_index>`` and
    validates every 10 epochs.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(
            train_images,
            transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(
            # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class loss weights
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: balance the flattened model across all visible GPUs
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # BUG FIX: load the merged full state dict, not the filtered partial
        # one — loading ``pretrained_dict`` directly fails with missing keys
        # under the default strict=True.
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            # weighted dice + 0.5 * weighted focal loss; flagvec masks out
            # samples without complete ground truth
            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o,
                                        y_val.to(o),
                                        mutually_exclusive=True,
                                        include_background=False)
                # ``l == l`` filters out NaN class scores (classes absent
                # from the ground truth of this case)
                val_loss = [
                    l.item() + tl if l == l else tl
                    for l, tl in zip(loss[0], val_loss)
                ]
                n_val = [
                    n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                ]
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            # checkpoint per foreground class when its score improves
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)
def main():
    """Config-driven 2D UNet training entry point.

    Reads a YAML config (``--config``), builds training/validation loaders
    from NIfTI file lists, trains a MONAI UNet with Dice loss, validates
    either with 3D sliding windows or on random 2D slices, checkpoints the
    best model, and logs losses/metrics/images to TensorBoard.
    """
    """
    Read input and configuration parameters
    """
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config',
                        type=str, help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving: timestamped output directory per run
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') +
                                 '_' + config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = config_info['output']['max_nr_models_saved']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    torch.cuda.set_device(cuda_device)

    """
    Data Preparation
    """
    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96],
                interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1],
                         random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2,
                    spatial_axes=[0, 1], interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True, num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96],
                    interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96],
                    interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1],
                             random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    """
    Network preparation
    """
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    """
    Training loop
    """
    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    net.to(device)
    for epoch in range(nr_train_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
            opt.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len,
                                               loss.item()))
            writer_train.add_scalar('loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % validation_every_n_epochs == 0:
            net.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                check_tot_validation = 0
                for val_data in val_loader:
                    check_tot_validation += 1
                    val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                    if sliding_window_validation:
                        print('Running sliding window validation')
                        # 3D dice via sliding windows over the full volume
                        roi_size = (96, 96, 1)
                        val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    else:
                        print('Running 2D validation')
                        # compute validation as 1 - dice loss on 2D slices
                        val_outputs = net(val_images)
                        value = 1.0 - loss_function(val_outputs, val_labels)
                        metric_count += 1
                        metric_sum += value.item()
                print("Total number of data in validation: %d" % check_tot_validation)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                    print('saved new best metric model')
                print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                      % (epoch + 1, metric, best_metric, best_metric_epoch))
                epoch_len = len(train_ds) // train_loader.batch_size
                writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
                writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer_valid, index=0, tag='image')
                plot_2d_or_3d_image(val_labels, epoch + 1, writer_valid, index=0, tag='label')
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer_valid, index=0, tag='output')
    print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
    writer_train.close()
    writer_valid.close()
def train_process(fast=False):
    """Train a 3D UNet with Dice loss and validate with mean Dice each epoch.

    Uses module-level ``train_files`` / ``val_files`` and ``transformations()``
    for the data pipeline; the best model is saved to ``sLUMRTL644.pth``.

    Parameters:
        fast: kept for interface compatibility (currently unused here).

    Returns:
        (epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time)
    """
    epoch_num = 10
    val_interval = 1

    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = list()
    metric_values = list()
    # BUG FIX: initialize before the loop — previously this was only assigned
    # inside the improvement branch, so a first validation that did not
    # improve (e.g. a NaN metric) raised NameError in the else branch.
    epochs_no_improve = 0

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    val_outputs = model(val_inputs)
                    # one-hot both sides before computing foreground dice
                    val_outputs = post_pred(val_outputs)
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    epochs_no_improve += 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
def main():
    """Distributed DiNTS neural-architecture search for 3D segmentation.

    Alternates between network-weight updates (first half of the training
    data) and architecture-parameter updates (second half) after a warm-up
    phase, validating with sliding-window inference and saving the decoded
    architecture codes plus progress files under ``--output_root``.

    Launched per-GPU via torch.distributed (NCCL backend).
    """
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    # FIX: was `required=int` (truthy, so it behaved like required=True but
    # never converted the value); declare the type and requirement explicitly.
    parser.add_argument(
        "--local_rank",
        type=int,
        required=True,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    # hyper-parameters for the search
    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iteration
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    # N-fold split: fold-th slice is validation, the rest is training
    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][:(
        split * fold)] + json_data[args.json_key][(split * (fold + 1)):]
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data (skip entries whose image/label files are missing)
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    # first half updates network weights ("w"), second half updates
    # architecture parameters ("a"); each half is sharded across ranks
    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                     shuffle=True,
                                     num_partitions=world_size,
                                     even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))
    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                     shuffle=True,
                                     num_partitions=world_size,
                                     even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        # dilate each class mask so RandCropByLabelClassesd can pick crop
        # centres near (not only on) each class
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(tuple([
                ndimage.binary_dilation(
                    (x == _k).astype(x.dtype),
                    iterations=48).astype(x.dtype)
                for _k in range(output_classes)
            ]), axis=0),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[1, ] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        # drop the helper mask to free memory once cropping is done
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as alternatives when debugging or RAM space is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as alternatives when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer (learning rates scaled by world size for multi-GPU training)
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

    model = DistributedDataParallel(model,
                                    device_ids=[device],
                                    find_unused_parameters=True)

    # FIX: identity comparison with None (was `!= None`)
    if args.checkpoint is not None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(
            args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        # step-wise LR decay at the fractional milestones after warm-up
        decay = 0.5**np.sum([
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones
        ])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)
            # phase 1: update network weights only
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]),
                                         1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]),
                                     1 - labels)
                else:
                    loss = loss_func(outputs, labels)

                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

            # architecture updates only start after the warm-up phase
            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(
                device), sample_a["label"].to(device)
            # phase 2: update architecture parameters only
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linear increase topology and RAM loss
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) *
                                torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) * \
                F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                         1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \
                        + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                     1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \
                    + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                # per-class accumulators: [2c] = sum of dice, [2c+1] = count
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate(
                            (metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        # NaN-safe accumulation: skip cases with no foreground
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronizes all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    # NOTE(review): collapsed source; assumed the decoded
                    # architecture is saved on every validation round (as in
                    # MONAI's DiNTS search script) — confirm indentation
                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode(
                    )
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        documents = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
def main():
    """Generate a small synthetic NIfTI dataset, run sliding-window inference
    with a pre-trained 3D UNet, print the mean Dice score, and save the
    thresholded predictions as NIfTI files under ``./output``.

    Expects ``best_metric_model.pth`` in the working directory and a CUDA
    device at ``cuda:0``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # FIX: use a context manager so the synthetic data directory is removed
    # even when inference raises; the original mkdtemp()/rmtree() pair leaked
    # the directory on any exception.
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ])
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference need to input 1 image in every iteration
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        device = torch.device("cuda:0")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        model.load_state_dict(torch.load("best_metric_model.pth"))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(
                    device), val_data["seg"].to(device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs,
                                         y=val_labels,
                                         include_background=True,
                                         to_onehot_y=False,
                                         add_sigmoid=True)
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarize the sigmoid output before saving
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(
                    val_outputs, {
                        "filename_or_obj": val_data["img.filename_or_obj"],
                        "affine": val_data["img.affine"]
                    })
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
def train():
    """Run a fixed-length training loop with periodic Dice validation.

    Relies on module-level globals: ``params``, ``model``, ``device``,
    ``optimizer``, ``loss_function``, ``train_loader``, ``val_loader`` and
    ``train_ds`` — TODO confirm all are initialized before calling.

    Side effects: writes TensorBoard scalars/GIFs, saves the best whole
    model (not just state_dict) and periodic checkpoints under
    ``saved_models/``.

    :return: None
    """
    print('Model training started')
    # reproducible runs (MONAI utility)
    set_determinism(seed=0)

    epoch_num = params['nb_epoch']
    val_interval = 2  # validate every 2 epochs
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # NOTE(review): floor division under-counts when the last batch
            # is partial; display only, does not affect training
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        writer.add_scalar("epoch_loss", epoch_loss, epoch + 1)

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = model(val_inputs)
                    # multi-class dice: argmax via mutually_exclusive,
                    # labels one-hot encoded via to_onehot_y
                    value = compute_meandice(
                        y_pred=val_outputs,
                        y=val_labels,
                        include_background=False,
                        to_onehot_y=True,
                        mutually_exclusive=True,
                    )
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                # mean dice over all validation samples/classes
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    # NOTE(review): saves the entire model object, not
                    # state_dict — loading requires the class definition
                    torch.save(
                        model,
                        os.path.join('saved_models', "best_metric_model.pth"))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last validation subject as GIF in TensorBoard with the corresponding image, label and pred
                val_pred = torch.argmax(val_outputs, dim=1, keepdim=True)
                summary_img = torch.cat((val_inputs, val_labels, val_pred),
                                        dim=3)
                plot_2d_or_3d_image(summary_img,
                                    epoch + 1,
                                    writer,
                                    tag='last_val_subject')

        # Model checkpointing
        if (epoch + 1) % 20 == 0:
            torch.save(
                model,
                os.path.join('saved_models',
                             params['f_name'] + '_' + str(epoch + 1) + '.pth'))

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
# NOTE(review): orphan, truncated snippet — the enclosing function/loop and
# the closing of the final print(...) are not present in this chunk, and the
# indentation has been lost (the metric aggregation below presumably sits
# outside the per-sample loop — TODO confirm against the original script).
# Kept byte-for-byte apart from comments; it is the validation tail of a
# sliding-window segmentation loop.
val_inputs, val_labels = (
    val_data["image"].to(device),
    val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size,
                                       model)
val_outputs = post_pred(val_outputs)
# NOTE(review): `largest` is created but never applied to val_outputs
largest = KeepLargestConnectedComponent(applied_labels=[1])
val_labels = post_label(val_labels)
value = compute_meandice(
    y_pred=val_outputs,
    y=val_labels,
    include_background=False,
)
metric_count += len(value)
metric_sum += value.sum().item()
metric = metric_sum / metric_count
metric_values.append(metric)
if metric > best_metric:
    best_metric = metric
    best_metric_epoch = epoch + 1
    torch.save(model.state_dict(),
               os.path.join(out_dir, "best_metric_model.pth"))
    print("saved new best metric model")
print(
    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
def main():
    """Train a 3D UNet on 40 synthetic NIfTI volumes and validate with mean
    Dice, logging losses/metrics and sample images to TensorBoard.

    The best model state_dict is saved to ``best_metric_model.pth``.
    Requires a CUDA device at ``cuda:0``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    # FIX: wrap the whole run in try/finally so the synthetic data directory
    # is removed even when training fails (it previously leaked on exception).
    try:
        print('generating synthetic data to {} (this may take a while)'.format(
            tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
        train_files = [{
            'img': img,
            'seg': seg
        } for img, seg in zip(images[:20], segs[:20])]
        val_files = [{
            'img': img,
            'seg': seg
        } for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            RandCropByPosNegLabeld(keys=['img', 'seg'],
                                   label_key='seg',
                                   size=[96, 96, 96],
                                   pos=1,
                                   neg=1,
                                   num_samples=4),
            RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=['img', 'seg'])
        ])
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            ToTensord(keys=['img', 'seg'])
        ])

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(check_ds,
                                  batch_size=2,
                                  num_workers=4,
                                  collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        check_data = monai.utils.misc.first(check_loader)
        print(check_data['img'].shape, check_data['seg'].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(train_ds,
                                  batch_size=2,
                                  shuffle=True,
                                  num_workers=4,
                                  collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device('cuda:0')
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        writer = SummaryWriter()
        for epoch in range(5):
            print('-' * 10)
            print('epoch {}/{}'.format(epoch + 1, 5))
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data['img'].to(
                    device), batch_data['seg'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                epoch_len = len(train_ds) // train_loader.batch_size
                print('{}/{}, train_loss: {:.4f}'.format(
                    step, epoch_len, loss.item()))
                writer.add_scalar('train_loss', loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data['img'].to(
                            device), val_data['seg'].to(device)
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(y_pred=val_outputs,
                                                 y=val_labels,
                                                 include_background=True,
                                                 to_onehot_y=False,
                                                 add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(),
                                   'best_metric_model.pth')
                        print('saved new best metric model')
                    print(
                        'current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'
                        .format(epoch + 1, metric, best_metric,
                                best_metric_epoch))
                    writer.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='image')
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='label')
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='output')
    finally:
        shutil.rmtree(tempdir)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(
        best_metric, best_metric_epoch))
    writer.close()
def fit(model, train_ds, val_ds, batch_size, epoch_num, loss_function,
        optimizer, device, root_dir, callbacks=None, verbose=1):
    """Train ``model`` on ``train_ds`` and validate each epoch with mean Dice.

    The best model's state_dict is saved to ``root_dir/best_metric_model.pth``.
    ``callbacks`` and ``verbose`` are accepted for interface compatibility but
    are currently unused.

    :param model: network to train (moved to ``device`` by the caller).
    :param train_ds: training dataset yielding dicts with "image"/"label".
    :param val_ds: validation dataset with the same item layout.
    :param batch_size: batch size for both loaders.
    :param epoch_num: number of epochs to run.
    :param loss_function: training criterion.
    :param optimizer: optimizer over ``model.parameters()``.
    :param device: torch device for inputs/labels.
    :param root_dir: directory for the best-model checkpoint.
    :return: tuple ``(epoch_num, total_time, epoch_loss_values,
        metric_values, epoch_times)``.
    """
    # train_loader = torch.utils.data.DataLoader(
    #     train_ds, batch_size=batch_size, shuffle=True, num_workers=2
    # )
    # val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, num_workers=2)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=2)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # tensorboard
    writer = SummaryWriter()

    val_interval = 1  # do validation for every epoch,
    best_metric = float("-inf")
    best_metric_epoch = float("-inf")
    epoch_loss_values = list()
    metric_values = list()
    epoch_times = list()
    total_start = time.time()
    for epoch in range(epoch_num):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            writer.add_scalar('Loss/train', loss.item(),
                              epoch * len(train_ds) + step - 1)
            print(
                f"Epoch: [{epoch + 1}], [{step}/{len(train_ds) // train_loader.batch_size}], train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f} "
            )  # ETA: 0:01:18
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = model(val_inputs)
                    # FIX: dice was computed against val_inputs (the images)
                    # instead of val_labels (the ground-truth masks), so the
                    # reported metric and the saved "best" model were
                    # meaningless; compare predictions with labels.
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             sigmoid=True,
                                             logit_thresh=0.5)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                writer.add_scalar('DiceMetric/val', metric,
                                  (epoch + 1) * len(train_ds))
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        os.path.join(root_dir, "best_metric_model.pth"),
                    )
                    # torch.save({
                    #     'epoch': EPOCH,
                    #     'model_state_dict': net.state_dict(),
                    #     'optimizer_state_dict': optimizer.state_dict(),
                    #     'loss': LOSS,
                    # }, PATH)
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}",
                    f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}",
                ),
        print(
            f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}"
        )
        epoch_times.append(time.time() - epoch_start)

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}",
        f" total time: {(time.time() - total_start):.4f}"
    ),
    return (
        epoch_num,
        time.time() - total_start,
        epoch_loss_values,
        metric_values,
        epoch_times,
    )
def main():
    """End-to-end demo: synthesize 40 3D image/segmentation NIfTI pairs,
    train a 3D UNet for 5 epochs with sliding-window validation every 2
    epochs, log to TensorBoard, and save the best-Dice checkpoint to
    ``best_metric_model.pth`` in the working directory.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        train_imtrans = Compose(
            [
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        train_segtrans = Compose(
            [
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        val_segtrans = Compose([AddChannel(), ToTensor()])

        # define nifti dataset, data loader; sanity-check one batch's shapes
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader (first 20 pairs)
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
        # create a validation data loader (last 20 pairs)
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device("cuda:0")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        writer = SummaryWriter()
        for epoch in range(5):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{5}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                epoch_len = len(train_ds) // train_loader.batch_size
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        # full volumes don't fit the network input size: use
                        # sliding-window inference over 96^3 patches
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(
                            y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    finally:
        # ROBUSTNESS FIX: previously the temp directory was only removed on
        # success; an exception during generation/training leaked ~40 volumes.
        shutil.rmtree(tempdir)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def train(model, train_loader, val_loader, loss_function, optimizer, output_dir, device, epoch_num=600, val_interval=2):
    """Train ``model`` with periodic sliding-window validation.

    Validation computes multi-class mean Dice (background excluded, one-hot /
    mutually-exclusive labels) over 160^3 patches; the best checkpoint is
    saved to ``<output_dir>/best_metric_model.pth`` and a loss/metric curve
    plot is written to ``<output_dir>/training_metrics.png``.

    Args:
        model: network to optimize.
        train_loader / val_loader: loaders yielding dicts with 'image'/'label'.
        loss_function: training criterion.
        optimizer: optimizer stepping ``model``'s parameters.
        output_dir: directory for the checkpoint and the metrics plot.
        device: torch device for tensors.
        epoch_num (int): number of epochs (default 600).
        val_interval (int): validate every this many epochs (default 2).
    """
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    for epoch in range(epoch_num):
        print('-' * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # BUG FIX: the original referenced the undefined name ``train_ds``
            # here (NameError on the first step); the dataset is reachable
            # through the loader.
            print(
                f"{step}/{len(train_loader.dataset) // train_loader.batch_size}, train_loss: {loss.item():.4f}"
            )
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    # full volumes are evaluated patch-wise
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False,
                                             to_onehot_y=True,
                                             mutually_exclusive=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        os.path.join(output_dir, 'best_metric_model.pth'))
                    print('saved new best metric model')
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
    metric_output = os.path.join(output_dir, "training_metrics.png")
    plot_metrics(epoch_loss_values, metric_values, val_interval, output_path=metric_output)
def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=False):
    """Integration-test training run: 6 epochs of 3D UNet segmentation on the
    img*/seg* NIfTI pairs under ``root_dir``, with deterministic training
    transforms (seeded below) and validation every 2 epochs.

    Args:
        root_dir: directory holding img*.nii.gz / seg*.nii.gz pairs; also
            receives the TensorBoard run dir and the best-metric checkpoint.
        device: torch device to train on (defaults to the first CUDA GPU).
        cachedataset (bool): if True, use a CacheDataset for training data.

    Returns:
        tuple: (epoch_loss_values, best_metric, best_metric_epoch)
    """
    monai.config.print_config()
    # first 20 pairs train, last 20 validate
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]
    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], interp_order=["bilinear", "nearest"]),
        ScaleIntensityd(keys=["img", "seg"]),
        RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg", size=[96, 96, 96], pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    # fixed seed so the random crops/rotations are reproducible across runs
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], interp_order=["bilinear", "nearest"]),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)
    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["seg"].to(device)
                    # evaluate full volumes via sliding-window over 96^3 patches
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=True,
                                             to_onehot_y=False,
                                             sigmoid=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(
        f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
def train_epoch(data_loader, model, alpha, criterion1, criterion2, optimizer=None, mode_train=True):
    """Run one epoch of training (or validation) over ``data_loader``.

    Args:
        data_loader: training or validation set in dataset pytorch format.
        model (class): unet architecture.
        alpha (int): alpha value weighting the boundary loss in the combination.
        criterion1: loss function 1 (in our case Generalized Dice Loss); may be None.
        criterion2: loss function 2 (in our case Boundary loss); may be None.
        optimizer (class): optimizer (e.g. adam); required when ``mode_train`` is True.
        mode_train (bool): True is train, False is validation.

    Returns:
        tuple: (mean loss, mean Dice, mean Hausdorff distance) over the epoch.

    NOTE(review): relies on module-level globals ``device``, ``opt``,
    ``softmax`` and ``average_metrics``; backward() is also invoked when
    ``mode_train`` is False (only optimizer.step() is skipped) — preserved
    from the original, confirm whether that is intended.
    """
    total_batch = len(data_loader)
    batch_loss = average_metrics()
    batch_dice = average_metrics()
    batch_hausdorff = average_metrics()

    # one-hot post-processing for binary (2-class) segmentation metrics
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)

    if mode_train:
        model.train()
    else:
        model.eval()

    for batch_idx, (data, y) in enumerate(data_loader):
        data = data.to(device)
        data = data.type(torch.cuda.FloatTensor)
        y = y.to(device)

        # Reset gradients
        if mode_train:
            optimizer.zero_grad()

        # Get prediction
        out = model(data.to(device))

        # Evaluation metrics; batches yielding NaN Dice/Hausdorff are skipped
        outputs = post_pred(out.to(device))
        labels = post_label(y.to(device))
        dice_val = compute_meandice(y_pred=outputs, y=labels, include_background=False)
        hausdorff_val = compute_hausdorff_distance(y_pred=outputs, y=labels, distance_metric='euclidean')
        if not (math.isnan(dice_val.item()) or math.isnan(hausdorff_val.item())):
            batch_dice.update(dice_val.item())
            batch_hausdorff.update(hausdorff_val.item())

        # Losses
        if criterion1 and criterion2 is None:
            # REGION loss only
            region_loss = criterion1(out.to(device), y.to(device))
            region_loss.backward()
            batch_loss.update(region_loss.item())
        elif criterion1 is None and criterion2:
            # CONTOUR loss only
            out_probs = softmax(out.to(device), dim=1)
            # BUG FIX: the original passed the undefined name ``dy`` here,
            # raising NameError in this branch; the target tensor is ``y``.
            contour_loss = criterion2(out_probs.to(device), y.to(device))
            contour_loss.backward()
            batch_loss.update(contour_loss.item())
        else:
            # weighted combination: region + alpha * contour
            region_loss = criterion1(out.to(device), y.to(device))
            out_probs = softmax(out.to(device), dim=1)
            contour_loss = criterion2(out_probs.to(device), y.to(device))
            loss = region_loss + alpha * contour_loss
            loss.backward()
            batch_loss.update(loss.item())

        # Optimize
        if mode_train:
            optimizer.step()

        # Log every ``opt.verbose`` iterations while training
        if (batch_idx + 1) % opt.verbose == 0 and mode_train:
            if criterion1 and criterion2 is None:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - GD Loss: {batch_loss.val} - Dice: {batch_dice.val} - Hausdorff:{batch_hausdorff.val}'
                )
            elif criterion1 is None and criterion2:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - B Loss: {batch_loss.val} - Dice: {batch_dice.val} - Hausdorff:{batch_hausdorff.val}'
                )
            else:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - GD & B Loss: {batch_loss.val} - Dice: {batch_dice.val} - Hausdorff:{batch_hausdorff.val}'
                )

    return batch_loss.avg, batch_dice.avg, batch_hausdorff.avg