def main():
    """Train a segmentation network and report Dice on train/val/test each epoch.

    All settings come from ``Options().parse()``. Image/label pairs are read
    from ``<images_folder>/<split>/image*.nii`` and
    ``<labels_folder>/<split>/label*.nii`` for the ``train``, ``val`` and
    ``test`` splits and paired by position after sorting. The network is
    trained for ``opt.epochs`` epochs; after every epoch the Dice score is
    computed with sliding-window inference on the three splits, scalars and
    image previews are written to TensorBoard, and the weights with the best
    validation Dice are saved to ``best_metric_model.pth``.

    Raises:
        ValueError: if ``opt.network`` is neither ``'nnunet'`` nor ``'unetr'``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus: '-1' means "no GPU", anything else is a comma-separated id list
    num_gpus = 0 if opt.gpu_ids == '-1' else len(opt.gpu_ids.split(','))
    print('number of GPU:', num_gpus)

    # ------------------------------------------------------------------
    # Data lists
    # ------------------------------------------------------------------
    def list_split(split):
        # Both lists are sorted; pairs are later formed by position with zip().
        images = sorted(glob(os.path.join(opt.images_folder, split, 'image*.nii')))
        segs = sorted(glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
        return images, segs

    def to_dicts(images, segs):
        return [{'image': image_name, 'label': label_name}
                for image_name, label_name in zip(images, segs)]

    train_images, train_segs = list_split('train')
    val_images, val_segs = list_split('val')
    test_images, test_segs = list_split('test')

    # Un-replicated copies of the training lists, used to measure training Dice
    # on whole images (taken before the replication below).
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # augment the data list for training
    # NOTE: every pass appends the list to itself, so the length doubles each
    # time: final length == original length * 2 ** increase_factor_data.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = to_dicts(train_images, train_segs)
    train_dice_dicts = to_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = to_dicts(val_images, val_segs)
    test_dicts = to_dicts(test_images, test_segs)

    # ------------------------------------------------------------------
    # Transforms
    # ------------------------------------------------------------------
    # Resampling is only applied when a target resolution was given.
    resample = opt.resolution is not None
    # The two original configurations use different upper bounds for the
    # Gaussian-noise std: 15 with resampling, 1 without.
    noise_std_max = 15 if resample else 1

    def preprocessing():
        """Deterministic steps shared by the training and evaluation pipelines."""
        steps = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # Optional CT HU window, currently disabled:
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if resample:
            steps.append(
                Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')))
        return steps

    def pad_to_patch():
        # pad if the image is smaller than patch
        return SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end')

    train_transforms = Compose(preprocessing() + [
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE: the np.random.uniform(...) arguments below are drawn once, here,
        # when the transform objects are built - not once per sample.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_max)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),

        pad_to_patch(),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ])

    val_transforms = Compose(preprocessing() + [
        pad_to_patch(),
        ToTensord(keys=['image', 'label']),
    ])

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    def eval_loader(dicts):
        """Whole-image loader (batch size 1) using the evaluation transforms."""
        dataset = monai.data.Dataset(data=dicts, transform=val_transforms)
        return DataLoader(dataset, batch_size=1, num_workers=opt.workers,
                          collate_fn=list_data_collate, pin_memory=False)

    train_dice_loader = eval_loader(train_dice_dicts)
    val_loader = eval_loader(val_dicts)
    test_loader = eval_loader(test_dicts)

    # ------------------------------------------------------------------
    # Network, loss, optimizer
    # ------------------------------------------------------------------
    # Strings must be compared with ==; `is` tests object identity and is not
    # reliable for values parsed at runtime.
    if opt.network == 'nnunet':
        net = build_net()  # nn build_net
    elif opt.network == 'unetr':
        net = build_UNETR()  # UneTR
    else:
        raise ValueError(f"unsupported network {opt.network!r}; expected 'nnunet' or 'unetr'")
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    # Only the nnunet configuration uses a learning-rate schedule.
    net_scheduler = None
    if opt.network == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:  # 'unetr'
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    def evaluate_dice(images_loader):
        """Return (mean Dice, last images, last labels, last outputs) for a loader.

        Must be called with gradients disabled. The last batch is returned so
        the caller can plot it; all three are None if the loader is empty.
        """
        last_images = None
        last_labels = None
        last_outputs = None
        for data in images_loader:
            last_images, last_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            last_outputs = sliding_window_inference(last_images, roi_size, sw_batch_size, net)
            last_outputs = [post_trans(i) for i in decollate_batch(last_outputs)]
            dice_metric(y_pred=last_outputs, y=last_labels)

        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
        # reset the status for next validation round
        dice_metric.reset()

        return metric, last_images, last_labels, last_outputs

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    # Number of full batches per epoch; used for the progress print and as the
    # stride of the global TensorBoard step.
    epoch_len = len(check_train) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(opt.epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{opt.epochs}")
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
                optim.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optim.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            if net_scheduler is not None:
                update_learning_rate(net_scheduler, optim)

            if (epoch + 1) % val_interval == 0:
                net.eval()
                with torch.no_grad():
                    metric, val_batch_images, val_batch_labels, val_batch_outputs = evaluate_dice(val_loader)

                    # Save best model
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")

                    metric_train, _, _, _ = evaluate_dice(train_dice_loader)
                    metric_test, test_batch_images, test_batch_labels, test_batch_outputs = evaluate_dice(test_loader)

                    # Logger bar
                    print(
                        "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                        )
                    )

                    writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                    writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                    writer.add_scalar("Training_dice", metric_train, epoch + 1)
                    writer.add_scalar("Validation_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_batch_images, epoch + 1, writer, index=0, tag="validation image")
                    plot_2d_or_3d_image(val_batch_labels, epoch + 1, writer, index=0, tag="validation label")
                    plot_2d_or_3d_image(val_batch_outputs, epoch + 1, writer, index=0, tag="validation inference")
                    plot_2d_or_3d_image(test_batch_images, epoch + 1, writer, index=0, tag="test image")
                    plot_2d_or_3d_image(test_batch_labels, epoch + 1, writer, index=0, tag="test label")
                    plot_2d_or_3d_image(test_batch_outputs, epoch + 1, writer, index=0, tag="test inference")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Close the writer even if training fails so pending events are flushed.
        writer.close()
    def test_test_time_augmentation(self):
        """Briefly train a small UNet, then check what TestTimeAugmentation returns.

        The checks cover the shapes of the mode/mean/std outputs, that the mode
        contains exactly the values 0 and 1, that the mean spans 0.0 to 1.0,
        and that the fourth returned value is a float.
        """
        input_size = (20, 20)
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        keys = ["image", "label"]
        train_data = self.get_data(10, input_size)
        test_data = self.get_data(1, input_size)

        add_channel = AddChanneld(keys)
        random_affine = RandAffined(
            keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
            as_tensor_output=False,
        )
        crop_foreground = CropForegroundd(keys, source_key="image")
        pad_divisible = DivisiblePadd(keys, 4)
        transforms = Compose([add_channel, random_affine, crop_foreground, pad_divisible])

        train_ds = CacheDataset(train_data, transforms)
        # samples may come out with different sizes, so the collate function pads them to match
        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        for _ in trange(10):
            running_loss = 0
            for batch in train_loader:
                inputs = batch["image"].to(device)
                labels = batch["label"].to(device)
                optimizer.zero_grad()
                loss = loss_function(model(inputs), labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            running_loss = running_loss / len(train_loader)

        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

        def inferrer_fn(x):
            return post_trans(model(x))

        tt_aug = TestTimeAugmentation(
            transforms,
            batch_size=5,
            num_workers=0,
            inferrer_fn=inferrer_fn,
            device=device,
        )
        mode, mean, std, vvc = tt_aug(test_data)

        expected_shape = (1,) + input_size
        self.assertEqual(mode.shape, expected_shape)
        self.assertEqual(mean.shape, expected_shape)
        self.assertTrue(all(np.unique(mode) == (0, 1)))
        self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
        self.assertEqual(std.shape, expected_shape)
        self.assertIsInstance(vvc, float)
cfg.thickness = [32, 96]

cfg.train_aug = Compose([
    Resized(
        keys=("input", "mask"),
        spatial_size=1120,
        size_mode="longest",
        mode="bilinear",
        align_corners=False,
    ),
    SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)),
    RandFlipd(keys=("input", "mask"), prob=0.5, spatial_axis=1),
    RandAffined(
        keys=("input", "mask"),
        prob=0.5,
        rotate_range=np.pi / 14.4,
        translate_range=(70, 70),
        scale_range=(0.1, 0.1),
        as_tensor_output=False,
    ),
    RandSpatialCropd(
        keys=("input", "mask"),
        roi_size=(cfg.img_size[0], cfg.img_size[1]),
        random_size=False,
    ),
    RandScaleIntensityd(keys="input", factors=(-0.2, 0.2), prob=0.5),
    RandShiftIntensityd(keys="input", offsets=(-51, 51), prob=0.5),
    RandLambdad(keys="input", func=lambda x: 255 - x, prob=0.5),
    RandCoarseDropoutd(
        keys=("input", "mask"),
        holes=8,
        spatial_size=(1, 1),
# One 3D case per (collate function, transform) pair:
# (display name, transform, collate_fn, number of spatial dims).
TESTS_3D = [
    (
        f"{t.__class__.__name__} {'pad_list_data_collate' if collate_fn else 'default_collate'}",
        t,
        collate_fn,
        3,
    )
    for collate_fn in (None, pad_list_data_collate)
    # The transform list is rebuilt for every collate function, so the two
    # groups of cases do not share transform instances.
    for t in (
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        Compose([
            RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
            ToTensord(keys=KEYS),
        ]),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),
        RandAffined(
            keys=KEYS,
            prob=0.5,
            rotate_range=np.pi,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
    )
]

TESTS_2D = [
    (t.__class__.__name__ +
     (" pad_list_data_collate" if collate_fn else " default_collate"), t,
     collate_fn, 2) for collate_fn in [None, pad_list_data_collate] for t in [
         RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
         RandAxisFlipd(keys=KEYS, prob=0.5),
         Compose([
             RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)),
             ToTensord(keys=KEYS)
         ]),
         RandZoomd(
Exemple #5
0
    has_nib = True
else:
    _, has_nib = optional_import("nibabel")

# Dictionary keys every transform below operates on.
KEYS = ["image", "label"]

# One case per (collate function, transform) pair:
# (display name, transform, collate_fn).
TESTS = [
    (
        f"{t.__class__.__name__} {'pad_list_data_collate' if collate_fn else 'default_collate'}",
        t,
        collate_fn,
    )
    for collate_fn in (None, pad_list_data_collate)
    # The transform list is rebuilt for every collate function, so the two
    # groups of cases do not share transform instances.
    for t in (
        RandFlipd(keys=KEYS, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS),
        RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, range_x=np.pi),
        RandAffined(keys=KEYS, rotate_range=np.pi),
    )
]


class TestInverseCollation(unittest.TestCase):
    """Test collation for of random transformations with prob == 0 and 1."""

    def setUp(self):
        if not has_nib:
            self.skipTest("nibabel required for test_inverse")

        set_determinism(seed=0)

        im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)]
        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)])
Exemple #6
0
# One 3D case per (collate function, transform) pair:
# (display name, transform, collate_fn, number of spatial dims).
TESTS_3D = [
    (
        f"{t.__class__.__name__} {'pad_list_data_collate' if collate_fn else 'default_collate'}",
        t,
        collate_fn,
        3,
    )
    for collate_fn in (None, pad_list_data_collate)
    # The transform list is rebuilt for every collate function, so the two
    # groups of cases do not share transform instances.
    for t in (
        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(keys=KEYS, prob=0.5),
        RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi),
        RandAffined(
            keys=KEYS,
            prob=0.5,
            rotate_range=np.pi,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            as_tensor_output=False,
        ),
    )
]

TESTS_2D = [(
    t.__class__.__name__ +
    (" pad_list_data_collate" if collate_fn else " default_collate"), t,
    collate_fn, 2
) for collate_fn in [None, pad_list_data_collate] for t in [
    RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
    RandAxisFlipd(keys=KEYS, prob=0.5),
    RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0,
                                                     1)),
    RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
    def test_invert(self):
        """Run a transform chain through an Engine and check the inverted outputs.

        Two ``TransformInverter`` handlers are attached to the same engine: the
        first inverts both keys with nearest interpolation, the second uses a
        different interpolation setting per key. After one epoch the test
        checks the shapes of the batch and of the inverted items, that the
        nearest-interpolated results are integer valued, and that the number
        of mismatching label voxels is one of the known values.
        """
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(volume)
            for volume in create_test_image_3d(101, 100, 107, noise_max=100)
        )
        steps = [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # only "image" becomes a tensor, so inversion is exercised on both Tensor and NumPy array
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ]
        transform = Compose(steps)
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        single_process = sys.platform == "darwin" or torch.cuda.is_available()
        num_workers = 0 if single_process else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        first_inverter = TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            postfix="inverted1",
            to_tensor=[True, False],
            device="cpu",
            num_workers=num_workers,
        )
        first_inverter.attach(engine)

        # test different nearest interpolation values
        second_inverter = TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="image",
            nearest_interp=[True, False],
            post_func=[lambda x: x + 10, lambda x: x],
            postfix="inverted2",
            collate_fn=pad_list_data_collate,
            num_workers=num_workers,
        )
        second_inverter.attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)

        batch_shape = (2, 1, 100, 100, 100)
        self.assertTupleEqual(engine.state.output["image"].shape, batch_shape)
        self.assertTupleEqual(engine.state.output["label"].shape, batch_shape)

        # check the nearest interpolation mode: values must survive a round trip through uint8
        item_shape = (1, 100, 101, 107)
        for item in engine.state.output["image_inverted1"]:
            as_float = item.to(torch.float)
            torch.testing.assert_allclose(item.to(torch.uint8).to(torch.float), as_float)
            self.assertTupleEqual(item.shape, item_shape)
        for item in engine.state.output["label_inverted1"]:
            as_float = item.astype(np.float32)
            np.testing.assert_allclose(item.astype(np.uint8).astype(np.float32), as_float)
            self.assertTupleEqual(item.shape, item_shape)

        # check labels match
        reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        label_meta = engine.state.output["label_meta_dict"]
        reverted_name = label_meta["filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        n_bad = reverted.size - n_good
        print("invert diff", n_bad)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1824: torch 1.5.1
        self.assertTrue(n_bad in (25300, 1812, 1824), "diff. in 3 possible values")

        # check the case that different items use different interpolation mode to invert transforms
        batched_shape = (2, 1, 100, 101, 107)

        inverted_image = engine.state.output["image_inverted2"]
        # if the interpolation mode is nearest, accumulated diff should be smaller than 1
        image_residual = torch.sum(
            inverted_image.to(torch.float) - inverted_image.to(torch.uint8).to(torch.float)
        ).item()
        self.assertLess(image_residual, 1.0)
        self.assertTupleEqual(inverted_image.shape, batched_shape)

        inverted_label = engine.state.output["label_inverted2"]
        # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
        label_residual = torch.sum(
            inverted_label.to(torch.float) - inverted_label.to(torch.uint8).to(torch.float)
        ).item()
        self.assertGreater(label_residual, 10000.0)
        self.assertTupleEqual(inverted_label.shape, batched_shape)
Exemple #8
0
def _find_split_files(opt, split):
    """Return the sorted image and label paths found for one data split.

    Args:
        opt: parsed options; ``images_folder`` and ``labels_folder`` are read.
        split: name of the sub-folder to scan (``'train'``, ``'val'``, ``'test'``).

    Returns:
        A ``(images, labels)`` pair of sorted path lists matching
        ``image*.nii`` and ``label*.nii`` respectively.
    """
    images = sorted(glob(os.path.join(opt.images_folder, split, 'image*.nii')))
    labels = sorted(glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
    return images, labels


def _to_data_dicts(images, labels):
    """Pair image and label paths into ``{'image': ..., 'label': ...}`` items."""
    return [{'image': image_name, 'label': label_name}
            for image_name, label_name in zip(images, labels)]


def _build_transforms(opt, augment):
    """Build the preprocessing pipeline, optionally with training augmentation.

    The deterministic steps are identical for training and evaluation; they
    are created fresh on every call so the two pipelines never share
    transform instances.

    Args:
        opt: parsed options; ``resolution`` and ``patch_size`` are read.
        augment: when true, append the random augmentations and the random
            patch crop used for training.

    Returns:
        A ``Compose`` of the selected transforms, ending with ``ToTensord``.
    """
    both = ['image', 'label']
    interp = ('bilinear', 'nearest')  # image, label

    steps = [
        LoadNiftid(keys=both),
        AddChanneld(keys=both),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-120,
            a_max=170,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        CropForegroundd(keys=both, source_key="image"),
    ]

    # Resampling is only applied when a target resolution was requested.
    if opt.resolution is not None:
        steps.append(Spacingd(keys=both, pixdim=opt.resolution, mode=interp))

    if augment:
        steps.extend(
            RandFlipd(keys=both, prob=0.1, spatial_axis=axis)
            for axis in (1, 0, 2))
        # One RandAffined per axis that is allowed the large rotation.
        steps.extend(
            RandAffined(keys=both,
                        mode=interp,
                        prob=0.1,
                        rotate_range=rotate_range,
                        padding_mode="zeros")
            for rotate_range in (
                (np.pi / 36, np.pi / 36, np.pi * 2),
                (np.pi / 36, np.pi / 2, np.pi / 36),
                (np.pi / 2, np.pi / 36, np.pi / 36),
            ))
        steps.extend([
            Rand3DElasticd(keys=both,
                           mode=interp,
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            # NOTE: np.random.uniform is evaluated once, here, so the noise
            # mean/std and the shift offset are fixed for the whole run.
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=both,
                             roi_size=opt.patch_size,
                             random_size=False),
        ])

    steps.append(ToTensord(keys=both))
    return Compose(steps)


def main():
    """Train the segmentation network and log Dice for train/val/test splits.

    All settings come from ``Options().parse()``. The function builds the
    data loaders, trains with a Tversky loss and Adam, evaluates every epoch
    with sliding-window inference, writes scalars and validation images to
    TensorBoard, and saves the weights with the best validation Dice to
    ``best_metric_model.pth`` in the current directory.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # A gpu_ids value of '-1' yields zero GPUs.
    num_gpus = 0 if opt.gpu_ids == '-1' else len(opt.gpu_ids.split(','))
    print('number of GPU:', num_gpus)

    # Data loader creation
    train_images, train_segs = _find_split_files(opt, 'train')
    val_files = _find_split_files(opt, 'val')
    test_files = _find_split_files(opt, 'test')

    # Keep an un-replicated copy of the training list for Dice evaluation.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    # augment the data list for training
    # NOTE(review): each pass doubles the list, so the final length is
    # len * 2**increase_factor_data -- confirm this growth is intended.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_files[0]))
    print('Number of test images per epoch:', len(test_files[0]))

    # Creation of data directories for data_loader
    train_dicts = _to_data_dicts(train_images, train_segs)
    train_dice_dicts = _to_data_dicts(train_images_for_dice,
                                      train_segs_for_dice)
    val_dicts = _to_data_dicts(*val_files)
    test_dicts = _to_data_dicts(*test_files)

    # Transforms (training pipeline is built first, as before)
    train_transforms = _build_transforms(opt, augment=True)
    val_transforms = _build_transforms(opt, augment=False)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts,
                                     transform=train_transforms)
    train_loader = DataLoader(check_train,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.workers,
                              pin_memory=torch.cuda.is_available())

    # create a training_dice data loader (training images, no augmentation)
    check_val = monai.data.Dataset(data=train_dice_dicts,
                                   transform=val_transforms)
    train_dice_loader = DataLoader(check_val,
                                   batch_size=1,
                                   num_workers=opt.workers,
                                   pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val,
                            batch_size=1,
                            num_workers=opt.workers,
                            pin_memory=torch.cuda.is_available())

    # create a test data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val,
                             batch_size=1,
                             num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    # try to use all the available GPUs
    # NOTE(review): the returned device list is never used below; the call is
    # kept in case it is relied on for validation -- confirm and remove.
    devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    sw_batch_size = 4

    def evaluate_dice(images_loader):
        """Return the mean Dice over a loader plus the last batch's tensors.

        Returns ``(metric, images, labels, outputs)``. For an empty loader
        the metric is NaN and the tensors are ``None``.
        """
        metric_sum = 0.0
        metric_count = 0
        images = labels = outputs = None
        for data in images_loader:
            images = data["image"].cuda()
            labels = data["label"].cuda()
            outputs = sliding_window_inference(images, opt.patch_size,
                                               sw_batch_size, net)
            value = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        # An empty split must not abort training with a ZeroDivisionError.
        metric = metric_sum / metric_count if metric_count else float('nan')
        return metric, images, labels, outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()

    # Actual number of batches per epoch (counts the last partial batch), so
    # the progress print and the TensorBoard global step stay consistent.
    epoch_len = len(train_loader)

    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].cuda()
            labels = batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():
                metric, val_inputs, val_labels, val_outputs = evaluate_dice(
                    val_loader)
                # Only validation Dice belongs in the validation history.
                metric_values.append(metric)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, *_ = evaluate_dice(train_dice_loader)
                metric_test, *_ = evaluate_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test,
                            best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)

                # plot the last model output as GIF image in TensorBoard with
                # the corresponding image and label (skipped when the
                # validation split produced no batch)
                if val_outputs is not None:
                    val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                    plot_2d_or_3d_image(val_inputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation image")
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation label")
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
    def test_train_timing(self):
        """Train a small 3D UNet with the fast-training components and check Dice.

        Uses the first 32 ``img*/seg*`` pairs found in ``self.data_dir`` for
        training and the last 9 for validation, trains for 5 epochs with AMP
        on ``cuda:0``, prints per-step / per-epoch timings, and asserts that
        the best validation mean Dice exceeds 0.95.
        """
        both = ["image", "label"]
        image_paths = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        label_paths = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))

        def pair_up(imgs, lbls):
            # one {"image", "label"} item per case
            return [{"image": img, "label": seg} for img, seg in zip(imgs, lbls)]

        train_files = pair_up(image_paths[:32], label_paths[:32])
        val_files = pair_up(image_paths[-9:], label_paths[-9:])

        device = torch.device("cuda:0")

        def shared_preprocessing():
            # deterministic steps common to both pipelines; new instances are
            # created on every call so the pipelines do not share transforms
            return [
                LoadImaged(keys=both),
                EnsureChannelFirstd(keys=both),
                Spacingd(keys=both,
                         pixdim=(1.0, 1.0, 1.0),
                         mode=("bilinear", "nearest")),
                ScaleIntensityd(keys="image"),
                CropForegroundd(keys=both, source_key="image"),
            ]

        # define transforms for train and validation
        train_transforms = Compose(shared_preprocessing() + [
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"),
            # change to execute transforms with Tensor data
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=device),
            # randomly crop out patch samples from the big image based on the
            # pos / neg ratio; the image centers of negative samples must be
            # in the valid image area
            RandCropByPosNegLabeld(
                keys=both,
                label_key="label",
                spatial_size=(64, 64, 64),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            RandFlipd(keys=both, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(keys=both, prob=0.5),
            RandRotate90d(keys=both, prob=0.5, spatial_axes=(1, 2)),
            RandZoomd(keys=both,
                      prob=0.5,
                      min_zoom=0.8,
                      max_zoom=1.2,
                      keep_size=True),
            RandRotated(
                keys=both,
                prob=0.5,
                range_x=np.pi / 4,
                mode=("bilinear", "nearest"),
                align_corners=True,
            ),
            RandAffined(keys=both,
                        prob=0.5,
                        rotate_range=np.pi / 2,
                        mode=("bilinear", "nearest")),
            RandGaussianNoised(keys="image", prob=0.5),
            RandStdShiftIntensityd(keys="image",
                                   prob=0.5,
                                   factors=0.05,
                                   nonzero=True),
        ])

        val_transforms = Compose(shared_preprocessing() + [
            EnsureTyped(keys=both),
            # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
            ToDeviced(keys=both, device=device),
        ])

        max_epochs = 5
        learning_rate = 2e-4
        val_interval = 1  # do validation for every epoch

        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
        train_ds = CacheDataset(data=train_files,
                                transform=train_transforms,
                                cache_rate=1.0,
                                num_workers=8)
        val_ds = CacheDataset(data=val_files,
                              transform=val_transforms,
                              cache_rate=1.0,
                              num_workers=5)
        # disable multi-workers because `ThreadDataLoader` works with multi-threads
        train_loader = ThreadDataLoader(train_ds,
                                        num_workers=0,
                                        batch_size=4,
                                        shuffle=True)
        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

        loss_function = DiceCELoss(to_onehot_y=True,
                                   softmax=True,
                                   squared_pred=True,
                                   batch=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(device)

        # Novograd paper suggests to use a bigger LR than Adam,
        # because Adam does normalization by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()

        post_pred = Compose([
            EnsureType(),
            AsDiscrete(argmax=True, to_onehot=True, num_classes=2),
        ])
        post_label = Compose([
            EnsureType(),
            AsDiscrete(to_onehot=True, num_classes=2),
        ])

        dice_metric = DiceMetric(include_background=True,
                                 reduction="mean",
                                 get_not_nans=False)

        # values that do not change between steps / epochs
        batches_per_epoch = math.ceil(len(train_ds) / train_loader.batch_size)
        roi_size = (96, 96, 96)
        sw_batch_size = 4

        best_metric = -1
        total_start = time.time()
        for epoch_no in range(1, max_epochs + 1):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch_no}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch in train_loader:
                step_start = time.time()
                step += 1
                optimizer.zero_grad()
                # set AMP for training
                with torch.cuda.amp.autocast():
                    prediction = model(batch["image"])
                    loss = loss_function(prediction, batch["label"])
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                epoch_loss += loss.item()
                print(f"{step}/{batches_per_epoch}, train_loss: {loss.item():.4f}"
                      f" step time: {(time.time() - step_start):.4f}")
            epoch_loss /= step
            print(f"epoch {epoch_no} average loss: {epoch_loss:.4f}")

            if epoch_no % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    for val_batch in val_loader:
                        # set AMP for validation
                        with torch.cuda.amp.autocast():
                            logits = sliding_window_inference(
                                val_batch["image"], roi_size, sw_batch_size,
                                model)

                        predictions = [
                            post_pred(item) for item in decollate_batch(logits)
                        ]
                        references = [
                            post_label(item)
                            for item in decollate_batch(val_batch["label"])
                        ]
                        dice_metric(y_pred=predictions, y=references)

                    metric = dice_metric.aggregate().item()
                    dice_metric.reset()
                    best_metric = max(best_metric, metric)
                    print(
                        f"epoch: {epoch_no} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}"
                    )
            print(
                f"time consuming of epoch {epoch_no} is: {(time.time() - epoch_start):.4f}"
            )

        total_time = time.time() - total_start
        print(
            f"train completed, best_metric: {best_metric:.4f} total time: {total_time:.4f}"
        )
        # test expected metrics
        self.assertGreater(best_metric, 0.95)
Exemple #10
0
            translate_params=[10, 5, -4],
            scale_params=[0.8, 1.3],
        ),
    )
)

# 3D random-affine case. Entry layout appears to be
# (label, dimensionality tag, numeric threshold, transform) -- confirm against
# the code that consumes TESTS.
TESTS.append((
    "RandAffine 3d",
    "3D",
    1e-1,
    RandAffined(
        KEYS,
        [155, 179, 192],
        prob=1,
        padding_mode="zeros",
        rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7],
        shear_range=[(0.5, 0.5)],
        translate_range=[10, 5, -4],
        scale_range=[(0.8, 1.2), (0.9, 1.3)],
    ),
))

# Repeat every case with its transform(s) wrapped in two nested Compose
# objects, tagging the label with " Compose".
TESTS_COMPOSE_X2 = [
    (label + " Compose", dims, threshold, Compose(Compose(tuple(transforms))))
    for label, dims, threshold, *transforms in TESTS
]

# Rebind (rather than extend in place) to a new list holding both sets.
TESTS = [*TESTS, *TESTS_COMPOSE_X2]  # type: ignore


def no_collation(x):
    """Return ``x`` unchanged.

    NOTE(review): the name suggests this is passed as a pass-through
    ``collate_fn`` to a data loader -- confirm at the call sites.
    """
    return x
def main(config):
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    trained_model_path = config['path'][
        'trained_model_path']  # if None, trained from scratch
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    if not os.path.exists(training_model_folder):
        os.makedirs(training_model_folder)
    logdir = os.path.join(training_model_folder, 'logs')
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    voxel_spacing = tuple(
        config['preprocessing']
        ['voxel_spacing'])  # (4.8, 4.8, 4.8)  # in millimeter, (x, y, z)
    data_augment = config['preprocessing'][
        'data_augment']  # True  # for training dataset only
    resize = config['preprocessing']['resize']  # True  # not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing'][
        'normalize']  # True  # whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'

    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing
    # use data augmentation for training
    train_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # user can also add other random transforms
        RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                    spatial_size=None,
                    prob=0.4,
                    rotate_range=(0, np.pi / 30, np.pi / 15),
                    shear_range=None,
                    translate_range=(10, 10, 10),
                    scale_range=(0.1, 0.1, 0.1),
                    mode=("bilinear", "bilinear", "nearest"),
                    padding_mode="border"),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
    # without data augmentation for validation
    val_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    # use batch_size=2 to load images to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        # TensorBoardImageHandler(
        #     log_dir="./runs/",
        #     batch_transform=lambda x: (x["image"], x["label"]),
        #     output_transform=lambda x: x["pred"],
        # ),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        # amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        # LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=lambda x: (x['image'], x['mask_img']),
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=True if monai.config.get_torch_version_tuple() >=
        (1, 6) else False,
    )
    trainer.run()
Exemple #12
0
# Dictionary keys shared by every sample: the image and its segmentation.
keys = ("img", "seg")

# Interpolation settings for the spatial augmentations. Nearest-neighbour
# interpolation keeps a binary segmentation at exactly 0 or 1.
zoom_mode = monai.utils.enums.InterpolateMode.NEAREST
# Pair of grid-sample modes; presumably matched to `keys` in order
# (bilinear for "img", nearest for "seg") -- TODO confirm.
elast_mode = (
    monai.utils.enums.GridSampleMode.BILINEAR,
    monai.utils.enums.GridSampleMode.NEAREST,
)


# Augmentation pipeline applied to each {"img", "seg"} sample.
trans = Compose(
    [
        # rescale image intensities to the range [0, 1]
        ScaleIntensityd(keys=("img",)),
        # add a size-1 channel dimension
        AddChanneld(keys=keys),
        # random spatial augmentations, each applied with probability aug_prob
        RandRotate90d(keys=keys, prob=aug_prob),
        RandFlipd(keys=keys, prob=aug_prob),
        RandZoomd(keys=keys, prob=aug_prob, mode=zoom_mode),
        Rand2DElasticd(
            keys=keys,
            prob=aug_prob,
            spacing=10,
            magnitude_range=(-2, 2),
            mode=elast_mode,
        ),
        RandAffined(
            keys=keys,
            prob=aug_prob,
            rotate_range=1,
            translate_range=16,
            mode=elast_mode,
        ),
        # convert to tensor
        ToTensord(keys=keys),
    ]
)


# Pair every training image with the segmentation stored at the same index.
data = [
    {"img": image, "seg": train_segs[index]}
    for index, image in enumerate(train_images)
]

ds = CacheDataset(data, trans)
loader = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
Exemple #13
0
def training(train_files, val_files, log_dir):
    """Train a DenseNetASPP classifier and keep the checkpoint with the best kappa.

    The function builds the train/validation pipelines, trains with a
    class-weighted cross-entropy loss, validates every ``val_interval`` epochs
    using the quadratic-weighted Cohen's kappa, saves the best model to
    ``log_dir``, plots the loss/kappa curves and finally calls
    ``evaluta_model(val_files, log_dir)``.

    Args:
        train_files: training items handed to ``monai.data.Dataset``.
            Presumably a list of dicts holding the ``modalDataKey`` image
            paths plus a ``"label"`` entry -- TODO confirm against the caller.
        val_files: validation items in the same format.
        log_dir: path of the checkpoint *file* (despite the name). If it
            exists, model and optimizer state are restored from it and
            training resumes with the epoch after the stored one; the best
            checkpoint is written back to this same path.
    """
    print(log_dir)

    # Preprocessing plus random augmentation for training samples.
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ToTensord(keys=["inputs"]),
        ]
    )
    # Same preprocessing without the random augmentations.
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )

    lr = 0.01
    batch_size = 256
    # pin_memory must be a bool. The previous value (the `torch.device` class
    # itself) was always truthy, which forced pinning even without CUDA.
    pin_memory = torch.cuda.is_available()

    # Un-shuffled loader used only to fetch one batch of labels for the
    # class-weight estimate below.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)
    check_data = monai.utils.misc.first(check_loader)

    # Training and validation loaders.
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    # Network, class-weighted cross-entropy loss, Adam optimizer and step LR decay.
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # The weights are estimated from the first batch only. `classes` and `y`
    # are passed by keyword because recent scikit-learn releases reject them
    # as positional arguments.
    class_weights = class_weight.compute_class_weight(
        'balanced', classes=classes, y=check_data["label"].numpy()
    )
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    # If a saved checkpoint exists, load it and continue training from there.
    if os.path.exists(log_dir):
        # map_location lets a checkpoint saved on another device be restored here.
        checkpoint = torch.load(log_dir, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        saved_epoch = checkpoint['epoch']
        # The checkpoint stores the 0-based index of the epoch that produced
        # it, so training continues with the following one.
        start_epoch = saved_epoch + 1
        print('加载 epoch {} 成功!'.format(saved_epoch))
    else:
        # Fresh run: begin at the very first epoch (index 0). Previously the
        # loop started at index 1 and silently skipped an epoch.
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    # Standard PyTorch training loop.
    epoch_num = 300
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    metric_epochs = list()  # 1-based epoch number of every validation run
    writer = SummaryWriter()
    # len(train_loader) is already the number of batches per epoch; dividing
    # it by the batch size again made the TensorBoard steps of different
    # epochs overlap.
    epoch_len = len(train_loader)
    for epoch in range(start_epoch, epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Accumulate logits and labels over the whole validation set.
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                predicted = y_pred.argmax(dim=1)
                kappa_value = cohen_kappa_score(y.to("cpu"), predicted.to("cpu"), weights='quadratic')
                metric_values.append(kappa_value)
                metric_epochs.append(epoch + 1)
                acc_value = torch.eq(predicted, y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'epoch': epoch
                                  }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()

    # Plot the training loss and the validation kappa against the epoch number.
    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    loss_epochs = [start_epoch + i + 1 for i in range(len(epoch_loss_values))]
    plt.xlabel('epoch')
    plt.plot(loss_epochs, epoch_loss_values)
    plt.subplot(1, 2, 2)
    # The plotted metric is Cohen's kappa, not ROC AUC as the old title claimed.
    plt.title("Validation: Quadratic Weighted Kappa")
    plt.xlabel('epoch')
    plt.plot(metric_epochs, metric_values)
    plt.show()
    evaluta_model(val_files, log_dir)
Exemple #14
0
    def test_invert(self):
        """Check that `TransformInverter` undoes a random spatial pipeline.

        A synthetic 3D image/label pair is written to NIfTI, pushed through a
        chain of deterministic and random transforms, batched, and run through
        an engine with a `TransformInverter` attached. The test then verifies
        the batch shapes, that the inverted outputs have shape
        (1, 100, 101, 107) and survive a uint8 round trip, that the file name
        is carried through the meta dict, and that the number of label voxels
        differing from the original is one of two accepted values.
        """
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(volume)
            for volume in create_test_image_3d(101, 100, 107, noise_max=100)
        )

        pipeline = [
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(
                KEYS,
                pixdim=(1.2, 1.01, 0.9),
                mode=["bilinear", "nearest"],
                dtype=np.float32,
            ),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
            RandRotated(
                KEYS,
                prob=0.5,
                range_x=np.pi,
                mode="bilinear",
                align_corners=True,
            ),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ]
        transform = Compose(pipeline)
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        single_process = sys.platform == "darwin" or torch.cuda.is_available()
        num_workers = 0 if single_process else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine: each iteration just forwards the batch as the output
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        inverter = TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            num_workers=num_workers,
        )
        inverter.attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)

        # 12 items with batch_size=5 leave 2 items in the final batch
        output = engine.state.output
        self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
        self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))
        for inverted in output["image_inverted"] + output["label_inverted"]:
            # values must be unchanged by a uint8 round trip
            torch.testing.assert_allclose(
                inverted.to(torch.uint8).to(torch.float),
                inverted.to(torch.float),
            )
            self.assertTupleEqual(inverted.shape, (1, 100, 101, 107))

        # check labels match
        last_label = output["label_inverted"][-1]
        reverted = last_label.detach().cpu().numpy()[0].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        n_diff = reverted.size - n_good
        print("invert diff", n_diff)
        self.assertTrue(n_diff in (25300, 1812), "diff. in two possible values")
Exemple #15
0
def main():
    """Prepare the HACKATHON data, then train, validate and test a 3D DenseNet121.

    Steps, in order:
      1. configure logging, the working directories and the torch device;
      2. read ``HACKATHON/train.txt`` and build the data list through
         ``get_data_from_info``, ``large_image_splitter``,
         ``transform_and_copy`` and ``balance_training_data2``;
      3. hold out 20% of the items for testing, then 20% of the remainder
         for validation;
      4. build the preprocessing/augmentation transforms, the persistent
         datasets and their data loaders;
      5. create the network, loss and optimizer, run the ``Trainer`` (with a
         ``Validator``) and finally the ``Tester`` on the held-out split.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Working directories and compute device
    dirs = setup_directories()
    device, using_gpu = create_device("cuda")

    # ---- Data listing -------------------------------------------------
    # HACKATHON image and segmentation data. Each line of train.txt is split
    # on commas; the first field is kept as-is, the second is parsed as int.
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        rows = [line.strip().split(',') for line in fp]
        train_info_hackathon = [(row[0], int(row[1])) for row in rows]

    all_data = get_data_from_info(
        os.path.join(hackathon_dir, 'images', 'train'),
        os.path.join(hackathon_dir, 'segmentations', 'train'),
        train_info_hackathon,
        dual_output=False,
    )
    all_data = large_image_splitter(all_data, dirs["cache"])
    copy_list = transform_and_copy(all_data, dirs['cache'])
    balance_training_data2(all_data, copy_list, seed=72)

    # Split data into train, validate and test (fixed seeds keep it repeatable)
    train_split, test_data_hackathon = train_test_split(
        all_data, test_size=0.2, shuffle=True, random_state=42)
    train_data_hackathon, valid_data_hackathon = train_test_split(
        train_split, test_size=0.2, shuffle=True, random_state=43)

    # ---- Transforms ---------------------------------------------------
    # Crop foreground (everything that is non-zero)
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
        select_fn=lambda x: x != 0,
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    window_width = 1500
    window_level = -600
    ct_window = CTWindowd(keys=["image"], width=window_width, level=window_level)
    # Random flips, one per spatial axis
    rand_x_flip = RandFlipd(keys=["image"], spatial_axis=0, prob=0.50)
    rand_y_flip = RandFlipd(keys=["image"], spatial_axis=1, prob=0.50)
    rand_z_flip = RandFlipd(keys=["image"], spatial_axis=2, prob=0.50)
    # Random affine transform
    rand_affine = RandAffined(
        keys=["image"],
        prob=0.5,
        rotate_range=(0, 0, np.pi / 12),
        shear_range=(0.07, 0.07, 0.0),
        translate_range=(0, 0, 0),
        scale_range=(0.07, 0.07, 0.0),
        padding_mode="zeros",
    )
    # Pad the image so that the last axis has size at least 30
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    # Resize the first two axes to half of 512; the last axis is left as-is
    in_plane = int(512 * 0.50)
    resize = Resized(
        keys=["image"],
        spatial_size=(in_plane, in_plane, -1),
        mode="trilinear",
    )
    # Random Gaussian noise
    rand_gaussian_noise = RandGaussianNoised(
        keys=["image"], prob=0.25, mean=0.0, std=0.1)

    # Deterministic preprocessing shared by every split
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    # Training adds the random augmentations on top of the common chain
    train_transform = Compose([
        common_transform,
        rand_x_flip,
        rand_y_flip,
        rand_z_flip,
        rand_affine,
        rand_gaussian_noise,
        ToTensord(keys=["image"]),
    ]).flatten()
    # Validation and test use the common chain only
    valid_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    test_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    # NOTE: psuf_transforms is not used anywhere below in this function.
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # ---- Datasets and loaders -----------------------------------------
    persistent_dir = dirs["persistent"]
    train_dataset = PersistentDataset(
        data=train_data_hackathon[:],
        transform=train_transform,
        cache_dir=persistent_dir,
    )
    valid_dataset = PersistentDataset(
        data=valid_data_hackathon[:],
        transform=valid_transform,
        cache_dir=persistent_dir,
    )
    test_dataset = PersistentDataset(
        data=test_data_hackathon[:],
        transform=test_transform,
        cache_dir=persistent_dir,
    )

    def make_collate():
        # A fresh padding collate function for each loader.
        return PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT)

    def label_of(x, i):
        # Label lookup used by the imbalanced samplers.
        return x[i]['_label']

    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            train_data_hackathon, callback_get_label=label_of),
        collate_fn=make_collate(),
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        sampler=ImbalancedDatasetSampler(
            valid_data_hackathon, callback_get_label=label_of),
        collate_fn=make_collate(),
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=False,
        pin_memory=using_gpu,
        num_workers=2,
        collate_fn=make_collate(),
    )

    # ---- Network, loss function, optimizer and scheduler --------------
    network = nets.DenseNet121(
        spatial_dims=3, in_channels=1, out_channels=1).to(device)
    # Weights for class imbalance.
    # NOTE(review): the tensor is passed positionally, which binds to the
    # first parameter of BCEWithLogitsLoss (`weight`), not to `pos_weight`
    # -- confirm which one was intended.
    _, n, p = calculate_class_imbalance(train_data_hackathon)
    pos_weight = torch.Tensor([n, p]).to(device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight)
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4, weight_decay=0)
    # The scheduler is created here, but the Trainer below receives
    # lr_scheduler=None, so it is never stepped by this function.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=0.95, last_epoch=-1)

    # ---- Validator and trainer ----------------------------------------
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])
    gpu_flags = {"amp": using_gpu, "non_blocking": using_gpu}
    validator = Validator(
        device=device,
        val_data_loader=valid_loader,
        network=network,
        post_transform=valid_post_transforms,
        **gpu_flags,
    )
    trainer = Trainer(
        device=device,
        out_dir=dirs["out"],
        out_name="DenseNet121",
        max_epochs=120,
        validation_epoch=1,
        validation_interval=1,
        train_data_loader=train_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        lr_scheduler=None,
        validator=validator,
        **gpu_flags,
    )

    # Run trainer
    train_output = trainer.run()

    # ---- Tester -------------------------------------------------------
    # The tester receives the trainer's return value as `load_dir`.
    tester = Tester(
        device=device,
        test_data_loader=test_loader,
        load_dir=train_output,
        out_dir=dirs["out"],
        network=network,
        post_transform=valid_post_transforms,
        **gpu_flags,
    )

    # Run tester
    tester.run()