def test_tb_image_shape(self, shape):
     with tempfile.TemporaryDirectory() as tempdir:
         writer = SummaryWriter(log_dir=tempdir)
         plot_2d_or_3d_image(torch.zeros(shape), 0, writer)
         writer.flush()
         writer.close()
         self.assertTrue(len(glob.glob(tempdir)) > 0)
示例#2
0
 def test_tbx_video(self, shape):
     with tempfile.TemporaryDirectory() as tempdir:
         writer = SummaryWriterX(log_dir=tempdir)
         plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3)
         writer.flush()
         writer.close()
         self.assertTrue(len(glob.glob(tempdir)) > 0)
示例#3
0
    def test_tb_image_shape(self, shape):
        tempdir = tempfile.mkdtemp()
        shutil.rmtree(tempdir, ignore_errors=True)

        plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter(log_dir=tempdir))

        self.assertTrue(os.path.exists(tempdir))
        self.assertTrue(len(glob.glob(tempdir)) > 0)
        shutil.rmtree(tempdir, ignore_errors=True)
    def test_tb_image_shape(self, shape):
        default_dir = os.path.join('.', 'runs')
        shutil.rmtree(default_dir, ignore_errors=True)

        plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter())

        self.assertTrue(os.path.exists(default_dir))
        self.assertTrue(len(glob.glob(default_dir)) > 0)
        shutil.rmtree(default_dir, ignore_errors=True)
示例#5
0
    def write_images(self, epoch):
        if not self.plot_data or not len(self.plot_data):
            return

        all_imgs = []
        for region in sorted(self.plot_data.keys()):
            metric = self.metric_data.get(region)
            region_data = self.plot_data[region]
            if len(region_data[0].shape) == 3:
                ti = Image.new("RGB", region_data[0].shape[1:])
                d = ImageDraw.Draw(ti)
                t = "region: {}".format(region)
                if self.compute_metric:
                    t = t + "\ndice: {:.4f}".format(metric.mean())
                    t = t + "\nstdev: {:.4f}".format(metric.stdev())
                d.multiline_text((10, 10), t, fill=(255, 255, 0))
                ti = rescale_array(
                    np.rollaxis(np.array(ti), 2, 0)[0][np.newaxis])
                all_imgs.append(ti)
            all_imgs.extend(region_data)

        if len(all_imgs[0].shape) == 3:
            img_tensor = make_grid(tensor=torch.from_numpy(np.array(all_imgs)),
                                   nrow=4,
                                   normalize=True,
                                   pad_value=2)
            self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})",
                                  img_tensor=img_tensor,
                                  global_step=epoch)

        if len(all_imgs[0].shape) == 4:
            for region in sorted(self.plot_data.keys()):
                tags = [
                    f"region_{region}_image", f"region_{region}_label",
                    f"region_{region}_output"
                ]
                if torch.distributed.is_initialized():
                    rank = "r{}-".format(torch.distributed.get_rank())
                    tags = [rank + tags[0], rank + tags[1], rank + tags[2]]
                for i in range(3):
                    img = self.plot_data[region][i]
                    img = np.moveaxis(img, -3, -1)
                    plot_2d_or_3d_image(img[np.newaxis], epoch, self.writer, 0,
                                        self.max_channels, self.max_frames,
                                        tags[i])

        self.logger.info(
            "Saved {} Regions {} into Tensorboard at epoch: {}".format(
                len(self.plot_data), sorted([*self.plot_data]), epoch))
        self.writer.flush()
示例#6
0
    def __call__(self, engine):
        step = self.global_iter_transform(engine.state.iteration)

        show_images = self.batch_transform(engine.state.batch)[0]
        if torch.is_tensor(show_images):
            show_images = show_images.detach().cpu().numpy()
        if show_images is not None:
            if not isinstance(show_images, np.ndarray):
                raise ValueError('output_transform(engine.state.output)[0] must be an ndarray or tensor.')
            plot_2d_or_3d_image(show_images, step, self._writer, self.index,
                                self.max_channels, self.max_frames, 'input_0')

        show_labels = self.batch_transform(engine.state.batch)[1]
        if torch.is_tensor(show_labels):
            show_labels = show_labels.detach().cpu().numpy()
        if show_labels is not None:
            if not isinstance(show_labels, np.ndarray):
                raise ValueError('batch_transform(engine.state.batch)[1] must be an ndarray or tensor.')
            plot_2d_or_3d_image(show_labels, step, self._writer, self.index,
                                self.max_channels, self.max_frames, 'input_1')

        show_outputs = self.output_transform(engine.state.output)
        if torch.is_tensor(show_outputs):
            show_outputs = show_outputs.detach().cpu().numpy()
        if show_outputs is not None:
            if not isinstance(show_outputs, np.ndarray):
                raise ValueError('output_transform(engine.state.output) must be an ndarray or tensor.')
            plot_2d_or_3d_image(show_outputs, step, self._writer, self.index,
                                self.max_channels, self.max_frames, 'output')

        self._writer.flush()
示例#7
0
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        Raises:
            TypeError: When ``output_transform(engine.state.output)[0]`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
            TypeError: When ``batch_transform(engine.state.batch)[1]`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
            TypeError: When ``output_transform(engine.state.output)`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.

        """
        step = self.global_iter_transform(
            engine.state.epoch if self.epoch_level else engine.state.iteration)
        show_images = self.batch_transform(engine.state.batch)[0]
        if torch.is_tensor(show_images):
            show_images = show_images.detach().cpu().numpy()
        if show_images is not None:
            if not isinstance(show_images, np.ndarray):
                raise TypeError(
                    "output_transform(engine.state.output)[0] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_images).__name__}."
                )
            plot_2d_or_3d_image(show_images, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "input_0")

        show_labels = self.batch_transform(engine.state.batch)[1]
        if torch.is_tensor(show_labels):
            show_labels = show_labels.detach().cpu().numpy()
        if show_labels is not None:
            if not isinstance(show_labels, np.ndarray):
                raise TypeError(
                    "batch_transform(engine.state.batch)[1] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
                )
            plot_2d_or_3d_image(show_labels, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "input_1")

        show_outputs = self.output_transform(engine.state.output)
        if torch.is_tensor(show_outputs):
            show_outputs = show_outputs.detach().cpu().numpy()
        if show_outputs is not None:
            if not isinstance(show_outputs, np.ndarray):
                raise TypeError(
                    "output_transform(engine.state.output) must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
                )
            plot_2d_or_3d_image(show_outputs, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "output")

        self._writer.flush()
示例#8
0
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation

    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    for i in range(int(opt.increase_factor_data)):

        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader

    train_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    # Need to concatenate multiple channels here if you want multichannel segmentation
    # Check other examples on Monai webpage.

    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'],  gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, pin_memory=torch.cuda.is_available())

    # # try to use all the available GPUs
    # devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):

                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = post_trans(val_outputs)
                        value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
示例#9
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask paris
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(
        check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(
                        y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                    )
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    shutil.rmtree(tempdir)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
示例#10
0
            g_loss.backward()
            optimizerG.step()

            G_losses.append(g_loss.item())
            lossC01s.append(lossC01.item())
            lossC02s.append(lossC02.item())
            lossC03s.append(lossC03.item())

            # log to tensorboard every 10 steps
            if batch_index % 1 == 0:
                writer.add_scalar('Train/G_loss', g_loss.item(), epoch)
                writer.add_scalar('Train/D_loss', d_loss.item(), epoch)
                plot_2d_or_3d_image(targetC01,
                                    epoch * len(training_loader) + batch_index,
                                    writer,
                                    index=0,
                                    tag="Groundtruth_train/C01")
                plot_2d_or_3d_image(targetC02,
                                    epoch * len(training_loader) + batch_index,
                                    writer,
                                    index=0,
                                    tag="Groundtruth_train/C01")
                plot_2d_or_3d_image(targetC03,
                                    epoch * len(training_loader) + batch_index,
                                    writer,
                                    index=0,
                                    tag="Groundtruth_train/C01")
                plot_2d_or_3d_image(outputC01,
                                    epoch * len(training_loader) + batch_index,
                                    writer,
示例#11
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["seg"].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = [
                        post_trans(i) for i in decollate_batch(val_outputs)
                    ]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
示例#12
0
                metric_sum += value.sum().item()
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), 'best_metric_model.pth')
                print('saved new best metric model')
            print(
                "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                % (epoch + 1, metric, best_metric, best_metric_epoch))
            writer.add_scalar('val_mean_dice', metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images,
                                epoch + 1,
                                writer,
                                index=0,
                                tag='image')
            plot_2d_or_3d_image(val_labels,
                                epoch + 1,
                                writer,
                                index=0,
                                tag='label')
            plot_2d_or_3d_image(val_outputs,
                                epoch + 1,
                                writer,
                                index=0,
                                tag='output')
shutil.rmtree(tempdir)
print('train completed, best_metric: %0.4f  at epoch: %d' %
      (best_metric, best_metric_epoch))
示例#13
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    # --------------- Dataset  ---------------
    
    datadir1 = "/home1/quanquan/datasets/lsw/benign_65/fpAML_55/slices/"
    image_files = np.array([x.path for x in os.scandir(datadir1+"image") if x.name.endswith(".npy")])
    label_files = np.array([x.path for x in os.scandir(datadir1+"label") if x.name.endswith(".npy")])
    image_files.sort()
    label_files.sort()
    # --- ??? what's up ???
    train_files = [{"img":img, "seg":seg} for img, seg in zip(image_files[:-20], label_files[:-20])]
    val_files   = [{"img":img, "seg":seg} for img, seg in zip(image_files[-20:], label_files[-20:])]
    # print("files", train_files[:20])
    # print(val_files)
   
    val_imtrans = Compose([LoadNumpy(data_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    val_segtrans = Compose([LoadNumpy(data_only=True), AddChannel(), ToTensor()])
    
    # define array dataset, data loader
    check_ds = ArrayDataset(image_files, train_imtrans, label_files, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)
    
    # create a training data loader
    train_ds = ArrayDataset(image_files[:-20], train_imtrans, label_files[:-20], train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(image_files[-20:], val_imtrans, label_files[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    
    # ---------------  model  ---------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    
    # ---------------  loss function  ---------------
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # writer = SummaryWriter(logdir=tdir(output_dir, "sumamry"))
    
    # -------------------  Training ----------------------
    for epoch in range(max_epoch):
        # print("-" * 10)
        # print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"\r\t Training batch: {step}/{epoch_len}, \ttrain_loss: {loss.item():.4f}\t", end="")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"\n\tepoch {epoch + 1} \taverage loss: {epoch_loss:.4f}")
        
        # -------------------  Save Model  ----------------------
        if epoch % 5 == 0:
            def get_lr(optimizer):
                for param_group in optimizer.param_groups:
                    return float(param_group['lr'])
                  
            state = {'epoch': epoch + 1,
                     'lr': get_lr(optimizer),
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()
                     }
            torch.save(state, tfilename(output_dir, "model", "{}_{}.pkl".format("lsw_monai_simple", epoch)))
    
        # -------------------  Evaluation  -----------------------
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    # value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
def main():

    """
    Read input and configuration parameters
    """
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str,
                        help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving
    # model saving
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = config_info['output']['max_nr_models_saved']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    """
    Data Preparation
    """
    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True, num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    """
    Network preparation
    """
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    """
    Training loop
    """
    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    net.to(device)
    for epoch in range(nr_train_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
            opt.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss.item()))
            writer_train.add_scalar('loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % validation_every_n_epochs == 0:
            net.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                check_tot_validation = 0
                for val_data in val_loader:
                    check_tot_validation += 1
                    val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                    if sliding_window_validation:
                        print('Running sliding window validation')
                        roi_size = (96, 96, 1)
                        val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid, net)
                        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                 to_onehot_y=False, add_sigmoid=True)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    else:
                        print('Running 2D validation')
                        # compute validation
                        val_outputs = net(val_images)
                        value = 1.0 - loss_function(val_outputs, val_labels)
                        metric_count += 1
                        metric_sum += value.item()
                print("Total number of data in validation: %d" % check_tot_validation)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                    print('saved new best metric model')
                print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                      % (epoch + 1, metric, best_metric, best_metric_epoch))
                epoch_len = len(train_ds) // train_loader.batch_size
                writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
                writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer_valid, index=0, tag='image')
                plot_2d_or_3d_image(val_labels, epoch + 1, writer_valid, index=0, tag='label')
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer_valid, index=0, tag='output')

    print('train completed, best_metric: %0.4f  at epoch: %d' % (best_metric, best_metric_epoch))
    writer_train.close()
    writer_valid.close()
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    for i in range(int(opt.increase_factor_data)):
    
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader

    train_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list

    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),               # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                          # augmentation
            ScaleIntensityd(keys=['image']),                                              # intensity
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resolution

            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 15)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),  # pad if the image is smaller than patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),                   # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                      # intensity
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resolution

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),  # pad if the image is smaller than patch
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),               # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                          # augmentation
            ScaleIntensityd(keys=['image']),                                              # intensity

            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),  # pad if the image is smaller than patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),                   # crop CropForeground

            NormalizeIntensityd(keys=['image']),                                      # intensity
            ScaleIntensityd(keys=['image']),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),  # pad if the image is smaller than patch
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # build the network
    if opt.network is 'nnunet':
        net = build_net()  # nn build_net
    elif opt.network is 'unetr':
        net = build_UNETR() # UneTR
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark


    if opt.network is 'nnunet':

        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True,)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)

    elif opt.network is 'unetr':

        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if opt.network is 'nnunet':
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):

                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()

                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
    def __call__(self, engine: Engine):
        step = self.global_iter_transform(
            engine.state.epoch if self.epoch_level else engine.state.iteration)
        image_tensor = self.batch_transform(engine.state.batch)[0]
        label_tensor = self.batch_transform(engine.state.batch)[1]

        if image_tensor is not None:
            show_images = image_tensor[self.index]
            if torch.is_tensor(show_images):
                show_images = show_images.detach().cpu().numpy()
            if show_images is not None:
                if not isinstance(show_images,
                                  (np.ndarray, torch.Tensor, list, tuple)):
                    raise TypeError(
                        "output_transform(engine.state.output)[0] must be None or one of "
                        f"(numpy.ndarray, torch.Tensor) but is {type(show_images).__name__}."
                    )
                plot_2d_or_3d_image(
                    # add batch dim and plot the first item
                    data=show_images[None],
                    step=step,
                    writer=self._writer,
                    index=0,
                    max_channels=self.max_channels,
                    frame_dim=self.frame_dim,
                    max_frames=self.max_frames,
                    tag=self.prefix_name + "/input_0",
                )

        if label_tensor is not None:
            show_labels = label_tensor[self.index]
            if isinstance(show_labels, torch.Tensor):
                show_labels = show_labels.detach().cpu().numpy()
            if show_labels is not None:
                if not isinstance(show_labels, np.ndarray):
                    raise TypeError(
                        "batch_transform(engine.state.batch)[1] must be None or one of "
                        f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
                    )
                plot_2d_or_3d_image(
                    data=show_labels[None],
                    step=step,
                    writer=self._writer,
                    index=0,
                    max_channels=self.max_channels,
                    frame_dim=self.frame_dim,
                    max_frames=self.max_frames,
                    tag=self.prefix_name + "/input_1",
                )

        if self.output_transform(engine.state.output) is not None:
            show_outputs = self.output_transform(
                engine.state.output)[self.index]
            # ! tmp solution to handle multi-inputs
            if isinstance(show_outputs, (list, tuple)):
                show_outputs = show_outputs[0]

            if isinstance(show_outputs, torch.Tensor):
                show_outputs = show_outputs.detach().cpu().numpy()
            if show_outputs is not None:
                if not isinstance(show_outputs, np.ndarray):
                    raise TypeError(
                        "output_transform(engine.state.output) must be None or one of "
                        f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
                    )
                plot_2d_or_3d_image(
                    data=show_outputs[None],
                    step=step,
                    writer=self._writer,
                    index=0,
                    max_channels=self.max_channels,
                    frame_dim=self.frame_dim,
                    max_frames=self.max_frames,
                    tag=self.prefix_name + "/output",
                )

        self._writer.flush()
示例#17
0
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slight different results between PyTorch 1.5 an 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slight different results between PyTorch 1.5 an 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
def train():
    """

    :return:
    """
    print('Model training started')
    set_determinism(seed=0)

    epoch_num = params['nb_epoch']
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        writer.add_scalar("epoch_loss", epoch_loss, epoch + 1)

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = model(val_inputs)
                    value = compute_meandice(
                        y_pred=val_outputs,
                        y=val_labels,
                        include_background=False,
                        to_onehot_y=True,
                        mutually_exclusive=True,
                    )
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model,
                        os.path.join('saved_models', "best_metric_model.pth"))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last validation subject as GIF in TensorBoard with the corresponding image, label and pred
                val_pred = torch.argmax(val_outputs, dim=1, keepdim=True)
                summary_img = torch.cat((val_inputs, val_labels, val_pred),
                                        dim=3)
                plot_2d_or_3d_image(summary_img,
                                    epoch + 1,
                                    writer,
                                    tag='last_val_subject')

        # Model checkpointing
        if (epoch + 1) % 20 == 0:
            torch.save(
                model,
                os.path.join('saved_models',
                             params['f_name'] + '_' + str(epoch + 1) + '.pth'))

    print(
        f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
    )
    writer.close()
示例#19
0
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose(
        [LoadImage(image_only=True),
         AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
示例#20
0
def run_training_test(root_dir,
                      device=torch.device("cuda:0"),
                      cachedataset=False):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, 'img*.nii.gz')))
    segs = sorted(glob(os.path.join(root_dir, 'seg*.nii.gz')))
    train_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        RandCropByPosNegLabeld(keys=['img', 'seg'],
                               label_key='seg',
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=['img', 'seg'])
    ])
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        ToTensord(keys=['img', 'seg'])
    ])

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files,
                                           transform=train_transforms,
                                           cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, 'runs'))
    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    for epoch in range(6):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, 6))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(
                device), batch_data['seg'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data['img'].to(
                        device), val_data['seg'].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=True,
                                             to_onehot_y=False,
                                             add_sigmoid=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print('saved new best metric model')
                print(
                    "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                    % (epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='image')
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='label')
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='output')
    print('train completed, best_metric: %0.4f  at epoch: %d' %
          (best_metric, best_metric_epoch))
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch