Example #1
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), dtype=np.uint8),
         FilterImaged(keys="image", min_size=5),
         AsChannelFirstd(keys="image"),
         AddChanneld(keys="label"),
         ToTensord(keys="image"),
         TorchVisiond(keys="image",
                      name="ColorJitter",
                      brightness=64.0 / 255.0,
                      contrast=0.75,
                      saturation=0.25,
                      hue=0.04),
         ToNumpyd(keys="image"),
         RandRotate90d(keys=("image", "label"),
                       prob=0.5,
                       spatial_axes=(0, 1)),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddInitialSeedPointExd(label="label", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys=("image", "label")),
     ]
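The method returns a plain list of MONAI transforms, so a caller is expected to wrap it in Compose before applying it. A minimal sketch under assumptions: `app` is a hypothetical object exposing the method above, the `context` argument is unused by the body, and the file paths are placeholders.

    from monai.transforms import Compose

    # Hedged sketch: `app` and the paths are placeholders, not part of the example.
    pipeline = Compose(app.train_pre_transforms(context=None))
    sample = pipeline({"image": "patch_0001.png", "label": "patch_0001_mask.png"})
    # sample["image"] now holds the channel-first image with guidance channels appended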
Example #2
    def _get_loader(self, folders):
        images = []
        segs = []
        for folder in folders:
            images += glob(os.path.join(folder, "*_im.nii.gz"))
            segs += glob(os.path.join(folder, "*_seg.nii.gz"))
        images = sorted(images, key=os.path.basename)
        segs = sorted(segs, key=os.path.basename)

        files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        transforms = Compose([
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ])

        ds = CacheDataset(data=files, transform=transforms)
        loader = DataLoader(ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

        return loader
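Batches from this loader are dictionaries keyed like the input files. A short consumption sketch, assuming `task` is an instance of the enclosing class and the folder path is a placeholder:

    # Hedged sketch: the directory is assumed to contain *_im.nii.gz / *_seg.nii.gz pairs.
    loader = task._get_loader(["/data/validation_cases"])
    for batch in loader:
        # channel-first "img"/"seg" tensors, batch size 1
        print(batch["img"].shape, batch["seg"].shape)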
Example #3
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     t = [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image",
                  pixdim=[1.0] * self.dimension,
                  mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image",
                                guidance="guidance",
                                dimensions=self.dimension),
     ]
     if self.dimension == 2:
         t.append(Fetch2DSliced(keys="image", guidance="guidance"))
     t.extend([
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image",
                              guidance="guidance",
                              spatial_size=self.spatial_size),
         Resized(keys="image", spatial_size=self.model_size, mode="area"),
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208,
                             divisor=388),  # type: ignore
         AddGuidanceSignald(image="image", guidance="guidance"),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ])
     return t
Example #4
 def pre_transforms(self, data=None) -> Sequence[Callable]:
     return [
         LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
     ]
Example #5
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        Spacingd(keys=["img", "seg"],
                 pixdim=[1.2, 0.8, 0.7],
                 mode=["bilinear", "nearest"],
                 dtype=np.float32),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        # resampling with align_corners=True or dtype=float64 will generate
        # slightly different results between PyTorch 1.5 and 1.6
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=np.float32)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            not_nans = dice_metric.not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = metric_sum / metric_count
    return metric
Example #6
    def __init__(self, transforms):
        self.transform_list = []
        for transform in transforms:
            if 'LoadImaged' == transform:
                self.transform_list.append(LoadImaged(keys=["image", "label"]))
            elif 'AsChannelFirstd' == transform:
                self.transform_list.append(AsChannelFirstd(keys="image"))
            elif 'ConvertToMultiChannelBasedOnBratsClassesd' == transform:
                self.transform_list.append(ConvertToMultiChannelBasedOnBratsClassesd(keys="label"))
            elif 'Spacingd' == transform:
                self.transform_list.append(Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")))
            elif 'Orientationd' == transform:
                self.transform_list.append(Orientationd(keys=["image", "label"], axcodes="RAS"))
            elif 'CenterSpatialCropd' == transform:
                self.transform_list.append(CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]))
            elif 'NormalizeIntensityd' == transform:
                self.transform_list.append(NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True))
            elif 'ToTensord' == transform:
                self.transform_list.append(ToTensord(keys=["image", "label"]))
            elif 'Activations' == transform:
                self.transform_list.append(Activations(sigmoid=True))
            elif 'AsDiscrete' == transform:
                self.transform_list.append(AsDiscrete(threshold_values=True))
            else:
                raise ValueError(
                    f"Unsupported transform: {transform}. Please add it to the supported list."
                )

        super().__init__(self.transform_list)
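Because the constructor finishes with super().__init__(self.transform_list), the enclosing class is evidently a Compose subclass configured by transform names. A usage sketch, with NamedCompose standing in for the real (unknown) class name:

    # `NamedCompose` is a placeholder for the class whose __init__ is shown above.
    pipeline = NamedCompose(["LoadImaged", "AsChannelFirstd", "ToTensord"])
    data = pipeline({"image": "img0.nii.gz", "label": "seg0.nii.gz"})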
Example #7
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=int)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for window inference
            sw_batch_size, roi_size = 4, (96, 96, 96)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs,
                                     y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False,
                                     add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
Example #8
 def test_shape(self, in_type, input_param, expected_shape):
     test_data = {
         "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
         "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
         "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
     }
     result = AsChannelFirstd(**input_param)(test_data)
     self.assertTupleEqual(result["image"].shape, expected_shape)
     self.assertTupleEqual(result["label"].shape, expected_shape)
     self.assertTupleEqual(result["extra"].shape, expected_shape)
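For a concrete parameter set that satisfies this test: with channel_dim=-1, AsChannelFirstd moves the trailing axis to the front, so shape (1, 2, 3, 4) becomes (4, 1, 2, 3). A self-contained sketch:

    import numpy as np
    from monai.transforms import AsChannelFirstd

    data = {k: np.zeros((1, 2, 3, 4)) for k in ("image", "label", "extra")}
    out = AsChannelFirstd(keys=["image", "label", "extra"], channel_dim=-1)(data)
    assert out["image"].shape == (4, 1, 2, 3)  # trailing axis moved to front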
Example #9
def _default_transforms(image_key, label_key, pixdim):
    keys = [image_key] if label_key is None else [image_key, label_key]
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST
            ] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    return Compose([
        LoadImaged(keys=keys),
        AsChannelFirstd(keys=keys),
        Spacingd(keys=keys, pixdim=pixdim, mode=mode),
        Orientationd(keys=keys, axcodes="RAS"),
    ])
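A sketch of calling this helper with placeholder NIfTI paths; passing label_key=None would yield an image-only pipeline with bilinear resampling:

    # Hedged sketch: file paths are placeholders.
    transforms = _default_transforms("image", "label", pixdim=(1.0, 1.0, 1.0))
    data = transforms({"image": "img0.nii.gz", "label": "seg0.nii.gz"})
    # both keys are now channel-first, resampled to the given spacing, RAS-oriented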
Example #10
 def test_shape(self, input_param, expected_shape):
     test_data = {
         'image': np.random.randint(0, 2, size=[1, 2, 3, 4]),
         'label': np.random.randint(0, 2, size=[1, 2, 3, 4]),
         'extra': np.random.randint(0, 2, size=[1, 2, 3, 4])
     }
     result = AsChannelFirstd(**input_param)(test_data)
     self.assertTupleEqual(result['image'].shape, expected_shape)
     self.assertTupleEqual(result['label'].shape, expected_shape)
     self.assertTupleEqual(result['extra'].shape, expected_shape)
Example #11
 def pre_transforms(self, data=None):
     return [
         LoadImagePatchd(keys="image",
                         conversion="RGB",
                         dtype=np.uint8,
                         padding=False),
         AsChannelFirstd(keys="image"),
         AddClickSignalsd(image="image"),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ]
Example #12
def _get_transforms(keys, pixdim):
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST
            ] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    transforms = [
        LoadImaged(keys=keys),
        AsChannelFirstd(keys=keys),
        Spacingd(keys=keys, pixdim=pixdim, mode=mode),
        Orientationd(keys=keys, axcodes="RAS"),
    ]

    return Compose(transforms)
Example #13
 def pre_transforms(self, data):
     return [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image", pixdim=[1.0, 1.0, 1.0], mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image", guidance="guidance", dimensions=3),
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image", guidance="guidance", spatial_size=self.spatial_size),
         Resized(keys="image", spatial_size=self.model_size, mode="area"),
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208, divisor=388),
         AddGuidanceSignald(image="image", guidance="guidance"),
     ]
Example #14
 def _get_predictions_iterator(self, segs):
     files = [{"seg": seg} for seg in segs]
     transforms = Compose([
         LoadImaged(keys=["seg"]),
         AsChannelFirstd(keys=["seg"], channel_dim=-1),
         ToTensord(keys=["seg"]),
     ])
     ds = CacheDataset(data=files, transform=transforms)
     loader = DataLoader(ds,
                         batch_size=1,
                         num_workers=4,
                         collate_fn=list_data_collate)
     for data in loader:
         yield (data["seg"], data["seg_meta_dict"])
Example #15
 def pre_transforms(self, data=None):
     return [
         LoadImagePatchd(keys="image", conversion="RGB", dtype=np.uint8),
         FilterImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         ScaleIntensityRangeD(keys="image",
                              a_min=0.0,
                              a_max=255.0,
                              b_min=-1.0,
                              b_max=1.0),
         AddClickGuidanced(image="image", guidance="guidance"),
         AddGuidanceSignald(image="image",
                            guidance="guidance",
                            number_intensity_ch=3),
         EnsureTyped(keys="image",
                     device=data.get("device") if data else None),
     ]
Example #16
 def pre_transforms(self, data=None):
     return [
         LoadImaged(keys="image"),
         AsChannelFirstd(keys="image"),
         Spacingd(keys="image", pixdim=[1.0, 1.0], mode="bilinear"),
         AddGuidanceFromPointsd(ref_image="image",
                                guidance="guidance",
                                dimensions=2),
         Fetch2DSliced(keys="image", guidance="guidance"),
         AddChanneld(keys="image"),
         SpatialCropGuidanced(keys="image",
                              guidance="guidance",
                              spatial_size=[256, 256]),
         Resized(keys="image", spatial_size=[256, 256], mode="area"),
         ResizeGuidanced(guidance="guidance", ref_image="image"),
         NormalizeIntensityd(keys="image", subtrahend=208,
                             divisor=388),  # type: ignore
         AddGuidanceSignald(image="image", guidance="guidance"),
     ]
Example #17
    def _get_loader(self, folders):
        images = []
        segs = []
        for folder in folders:
            images += glob(os.path.join(folder, "*_im.nii.gz"))
            segs += glob(os.path.join(folder, "*_seg.nii.gz"))
        images = sorted(images, key=os.path.basename)
        segs = sorted(segs, key=os.path.basename)

        files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        transforms = Compose([
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(keys=["img", "seg"],
                                   label_key="seg",
                                   spatial_size=[96, 96, 96],
                                   pos=1,
                                   neg=1,
                                   num_samples=4),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ])

        ds = CacheDataset(data=files, transform=transforms)
        loader = DataLoader(
            ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            collate_fn=list_data_collate,
            pin_memory=torch.cuda.is_available(),
        )

        return loader
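Note that RandCropByPosNegLabeld(num_samples=4) emits four crops per volume and list_data_collate flattens them, so batch_size=2 produces an effective batch of 8 patches. A quick shape check, assuming the loader built above:

    from monai.utils import first

    batch = first(loader)
    # 2 volumes x 4 crops each -> effective batch of 8 patches
    print(batch["img"].shape)  # expected: (8, 1, 96, 96, 96)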
Example #18
def run_training_test(root_dir, device="cuda:0"):
    real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    train_files = [{"reals": img} for img in zip(real_images)]

    # prepare real data
    train_transforms = Compose([
        LoadNiftid(keys=["reals"]),
        AsChannelFirstd(keys=["reals"]),
        ScaleIntensityd(keys=["reals"]),
        RandFlipd(keys=["reals"], prob=0.5),
        ToTensord(keys=["reals"]),
    ])
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)

    learning_rate = 2e-4
    betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # create discriminator
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return torch.div(torch.add(realloss, genloss), 2)

    # create generator
    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])
    gen_net.apply(normal_init)
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    key_train_metric = None

    train_handlers = [
        StatsHandler(
            name="training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        TensorBoardStatsHandler(
            log_dir=root_dir,
            tag_name="training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={
                            "g_net": gen_net,
                            "d_net": disc_net
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    disc_train_steps = 2
    num_epochs = 5

    trainer = GanTrainer(
        device,
        num_epochs,
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    return trainer.state
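Once trained, the generator can be sampled directly by feeding latent vectors; a sketch reusing gen_net, latent_size, and device from above (the appended Sigmoid keeps outputs in [0, 1]):

    gen_net.eval()
    with torch.no_grad():
        latent = torch.randn(8, latent_size, device=device)
        fakes = gen_net(latent)  # expected shape: (8, 1, 64, 64)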
Example #19
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch + 1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
Example #20
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds, shuffle=False)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)

    # create UNet and move it to the expected GPU device
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    net = DistributedDataParallel(net, device_ids=[device])

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        CheckpointLoader(
            load_path="./runs/checkpoint_epoch=4.pt",
            load_dict={"net": net},
            # config mapping to expected GPU device
            map_location={"cuda:0": f"cuda:{args.local_rank}"},
        ),
    ]
    if dist.get_rank() == 0:
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        val_handlers.extend(
            [
                StatsHandler(output_transform=lambda x: None),
                SegmentationSaver(
                    output_dir="./runs/",
                    batch_transform=lambda batch: batch["image_meta_dict"],
                    output_transform=lambda output: output["pred"],
                ),
            ]
        )

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=True,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=device,
            )
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]), device=device)},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )
    evaluator.run()
    dist.destroy_process_group()
Example #21
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        Keys.IMAGE: img,
        Keys.LABEL: seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        RandCropByPosNegLabeld(keys=[Keys.IMAGE, Keys.LABEL],
                               label_key=Keys.LABEL,
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=[Keys.IMAGE, Keys.LABEL],
                      prob=0.5,
                      spatial_axes=[0, 2]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=[Keys.IMAGE, Keys.LABEL]),
        AsChannelFirstd(keys=[Keys.IMAGE, Keys.LABEL], channel_dim=-1),
        ScaleIntensityd(keys=[Keys.IMAGE, Keys.LABEL]),
        ToTensord(keys=[Keys.IMAGE, Keys.LABEL]),
    ])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    val_handlers = [StatsHandler(output_transform=lambda x: None)]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        val_handlers=val_handlers,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     add_sigmoid=True,
                     output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL]))
        },
        additional_metrics=None,
    )

    train_handlers = [
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x[Keys.INFO][Keys.LOSS]),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        train_handlers=train_handlers,
        amp=False,
        key_train_metric=None,
    )
    trainer.run()

    shutil.rmtree(tempdir)
Example #22
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()
Example #23
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    ################################ DATASET ################################
    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    ################################ DATASET ################################
    
    ################################ NETWORK ################################
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    ################################ NETWORK ################################
    
    ################################ LOSS ################################    
    loss = monai.losses.DiceLoss(sigmoid=True)
    ################################ LOSS ################################
    
    ################################ OPT ################################
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    ################################ OPT ################################
    
    ################################ LR ################################
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
    ################################ LR ################################
    
    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir="./runs/",
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False,
    )
    trainer.run()
Example #24
def evaluate(args):
    if args.local_rank == 0 and not os.path.exists(args.dir):
        # create 16 random image, mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed evaluation process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation data loader
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True,
                            sampler=val_sampler)
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet and move it to the expected GPU device
    device = torch.device(f"cuda:{args.local_rank}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    # config mapping to expected GPU device
    map_location = {"cuda:0": f"cuda:{args.local_rank}"}
    # load model parameters to GPU device
    model.load_state_dict(
        torch.load("final_model.pth", map_location=map_location))

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for window inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
        metric = metric.tolist()
        if dist.get_rank() == 0:
            print("evaluation metric:", metric[0] / metric[1])
        dist.destroy_process_group()
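The [sum, count] trick used above deserves a worked example: after all_reduce with SUM, every rank holds the global totals, so the global mean Dice falls out without gathering per-sample values. A two-rank arithmetic sketch:

    import torch

    rank0 = torch.tensor([1.8, 2.0])  # [sum of dice, not-nan count] on rank 0
    rank1 = torch.tensor([2.7, 3.0])  # [sum of dice, not-nan count] on rank 1
    reduced = rank0 + rank1           # what dist.all_reduce(metric, SUM) produces
    print((reduced[0] / reduced[1]).item())  # 4.5 / 5.0 = 0.9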
Example #25
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(
        check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(
                        y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                    )
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    shutil.rmtree(tempdir)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
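A minimal sketch (not part of the examples on this page) of what list_data_collate does with the list-of-crops output of RandCropByPosNegLabeld: each dataset item becomes a list of num_samples crop dicts, and the collate function flattens them into a single batch.

import torch
from monai.data import list_data_collate

# two dataset items, each transformed into a list of 4 crop dicts
item = [{"img": torch.zeros(1, 96, 96, 96)} for _ in range(4)]
batch = list_data_collate([item, item])
print(batch["img"].shape)  # torch.Size([8, 1, 96, 96, 96])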
Exemplo n.º 26
0
def run_training_test(root_dir, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        RandCropByPosNegLabeld(keys=["image", "label"],
                               label_key="label",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
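    # StepLR decays the learning rate by gamma=0.1 every step_size=2 epochs;
    # the LrScheduleHandler below calls lr_scheduler.step() at the end of each epoch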

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=root_dir,
                                output_transform=lambda x: None),
        TensorBoardImageHandler(log_dir=root_dir,
                                batch_transform=lambda x:
                                (x["image"], x["label"]),
                                output_transform=lambda x: x["pred"]),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=root_dir,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()

    return evaluator.state.best_metric
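A hedged usage sketch (the "./testdata" directory is hypothetical): run_training_test expects root_dir to contain img*/seg* NIfTI pairs such as those generated in the previous example.

best_dice = run_training_test("./testdata", device="cuda:0", amp=False)
print(f"best mean dice: {best_dice:.4f}")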
Exemplo n.º 27
0
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False):
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir=root_dir,
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96),
                                     sw_batch_size=4,
                                     overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        amp=amp,
    )
    evaluator.run()

    return evaluator.state.best_metric
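A hedged usage sketch: with save_key_metric=True, the CheckpointSaver in the training run above writes checkpoints whose names start with "net_key_metric" (the "./testdata" directory is hypothetical), so the latest one can be located with glob and passed in.

model_file = sorted(glob(os.path.join("./testdata", "net_key_metric*")))[-1]
best_dice = run_inference_test("./testdata", model_file, device="cuda:0", amp=False)
print(f"best mean dice: {best_dice:.4f}")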
Exemplo n.º 28
0
for i in range(5):  # reconstructed loop head: the snippet is truncated above; 5 image/seg pairs assumed
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1,
                                   channel_dim=-1)

    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]

# define transforms for image and segmentation
val_transforms = Compose([
    LoadNiftid(keys=['img', 'seg']),
    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
    ScaleIntensityd(keys=['img', 'seg']),
    ToTensord(keys=['img', 'seg'])
])
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

device = torch.device("cuda:0")
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
net.to(device)
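The snippet above ends after moving the network to the GPU; a minimal continuation sketch, assuming a checkpoint "best_metric_model.pth" (hypothetical name) was saved earlier with torch.save(net.state_dict(), ...):

from monai.inferers import sliding_window_inference

net.load_state_dict(torch.load("best_metric_model.pth"))
net.eval()
val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
with torch.no_grad():
    for val_data in val_loader:
        val_images = val_data["img"].to(device)
        # sliding-window inference over (96, 96, 96) patches, 4 patches per forward pass
        val_outputs = sliding_window_inference(val_images, (96, 96, 96), 4, net)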
Exemplo n.º 29
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 5 random image/mask pairs in the temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]

    # model file path
    model_file = glob("./runs/net_key_metric*")[0]

    # define transforms for image and segmentation
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    val_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointLoader(load_path=model_file, load_dict={"net": net}),
        SegmentationSaver(
            output_dir="./runs/",
            batch_transform=lambda batch: batch["image_meta_dict"],
            output_transform=lambda output: output["pred"],
        ),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        # AMP evaluation is only enabled on PyTorch 1.6+ (and requires FP16 support in the GPU)
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    evaluator.run()
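    # the aggregated key metric is available on the engine state after run(),
    # as the earlier examples return via evaluator.state.best_metric
    print(f"evaluation mean dice: {evaluator.state.best_metric:.4f}")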
Exemplo n.º 30
0
def train(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # partition dataset based on current rank number, every rank trains with its own data
    # it can avoid duplicated caching content in each rank, but will not do global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # this example assumes 2 ranks, each rank holding 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # stop replacement thread of SmartCache
    train_ds.shutdown()
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # every process holds the same parameters: all start from the same random
        # initialization and gradients are synchronized in the backward passes,
        # so saving the model in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
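A hedged sketch of the entry point this distributed example assumes: torch.distributed.launch spawns one process per GPU and passes --local_rank to each process, matching the args fields train() reads (the script name in the launch command is hypothetical).

import argparse

# launch with, e.g., two processes on one node (matching the cache_num comment above):
#   python -m torch.distributed.launch --nproc_per_node=2 unet_training_smartcache.py --dir ./testdata
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", default="./testdata", type=str, help="data directory")
    parser.add_argument("--local_rank", type=int, default=0, help="set by torch.distributed.launch")
    args = parser.parse_args()
    train(args)

if __name__ == "__main__":
    main()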