def train_handlers(self, context: Context):
    """Assemble the Ignite handlers attached to the trainer engine.

    Adds an optional LR-scheduler handler, rank-0-only loss logging
    (console + TensorBoard), and periodic validation when the context
    carries an evaluator.
    """
    train_event_handlers: List[Any] = []

    # LR Scheduler: attach only when the task actually defines one.
    scheduler_handler = self.lr_scheduler_handler(context)
    if scheduler_handler:
        train_event_handlers.append(scheduler_handler)

    # Loss logging is restricted to the main process in multi-GPU runs.
    if context.local_rank == 0:
        train_event_handlers += [
            StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
            TensorBoardStatsHandler(
                log_dir=context.events_dir,
                tag_name="train_loss",
                output_transform=from_engine(["loss"], first=True),
            ),
        ]

    if context.evaluator:
        logger.info(
            f"{context.local_rank} - Adding Validation to run every '{self._val_interval}' interval"
        )
        train_event_handlers.append(
            ValidationHandler(self._val_interval, validator=context.evaluator, epoch_level=True)
        )
    return train_event_handlers
def val_key_metric(self, context: Context):
    """Return the validation metrics: an overall mean Dice plus one
    Dice metric per non-background label (keyed by ``<label>_dice``).
    """
    metrics = {
        "val_mean_dice": MeanDice(
            output_transform=from_engine(["pred", "label"]),
            include_background=False,
        )
    }
    for label_name in self._labels:
        if label_name == "background":
            continue
        # Per-label dice reads the label-specific pred/label entries.
        metrics[f"{label_name}_dice"] = MeanDice(
            output_transform=from_engine([f"pred_{label_name}", f"label_{label_name}"]),
            include_background=False,
        )
    return metrics
def train_handlers(self, context: Context):
    """Extend the base train handlers with periodic TensorBoard image
    logging (every 20 epochs) on the main process only.
    """
    handlers = super().train_handlers(context)
    if context.local_rank != 0:
        return handlers
    image_logger = TensorBoardImageHandler(
        log_dir=context.events_dir,
        batch_transform=from_engine(["image", "label"]),
        output_transform=from_engine(["pred"]),
        interval=20,
        epoch_level=True,
    )
    handlers.append(image_logger)
    return handlers
def test_compute(self, data, expected):
    """Exercise the NVTX-style mark/range handlers on a SupervisedEvaluator
    and verify the post-processed prediction equals ``expected``.
    """
    # Set up handlers
    handlers = [
        # Mark with Ignite Event
        MarkHandler(Events.STARTED),
        # Mark with literal
        MarkHandler("EPOCH_STARTED"),
        # Mark with literal and providing the message
        MarkHandler("EPOCH_STARTED", "Start of the epoch"),
        # Define a range using one prefix (between BATCH_STARTED and BATCH_COMPLETED)
        RangeHandler("Batch"),
        # Define a range using a pair of events
        RangeHandler((Events.STARTED, Events.COMPLETED)),
        # Define a range using a pair of literals
        RangeHandler(("GET_BATCH_STARTED", "GET_BATCH_COMPLETED"), msg="Batching!"),
        # Define a range using a pair of literal and events
        RangeHandler(("GET_BATCH_STARTED", Events.COMPLETED)),
        # Define the start of range using literal
        RangePushHandler("ITERATION_STARTED"),
        # Define the start of range using event
        RangePushHandler(Events.ITERATION_STARTED, "Iteration 2"),
        # Define the start of range using literals and providing message
        RangePushHandler("EPOCH_STARTED", "Epoch 2"),
        # Define the end of range using Ignite Event
        RangePopHandler(Events.ITERATION_COMPLETED),
        RangePopHandler(Events.EPOCH_COMPLETED),
        # Define the end of range using literal
        RangePopHandler("ITERATION_COMPLETED"),
        # Other handlers
        StatsHandler(tag_name="train", output_transform=from_engine(["label"], first=True)),
    ]

    # Set up an engine
    engine = SupervisedEvaluator(
        device=torch.device("cpu:0"),
        val_data_loader=data,
        epoch_length=1,
        network=torch.nn.PReLU(),
        postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
        decollate=True,
        val_handlers=handlers,
    )
    # Run the engine
    engine.run()

    # Get the output from the engine
    output = engine.state.output[0]

    # torch.testing.assert_allclose is deprecated (removed in PyTorch 2.0);
    # assert_close is the supported replacement.
    torch.testing.assert_close(output["pred"], expected)
def region_wise_metrics(regions, metric, prefix, keys=("pred", "label")):
    """Build a metric dict holding one overall MeanDice (under ``metric``)
    plus, when ``regions`` is given, one MeanDice per region keyed
    ``<prefix>_<name>_dice``.

    ``regions`` may be a name->index dict or a sequence of names, in which
    case indices are assigned from 1 in order.
    """
    metrics = {metric: MeanDice(output_transform=from_engine(keys), include_background=False)}
    if not regions:
        return metrics

    if isinstance(regions, dict):
        label_indices = regions
    else:
        label_indices = {name: idx for idx, name in enumerate(regions, start=1)}

    for region_name, region_idx in label_indices.items():
        metrics[f"{prefix}_{region_name}_dice"] = MeanDice(
            output_transform=from_engine_idx(keys, region_idx),
            include_background=False,
        )
    return metrics
def main(tempdir):
    """Evaluate a pretrained 3D UNet with a MONAI ``SupervisedEvaluator``.

    Generates synthetic image/segmentation pairs into ``tempdir``, loads the
    best checkpoint matching ``./runs/net_key_metric*``, runs sliding-window
    inference, and writes the post-processed predictions under ``./runs/``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create 5 random image, mask pairs in the temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path (first checkpoint whose name carries the key metric)
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet (weights are loaded by CheckpointLoader below)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # post transforms: sigmoid -> threshold -> keep largest component -> save NIfTI
    val_post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        SaveImaged(keys="pred", meta_keys="image_meta_dict", output_dir="./runs/")
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
def main(tempdir):
    """End-to-end workflow demo: synthesize 3D data, then train a UNet with
    engine-based training/validation, TensorBoard logging and checkpointing.
    """
    monai.config.print_config()
    # set root log level to INFO and init a train logger, will be used in `StatsHandler`
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    get_logger("train_log")

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    # first 20 pairs for training, last 20 for validation
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        RandCropByPosNegLabeld(keys=["image", "label"],
                               label_key="label",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    # validation post-processing: sigmoid -> threshold -> keep largest component
    val_post_transforms = Compose([
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        # use the logger "train_log" defined at the beginning of this program
        StatsHandler(name="train_log", output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # run validation every 2 epochs
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        # use the logger "train_log" defined at the beginning of this program
        StatsHandler(name="train_log", tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={"net": net, "opt": opt},
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_post_transforms,
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    trainer.run()
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import torch from ignite.engine import Engine, Events from parameterized import parameterized from monai.handlers import MeanDice, from_engine TEST_CASE_1 = [{ "include_background": True, "output_transform": from_engine(["pred", "label"]) }, 0.75, (4, 2)] TEST_CASE_2 = [{ "include_background": False, "output_transform": from_engine(["pred", "label"]) }, 0.66666, (4, 1)] TEST_CASE_3 = [ { "reduction": "mean_channel", "output_transform": from_engine(["pred", "label"]) }, torch.Tensor([1.0, 0.0, 1.0, 1.0]), (4, 2), ]
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
    """Run sliding-window evaluation of a trained UNet over images in ``root_dir``.

    Loads network weights from ``model_file``, saves segmentations twice —
    via the SaveImaged post-transform and via the SegmentationSaver handler —
    and returns ``evaluator.state.best_metric`` (the key mean-Dice value).
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_postprocessing = Compose([
        ToTensord(keys=["pred", "label"]),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
        SaveImaged(keys="pred",
                   meta_keys="image_meta_dict",
                   output_dir=root_dir,
                   output_postfix="seg_transform"),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            output_postfix="seg_handler",
            batch_transform=from_engine("image_meta_dict"),
            output_transform=from_engine("pred"),
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=True if amp else False,
    )
    evaluator.run()
    return evaluator.state.best_metric
def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
    """Train a 3D UNet for 5 epochs using engine-based workflows.

    Builds train/val CacheDatasets from ``root_dir``, attaches logging,
    TensorBoard and checkpoint handlers plus custom IterationEvents probe
    classes, and returns ``evaluator.state.best_metric``.
    """
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # first 20 pairs for training, last 20 for validation
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        RandCropByPosNegLabeld(keys=["image", "label"],
                               label_key="label",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    summary_writer = SummaryWriter(log_dir=root_dir)

    val_postprocessing = Compose([
        ToTensord(keys=["pred", "label"]),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])

    class _TestEvalIterEvents:
        # probe: registers a no-op handler on the evaluator's FORWARD_COMPLETED event
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)

        def _forward_completed(self, engine):
            pass

    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None),
        TensorBoardImageHandler(log_dir=root_dir,
                                batch_transform=from_engine(["image", "label"]),
                                output_transform=from_engine("pred")),
        CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True),
        _TestEvalIterEvents(),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
        },
        additional_metrics={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric
        val_handlers=val_handlers,
        amp=True if amp else False,
    )

    train_postprocessing = Compose([
        ToTensord(keys=["pred", "label"]),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])

    class _TestTrainIterEvents:
        # probe: registers no-op handlers on all four trainer IterationEvents
        def attach(self, engine):
            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)
            engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)
            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)
            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed)

        def _forward_completed(self, engine):
            pass

        def _loss_completed(self, engine):
            pass

        def _backward_completed(self, engine):
            pass

        def _model_completed(self, engine):
            pass

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
        TensorBoardStatsHandler(summary_writer=summary_writer,
                                tag_name="train_loss",
                                output_transform=from_engine("loss", first=True)),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={"net": net, "opt": opt},
                        save_interval=2,
                        epoch_level=True),
        _TestTrainIterEvents(),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=True if amp else False,
        optim_set_to_none=True,
    )
    trainer.run()
    return evaluator.state.best_metric
def save_func(engine):
    """Feed each decollated prediction from the engine output to the saver."""
    predictions = from_engine("pred")(engine.state.output)
    for prediction in predictions:
        saver(prediction)
def val_key_metric(self, context: Context):
    """Return the validation key metric: accuracy over pred/label pairs."""
    metrics = {}
    metrics["val_acc"] = Accuracy(output_transform=from_engine(["pred", "label"]))
    return metrics
def create_trainer(args):
    """Build a SupervisedTrainer (with its attached evaluator) for
    interactive, click-based segmentation training; supports single-
    and multi-GPU (NCCL DistributedDataParallel) execution.
    """
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(local_rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size, args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels, args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info("{}:: Loading Network...".format(local_rank))
        # remap weights saved on cuda:0 onto this rank's device
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        network.load_state_dict(
            torch.load(args.model_filepath, map_location=map_location))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output, output_transform=lambda x: None),
        CheckpointSaver(
            save_dir=args.output,
            save_dict={"net": network},
            save_key_metric=True,
            save_final=True,
            save_interval=args.save_interval,
            final_filename="model.pt",
        ),
    ]
    # only rank 0 logs and saves checkpoints during validation
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability="probability",
            train=False,
        ),
        inferer=SimpleInferer(),
        postprocessing=post_transform,
        key_val_metric={
            "val_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(["pred", "label"]),
            )
        },
        val_handlers=val_handlers,
    )

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=args.val_freq, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(
            log_dir=args.output,
            tag_name="train_loss",
            output_transform=from_engine(["loss"], first=True),
        ),
        CheckpointSaver(
            save_dir=args.output,
            save_dict={"net": network, "opt": optimizer, "lr": lr_scheduler},
            save_interval=args.save_interval * 2,
            save_final=True,
            final_filename="checkpoint.pt",
        ),
    ]
    # non-zero ranks keep only the first two handlers (LR schedule + validation);
    # logging and checkpointing stay on rank 0
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability="probability",
            train=True,
        ),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        postprocessing=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(["pred", "label"]),
            )
        },
        train_handlers=train_handlers,
    )
    return trainer
def train(data_folder=".", model_folder="runs"):
    """run a training pipeline.

    Splits ``*_ct.nii.gz``/``*_seg.nii.gz`` pairs in ``data_folder`` into
    train/val fractions, trains the network from ``get_net()`` with DiceCE
    loss, and checkpoints the top-3 key-metric models into ``model_folder``.
    """
    images = sorted(glob.glob(os.path.join(data_folder, "*_ct.nii.gz")))
    labels = sorted(glob.glob(os.path.join(data_folder, "*_seg.nii.gz")))
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")

    amp = True  # auto. mixed precision
    keys = ("image", "label")
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")

    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[-n_val:], labels[-n_val:])]

    # create a training data loader
    batch_size = 2
    logging.info(f"batch size {batch_size}")
    train_transforms = get_xforms("train", keys)
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    # create BasicUNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # NOTE(review): momentum is logged but unused — Adam below takes no momentum argument
    max_epochs, lr, momentum = 500, 1e-4, 0.95
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        EnsureTyped(keys=("pred", "label")),
        AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=2)
    ])
    val_handlers = [
        ProgressBar(),
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=3),
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        postprocessing=val_post_transform,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def train(args):
    """Train a DynUNet for the given task/fold with periodic
    DynUNetEvaluator validation, checkpointing and file/console logging.
    """
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold, args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net, device_ids=[device])

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )
    # polynomial learning-rate decay over max_epochs
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)

    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        num_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        postprocessing=None,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(["pred", "label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )

    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]
    train_handlers += [
        ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    ]
    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        postprocessing=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # silence the engine loggers on non-zero ranks
    if local_rank > 0:
        evaluator.logger.setLevel(logging.WARNING)
        trainer.logger.setLevel(logging.WARNING)

    logger = logging.getLogger()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Setup file handler
    fhandler = logging.FileHandler(log_filename)
    fhandler.setLevel(logging.INFO)
    fhandler.setFormatter(formatter)
    logger.addHandler(fhandler)

    # console handler with the same format
    chandler = logging.StreamHandler()
    chandler.setLevel(logging.INFO)
    chandler.setFormatter(formatter)
    logger.addHandler(chandler)
    logger.setLevel(logging.INFO)

    trainer.run()
def validation(args):
    """Run (optionally test-time-augmented) sliding-window validation for a
    trained DynUNet and, on rank 0, print overall and per-label mean Dice.
    """
    # load hyper parameters
    task_id = args.task_id
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    multi_gpu_flag = args.multi_gpu
    local_rank = args.local_rank
    amp = args.amp

    # produce the network
    checkpoint = args.checkpoint
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, args.fold, args.expr_name)

    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net, device_ids=[device])

    num_classes = len(properties["labels"])

    net.eval()

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        num_classes=num_classes,
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        postprocessing=None,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(["pred", "label"]),
            )
        },
        additional_metrics=None,
        amp=amp,
        tta_val=tta_val,
    )

    evaluator.run()
    if local_rank == 0:
        print(evaluator.state.metrics)
        results = evaluator.state.metric_details["val_mean_dice"]
        # per-label report only makes sense with more than one foreground label
        if num_classes > 2:
            for i in range(num_classes - 1):
                print("mean dice for label {} is {}".format(
                    i + 1, results[:, i].mean()))
def train(cfg):
    """Train a patch-based whole-slide-image (WSI) classifier.

    Builds MONAI preprocessing pipelines, patch datasets and loaders from
    ``cfg``, initializes a ResNet-18 based binary classifier, and runs an
    Ignite SupervisedTrainer with per-epoch validation, checkpointing and
    TensorBoard logging.

    Args:
        cfg: configuration dict; keys used here include dataset_json,
            data_root, region_size, grid_shape, patch_size, use_openslide,
            num_workers, batch_size, pretrain, novograd, lr, amp, n_epochs
            and logdir.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    # training: color jitter + flips/rotations/zoom, then scale to [-1, 1]
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # validation: no augmentation, only cast + intensity scaling
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("Fist sample is None!")
    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")
    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model: single-logit (num_classes=1) fully-convolutional head
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)
    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)
    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)
    # AMP scaler: AMP is only enabled when requested AND torch >= 1.6
    if cfg["amp"]:
        cfg["amp"] = True if monai.utils.get_torch_version_tuple() >= (
            1, 6) else False
    else:
        cfg["amp"] = False
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])
    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )
    # Trainer
    # NOTE(review): evaluator checkpoints go to log_dir while trainer
    # checkpoints/TensorBoard go to cfg["logdir"] — confirm whether these
    # are intended to be the same directory.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def train_key_metric(self, context: Context):
    """Build the key metric tracked during training.

    Returns a single-entry dict mapping ``self.TRAIN_KEY_METRIC`` to a
    MeanDice metric computed from the engine's "pred"/"label" outputs.
    """
    dice_metric = MeanDice(output_transform=from_engine(["pred", "label"]))
    return {self.TRAIN_KEY_METRIC: dice_metric}
def train(data_folder=".", model_folder="runs", continue_training=False):
    """run a training pipeline."""
    # Trains a BasicUNet for COVID-19 lesion segmentation using CeA
    # synthetic-lesion augmentation; checkpoints the 10 best models by
    # validation mean Dice into `model_folder`.
    #
    # Args:
    #     data_folder: directory containing *_ct.nii.gz / *_seg.nii.gz pairs.
    #     model_folder: directory where checkpoints (*.pt) are saved/loaded.
    #     continue_training: if True, resume from the last checkpoint in
    #         `model_folder` (relies on sorted() picking the best/latest).
    #/== files for synthesis
    # NOTE(review): hard-coded Google Drive paths — this only runs in the
    # original Colab environment; consider parameterizing.
    path_parent = Path(
        '/content/drive/My Drive/Datasets/covid19/COVID-19-20_augs_cea/')
    path_synthesis = Path(
        path_parent /
        'CeA_BASE_grow=1_bg=-1.00_step=-1.0_scale=-1.0_seed=1.0_ch0_1=-1_ch1_16=-1_ali_thr=0.1'
    )
    scans_syns = os.listdir(path_synthesis)
    decreasing_sequence = get_decreasing_sequence(255, splits=20)
    keys2 = ("image", "label", "synthetic_lesion")
    # READ THE SYTHETIC HEALTHY TEXTURE
    path_synthesis_old = '/content/drive/My Drive/Datasets/covid19/results/cea_synthesis/patient0/'
    texture_orig = np.load(f'{path_synthesis_old}texture.npy.npz')
    texture_orig = texture_orig.f.arr_0
    # shift texture to be strictly positive, then pad for sampling at borders
    texture = texture_orig + np.abs(np.min(texture_orig)) + .07
    texture = np.pad(texture, ((100, 100), (100, 100)), mode='reflect')
    print(f'type(texture) = {type(texture)}, {np.shape(texture)}')
    #==/
    images = sorted(glob.glob(os.path.join(data_folder,
                                           "*_ct.nii.gz")))  #[:20] # XX
    labels = sorted(glob.glob(os.path.join(data_folder,
                                           "*_seg.nii.gz")))  #[:20] # XX
    logging.info(
        f"training: image/label ({len(images)}) folder: {data_folder}")
    amp = True  # auto. mixed precision
    keys = ("image", "label")
    # 80/20 train/val split over the sorted file list
    train_frac, val_frac = 0.8, 0.2
    n_train = int(train_frac * len(images)) + 1
    n_val = min(len(images) - n_train, int(val_frac * len(images)))
    logging.info(
        f"training: train {n_train} val {n_val}, folder: {data_folder}")
    train_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[:n_train], labels[:n_train])]
    val_files = [{
        keys[0]: img,
        keys[1]: seg
    } for img, seg in zip(images[-n_val:], labels[-n_val:])]
    # create a training data loader
    batch_size = 2  # XX should be 2
    logging.info(f"batch size {batch_size}")
    GEN = np.random.randint(5, 45)
    train_transforms = get_xforms("synthesis", keys, keys2, path_synthesis,
                                  decreasing_sequence, scans_syns, texture,
                                  GEN)
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms)
    train_loader = monai.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_transforms = get_xforms("val", keys)
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(
        val_ds,
        batch_size=
        1,  # image-level batch to the sliding window method, not the window-level batch
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    # create BasicUNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = get_net().to(device)
    # if continue training
    if continue_training:
        ckpts = sorted(glob.glob(os.path.join(model_folder, "*.pt")))
        # ckpts = glob.glob(os.path.join(model_folder, "*.pt")) # XX should use sorted() to take the best model
        ckpt = ckpts[-1]
        logging.info(f"continue training using {ckpt}.")
        net.load_state_dict(torch.load(ckpt, map_location=device))
    max_epochs, lr, momentum = 20, 1e-4, 0.95
    # max_epochs, lr, momentum = 500, 1e-4, 0.95
    logging.info(f"epochs {max_epochs}, lr {lr}, momentum {momentum}")
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    # create evaluator (to be used to measure model quality during training)
    val_post_transform = monai.transforms.Compose([
        EnsureTyped(keys=("pred", "label")),
        AsDiscreted(keys=("pred", "label"),
                    argmax=(True, False),
                    to_onehot=True,
                    n_classes=2)
    ])
    val_handlers = [
        ProgressBar(),
        # keep the 10 best checkpoints by key validation metric
        CheckpointSaver(save_dir=model_folder,
                        save_dict={"net": net},
                        save_key_metric=True,
                        key_metric_n_saved=10),  #key_metric_n_saved=3
    ]
    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=get_inferer(),
        postprocessing=val_post_transform,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=False,
                     output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )
    # evaluator as an event handler of the trainer
    train_handlers = [
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
    ]
    trainer = monai.engines.SupervisedTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=DiceCELoss(),
        inferer=get_inferer(),
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
def val_key_metric(self, context):
    """Build the key metric tracked during validation.

    Returns a single-entry dict mapping ``self.VAL_KEY_METRIC`` to a
    MeanDice metric computed from the engine's "pred"/"label" outputs.
    """
    transform = from_engine(["pred", "label"])
    return {self.VAL_KEY_METRIC: MeanDice(output_transform=transform)}
def train(args):
    """Run distributed (DDP) 3D UNet segmentation training on synthetic data.

    Rank 0 generates 40 random image/segmentation NIfTI pairs under
    ``args.dir`` if missing; every rank then trains a DistributedDataParallel
    UNet for 5 epochs with a DistributedSampler, logging and checkpointing
    on rank 0 only.

    Args:
        args: parsed CLI namespace; ``args.dir`` is the data directory.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    if idist.get_local_rank() == 0 and not os.path.exists(args.dir):
        # create 40 random image, mask paris for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))
    # wait until rank 0 has finished writing the dataset
    idist.barrier()

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            EnsureTyped(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = Dataset(data=train_files, transform=train_transforms)
    # create a training data sampler (shards the dataset across ranks,
    # so the DataLoader itself must not shuffle)
    train_sampler = DistributedSampler(train_ds)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=train_sampler,
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{idist.get_local_rank()}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    # post-processing: sigmoid -> threshold -> keep largest component
    train_post_transforms = Compose(
        [
            EnsureTyped(keys="pred"),
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
    ]
    # stats logging and checkpointing only on the global rank-0 process
    if idist.get_rank() == 0:
        train_handlers.extend(
            [
                StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
                CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2),
            ]
        )

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
        postprocessing=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]), device=device)},
        train_handlers=train_handlers,
    )
    trainer.run()
    dist.destroy_process_group()