def test_tb_image_shape(self, shape):
    """Plot an all-zero tensor of ``shape`` and check that log files were written.

    Args:
        shape: shape of the zero tensor handed to ``plot_2d_or_3d_image``.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        writer = SummaryWriter(log_dir=tempdir)
        try:
            plot_2d_or_3d_image(torch.zeros(shape), 0, writer)
            writer.flush()
        finally:
            # Close even when plotting fails so the temporary directory can be removed.
            writer.close()
        # ``glob.glob(tempdir)`` matches the directory itself and is therefore
        # non-empty whenever the directory exists; look at its contents instead.
        written = glob.glob(f"{glob.escape(tempdir)}/*")
        self.assertTrue(len(written) > 0)
def test_tbx_video(self, shape):
    """Plot a random tensor of ``shape`` via ``SummaryWriterX`` and check that log files were written.

    Args:
        shape: shape of the random tensor handed to ``plot_2d_or_3d_image``.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        writer = SummaryWriterX(log_dir=tempdir)
        try:
            plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3)
            writer.flush()
        finally:
            # Close even when plotting fails so the temporary directory can be removed.
            writer.close()
        # ``glob.glob(tempdir)`` matches the directory itself and is therefore
        # non-empty whenever the directory exists; look at its contents instead.
        written = glob.glob(f"{glob.escape(tempdir)}/*")
        self.assertTrue(len(written) > 0)
def test_tb_image_shape(self, shape):
    """Check that plotting into a not-yet-existing log directory creates and fills it.

    Args:
        shape: shape of the zero tensor handed to ``plot_2d_or_3d_image``.
    """
    tempdir = tempfile.mkdtemp()
    # Remove the directory again so the writer has to create it itself.
    shutil.rmtree(tempdir, ignore_errors=True)
    writer = None
    try:
        writer = SummaryWriter(log_dir=tempdir)
        plot_2d_or_3d_image(torch.zeros(shape), 0, writer)
        writer.flush()
        self.assertTrue(os.path.exists(tempdir))
        # Inspect the directory contents; globbing the directory path itself
        # would succeed even if nothing had been written into it.
        self.assertTrue(len(os.listdir(tempdir)) > 0)
    finally:
        # Release the event file and clean up even when an assertion fails.
        if writer is not None:
            writer.close()
        shutil.rmtree(tempdir, ignore_errors=True)
def test_tb_image_shape(self, shape):
    """Check that plotting with a default ``SummaryWriter()`` creates and fills ``./runs``.

    Args:
        shape: shape of the zero tensor handed to ``plot_2d_or_3d_image``.
    """
    default_dir = os.path.join('.', 'runs')
    # NOTE(review): this deletes any existing ./runs directory in the current
    # working directory before the test starts; kept as in the original test.
    shutil.rmtree(default_dir, ignore_errors=True)
    writer = None
    try:
        writer = SummaryWriter()
        plot_2d_or_3d_image(torch.zeros(shape), 0, writer)
        writer.flush()
        self.assertTrue(os.path.exists(default_dir))
        # Inspect the directory contents; globbing the directory path itself
        # would succeed even if nothing had been written into it.
        self.assertTrue(len(os.listdir(default_dir)) > 0)
    finally:
        # Release the event file and clean up even when an assertion fails.
        if writer is not None:
            writer.close()
        shutil.rmtree(default_dir, ignore_errors=True)
def write_images(self, epoch):
    """Write the images collected in ``self.plot_data`` to ``self.writer`` for ``epoch``.

    Does nothing when ``self.plot_data`` is empty. When the first collected
    array has 3 dimensions, every region is prefixed with a rendered caption
    tile and all arrays are written as one grid image. When it has 4
    dimensions, the first three entries of every region are plotted
    separately under the ``image`` / ``label`` / ``output`` tags.

    Args:
        epoch: value passed on as the global step of every written item.
    """
    if not (self.plot_data and len(self.plot_data)):
        return

    region_names = sorted(self.plot_data.keys())
    collected = []
    for region in region_names:
        metric = self.metric_data.get(region)
        region_data = self.plot_data[region]
        sample_shape = region_data[0].shape
        if len(sample_shape) == 3:
            # Render a caption tile sized like the trailing two dims of the sample.
            tile = Image.new("RGB", sample_shape[1:])
            canvas = ImageDraw.Draw(tile)
            caption = "region: {}".format(region)
            if self.compute_metric:
                # NOTE(review): assumes a metric entry exists for every region
                # whenever compute_metric is set - confirm against the caller.
                caption += "\ndice: {:.4f}".format(metric.mean())
                caption += "\nstdev: {:.4f}".format(metric.stdev())
            canvas.multiline_text((10, 10), caption, fill=(255, 255, 0))
            # Move the colour axis to the front and keep only its first plane.
            channel_first = np.rollaxis(np.array(tile), 2, 0)
            collected.append(rescale_array(channel_first[0][np.newaxis]))
        collected.extend(region_data)

    rank = len(collected[0].shape)
    if rank == 3:
        grid = make_grid(tensor=torch.from_numpy(np.array(collected)), nrow=4, normalize=True, pad_value=2)
        self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})", img_tensor=grid, global_step=epoch)
    elif rank == 4:
        for region in region_names:
            tags = [f"region_{region}_image", f"region_{region}_label", f"region_{region}_output"]
            if torch.distributed.is_initialized():
                # Prefix the tags with the process rank when torch.distributed is initialised.
                prefix = "r{}-".format(torch.distributed.get_rank())
                tags = [prefix + tag for tag in tags]
            for position, tag in enumerate(tags):
                volume = np.moveaxis(self.plot_data[region][position], -3, -1)
                plot_2d_or_3d_image(
                    volume[np.newaxis], epoch, self.writer, 0, self.max_channels, self.max_frames, tag
                )

    self.logger.info(
        "Saved {} Regions {} into Tensorboard at epoch: {}".format(len(self.plot_data), region_names, epoch)
    )
    self.writer.flush()
def __call__(self, engine):
    """Plot the transformed batch inputs and the transformed output, then flush the writer.

    Items 0 and 1 of ``batch_transform(engine.state.batch)`` are plotted under
    the tags ``input_0`` and ``input_1``; ``output_transform(engine.state.output)``
    is plotted under ``output``. Tensors are detached and converted to numpy
    first; ``None`` values are skipped.

    Args:
        engine: object exposing ``state.iteration``, ``state.batch`` and ``state.output``.

    Raises:
        ValueError: when one of the three values is neither ``None``, a tensor
            nor an ``np.ndarray``.
    """
    step = self.global_iter_transform(engine.state.iteration)

    def _plot(data, tag, source):
        # Shared conversion / validation / plotting for one value; ``source``
        # names the expression the value came from for the error message.
        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        if data is None:
            return
        if not isinstance(data, np.ndarray):
            raise ValueError(f'{source} must be an ndarray or tensor.')
        plot_2d_or_3d_image(data, step, self._writer, self.index, self.max_channels, self.max_frames, tag)

    # Evaluate the batch transform once instead of once per item.
    batch = self.batch_transform(engine.state.batch)
    # The first item comes from batch_transform, not output_transform, so the
    # error message now names the right source.
    _plot(batch[0], 'input_0', 'batch_transform(engine.state.batch)[0]')
    _plot(batch[1], 'input_1', 'batch_transform(engine.state.batch)[1]')
    _plot(self.output_transform(engine.state.output), 'output', 'output_transform(engine.state.output)')

    self._writer.flush()
def __call__(self, engine: Engine) -> None:
    """
    Plot the transformed batch inputs and the transformed output, then flush the writer.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.

    Raises:
        TypeError: When ``batch_transform(engine.state.batch)[0]`` type is not in
            ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
        TypeError: When ``batch_transform(engine.state.batch)[1]`` type is not in
            ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
        TypeError: When ``output_transform(engine.state.output)`` type is not in
            ``Optional[Union[numpy.ndarray, torch.Tensor]]``.

    """
    step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration)

    def _plot(data, tag, source):
        # Shared conversion / validation / plotting for one value; ``source``
        # names the expression the value came from for the error message.
        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        if data is None:
            return
        if not isinstance(data, np.ndarray):
            raise TypeError(
                f"{source} must be None or one of "
                f"(numpy.ndarray, torch.Tensor) but is {type(data).__name__}."
            )
        plot_2d_or_3d_image(data, step, self._writer, self.index, self.max_channels, self.max_frames, tag)

    # Evaluate the batch transform once instead of once per item.
    batch = self.batch_transform(engine.state.batch)
    # The first item comes from batch_transform, not output_transform, so the
    # error message now names the right source.
    _plot(batch[0], "input_0", "batch_transform(engine.state.batch)[0]")
    _plot(batch[1], "input_1", "batch_transform(engine.state.batch)[1]")
    _plot(self.output_transform(engine.state.output), "output", "output_transform(engine.state.output)")

    self._writer.flush()
def main():
    """Train a segmentation network on NIfTI image/label pairs and log to TensorBoard.

    Options come from ``Options().parse()``. Image and label files are globbed
    from the ``train``, ``val`` and ``test`` sub-folders of ``opt.images_folder``
    and ``opt.labels_folder``. After every epoch the Dice score is computed on
    the validation, training and test loaders, the model with the best
    validation Dice is saved to ``best_metric_model.pth`` and the last
    validation sample (image, label, prediction) is plotted.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    # NOTE(review): every pass appends the list to itself, so the list doubles
    # per iteration (2**n growth, not n-fold) - confirm this is intended.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader
    def as_dicts(image_names, label_names):
        # Pair every image path with its label path.
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(image_names, label_names)]

    train_dicts = as_dicts(train_images, train_segs)
    train_dice_dicts = as_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = as_dicts(val_images, val_segs)
    test_dicts = as_dicts(test_images, test_segs)

    # Transforms list
    # Need to concatenate multiple channels here if you want multichannel segmentation
    # Check other examples on Monai webpage.
    def preprocessing():
        # Deterministic loading / normalisation steps shared by training and
        # evaluation; resampling is added only when a resolution is requested.
        steps = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if opt.resolution is not None:
            steps.append(Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')))
        return steps

    # NOTE: the np.random.uniform(...) arguments below are sampled once, when
    # the transform objects are built, not once per sample.
    augmentations = [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
    ]

    train_transforms = Compose(preprocessing() + augmentations + [ToTensord(keys=['image', 'label'])])
    val_transforms = Compose(preprocessing() + [ToTensord(keys=['image', 'label'])])

    pin_memory = torch.cuda.is_available()

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=pin_memory)

    # create a training_dice data loader
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # # try to use all the available GPUs
    # devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    def plot_dice(images_loader):
        # Run sliding-window inference over a loader; return the mean Dice and
        # the last batch's image, label and post-processed prediction.
        metric_sum = 0.0
        metric_count = 0
        images = None
        labels = None
        outputs = None
        for data in images_loader:
            images, labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, net)
            outputs = post_trans(outputs)
            value, _ = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        return metric_sum / metric_count, images, labels, outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Loop-invariant: number of optimisation steps per epoch (floor division).
    epoch_len = len(check_train) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(opt.epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{opt.epochs}")
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
                optim.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optim.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            update_learning_rate(net_scheduler, optim)

            if (epoch + 1) % val_interval == 0:
                net.eval()
                with torch.no_grad():
                    metric, val_images, val_labels, val_outputs = plot_dice(val_loader)
                    metric_values.append(metric)

                    # Save best model
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")

                    metric_train, _, _, _ = plot_dice(train_dice_loader)
                    metric_test, _, _, _ = plot_dice(test_loader)

                    # Logger bar
                    print(
                        "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                        )
                    )

                    writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                    writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                    writer.add_scalar("Training_dice", metric_train, epoch + 1)
                    writer.add_scalar("Validation_dice", metric, epoch + 1)

                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    # val_outputs was already passed through sigmoid + threshold by
                    # post_trans; applying sigmoid()>=0.5 again would map both 0 and 1
                    # to 1 and plot an all-foreground mask, so it is plotted as is.
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Close the event file even when training is interrupted.
        writer.close()
def main():
    """Train a 3D UNet on synthetic data and log losses, Dice and sample volumes to TensorBoard.

    40 synthetic image/segmentation pairs are written to a temporary directory;
    the first 20 are used for training and the last 20 for validation. The
    model with the best validation mean Dice is saved to
    ``best_metric_model.pth``. The temporary directory is removed and the
    TensorBoard writer closed when the function exits, including on error.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    max_epochs = 5
    val_interval = 2

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    writer = None
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                RandCropByPosNegLabeld(
                    keys=["img", "seg"], label_key="seg", size=[96, 96, 96], pos=1, neg=1, num_samples=4
                ),
                RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
                ToTensord(keys=["img", "seg"]),
            ]
        )
        val_transforms = Compose(
            [
                LoadNiftid(keys=["img", "seg"]),
                AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
                ScaleIntensityd(keys=["img", "seg"]),
                ToTensord(keys=["img", "seg"]),
            ]
        )

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(
            check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
        )
        check_data = monai.utils.misc.first(check_loader)
        print(check_data["img"].shape, check_data["seg"].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(
            val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
        )

        # create UNet, DiceLoss and Adam optimizer
        # Fall back to the CPU when no GPU is available instead of failing on "cuda:0".
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # Loop-invariant: number of optimisation steps per epoch (floor division).
        epoch_len = len(train_ds) // train_loader.batch_size
        writer = SummaryWriter()
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        roi_size = (96, 96, 96)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(
                            y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    finally:
        # Always release the event file and the synthetic data, even when
        # training fails part-way; cleanup errors must not mask the real one.
        if writer is not None:
            writer.close()
        shutil.rmtree(tempdir, ignore_errors=True)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
g_loss.backward() optimizerG.step() G_losses.append(g_loss.item()) lossC01s.append(lossC01.item()) lossC02s.append(lossC02.item()) lossC03s.append(lossC03.item()) # log to tensorboard every 10 steps if batch_index % 1 == 0: writer.add_scalar('Train/G_loss', g_loss.item(), epoch) writer.add_scalar('Train/D_loss', d_loss.item(), epoch) plot_2d_or_3d_image(targetC01, epoch * len(training_loader) + batch_index, writer, index=0, tag="Groundtruth_train/C01") plot_2d_or_3d_image(targetC02, epoch * len(training_loader) + batch_index, writer, index=0, tag="Groundtruth_train/C01") plot_2d_or_3d_image(targetC03, epoch * len(training_loader) + batch_index, writer, index=0, tag="Groundtruth_train/C01") plot_2d_or_3d_image(outputC01, epoch * len(training_loader) + batch_index, writer,
def main(tempdir):
    """Train a 2D UNet on synthetic PNG data and log losses, Dice and sample images to TensorBoard.

    40 synthetic image/segmentation pairs are written into ``tempdir``; the
    first 20 are used for training and the last 20 for validation. The model
    with the best validation mean Dice is saved to
    ``best_metric_model_segmentation2d_dict.pth``.

    Args:
        tempdir: existing directory the synthetic PNG files are written to.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    max_epochs = 10
    val_interval = 2

    # create 40 random image, mask pairs in the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Loop-invariant: number of optimisation steps per epoch (floor division).
    epoch_len = len(train_ds) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        roi_size = (96, 96)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_segmentation2d_dict.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Close the event file even when training is interrupted.
        writer.close()
metric_sum += value.sum().item() metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: best_metric = metric best_metric_epoch = epoch + 1 torch.save(model.state_dict(), 'best_metric_model.pth') print('saved new best metric model') print( "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d" % (epoch + 1, metric, best_metric, best_metric_epoch)) writer.add_scalar('val_mean_dice', metric, epoch + 1) # plot the last model output as GIF image in TensorBoard with the corresponding image and label plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image') plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label') plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output') shutil.rmtree(tempdir) print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
def main():
    """Train a 2D UNet on ``.npy`` slices and log losses, Dice and sample images.

    Image and label files are read from the ``image`` and ``label``
    sub-directories of a hard-coded dataset path; the last 20 sorted files are
    used for validation and the rest for training. A checkpoint is saved every
    fifth epoch and the model with the best validation Dice is saved to
    ``best_metric_model_segmentation2d_array.pth``.

    NOTE(review): ``train_imtrans``, ``train_segtrans``, ``writer``,
    ``max_epoch``, ``output_dir`` and ``tfilename`` are used below but are not
    defined in this function - presumably module-level names; confirm they
    exist, otherwise this function fails with ``NameError``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # --------------- Dataset ---------------
    datadir1 = "/home1/quanquan/datasets/lsw/benign_65/fpAML_55/slices/"
    image_files = np.array([x.path for x in os.scandir(datadir1+"image") if x.name.endswith(".npy")])
    label_files = np.array([x.path for x in os.scandir(datadir1+"label") if x.name.endswith(".npy")])
    # Images and labels are paired by position after sorting each array
    # independently - presumably the file names sort identically; verify.
    image_files.sort()
    label_files.sort()
    # --- ??? what's up ???
    # NOTE(review): train_files / val_files are built but not used anywhere
    # else in this function; the datasets below slice the arrays directly.
    train_files = [{"img":img, "seg":seg} for img, seg in zip(image_files[:-20], label_files[:-20])]
    val_files = [{"img":img, "seg":seg} for img, seg in zip(image_files[-20:], label_files[-20:])]
    # print("files", train_files[:20])
    # print(val_files)
    val_imtrans = Compose([LoadNumpy(data_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    val_segtrans = Compose([LoadNumpy(data_only=True), AddChannel(), ToTensor()])

    # define array dataset, data loader
    # (sanity check over ALL files: load one batch and print its shapes)
    check_ds = ArrayDataset(image_files, train_imtrans, label_files, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(image_files[:-20], train_imtrans, label_files[:-20], train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(image_files[-20:], val_imtrans, label_files[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # --------------- model ---------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # --------------- loss function ---------------
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # writer = SummaryWriter(logdir=tdir(output_dir, "sumamry"))
    # NOTE(review): the writer construction above is commented out, yet
    # ``writer`` is used in the loops below - confirm where it is created.

    # ------------------- Training ----------------------
    for epoch in range(max_epoch):
        # print("-" * 10)
        # print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Floor division: a trailing partial batch is not counted here.
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"\r\t Training batch: {step}/{epoch_len}, \ttrain_loss: {loss.item():.4f}\t", end="")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"\n\tepoch {epoch + 1} \taverage loss: {epoch_loss:.4f}")

        # ------------------- Save Model ----------------------
        if epoch % 5 == 0:
            def get_lr(optimizer):
                # Returns the learning rate of the first parameter group only
                # (the loop returns on its first iteration).
                for param_group in optimizer.param_groups:
                    return float(param_group['lr'])
            state = {'epoch': epoch + 1,
                     'lr': get_lr(optimizer),
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()
                     }
            torch.save(state, tfilename(output_dir, "model", "{}_{}.pkl".format("lsw_monai_simple", epoch)))

        # ------------------- Evaluation -----------------------
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    # value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    # NOTE(review): ``value.item()`` only works for a
                    # single-element result - holds for batch_size=1; confirm
                    # before changing the validation batch size.
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():
    """Train and validate a 2D MONAI UNet configured through a YAML file.

    The YAML path comes from the ``--config`` command-line option. The
    function builds the training and validation pipelines, trains the network
    with a Dice loss and Adam, validates every
    ``validation_every_n_epochs`` epochs, stores the best-scoring weights as
    ``best_metric_model.pth`` in a time-stamped output directory, and logs
    scalars and images to two TensorBoard writers (``train`` and ``valid``).
    """
    # ---- Read input and configuration parameters ----
    arg_parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    arg_parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file')
    cli_args = arg_parser.parse_args()

    # NOTE(review): yaml.FullLoader is used on a user-supplied file; consider
    # yaml.safe_load if the configuration can come from an untrusted source.
    with open(cli_args.config) as config_stream:
        config_info = yaml.load(config_stream, Loader=yaml.FullLoader)

    # echo the parameter setup to the log
    print(yaml.dump(config_info))

    # GPU params
    device_cfg = config_info['device']
    cuda_device = device_cfg['cuda_device']
    num_workers = device_cfg['num_workers']

    # training and validation params
    training_cfg = config_info['training']
    loss_type = training_cfg['loss_type']  # read from the config but not used below
    batch_size_train = training_cfg['batch_size_train']
    batch_size_valid = training_cfg['batch_size_valid']
    lr = float(training_cfg['lr'])
    nr_train_epochs = training_cfg['nr_train_epochs']
    validation_every_n_epochs = training_cfg['validation_every_n_epochs']
    sliding_window_validation = training_cfg['sliding_window_validation']

    # data params
    data_cfg = config_info['data']
    data_root = data_cfg['data_root']
    training_list = data_cfg['training_list']
    validation_list = data_cfg['validation_list']

    # model saving: one time-stamped sub-directory per run
    output_cfg = config_info['output']
    model_root = output_cfg['out_model_dir']
    run_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + output_cfg['output_subfix']
    out_model_dir = os.path.join(model_root, run_name)
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = output_cfg['max_nr_models_saved']  # read but not used below

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # ---- Data preparation ----

    def _load_normalize_resize():
        """Return fresh instances of the steps every pipeline starts with."""
        return [
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0],
                    anti_aliasing=[True, False]),
        ]

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')
    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # Training pipeline: load, add channel, whiten, resize in-plane to
    # (96, 96), extract a single-slice patch, augment with random rotation
    # and flip, then squeeze the slice axis so the network sees 2D data.
    train_steps = _load_normalize_resize()
    train_steps.extend([
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1],
                    interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg']),
    ])
    train_transforms = Compose(train_steps)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # Validation pipeline: same load / whiten / resize front end. Without
    # sliding-window validation, random 2D slices are extracted as well, to
    # emulate how the loss is computed during training.
    val_steps = _load_normalize_resize()
    if not sliding_window_validation:
        val_steps.append(RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False))
        val_steps.append(SqueezeDimd(keys=['img', 'seg'], dim=-1))
    val_steps.append(ToTensord(keys=['img', 'seg']))
    val_transforms = Compose(val_steps)
    do_shuffle = not sliding_window_validation
    collate_fn_to_use = None if sliding_window_validation else list_data_collate

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    # ---- Network preparation ----
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    # ---- Training loop ----
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    net.to(device)

    for epoch in range(nr_train_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
        net.train()
        epoch_loss = 0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data['img'].to(device)
            labels = batch_data['seg'].to(device)
            opt.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            opt.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, batch_loss))
            writer_train.add_scalar('loss', batch_loss, epoch_len * epoch + step)
        epoch_loss = epoch_loss / step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        # Only validate every `validation_every_n_epochs` epochs.
        if (epoch + 1) % validation_every_n_epochs != 0:
            continue

        net.eval()
        with torch.no_grad():
            metric_sum = 0.
            metric_count = 0
            val_images = val_labels = val_outputs = None
            check_tot_validation = 0
            for check_tot_validation, val_data in enumerate(val_loader, start=1):
                val_images = val_data['img'].to(device)
                val_labels = val_data['seg'].to(device)
                if sliding_window_validation:
                    print('Running sliding window validation')
                    roi_size = (96, 96, 1)
                    val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                             to_onehot_y=False, add_sigmoid=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                else:
                    print('Running 2D validation')
                    # the score of a batch is one minus its Dice loss
                    val_outputs = net(val_images)
                    value = 1.0 - loss_function(val_outputs, val_labels)
                    metric_count += 1
                    metric_sum += value.item()
            print("Total number of data in validation: %d" % check_tot_validation)
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                print('saved new best metric model')
            print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                  % (epoch + 1, metric, best_metric, best_metric_epoch))
            epoch_len = len(train_ds) // train_loader.batch_size
            writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
            writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with
            # the corresponding image and label
            for tensor, tag in ((val_images, 'image'), (val_labels, 'label'), (val_outputs, 'output')):
                plot_2d_or_3d_image(tensor, epoch + 1, writer_valid, index=0, tag=tag)

    print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
    writer_train.close()
    writer_valid.close()
def main():
    """Train a segmentation network ('nnunet' or 'unetr') and log to TensorBoard.

    Options come from ``Options().parse()``. The function collects the
    train/val/test NIfTI files, builds the transform pipelines and data
    loaders, trains with ``DiceCELoss`` and, every ``val_interval`` epochs,
    evaluates the mean Dice on the validation, training and test sets. The
    weights with the best validation Dice are written to
    ``best_metric_model.pth``.

    Raises:
        ValueError: if ``opt.network`` is neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus ('-1' means no GPU ids were given)
    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # ---- File lists ----
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))
    # un-augmented copies, used to measure the Dice on the training images
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # Augment the training list: every pass extends the list with itself, so
    # the number of entries doubles per unit of `increase_factor_data`.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    def _as_dicts(image_names, label_names):
        """Pair image and label paths into the dict records the loaders expect."""
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(image_names, label_names)]

    train_dicts = _as_dicts(train_images, train_segs)
    train_dice_dicts = _as_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = _as_dicts(val_images, val_segs)
    test_dicts = _as_dicts(test_images, test_segs)

    # ---- Transforms ----
    use_spacing = opt.resolution is not None
    # NOTE(review): the original used a noise-std upper bound of 15 when a
    # resolution is given and 1 otherwise; kept as-is, confirm it is intended.
    noise_std_upper = 15 if use_spacing else 1

    def _preprocess():
        """Load, add channel, crop to foreground and normalize intensities.

        CT HU windowing is disabled here; the original kept it commented out
        as ThresholdIntensityd(threshold=-135, above=True, cval=-135) and
        ThresholdIntensityd(threshold=215, above=False, cval=215).
        """
        return [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]

    def _resample():
        """Return the resampling step, or nothing when no resolution is set."""
        if not use_spacing:
            return []
        return [Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest'))]

    def _pad_to_patch():
        """Pad images that are smaller than the patch size."""
        return SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end')

    train_transforms = [
        *_preprocess(),
        *_resample(),
        # augmentation
        *[RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=axis) for axis in (1, 0, 2)],
        *[
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=rotate_range, padding_mode="zeros")
            for rotate_range in (
                (np.pi / 36, np.pi / 36, np.pi * 2),
                (np.pi / 36, np.pi / 2, np.pi / 36),
                (np.pi / 2, np.pi / 36, np.pi / 36),
            )
        ],
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15),
                            prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # mean / std / offsets are drawn once, when the pipeline is built
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_upper)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        _pad_to_patch(),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ]
    val_transforms = [
        *_preprocess(),
        *_resample(),
        _pad_to_patch(),
        ToTensord(keys=['image', 'label']),
    ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # ---- Data loaders ----
    def _eval_loader(records):
        """Build a batch-size-1 loader that applies the validation transforms."""
        dataset = monai.data.Dataset(data=records, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers,
                          collate_fn=list_data_collate, pin_memory=False)

    # training loader
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)
    # loaders for the Dice on the training, validation and test images
    train_dice_loader = _eval_loader(train_dice_dicts)
    val_loader = _eval_loader(val_dicts)
    test_loader = _eval_loader(test_dicts)

    # ---- Network ----
    # Strings must be compared with `==`: identity (`is`) is not guaranteed to
    # hold for equal strings, which would leave the network undefined.
    if opt.network == 'nnunet':
        net = build_net()
    elif opt.network == 'unetr':
        net = build_UNETR()
    else:
        raise ValueError(f"Unsupported network {opt.network!r}; expected 'nnunet' or 'unetr'.")
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    if opt.network == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:
        # 'unetr' uses a fixed learning rate here; opt.lr is not consulted
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    def evaluate_dice(images_loader):
        """Run sliding-window inference over a loader and return the mean Dice.

        Also returns the images, labels and post-processed outputs of the last
        batch (``None`` for all three if the loader yields nothing). Must be
        called under ``torch.no_grad()``.
        """
        last_images = None
        last_labels = None
        last_outputs = None
        for data in images_loader:
            last_images, last_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            last_outputs = sliding_window_inference(last_images, roi_size, sw_batch_size, net)
            last_outputs = [post_trans(i) for i in decollate_batch(last_outputs)]
            dice_metric(y_pred=last_outputs, y=last_labels)
        # aggregate the final mean dice result
        mean_dice = dice_metric.aggregate().item()
        # reset the status for the next evaluation round
        dice_metric.reset()
        return mean_dice, last_images, last_labels, last_outputs

    # ---- Training ----
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
    # number of batches per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if opt.network == 'nnunet':
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_images, val_labels, val_outputs = evaluate_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = evaluate_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = evaluate_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with
                # the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def __call__(self, engine: Engine):
    """Write the selected input image, label and network output to TensorBoard.

    The item at ``self.index`` is taken from the first and second elements
    of the transformed batch and from the transformed output. Tensors are
    detached and converted to numpy, a leading batch axis is added, and the
    result is passed to ``plot_2d_or_3d_image`` under the tags
    ``<prefix_name>/input_0``, ``<prefix_name>/input_1`` and
    ``<prefix_name>/output``. The writer is flushed at the end.

    Args:
        engine: engine whose ``state`` provides ``epoch`` / ``iteration``,
            ``batch`` and ``output``.

    Raises:
        TypeError: if a selected item is neither ``None`` nor of a supported
            array type.
    """
    step = self.global_iter_transform(
        engine.state.epoch if self.epoch_level else engine.state.iteration
    )

    # Transform the batch once and reuse the result for image and label.
    batch = self.batch_transform(engine.state.batch)
    image_tensor = batch[0]
    label_tensor = batch[1]

    if image_tensor is not None:
        show_images = image_tensor[self.index]
        if isinstance(show_images, torch.Tensor):
            show_images = show_images.detach().cpu().numpy()
        if show_images is not None:
            # NOTE(review): list/tuple pass this check, but `show_images[None]`
            # below is not valid list indexing; confirm whether lists are
            # really expected here.
            if not isinstance(show_images, (np.ndarray, torch.Tensor, list, tuple)):
                # The value comes from batch_transform, so name that in the
                # message (the original wrongly named output_transform).
                raise TypeError(
                    "batch_transform(engine.state.batch)[0] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor, list, tuple) but is {type(show_images).__name__}."
                )
            plot_2d_or_3d_image(
                # add batch dim and plot the first item
                data=show_images[None],
                step=step,
                writer=self._writer,
                index=0,
                max_channels=self.max_channels,
                frame_dim=self.frame_dim,
                max_frames=self.max_frames,
                tag=self.prefix_name + "/input_0",
            )

    if label_tensor is not None:
        show_labels = label_tensor[self.index]
        if isinstance(show_labels, torch.Tensor):
            show_labels = show_labels.detach().cpu().numpy()
        if show_labels is not None:
            if not isinstance(show_labels, np.ndarray):
                raise TypeError(
                    "batch_transform(engine.state.batch)[1] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
                )
            plot_2d_or_3d_image(
                data=show_labels[None],
                step=step,
                writer=self._writer,
                index=0,
                max_channels=self.max_channels,
                frame_dim=self.frame_dim,
                max_frames=self.max_frames,
                tag=self.prefix_name + "/input_1",
            )

    # Transform the output once; the original evaluated it twice.
    outputs = self.output_transform(engine.state.output)
    if outputs is not None:
        show_outputs = outputs[self.index]
        # temporary solution to handle multi-inputs: plot only the first entry
        if isinstance(show_outputs, (list, tuple)):
            show_outputs = show_outputs[0]
        if isinstance(show_outputs, torch.Tensor):
            show_outputs = show_outputs.detach().cpu().numpy()
        if show_outputs is not None:
            if not isinstance(show_outputs, np.ndarray):
                raise TypeError(
                    "output_transform(engine.state.output) must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
                )
            plot_2d_or_3d_image(
                data=show_outputs[None],
                step=step,
                writer=self._writer,
                index=0,
                max_channels=self.max_channels,
                frame_dim=self.frame_dim,
                max_frames=self.max_frames,
                tag=self.prefix_name + "/output",
            )

    self._writer.flush()
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    """Train a 3D UNet for six epochs on the volumes in ``root_dir``.

    Args:
        root_dir: directory holding ``img*.nii.gz`` / ``seg*.nii.gz`` pairs;
            the TensorBoard logs (``runs``) and the best weights
            (``best_metric_model.pth``) are written there as well.
        device: device the model and the batches are moved to.
        cachedataset: ``2`` selects ``CacheDataset`` (cache rate 0.8), ``3``
            selects ``LMDBDataset``, any other value the plain ``Dataset``.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()
    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    seg_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))

    def _pair(img_subset, seg_subset):
        """Zip image and segmentation paths into dict records."""
        return [{"img": img, "seg": seg} for img, seg in zip(img_subset, seg_subset)]

    # first 20 pairs for training, last 20 for validation
    train_files = _pair(image_paths[:20], seg_paths[:20])
    val_files = _pair(image_paths[-20:], seg_paths[-20:])

    def _load_and_resample():
        """Return fresh instances of the steps shared by both pipelines."""
        return [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd(keys="img"),
        ]

    # define transforms for image and segmentation
    train_steps = _load_and_resample()
    train_steps.append(
        RandCropByPosNegLabeld(
            keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
        )
    )
    train_steps.append(RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]))
    train_steps.append(ToTensord(keys=["img", "seg"]))
    train_transforms = Compose(train_steps)
    train_transforms.set_random_state(1234)

    val_steps = _load_and_resample()
    val_steps.append(ToTensord(keys=["img", "seg"]))
    val_transforms = Compose(val_steps)

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # batch_size=2 together with the 4 samples per volume from
    # RandCropByPosNegLabeld gives 2 x 4 images per training step
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")

    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss = epoch_loss / step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        # validate only every `val_interval` epochs
        if (epoch + 1) % val_interval != 0:
            continue

        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = val_labels = val_outputs = None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                sw_batch_size = 4
                roi_size = (96, 96, 96)
                raw_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = val_post_tran(raw_outputs)
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                # weight each score by the number of non-NaN entries behind it
                valid_entries = not_nans.item()
                metric_count += valid_entries
                metric_sum += value.item() * valid_entries
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), model_filename)
                print("saved new best metric model")
            print(
                f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with
            # the corresponding image and label
            for tensor, tag in ((val_images, "image"), (val_labels, "label"), (val_outputs, "output")):
                plot_2d_or_3d_image(tensor, epoch + 1, writer, index=0, tag=tag)

    print(f"train completed, best_metric: {best_metric:0.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
def train():
    """Run the training/validation loop and log progress to TensorBoard.

    Uses objects that are defined outside this function: ``params``,
    ``model``, ``optimizer``, ``loss_function``, ``device``, ``train_ds``,
    ``train_loader`` and ``val_loader``.

    Every ``val_interval`` (2) epochs the mean Dice is computed on the
    validation loader; the model with the best score is saved to
    ``saved_models/best_metric_model.pth``. Every 20 epochs an additional
    checkpoint named after ``params['f_name']`` is saved.

    :return: None
    """
    print('Model training started')
    set_determinism(seed=0)

    epoch_num = params['nb_epoch']
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    checkpoint_dir = 'saved_models'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # number of batches per epoch; constant, so computed once
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter()
    try:
        for epoch in range(epoch_num):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epoch_num}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = (
                    batch_data["image"].to(device),
                    batch_data["label"].to(device),
                )
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            writer.add_scalar("epoch_loss", epoch_loss, epoch + 1)

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    for val_data in val_loader:
                        val_inputs, val_labels = (
                            val_data["image"].to(device),
                            val_data["label"].to(device),
                        )
                        val_outputs = model(val_inputs)
                        value = compute_meandice(
                            y_pred=val_outputs,
                            y=val_labels,
                            include_background=False,
                            to_onehot_y=True,
                            mutually_exclusive=True,
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model,
                            os.path.join(checkpoint_dir, "best_metric_model.pth"))
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                        f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)

                    # plot the last validation subject as GIF in TensorBoard:
                    # input, label and prediction concatenated along dim 3
                    val_pred = torch.argmax(val_outputs, dim=1, keepdim=True)
                    summary_img = torch.cat((val_inputs, val_labels, val_pred), dim=3)
                    plot_2d_or_3d_image(summary_img, epoch + 1, writer, tag='last_val_subject')

            # Model checkpointing every 20 epochs. Every multiple of 20 is also
            # a validation epoch (val_interval is 2), so this fires on the
            # same epochs whether or not it is nested in the branch above.
            if (epoch + 1) % 20 == 0:
                torch.save(
                    model,
                    os.path.join(checkpoint_dir, params['f_name'] + '_' + str(epoch + 1) + '.pth'))

        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        # flush and close the event file even if training is interrupted
        writer.close()
def main(tempdir):
    """Train and validate a 2D segmentation UNet on synthetic images.

    Writes 40 synthetic image/segmentation PNG pairs into ``tempdir``, trains
    on the first 20 pairs for 10 epochs and, every second epoch, validates on
    the last 20 pairs using sliding-window inference. The per-step training
    loss, the validation mean Dice and the last validation image/label/output
    are logged to TensorBoard. The best state dict is saved under the relative
    path ``best_metric_model_segmentation2d_array.pth``.

    :param tempdir: directory the generated PNG files are written to (it is
        not created here).
    :return: None
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 40 random image, mask pairs in the given directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor(),
    ])
    val_segtrans = Compose(
        [LoadImage(image_only=True), AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 10
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # Loop-invariant; used for the progress message and the global step.
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data[0].to(device)
                labels = batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images = val_data[0].to(device)
                        val_labels = val_data[1].to(device)
                        roi_size = (96, 96)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        val_outputs = post_trans(val_outputs)
                        value, _ = dice_metric(y_pred=val_outputs,
                                               y=val_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model.state_dict(),
                            "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, metric, best_metric,
                                best_metric_epoch))
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard
                    # with the corresponding image and label
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="image")
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="label")
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="output")
        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        # Close the event file even when training aborts with an exception.
        writer.close()
def run_training_test(root_dir,
                      device=torch.device("cuda:0"),
                      cachedataset=False):
    """Run a short 3D UNet training/validation cycle on NIfTI data.

    Reads ``img*.nii.gz`` / ``seg*.nii.gz`` files from ``root_dir``, trains on
    the first 20 pairs for 6 epochs and, every second epoch, validates on the
    last 20 pairs with sliding-window inference and ``compute_meandice``.
    TensorBoard logs go to ``<root_dir>/runs`` and the best state dict is
    saved to ``<root_dir>/best_metric_model.pth``.

    :param root_dir: directory holding the input volumes; also receives the
        TensorBoard logs and the saved model.
    :param device: torch device the model and batches are moved to.
    :param cachedataset: when true, the training set is a
        ``monai.data.CacheDataset`` with ``cache_rate=0.8`` instead of a plain
        ``monai.data.Dataset``.
    :return: tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``,
        i.e. the list of per-epoch average losses, the best validation mean
        Dice and the (1-based) epoch at which it was reached.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, 'img*.nii.gz')))
    segs = sorted(glob(os.path.join(root_dir, 'seg*.nii.gz')))
    train_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        RandCropByPosNegLabeld(keys=['img', 'seg'],
                               label_key='seg',
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=['img', 'seg']),
    ])
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        ToTensord(keys=['img', 'seg']),
    ])

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files,
                                           transform=train_transforms,
                                           cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to
    # generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    max_epochs = 6
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    # Loop-invariant; used for the progress message and the global step.
    epoch_len = len(train_ds) // train_loader.batch_size

    writer = SummaryWriter(log_dir=os.path.join(root_dir, 'runs'))
    try:
        for epoch in range(max_epochs):
            print('-' * 10)
            print('Epoch {}/{}'.format(epoch + 1, max_epochs))
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data['img'].to(device)
                labels = batch_data['seg'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print("%d/%d, train_loss:%0.4f" %
                      (step, epoch_len, loss.item()))
                writer.add_scalar('train_loss', loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images = val_data['img'].to(device)
                        val_labels = val_data['seg'].to(device)
                        sw_batch_size, roi_size = 4, (96, 96, 96)
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(y_pred=val_outputs,
                                                 y=val_labels,
                                                 include_background=True,
                                                 to_onehot_y=False,
                                                 add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), model_filename)
                        print('saved new best metric model')
                    print(
                        "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                        % (epoch + 1, metric, best_metric, best_metric_epoch))
                    writer.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard
                    # with the corresponding image and label
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='image')
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='label')
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag='output')
        print('train completed, best_metric: %0.4f at epoch: %d' %
              (best_metric, best_metric_epoch))
    finally:
        # Close the event file even when training aborts with an exception.
        writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch