def _sliding_window_processor(_engine, batch):
    """Run sliding-window inference on one batch and return its segmentation.

    ``batch`` must unpack into exactly three items; only the first (the image)
    is used. ``_engine`` is accepted for the caller's signature but not read.
    ``net``, ``device``, ``roi_size`` and ``sw_batch_size`` come from the
    enclosing scope.
    """
    net.eval()
    # Three-way unpack kept on purpose: it rejects batches of any other length.
    image, _seg, _meta_data = batch
    with torch.no_grad():
        image_on_device = image.to(device)
        probabilities = sliding_window_inference(image_on_device, roi_size, sw_batch_size, net)
        return predict_segmentation(probabilities)
def _sliding_window_processor(engine, batch):
    """Run sliding-window inference on one batch and return ``(outputs, labels)``.

    ``batch[0]`` is the image and ``batch[1]`` the label; both are moved to
    ``device``. ``engine`` is accepted for the caller's signature but not read.
    ``net``, ``device``, ``roi_size`` and ``sw_batch_size`` come from the
    enclosing scope.
    """
    net.eval()
    with torch.no_grad():
        images = batch[0].to(device)
        labels = batch[1].to(device)
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, net)
        return outputs, labels
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Evaluate the saved UNet on the NIfTI pairs in ``root_dir`` and return the mean Dice.

    Args:
        root_dir: directory holding ``im*.nii.gz`` / ``seg*.nii.gz`` files and
            the ``best_metric_model.pth`` checkpoint. The saver is configured
            with ``<root_dir>/output`` as its output directory.
        device: device on which the model and data are placed.

    Returns:
        Sum of the per-item Dice values divided by the number of values.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location lets a checkpoint saved on another device (e.g. a GPU) be
    # restored when this function is called with a different ``device``.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()

    # sliding window size and batch size for windows inference (loop-invariant)
    sw_batch_size, roi_size = 4, (96, 96, 96)

    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=int)
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = compute_meandice(
                y_pred=val_outputs,
                y=val_labels,
                include_background=True,
                to_onehot_y=False,
                add_sigmoid=True,
            )
            metric_count += len(value)
            metric_sum += value.sum().item()
            # binarise the logits at probability 0.5 before saving
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs,
                {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"],
                },
            )
        metric = metric_sum / metric_count
    return metric
def main():
    """Evaluate a saved 3D UNet on freshly generated synthetic volumes.

    Writes five synthetic image/segmentation pairs into a temporary directory,
    loads weights from ``best_metric_model.pth`` (relative to the working
    directory), runs sliding-window inference on every volume, hands the
    thresholded predictions to a ``NiftiSaver`` configured with ``./output``,
    and prints the averaged Dice value.

    The temporary directory is removed even when evaluation raises.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

        device = torch.device('cuda:0')
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        model.load_state_dict(torch.load('best_metric_model.pth'))
        model.eval()

        # sliding window size and batch size for windows inference (loop-invariant)
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        with torch.no_grad():
            metric_sum = 0.
            metric_count = 0
            saver = NiftiSaver(output_dir='./output')
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                         to_onehot_y=False, add_sigmoid=True)
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarise the logits at probability 0.5 before saving
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                # third batch element is passed through as the saver's meta data
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print('evaluation metric:', metric)
    finally:
        # mkdtemp never cleans up after itself; do it on success and failure alike
        shutil.rmtree(tempdir)
def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size):
    """Sliding-window inference with an add-one predictor must turn ones into twos."""
    cpu = torch.device("cpu:0")
    all_ones = torch.ones(*image_shape)

    def add_one(patch):
        return patch + 1

    output = sliding_window_inference(all_ones.to(cpu), roi_shape, sw_batch_size, add_one)
    target = np.ones(image_shape, dtype=np.float32) + 1
    self.assertTrue(np.allclose(output.numpy(), target))
def main():
    """Train a 3D UNet on synthetic volumes using array-style NIfTI datasets.

    Generates 40 synthetic image/segmentation pairs into a temporary
    directory, trains on the first 20 and validates on the last 20 every
    ``val_interval`` epochs, saving the best-scoring weights to
    ``best_metric_model.pth`` and logging scalars/images to a
    ``SummaryWriter``.

    The temporary directory is removed and the writer closed even when
    training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    writer = None
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        train_imtrans = Compose(
            [
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        train_segtrans = Compose(
            [
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        val_segtrans = Compose([AddChannel(), ToTensor()])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(
            train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()
        )
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device("cuda:0")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        writer = SummaryWriter()
        # number of full batches per epoch; constant, so computed once
        epoch_len = len(train_ds) // train_loader.batch_size
        # sliding window size and batch size used during validation
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        for epoch in range(5):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{5}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    # kept after the loop so the last batch can be plotted
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(
                            y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False,
                            add_sigmoid=True
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    finally:
        # release the temp data and the event-file writer on success and failure alike
        shutil.rmtree(tempdir)
        if writer is not None:
            writer.close()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
def main():
    """Evaluate a saved 3D UNet on synthetic volumes using dictionary transforms.

    Writes five synthetic image/segmentation pairs into a temporary directory,
    loads weights from ``best_metric_model.pth`` (relative to the working
    directory), runs sliding-window inference on every volume, hands the
    thresholded predictions to a ``NiftiSaver`` configured with ``./output``,
    and prints the averaged Dice value.

    The temporary directory is removed even when evaluation raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ])
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference need to input 1 image in every iteration
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )

        device = torch.device("cuda:0")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        model.load_state_dict(torch.load("best_metric_model.pth"))
        model.eval()

        # sliding window size and batch size for windows inference (loop-invariant)
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=True,
                    to_onehot_y=False,
                    add_sigmoid=True,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
                # binarise the logits at probability 0.5 before saving
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(
                    val_outputs,
                    {
                        "filename_or_obj": val_data["img.filename_or_obj"],
                        "affine": val_data["img.affine"],
                    },
                )
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        # mkdtemp never cleans up after itself; do it on success and failure alike
        shutil.rmtree(tempdir)
print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss)) if (epoch + 1) % val_interval == 0: model.eval() with torch.no_grad(): metric_sum = 0. metric_count = 0 val_images = None val_labels = None val_outputs = None for val_data in val_loader: val_images, val_labels = val_data['img'].to( device), val_data['seg'].to(device) roi_size = (96, 96, 96) sw_batch_size = 4 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True) metric_count += len(value) metric_sum += value.sum().item() metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save(model.state_dict(), 'best_metric_model.pth') print('saved new best metric model') print(
def main():
    """Train a 3D UNet on synthetic volumes using dictionary-based transforms.

    Generates 40 synthetic image/segmentation pairs into a temporary
    directory, trains on the first 20 and validates on the last 20 every
    ``val_interval`` epochs, saving the best-scoring weights to
    ``best_metric_model.pth`` and logging scalars/images to a
    ``SummaryWriter``.

    The temporary directory is removed and the writer closed even when
    training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    writer = None
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
        train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
        val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96],
                                   pos=1, neg=1, num_samples=4),
            RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=['img', 'seg'])
        ])
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            ToTensord(keys=['img', 'seg'])
        ])

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        check_data = monai.utils.misc.first(check_loader)
        print(check_data['img'].shape, check_data['seg'].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4,
                                  collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device('cuda:0')
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        writer = SummaryWriter()
        # number of full batches per epoch; constant, so computed once
        epoch_len = len(train_ds) // train_loader.batch_size
        # sliding window size and batch size used during validation
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        for epoch in range(5):
            print('-' * 10)
            print('epoch {}/{}'.format(epoch + 1, 5))
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
                writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    # kept after the loop so the last batch can be plotted
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), 'best_metric_model.pth')
                        print('saved new best metric model')
                    print(
                        'current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'
                        .format(epoch + 1, metric, best_metric, best_metric_epoch))
                    writer.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
    finally:
        # release the temp data and the event-file writer on success and failure alike
        shutil.rmtree(tempdir)
        if writer is not None:
            writer.close()

    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
def main():
    """Run config-driven UNet inference and save post-processed masks.

    Reads a YAML file given by ``--config``, builds a validation loader from
    the listed subjects, loads the trained weights named in the config, and
    for each volume: resizes the first two spatial axes to 96, runs
    sliding-window inference, thresholds at the configured probability,
    resizes back with nearest-neighbour interpolation, applies binary closing
    and largest-component filtering, and passes the result to a
    ``NiftiSaver`` configured with the config's output directory.

    Raises:
        ValueError: if ``inference.batch_size_inference`` is not 1.
        IOError: if the model file named in the config does not exist.
    """
    parser = argparse.ArgumentParser(description='Run inference with basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        # NOTE(review): FullLoader can construct more than plain data types;
        # if configs may come from untrusted sources, switch to yaml.safe_load.
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']

    # inference params
    batch_size_inference = config_info['inference']['batch_size_inference']
    # temporary check as sliding window inference does not accept higher batch size
    # (explicit raise instead of assert so the check survives ``python -O``)
    if batch_size_inference != 1:
        raise ValueError(
            'batch_size_inference must be 1 for sliding window inference, got {}'.format(batch_size_inference))
    prob_thr = config_info['inference']['probability_threshold']
    model_to_load = config_info['inference']['model_to_load']
    if not os.path.exists(model_to_load):
        raise IOError('Trained model not found')

    # data params
    data_root = config_info['data']['data_root']
    inference_list = config_info['data']['inference_list']

    # output saving
    out_dir = config_info['output']['out_dir']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # ---- Data preparation ----
    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=inference_list,
                                 img_postfix='_Image',
                                 is_inference=True)

    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for inference:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - NOTE: resizing needs to be applied afterwards, otherwise it cannot be remapped back to original size
    val_transforms = Compose([
        LoadNiftid(keys=['img']),
        AddChanneld(keys=['img']),
        NormalizeIntensityd(keys=['img']),
        ToTensord(keys=['img'])
    ])
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size_inference, num_workers=num_workers)

    # ---- Network preparation ----
    # index of the device selected by set_device above
    device = torch.cuda.current_device()
    # Create UNet and load the trained weights.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    net.load_state_dict(torch.load(model_to_load))
    net.eval()

    # ---- Run inference ----
    # sliding window size for windows inference (loop-invariant)
    roi_size = (96, 96, 1)
    # structuring element for the closing step (loop-invariant)
    strt = ndimage.generate_binary_structure(3, 2)

    with torch.no_grad():
        saver = NiftiSaver(output_dir=out_dir)
        for val_data in val_loader:
            val_images = val_data['img'].to(device)
            orig_size = list(val_images.shape)
            # a shallow copy suffices: the shape is a flat list of ints
            resized_size = list(orig_size)
            resized_size[2] = 96
            resized_size[3] = 96
            val_images_resize = torch.nn.functional.interpolate(val_images, size=resized_size[2:], mode='trilinear')

            val_outputs = sliding_window_inference(val_images_resize, roi_size, batch_size_inference, net)
            val_outputs = (val_outputs.sigmoid() >= prob_thr).float()
            # nearest-neighbour keeps the thresholded mask binary when mapping back to the input size
            val_outputs_resized = torch.nn.functional.interpolate(val_outputs, size=orig_size[2:], mode='nearest')

            # add post-processing
            val_outputs_resized = val_outputs_resized.detach().cpu().numpy()
            post = padded_binary_closing(np.squeeze(val_outputs_resized), strt)
            post = get_largest_component(post)
            # multiplying by the post-processed mask zeroes everything it excludes
            val_outputs_resized = val_outputs_resized * post

            saver.save_batch(
                val_outputs_resized,
                {
                    'filename_or_obj': val_data['img.filename_or_obj'],
                    'affine': val_data['img.affine']
                })
def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=False):
    """Train a 3D UNet for six epochs on the NIfTI pairs found in ``root_dir``.

    Args:
        root_dir: directory holding ``img*.nii.gz`` / ``seg*.nii.gz`` files.
            TensorBoard logs go to ``<root_dir>/runs`` and the best weights to
            ``<root_dir>/best_metric_model.pth``.
        device: device on which the model and data are placed.
        cachedataset: when true, train from ``monai.data.CacheDataset``
            (``cache_rate=0.8``) instead of ``monai.data.Dataset``.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``: the
        per-epoch average training losses, the best validation mean Dice, and
        the 1-based epoch at which it was reached.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg", size=[96, 96, 96],
                               pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    # fixed seed for the random transforms
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # number of full batches per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size
    # sliding window size and batch size used during validation
    sw_batch_size, roi_size = 4, (96, 96, 96)

    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    try:
        for epoch in range(6):
            print("-" * 10)
            print(f"Epoch {epoch + 1}/{6}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    # kept after the loop so the last batch can be plotted
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), model_filename)
                        print("saved new best metric model")
                    print(
                        f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                        f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    finally:
        # close the event-file writer even when training raises
        writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch