Exemplo n.º 1
0
    def test_invert(self):
        """Run ``TransformInverter`` on an engine and check the inverted outputs.

        A single synthetic image/label pair is written to disk and repeated
        12 times, pushed through a chain of spatial/intensity transforms, and
        fed to an ``Engine`` with a ``TransformInverter`` attached. The test
        then checks the shapes of the last batch, that the inverted tensors
        survive a ``uint8`` round trip, and that the number of label voxels
        differing from the file on disk is one of two known values.
        """
        set_determinism(seed=0)
        # Restore non-deterministic mode even if setup or the engine run
        # raises, so a failure here cannot leak the fixed seed into other tests.
        try:
            im_fname, seg_fname = [
                make_nifti_image(i)
                for i in create_test_image_3d(101, 100, 107, noise_max=100)
            ]
            transform = Compose([
                LoadImaged(KEYS),
                AddChanneld(KEYS),
                Orientationd(KEYS, "RPS"),
                Spacingd(KEYS,
                         pixdim=(1.2, 1.01, 0.9),
                         mode=["bilinear", "nearest"],
                         dtype=np.float32),
                ScaleIntensityd("image", minv=1, maxv=10),
                RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlipd(KEYS, prob=0.5),
                RandRotate90d(KEYS, spatial_axes=(1, 2)),
                RandZoomd(KEYS,
                          prob=0.5,
                          min_zoom=0.5,
                          max_zoom=1.1,
                          keep_size=True),
                RandRotated(KEYS,
                            prob=0.5,
                            range_x=np.pi,
                            mode="bilinear",
                            align_corners=True),
                RandAffined(KEYS,
                            prob=0.5,
                            rotate_range=np.pi,
                            mode="nearest"),
                ResizeWithPadOrCropd(KEYS, 100),
                ToTensord(KEYS),
                CastToTyped(KEYS, dtype=torch.uint8),
            ])
            data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

            # num workers = 0 for mac or gpu transforms
            use_main_process = (sys.platform == "darwin"
                                or torch.cuda.is_available())
            num_workers = 0 if use_main_process else 2

            dataset = CacheDataset(data, transform=transform, progress=False)
            loader = DataLoader(dataset,
                                num_workers=num_workers,
                                batch_size=5)

            # set up engine
            def _train_func(engine, batch):
                self.assertTupleEqual(batch["image"].shape[1:],
                                      (1, 100, 100, 100))
                engine.state.output = batch
                engine.fire_event(IterationEvents.MODEL_COMPLETED)
                return engine.state.output

            engine = Engine(_train_func)
            engine.register_events(*IterationEvents)

            # set up testing handler; it shares the loader's worker setting
            TransformInverter(
                transform=transform,
                loader=loader,
                output_keys=["image", "label"],
                batch_keys="label",
                nearest_interp=True,
                num_workers=num_workers,
            ).attach(engine)

            engine.run(loader, max_epochs=1)
        finally:
            set_determinism(seed=None)

        # 12 items with batch_size=5 leave 2 items in the final batch
        output = engine.state.output
        self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
        self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))
        for i in output["image_inverted"] + output["label_inverted"]:
            # values must be unchanged by a cast to uint8 and back
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        # check labels match
        reverted = output["label_inverted"][-1].detach().cpu().numpy()[0]
        reverted = reverted.astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        self.assertIn(reverted.size - n_good, (25300, 1812),
                      "diff. in two possible values")
Exemplo n.º 2
0
 def test_fail_random_but_not_invertible(self):
     """Expect ``TestTimeAugmentation`` to reject this pipeline.

     The pipeline is ``AddChanneld`` followed by ``Rand2DElasticd``; building
     a ``TestTimeAugmentation`` from it must raise ``RuntimeError``.
     """
     pipeline = [AddChanneld("im"), Rand2DElasticd("im", None, None)]
     transforms = Compose(pipeline)
     # Only the TestTimeAugmentation constructor is expected to raise.
     with self.assertRaises(RuntimeError):
         TestTimeAugmentation(transforms, None, None, None)
Exemplo n.º 3
0
    def test_values(self):
        """Exercise ``DecathlonDataset`` on the ``Task04_Hippocampus`` task.

        Downloads the task into ``testing_data`` next to this file (returning
        early on download errors), checks dataset lengths and sample shapes
        for the validation and training sections with and without
        transforms, checks the ``labels`` property, then deletes the task
        directory and verifies that loading with ``download=False`` raises.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        task = "Task04_Hippocampus"
        transform = Compose([
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ])

        def _test_dataset(dataset):
            """Check length, expected keys and image shape of the first item."""
            self.assertEqual(len(dataset), 52)
            # Fetch the first item once; every dataset[0] re-runs the lookup.
            sample = dataset[0]
            for key in ("image", "label", "image_meta_dict"):
                self.assertIn(key, sample)
            self.assertTupleEqual(sample["image"].shape, (1, 36, 47, 44))

        try:  # will start downloading if testing_dir doesn't have the Decathlon files
            data = DecathlonDataset(
                root_dir=testing_dir,
                task=task,
                transform=transform,
                section="validation",
                download=True,
            )
        except (ContentTooShortError, HTTPError, RuntimeError) as e:
            print(str(e))
            if isinstance(e, RuntimeError):
                # FIXME: skip MD5 check as current downloading method may fail
                self.assertTrue(str(e).startswith("md5 check"))
            return  # skipping this test due the network connection errors

        _test_dataset(data)
        data = DecathlonDataset(root_dir=testing_dir,
                                task=task,
                                transform=transform,
                                section="validation",
                                download=False)
        _test_dataset(data)
        # test validation without transforms
        data = DecathlonDataset(root_dir=testing_dir,
                                task=task,
                                section="validation",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44))
        self.assertEqual(len(data), 52)
        data = DecathlonDataset(root_dir=testing_dir,
                                task=task,
                                section="training",
                                download=False)
        self.assertTupleEqual(data[0]["image"].shape, (34, 56, 31))
        self.assertEqual(len(data), 208)

        # test dataset properties
        data = DecathlonDataset(root_dir=testing_dir,
                                task=task,
                                section="validation",
                                download=False)
        properties = data.get_properties(keys="labels")
        self.assertDictEqual(properties["labels"], {
            "0": "background",
            "1": "Anterior",
            "2": "Posterior"
        })

        shutil.rmtree(os.path.join(testing_dir, task))
        # With the data removed and download disabled the constructor must
        # fail; assertRaises makes the test fail if no error is raised at all.
        with self.assertRaises(RuntimeError) as ctx:
            DecathlonDataset(
                root_dir=testing_dir,
                task=task,
                transform=transform,
                section="validation",
                download=False,
            )
        message = str(ctx.exception)
        print(message)
        self.assertTrue(message.startswith("Cannot find dataset directory"))
Exemplo n.º 4
0
def main():
    """Fine-tune a UNETR segmentation model, optionally from ViT weights.

    Builds training/validation transform chains and cached datasets from a
    Decathlon-style datalist, creates a ``UNETR`` model (optionally loading
    pretrained ViT backbone weights), then trains until ``max_iterations``
    global steps have run. Every ``eval_num`` steps the model is validated
    with sliding-window inference; the best-scoring weights and a loss/dice
    plot are written to ``logdir``. The best weights are reloaded at the end.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")

    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # Print the shapes of one validation case as a quick sanity check.
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        # The model is still on the CPU here, so map the checkpoint to the CPU
        # as well; this also lets GPU-saved checkpoints load without CUDA.
        vit_dict = torch.load(pretrained_path, map_location="cpu")
        vit_weights = vit_dict['state_dict']

        # Remove items of vit_weights if they are not in the ViT backbone (this is used in UNETR).
        # For example, some variables names like conv3d_transpose.weight, conv3d_transpose.bias,
        # conv3d_transpose_1.weight and conv3d_transpose_1.bias are used to match dimensions
        # while pretraining with ViTAutoEnc and are not a part of ViT backbone.
        model_dict = model.vit.state_dict()
        vit_weights = {k: v for k, v in vit_weights.items() if k in model_dict}
        model_dict.update(vit_weights)
        model.vit.load_state_dict(model_dict)
        del model_dict, vit_weights, vit_dict
        print('Pretrained Weights Succesfully Loaded !')

    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val, global_step):
        """Run sliding-window validation and return the mean of the dice values.

        ``global_step`` is passed in explicitly so the progress text shows the
        caller's current step rather than a stale outer variable.
        """
        model.eval()
        dice_vals = []

        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs = batch["image"].to(device)
                val_labels = batch["label"].to(device)
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in decollate_batch(val_labels)
                ]
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # NOTE(review): the metric is only reset after the loop, so
                # each value here is presumably a running aggregate over the
                # batches seen so far -- confirm that averaging these is intended.
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (global_step, max_iterations, dice))

            dice_metric.reset()

        return np.mean(dice_vals)

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Train for one pass over ``train_loader``, validating periodically.

        Returns the updated ``(global_step, dice_val_best, global_step_best)``.
        """
        model.train()
        epoch_loss = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        # step counts the batches processed so far in this pass (1-based)
        for step, batch in enumerate(epoch_iterator, start=1):
            x, y = batch["image"].to(device), batch["label"].to(device)
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss.item()))

            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val, global_step)
                # validation() switched the model to eval mode
                model.train()

                # Average without touching the running sum, so later batches
                # are not added on top of an already-divided value.
                epoch_loss_values.append(epoch_loss / step)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(),
                               os.path.join(logdir, "best_metric_model.pth"))
                    print(
                        "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))

                # Refresh the loss / dice curves on disk.
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                loss_steps = [
                    eval_num * (i + 1) for i in range(len(epoch_loss_values))
                ]
                plt.xlabel("Iteration")
                plt.plot(loss_steps, epoch_loss_values)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                metric_steps = [
                    eval_num * (i + 1) for i in range(len(metric_values))
                ]
                plt.xlabel("Iteration")
                plt.plot(metric_steps, metric_values)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    model.load_state_dict(
        torch.load(os.path.join(logdir, "best_metric_model.pth"),
                   map_location=device))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
Exemplo n.º 5
0
    def test_values(self):
        """Exercise ``TciaDataset`` on the ``QIN-PROSTATE-Repeatability`` collection.

        Downloads ``download_len`` series into ``testing_data`` next to this
        file, checks dataset length, keys, types, shapes and file names for
        the validation section with and without transforms, then deletes the
        collection directory and verifies that loading with
        ``download=False`` raises.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                   "testing_data")
        download_len = 1
        val_frac = 1.0
        collection = "QIN-PROSTATE-Repeatability"

        transform = Compose([
            LoadImaged(keys=["image", "seg"],
                       reader="PydicomReader",
                       label_dict=TCIA_LABEL_DICT[collection]),
            AddChanneld(keys="image"),
            ScaleIntensityd(keys="image"),
        ])

        def _test_dataset(dataset):
            """Check length, keys, type and shapes of the first item."""
            self.assertEqual(len(dataset), int(download_len * val_frac))
            # Fetch the first item once; every dataset[0] re-runs the lookup.
            sample = dataset[0]
            self.assertIn("image", sample)
            self.assertIn("seg", sample)
            self.assertIsInstance(sample["image"], MetaTensor)
            self.assertTupleEqual(sample["image"].shape, (1, 256, 256, 24))
            self.assertTupleEqual(sample["seg"].shape, (256, 256, 24, 4))

        with skip_if_downloading_fails():
            data = TciaDataset(
                root_dir=testing_dir,
                collection=collection,
                transform=transform,
                section="validation",
                download=True,
                download_len=download_len,
                copy_cache=False,
                val_frac=val_frac,
            )

        _test_dataset(data)
        data = TciaDataset(
            root_dir=testing_dir,
            collection=collection,
            transform=transform,
            section="validation",
            download=False,
            val_frac=val_frac,
        )
        _test_dataset(data)
        first = data[0]
        self.assertTrue(first["image"].meta["filename_or_obj"].endswith(
            "QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/image"))
        self.assertTrue(first["seg"].meta["filename_or_obj"].endswith(
            "QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/seg"))
        # test validation without transforms
        data = TciaDataset(root_dir=testing_dir,
                           collection=collection,
                           section="validation",
                           download=False,
                           val_frac=val_frac)
        self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
        self.assertEqual(len(data), int(download_len * val_frac))
        # NOTE(review): this repeats the construction above with the same
        # arguments and only differs in the expected-length expression;
        # confirm whether another section was intended.
        data = TciaDataset(root_dir=testing_dir,
                           collection=collection,
                           section="validation",
                           download=False,
                           val_frac=val_frac)
        self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
        self.assertEqual(len(data), download_len)

        shutil.rmtree(os.path.join(testing_dir, collection))
        # With the data removed and download disabled the constructor must
        # fail; assertRaises makes the test fail if no error is raised at all.
        with self.assertRaises(RuntimeError) as ctx:
            TciaDataset(
                root_dir=testing_dir,
                collection=collection,
                transform=transform,
                section="validation",
                download=False,
                val_frac=val_frac,
            )
        self.assertTrue(
            str(ctx.exception).startswith("Cannot find dataset directory"))
    opt = Options().parse()

    train_images = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    data_dicts = [{
        'image': image_name,
        'label': label_name,
        'mask': mask_name
    } for image_name, label_name, mask_name in zip(train_images, train_segs)]

    monai_transforms = [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
        # RandFlipd(keys=['image', 'label'], prob=1, spatial_axis=2),
        # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
        #             rotate_range=(np.pi / 36, np.pi / 4, np.pi / 36)),
        # Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1,
        #                sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.20, 0.20, 0.20)),
        # RandAdjustContrastd(keys=['image'], gamma=(0.5, 3), prob=1),
        # RandGaussianNoised(keys=['image'], prob=1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        # RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=1),
        # BorderPadd(keys=['image', 'label'],spatial_border=(16,16,0)),
        # RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        # Orientationd(keys=["image", "label"], axcodes="PLI"),
        ToTensord(keys=['image', 'label'])
Exemplo n.º 7
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Check the 'img' and 'seg' shapes after applying ``AddChanneld``.

     ``input_param`` holds the keyword arguments for ``AddChanneld``,
     ``input_data`` is the dictionary passed to the transform, and
     ``expected_shape`` is the shape both entries must have afterwards.
     """
     transform = AddChanneld(**input_param)
     result = transform(input_data)
     for key in ('img', 'seg'):
         self.assertEqual(result[key].shape, expected_shape)
Exemplo n.º 8
0
def main():
    """Evaluate a saved 3D DenseNet121 classifier on ten IXI T1 images.

    Builds a validation loader over a fixed list of image paths and labels,
    loads weights from ``best_metric_model_classification3d_dict.pth``,
    prints the fraction of correct predictions and saves the predicted
    classes through a ``CSVSaver`` writing to ``./output``.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    image_dir = os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1"])
    image_names = [
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join([image_dir, name]) for name in image_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images, labels)]

    # Define transforms for image
    val_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3,
                                            in_channels=1,
                                            out_channels=2).to(device)

    # map_location lets a checkpoint saved on a GPU load when the selected
    # device is the CPU.
    model.load_state_dict(
        torch.load("best_metric_model_classification3d_dict.pth",
                   map_location=device))
    model.eval()
    with torch.no_grad():
        num_correct = 0.0
        metric_count = 0
        saver = CSVSaver(output_dir="./output")
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["label"].to(device)
            # predicted class = index of the largest output along dim 1
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data["img_meta_dict"])
        metric = num_correct / metric_count
        print("evaluation metric:", metric)
        saver.finalize()
Exemplo n.º 9
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI314-IOP-0889-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI249-Guys-1072-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI609-HH-2600-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI173-HH-1590-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI020-Guys-0700-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI342-Guys-0909-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI134-Guys-0780-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI577-HH-2661-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI066-Guys-0731-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI130-HH-1528-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI607-Guys-1097-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI175-HH-1570-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI385-HH-2078-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI344-Guys-0905-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI409-Guys-0960-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI584-Guys-1129-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI253-HH-1694-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI092-HH-1436-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI574-IOP-1156-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI585-Guys-1130-T1.nii.gz"
        ]),
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3,
                                                   in_channels=1,
                                                   out_channels=2)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):

        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, softmax=True)
    }
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    print(state)
Exemplo n.º 10
0
    Rand2DElasticd,
    RandAffined,
)

# Probability shared by every random augmentation below.
aug_prob = 0.5
# Dictionary keys the transforms operate on: the image and its segmentation.
keys = ("img", "seg")

# Nearest-neighbour interpolation keeps binary segmentation values at 0 or 1.
zoom_mode = monai.utils.enums.InterpolateMode.NEAREST
# One grid-sample mode per key: bilinear for "img", nearest for "seg".
elast_mode = (
    monai.utils.enums.GridSampleMode.BILINEAR,
    monai.utils.enums.GridSampleMode.NEAREST,
)


trans = Compose([
    # rescale image data to range [0,1]
    ScaleIntensityd(keys=("img",)),
    # add 1-size channel dimension
    AddChanneld(keys=keys),
    RandRotate90d(keys=keys, prob=aug_prob),
    RandFlipd(keys=keys, prob=aug_prob),
    RandZoomd(keys=keys, prob=aug_prob, mode=zoom_mode),
    Rand2DElasticd(
        keys=keys,
        prob=aug_prob,
        spacing=10,
        magnitude_range=(-2, 2),
        mode=elast_mode,
    ),
    RandAffined(
        keys=keys,
        prob=aug_prob,
        rotate_range=1,
        translate_range=16,
        mode=elast_mode,
    ),
    # convert to tensor
    ToTensord(keys=keys),
])


# Pair each training image with the segmentation stored at the same index.
data = [
    {"img": train_images[idx], "seg": train_segs[idx]}
    for idx, _ in enumerate(train_images)
]

ds = CacheDataset(data, trans)
def main(config):
    """Train a 3D UNet on paired PET/CT volumes to predict a binary mask.

    Reads paths, preprocessing and training settings from ``config``, builds
    the training/validation transform pipelines and cached datasets, then runs
    a MONAI ``SupervisedTrainer`` that validates after every epoch and saves a
    checkpoint every second epoch under ``./runs/``.

    Args:
        config: nested mapping. The entries read are
            ``path`` (``csv_path``, ``trained_model_path``,
            ``training_model_folder``),
            ``preprocessing`` (``image_shape``, ``in_channels``,
            ``voxel_spacing``, ``data_augment``, ``resize``, ``origin``,
            ``normalize``, ``number_class``),
            ``model`` (``architecture`` and
            ``model[architecture]['cnn_params']``) and
            ``training`` (``epochs``, ``batch_size``, ``shuffle``,
            ``optimizer.opt_params``).
            List values inside ``cnn_params`` are converted to tuples in place.

    Raises:
        KeyError: if one of the configuration entries listed above is missing.
    """
    now = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    # path
    csv_path = config['path']['csv_path']

    # if None, trained from scratch
    # NOTE(review): read but not used anywhere below.
    trained_model_path = config['path']['trained_model_path']
    training_model_folder = os.path.join(
        config['path']['training_model_folder'], now)  # '/path/to/folder'
    logdir = os.path.join(training_model_folder, 'logs')
    # exist_ok avoids the check-then-create race; creating logdir also creates
    # training_model_folder, its parent.
    # NOTE(review): the handlers below still write to "./runs/", not to logdir.
    os.makedirs(logdir, exist_ok=True)

    # PET CT scan params
    image_shape = tuple(config['preprocessing']['image_shape'])  # (x, y, z)
    in_channels = config['preprocessing']['in_channels']
    # (4.8, 4.8, 4.8)  # in millimeter, (x, y, z)
    voxel_spacing = tuple(config['preprocessing']['voxel_spacing'])
    # NOTE(review): the next five settings are read but not used below.
    data_augment = config['preprocessing']['data_augment']  # for training dataset only
    resize = config['preprocessing']['resize']  # not use yet
    origin = config['preprocessing']['origin']  # how to set the new origin
    normalize = config['preprocessing']['normalize']  # whether or not to normalize the inputs
    number_class = config['preprocessing']['number_class']  # 2

    # CNN params
    architecture = config['model']['architecture']  # 'unet' or 'vnet'

    # NOTE(review): cnn_params is prepared here but the UNet below is built
    # from hard-coded arguments.
    cnn_params = config['model'][architecture]['cnn_params']
    # transform list to tuple (only existing keys are reassigned, so mutating
    # the dict while iterating its items is safe)
    for key, value in cnn_params.items():
        if isinstance(value, list):
            cnn_params[key] = tuple(value)

    # Training params
    epochs = config['training']['epochs']
    batch_size = config['training']['batch_size']
    shuffle = config['training']['shuffle']
    # NOTE(review): read but not used; Adam below uses a fixed 1e-3 rate.
    opt_params = config['training']["optimizer"]["opt_params"]

    # Get Data
    DM = DataManager(csv_path=csv_path)
    train_images_paths, val_images_paths, test_images_paths = DM.get_train_val_test(
        wrap_with_dict=True)

    # Input preprocessing
    # use data augmentation for training
    train_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # user can also add other random transforms
        RandAffined(keys=("pet_img", "ct_img", "mask_img"),
                    spatial_size=None,
                    prob=0.4,
                    rotate_range=(0, np.pi / 30, np.pi / 15),
                    shear_range=None,
                    translate_range=(10, 10, 10),
                    scale_range=(0.1, 0.1, 0.1),
                    mode=("bilinear", "bilinear", "nearest"),
                    padding_mode="border"),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        # presumably ConcatModality stores its result under "image" (the key
        # ToTensord reads next) -- verify against its definition
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
    # without data augmentation for validation
    val_transforms = Compose([  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Roi2Mask(keys=['pet_img', 'mask_img'],
                 method='otsu',
                 tval=0.0,
                 idx_channel=0),
        ResampleReshapeAlign(target_shape=image_shape,
                             target_voxel_spacing=voxel_spacing,
                             keys=['pet_img', "ct_img", 'mask_img'],
                             origin='head',
                             origin_key='pet_img'),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        # normalize input
        ScaleIntensityRanged(
            keys=["pet_img"],
            a_min=0.0,
            a_max=25.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=["ct_img"],
            a_min=-1000.0,
            a_max=1000.0,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Prepare for neural network
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])

    # create a training data loader (half of the samples are cached)
    train_ds = monai.data.CacheDataset(data=train_images_paths,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=2)
    # create a validation data loader (fully cached)
    val_ds = monai.data.CacheDataset(data=val_images_paths,
                                     transform=val_transforms,
                                     cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds,
                                       batch_size=batch_size,
                                       num_workers=2)

    # Model
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,  # 3D
        in_channels=in_channels,
        out_channels=1,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)

    # The samples carry their target under "mask_img" rather than "label", so
    # both engines need the same explicit (input, target) extraction.
    def image_and_mask(batch):
        return batch['image'], batch['mask_img']

    # training
    val_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    val_handlers = [
        # output_transform returns None to disable per-iteration output
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir="./runs/",
                                output_transform=lambda x: None),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # same batch unpacking as the trainer, so the metrics below receive
        # the mask as x["label"]
        prepare_batch=image_and_mask,
        inferer=SimpleInferer(),
        post_transform=val_post_transforms,
        key_val_metric={
            "val_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "val_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "val_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "val_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers,
    )

    train_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ])
    train_handlers = [
        # run the evaluator after every training epoch
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir="./runs/",
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir="./runs/",
                        save_dict={
                            "net": net,
                            "opt": opt
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        # honour the configured epoch count instead of a fixed value
        max_epochs=epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        prepare_batch=image_and_mask,
        inferer=SimpleInferer(),
        post_transform=train_post_transforms,
        key_train_metric={
            "train_mean_dice":
            MeanDice(include_background=True,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={
            "train_acc":
            Accuracy(output_transform=lambda x: (x["pred"], x["label"])),
            "train_precision":
            Precision(output_transform=lambda x: (x["pred"], x["label"])),
            "train_recall":
            Recall(output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training
        amp=monai.config.get_torch_version_tuple() >= (1, 6),
    )
    trainer.run()
Exemplo n.º 12
0
def evaluta_model(test_files, model_name):
    """Evaluate a saved DenseNetASPP checkpoint on ``test_files`` (CPU only).

    Loads the ``'model'`` state dict from the checkpoint at ``model_name``,
    predicts a class for every sample, writes the predictions to a CSV file
    named after the checkpoint, then prints the quadratic weighted kappa and a
    classification report and saves a confusion-matrix plot.

    Side effects on module-level state: appends the kappa value to
    ``kappa_list`` and the accuracy to ``accuracy_list``.

    Args:
        test_files: list of sample dicts fed to ``monai.data.Dataset``; each
            batch is expected to expose ``"label"`` and
            ``"dwiImg_meta_dict"`` next to the modalities in ``modalDataKey``.
        model_name: path of the checkpoint file to load.
    """
    from sklearn.metrics import classification_report

    test_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    # create a validation data loader
    device = torch.device("cpu")
    print(len(test_files))
    test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
    # pin_memory must be a bool; the torch.device class object passed before
    # was merely truthy.
    test_loader = DataLoader(test_ds, batch_size=len(test_files), num_workers=2,
                             pin_memory=torch.cuda.is_available())
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    # map_location lets a checkpoint saved on GPU load on this CPU-only path
    checkpoint = torch.load(model_name, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Collected per batch so the metrics cover every sample, not just the
    # variables left over from the final loop iteration.
    label_batches = []
    prediction_batches = []
    with torch.no_grad():
        saver = CSVSaver(output_dir="../result/GLeason/2d_output/",
                         filename=os.path.basename(model_name).split('.')[0] + '.csv')
        for test_data in test_loader:
            test_images, test_labels = test_data["inputs"].to(device), test_data["label"].to(device)
            pred = model(test_images)  # Gleason Classification
            probabilities = torch.sigmoid(pred)
            predictions = probabilities.argmax(dim=1)
            saver.save_batch(predictions, test_data["dwiImg_meta_dict"])
            label_batches.append(test_labels)
            prediction_batches.append(predictions)
        saver.finalize()

    all_labels = torch.cat(label_batches, dim=0)
    all_predictions = torch.cat(prediction_batches, dim=0)

    cm = confusion_matrix(all_labels, all_predictions)
    kappa_value = cohen_kappa_score(all_labels, all_predictions, weights='quadratic')
    print('quadratic weighted kappa=' + str(kappa_value))
    kappa_list.append(kappa_value)
    plot_confusion_matrix(cm, 'confusion_matrix.png', title='confusion matrix')

    print(classification_report(all_labels, all_predictions, digits=4))
    accuracy_list.append(
        classification_report(all_labels, all_predictions, digits=4, output_dict=True)["accuracy"])
Exemplo n.º 13
0
def training(train_files, val_files, log_dir):
    """Train DenseNetASPP as a 5-class classifier and evaluate the best model.

    Uses a class-weighted cross-entropy loss with Adam and a step learning-rate
    schedule. Every ``val_interval`` epochs the quadratic weighted kappa is
    computed on ``val_files``; whenever it improves, the model, optimizer and
    epoch index are saved to ``log_dir``. After training, loss and kappa
    curves are plotted and ``evaluta_model`` is run on ``val_files``.

    Relies on the module-level names ``device`` and ``modalDataKey``.

    Args:
        train_files: list of sample dicts for the training dataset; batches
            are read through the ``"inputs"`` and ``"label"`` keys.
        val_files: list of sample dicts for the validation dataset.
        log_dir: path of the checkpoint FILE (despite the name). If it
            exists, training resumes from the epoch stored in it.
    """
    # Define transforms for image
    print(log_dir)
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ToTensord(keys=["inputs"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )

    lr = 0.01
    batch_size = 256
    # pin_memory must be a bool; the torch.device class object passed before
    # was merely truthy.
    pin_memory = torch.cuda.is_available()

    # Define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)
    check_data = monai.utils.misc.first(check_loader)
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
                              pin_memory=pin_memory)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    # Create Net, CrossEntropyLoss and Adam optimizer
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # Weights are estimated from the first (unshuffled) training batch only.
    # Keyword arguments work on both old and new scikit-learn releases; recent
    # ones reject the positional form.
    class_weights = class_weight.compute_class_weight(
        'balanced', classes=classes, y=check_data["label"].numpy())
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    # If a saved model exists, load it and continue training from there.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        saved_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(saved_epoch))
        # the checkpoint stores the index of the last completed epoch
        first_epoch = saved_epoch + 1
    else:
        # a fresh run starts at epoch index 0 (previously index 0 was skipped)
        first_epoch = 0
        print('无保存模型,将从头开始训练!')

    # start a typical PyTorch training
    epoch_num = 300
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    metric_epochs = list()
    writer = SummaryWriter()
    # len(train_loader) is already the number of batches per epoch; dividing
    # it by the batch size again collapsed the global step to the in-epoch step.
    steps_per_epoch = len(train_loader)
    for epoch in range(first_epoch, epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), steps_per_epoch * epoch + step)
        epoch_loss /= step
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                predicted_classes = y_pred.argmax(dim=1)
                kappa_value = cohen_kappa_score(y.to("cpu"), predicted_classes.to("cpu"),
                                                weights='quadratic')
                metric_values.append(kappa_value)
                metric_epochs.append(epoch + 1)
                acc_value = torch.eq(predicted_classes, y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'epoch': epoch
                                  }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()

    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    # 1-based epoch numbers, offset so a resumed run is labelled correctly
    loss_epochs = [first_epoch + i + 1 for i in range(len(epoch_loss_values))]
    plt.xlabel('epoch')
    plt.plot(loss_epochs, epoch_loss_values)
    plt.subplot(1, 2, 2)
    # the plotted validation metric is the quadratic weighted kappa, not ROC AUC
    plt.title("Validation: Quadratic Weighted Kappa")
    plt.xlabel('epoch')
    plt.plot(metric_epochs, metric_values)
    plt.show()
    evaluta_model(val_files, log_dir)
Exemplo n.º 14
0
def image_mixing(data, seed=None):
    """Interactively create "mixed" volumes from pairs of label-1 samples.

    Every record of ``data`` whose ``'_label'`` converts to ``1`` is kept, the
    kept records are shuffled, and each unordered pair is preprocessed and
    combined: the result holds, voxel by voxel, the larger of the two
    intensities wherever both preprocessed volumes are positive, and zero
    elsewhere.  Each result is shown with ``multi_slice_viewer`` and the user
    is asked on stdin whether to save it (``y``), skip it (``n``) or
    terminate the program (``e``).  Saved volumes are written below the
    ``'data'`` directory returned by ``setup_directories()``, in a
    ``mixed_images`` sub-directory.

    Args:
        data: iterable of mappings providing ``'_label'``, ``'image'`` and
            ``'seg'`` entries.
        seed: kept for interface compatibility; it is not used, the shuffle
            draws from the global ``random`` state.

    Raises:
        SystemExit: when the user answers ``e``.
    """
    # NOTE: seeding of the global RNG is disabled, so ``seed`` has no effect.
    positives = [x for x in data if int(x['_label']) == 1]
    random.shuffle(positives)

    crop_foreground = CropForegroundd(keys=["image"],
                                      source_key="image",
                                      margin=(0, 0, 0),
                                      select_fn=lambda x: x != 0)
    window_width, window_level = 1500, -600
    ct_window = CTWindowd(keys=["image"],
                          width=window_width,
                          level=window_level)
    resize_in_plane = Resized(keys=["image"],
                              spatial_size=(int(512 * 0.75), int(512 * 0.75), -1),
                              mode="area")
    resize_depth = Resized(keys=["image"],
                           spatial_size=(-1, -1, 40),
                           mode="nearest")
    affine = Affined(keys=["image"],
                     scale_params=(1.0, 2.0, 1.0),
                     padding_mode='zeros')

    # The order of the steps is significant: depth is resized before the
    # in-plane resize, exactly as in the original pipeline.
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        affine,
        crop_foreground,
        resize_depth,
        resize_in_plane,
        SqueezeDimd(keys=["image"]),
    ])

    dirs = setup_directories()
    mixed_images_dir = os.path.join(dirs['data'], 'mixed_images')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(mixed_images_dir, exist_ok=True)

    for first, second in itertools.combinations(positives, 2):
        # Only the two keys below are forwarded to the transform chain.
        first_item = {'image': first["image"], 'seg': first['seg']}
        second_item = {'image': second["image"], 'seg': second['seg']}

        first_vol = common_transform(first_item)["image"]
        second_vol = common_transform(second_item)["image"]

        # Voxels that are positive in both volumes.
        overlap = np.logical_and(first_vol > 0, second_vol > 0)
        img = np.maximum(overlap * first_vol, overlap * second_vol)

        multi_slice_viewer(img, "img1")

        # Keep asking until the answer is one of y / n / e.
        while True:
            answer = input("Save image [y/n/e]: ").lower()
            if answer == 'y':
                # File name: keyed BLAKE2b digest (empty message) using the
                # current timestamp as the key.
                key = str(time.time()).encode('utf-8')
                name = blake2b(key=key, digest_size=16).hexdigest() + '.nii.gz'
                out_path = os.path.join(mixed_images_dir, name)
                write_nifti(img, out_path, resample=False)
                break
            if answer == 'n':
                break
            if answer == 'e':
                print("exeting")
                # Equivalent to exit(), without relying on the site module.
                raise SystemExit
            print("wrong input!")
    def test_invert(self):
        """Attach two ``TransformInverter`` handlers to an engine and verify
        the shapes, dtypes and interpolation behaviour of the inverted data."""
        set_determinism(seed=0)

        # Write the synthetic image / segmentation pair to NIfTI files.
        image_vol, label_vol = create_test_image_3d(101, 100, 107, noise_max=100)
        im_fname = make_nifti_image(image_vol)
        seg_fname = make_nifti_image(label_vol)

        pipeline = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(
                KEYS,
                pixdim=(1.2, 1.01, 0.9),
                mode=["bilinear", "nearest"],
                dtype=np.float32,
            ),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(
                KEYS,
                prob=0.5,
                min_zoom=0.5,
                max_zoom=1.1,
                keep_size=True,
            ),
            RandRotated(
                KEYS,
                prob=0.5,
                range_x=np.pi,
                mode="bilinear",
                align_corners=True,
            ),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # only "image" becomes a Tensor, so that inverting is exercised
            # for both Tensor and Numpy array inputs
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
        ])
        samples = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # no worker processes on macOS or when CUDA is available
        single_process = sys.platform == "darwin" or torch.cuda.is_available()
        worker_count = 0 if single_process else 2

        dataset = CacheDataset(samples, transform=pipeline, progress=False)
        loader = DataLoader(dataset, num_workers=worker_count, batch_size=5)

        # engine step: validate the batch shape, publish the batch as output
        def _step(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_step)
        engine.register_events(*IterationEvents)

        # first handler: nearest interpolation for every key
        first_inverter = TransformInverter(
            transform=pipeline,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            postfix="inverted1",
            to_tensor=[True, False],
            device="cpu",
            num_workers=worker_count,
        )
        first_inverter.attach(engine)

        # second handler: a different interpolation setting per key
        second_inverter = TransformInverter(
            transform=pipeline,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="image",
            nearest_interp=[True, False],
            post_func=[lambda x: x + 10, lambda x: x],
            postfix="inverted2",
            collate_fn=pad_list_data_collate,
            num_workers=worker_count,
        )
        second_inverter.attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)

        output = engine.state.output
        self.assertTupleEqual(output["image"].shape, (2, 1, 100, 100, 100))
        self.assertTupleEqual(output["label"].shape, (2, 1, 100, 100, 100))

        # with nearest interpolation a round trip through uint8 is lossless
        for inv_image in output["image_inverted1"]:
            torch.testing.assert_allclose(
                inv_image.to(torch.uint8).to(torch.float),
                inv_image.to(torch.float),
            )
            self.assertTupleEqual(inv_image.shape, (1, 100, 101, 107))
        for inv_label in output["label_inverted1"]:
            np.testing.assert_allclose(
                inv_label.astype(np.uint8).astype(np.float32),
                inv_label.astype(np.float32),
            )
            self.assertTupleEqual(inv_label.shape, (1, 100, 101, 107))

        # compare the last inverted label with the label loaded from disk
        reverted = output["label_inverted1"][-1].astype(np.int32)
        original = LoadImaged(KEYS)(samples[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = output["label_meta_dict"]["filename_or_obj"][-1]
        original_name = samples[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        n_diff = reverted.size - n_good
        print("invert diff", n_diff)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1824: torch 1.5.1
        self.assertTrue(n_diff in (25300, 1812, 1824),
                        "diff. in 3 possible values")

        def _rounding_residual(tensor):
            # summed difference between the values and their uint8 cast
            as_float = tensor.to(torch.float)
            return torch.sum(as_float - tensor.to(torch.uint8).to(torch.float)).item()

        # items inverted with different interpolation modes
        inverted_image = output["image_inverted2"]
        # nearest interpolation: accumulated diff should be smaller than 1
        self.assertLess(_rounding_residual(inverted_image), 1.0)
        self.assertTupleEqual(inverted_image.shape, (2, 1, 100, 101, 107))

        inverted_label = output["label_inverted2"]
        # non-nearest interpolation: accumulated diff should exceed 10000
        self.assertGreater(_rounding_residual(inverted_label), 10000.0)
        self.assertTupleEqual(inverted_label.shape, (2, 1, 100, 101, 107))
Exemplo n.º 16
0
 def load_image(filename):
     """Run ``LoadImaged`` then ``AddChanneld`` on ``{"image": filename}``
     and return the result of the composed transform."""
     pipeline = Compose([
         LoadImaged(keys="image"),
         AddChanneld(keys="image"),
     ])
     return pipeline({"image": filename})
Exemplo n.º 17
0
    def test_values(self):
        """Check ``DecathlonDataset`` on ``Task04_Hippocampus``.

        Covers: download into ``testing_data``, reloading without download,
        loading without transforms, the ``labels`` property, and the error
        raised once the dataset directory has been removed.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        task = "Task04_Hippocampus"
        transform = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                AddChanneld(keys=["image", "label"]),
                ScaleIntensityd(keys="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )

        def _test_dataset(dataset):
            # Expectations shared by the downloaded and the re-loaded dataset.
            self.assertEqual(len(dataset), 52)
            self.assertIn("image", dataset[0])
            self.assertIn("label", dataset[0])
            self.assertIn(PostFix.meta("image"), dataset[0])
            self.assertTupleEqual(dataset[0]["image"].shape, (1, 36, 47, 44))

        with skip_if_downloading_fails():
            data = DecathlonDataset(
                root_dir=testing_dir,
                task=task,
                transform=transform,
                section="validation",
                download=True,
                copy_cache=False,
            )

        _test_dataset(data)
        data = DecathlonDataset(
            root_dir=testing_dir, task=task, transform=transform, section="validation", download=False
        )
        _test_dataset(data)
        self.assertTrue(data[0][PostFix.meta("image")]["filename_or_obj"].endswith("hippocampus_163.nii.gz"))
        self.assertTrue(data[0][PostFix.meta("label")]["filename_or_obj"].endswith("hippocampus_163.nii.gz"))

        # test validation without transforms
        data = DecathlonDataset(root_dir=testing_dir, task=task, section="validation", download=False)
        self.assertTupleEqual(data[0]["image"].shape, (36, 47, 44))
        self.assertEqual(len(data), 52)
        data = DecathlonDataset(root_dir=testing_dir, task=task, section="training", download=False)
        self.assertTupleEqual(data[0]["image"].shape, (34, 56, 31))
        self.assertEqual(len(data), 208)

        # test dataset properties (root_dir passed as a Path this time)
        data = DecathlonDataset(
            root_dir=Path(testing_dir), task=task, section="validation", download=False
        )
        properties = data.get_properties(keys="labels")
        self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"})

        shutil.rmtree(os.path.join(testing_dir, task))
        # assertRaises makes the test fail when no error is raised at all;
        # a plain try/except would pass silently in that case.
        with self.assertRaises(RuntimeError) as raised:
            DecathlonDataset(
                root_dir=testing_dir,
                task=task,
                transform=transform,
                section="validation",
                download=False,
            )
        message = str(raised.exception)
        print(message)
        self.assertTrue(message.startswith("Cannot find dataset directory"))
def main(tempdir):
    """Evaluate a saved 2D UNet on freshly generated synthetic PNG images.

    Five image / segmentation pairs are written to ``tempdir``, loaded back
    through a dictionary-based transform chain, segmented with sliding-window
    inference and scored with ``DiceMetric``.  The predictions are saved with
    ``PNGSaver`` to ``./output`` and the mean metric is printed.

    Args:
        tempdir: directory that receives the generated ``img*.png`` and
            ``seg*.png`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # Scale intensities to the 8-bit range before the cast, as the
        # training script in this file does; a bare cast would truncate
        # fractional intensities.
        # NOTE(review): assumes `im` lies in [0, 1] -- confirm against
        # create_test_image_2d.
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        # The segmentation is stored unscaled because no intensity scaling
        # is applied to "seg" in the transforms below.
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location lets a checkpoint saved from a CUDA device be loaded on a
    # CPU-only host.
    model.load_state_dict(
        torch.load("best_metric_model_segmentation2d_dict.pth",
                   map_location=device))

    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        # sliding window size and number of windows per forward pass
        roi_size = (96, 96)
        sw_batch_size = 4
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = post_trans(val_outputs)
            value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
            # accumulate a sample-weighted sum so the final value is a mean
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Exemplo n.º 19
0
def main(tempdir):
    """Train a 2D UNet on synthetic PNG data and keep the best checkpoint.

    Forty image / segmentation pairs are generated in ``tempdir``; the first
    twenty are used for training and the last twenty for validation.  The
    model is trained with ``DiceLoss`` and Adam, validated every
    ``val_interval`` epochs with sliding-window inference, and its weights
    are saved to ``best_metric_model_segmentation2d_dict.pth`` whenever the
    mean Dice improves.  Losses, metrics and sample images are logged through
    a ``SummaryWriter``.

    Args:
        tempdir: directory that receives the generated ``img*.png`` and
            ``seg*.png`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ])

    # sanity check: load one batch and print its shapes
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    max_epochs = 10  # single source for the loop bound and the progress text
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # number of batches per epoch; does not change during training
    epoch_len = len(train_ds) // train_loader.batch_size
    # sliding window size and number of windows per forward pass
    roi_size = (96, 96)
    sw_batch_size = 4
    writer = SummaryWriter()
    # try/finally: the writer is closed even when training raises
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data["img"].to(device)
                labels = batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images = val_data["img"].to(device)
                        val_labels = val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, model)
                        val_outputs = [
                            post_trans(i) for i in decollate_batch(val_outputs)
                        ]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(),
                                   "best_metric_model_segmentation2d_dict.pth")
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f} "
                        f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="image")
                    plot_2d_or_3d_image(val_labels,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="label")
                    plot_2d_or_3d_image(val_outputs,
                                        epoch + 1,
                                        writer,
                                        index=0,
                                        tag="output")

        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        writer.close()
Exemplo n.º 20
0
def main():
    """Evaluate a 3D DenseNet121 classifier on ten IXI T1 volumes.

    Builds an Ignite supervised evaluator with an ``Accuracy`` metric,
    attaches a stats handler, a ``ClassificationSaver`` and a
    ``CheckpointLoader``, runs the evaluator over the validation loader and
    prints the returned state.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    ixi_t1_dir = ["workspace", "data", "medical", "ixi", "IXI-T1"]
    scan_names = (
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    )
    images = [os.sep.join(ixi_t1_dir + [scan]) for scan in scan_names]

    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
    val_files = [
        {"img": path, "label": target} for path, target in zip(images, labels)
    ]

    # define transforms for image
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # create DenseNet121
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def prepare_batch(batch, device=None, non_blocking=False):
        # turn the dictionary batch into the (img, label) pair for the evaluator
        pair = (batch["img"], batch["label"])
        return _prepare_batch(pair, device, non_blocking)

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(
        net,
        val_metrics,
        device,
        True,
        prepare_batch=prepare_batch,
    )

    # add stats event handler to print validation stats via evaluator;
    # returning None disables the per-iteration output (no loss to print)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,
    )
    val_stats_handler.attach(evaluator)

    # the saver reads the batch's "img_meta_dict" entry and stores the
    # argmax over dim 1 of the first output element
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        name="evaluator",
        batch_transform=lambda batch: batch["img_meta_dict"],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by "densenet_training_dict" example
    checkpoint_loader = CheckpointLoader(
        load_path="./runs_dict/net_checkpoint_20.pth",
        load_dict={"net": net},
    )
    checkpoint_loader.attach(evaluator)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(
        val_ds,
        batch_size=2,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    state = evaluator.run(val_loader)
    print(state)
Exemplo n.º 21
0
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # Data loader creation

    # train images
    train_images = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    train_images_for_dice = sorted(
        glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(
        glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(
        glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(
        glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(
        glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # augment the data list for training
    for i in range(int(opt.increase_factor_data)):

        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data directories for data_loader

    train_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{
        'image': image_name,
        'label': label_name
    }
                        for image_name, label_name in zip(
                            train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{
        'image': image_name,
        'label': label_name
    } for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list

    if opt.resolution is not None:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-120,
                a_max=170,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts,
                                     transform=train_transforms)
    train_loader = DataLoader(check_train,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.workers,
                              pin_memory=torch.cuda.is_available())

    # create a training_dice data loader
    check_val = monai.data.Dataset(data=train_dice_dicts,
                                   transform=val_transforms)
    train_dice_loader = DataLoader(check_val,
                                   batch_size=1,
                                   num_workers=opt.workers,
                                   pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val,
                            batch_size=1,
                            num_workers=opt.workers,
                            pin_memory=torch.cuda.is_available())

    # create a validation data loader
    check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_val,
                             batch_size=1,
                             num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    # try to use all the available GPUs
    devices = get_devices_spec(None)

    # build the network
    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(
            ), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):
                    """Evaluate ``net`` on every batch of ``images_loader`` with the Dice metric.

                    Despite its name, this helper does not plot anything itself: it runs
                    sliding-window inference on each batch, accumulates the Dice value
                    returned by ``dice_metric`` and returns the average together with
                    the tensors of the last processed batch.

                    Side effect: the averaged metric is appended to ``metric_values``
                    from the enclosing scope on every call, whichever loader is passed.

                    Args:
                        images_loader: iterable of batches; each batch is indexed with
                            the keys ``"image"`` and ``"label"``.

                    Returns:
                        Tuple ``(metric, val_images, val_labels, val_outputs)``. The
                        last three come from the final batch iterated.

                    Raises:
                        ZeroDivisionError: if ``images_loader`` yields no batches,
                            because ``metric_count`` is still 0 at the division.
                    """

                    metric_sum = 0.0
                    metric_count = 0
                    # Pre-declared so the names exist for the return statement.
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(
                        ), data["label"].cuda()
                        # Window size for inference is the configured patch size.
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(
                            val_images, roi_size, sw_batch_size, net)
                        value = dice_metric(y_pred=val_outputs, y=val_labels)
                        # NOTE(review): len(value) assumes the metric returns a sized
                        # (non 0-d) tensor - confirm against the installed MONAI version.
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(
                    val_loader)

                # Save best model
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(
                    train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(
                    test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric_train, metric, metric_test,
                            best_metric, best_metric_epoch))

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="validation inference")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
def main():
    """Train a 3D DenseNet121 on 20 IXI T1 volumes and validate it periodically.

    The first 10 volumes are used for training and the last 10 for validation.
    Every ``val_interval`` epochs the model is evaluated on the validation set
    (accuracy and ROC AUC) and the weights with the best accuracy so far are
    saved to ``best_metric_model_classification3d_dict.pth``. The per-step
    training loss and the validation accuracy are logged through a
    ``SummaryWriter``, which is closed even if training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # The joined paths have no leading separator, so they are resolved relative
    # to the current working directory.
    data_root = ["workspace", "data", "medical", "ixi", "IXI-T1"]
    filenames = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]
    images = [os.sep.join(data_root + [name]) for name in filenames]

    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # Sanity check: load one batch and print its shape and labels
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3,
                                                     in_channels=1,
                                                     out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    max_epochs = 5
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    # Steps per epoch (floor division); constant for the whole run, so it is
    # computed once instead of on every training step.
    epoch_len = len(train_ds) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data["img"].to(device)
                labels = batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)
            epoch_loss /= step
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # Collect raw network outputs and labels for the whole
                    # validation set before computing the metrics.
                    y_pred = torch.tensor([],
                                          dtype=torch.float32,
                                          device=device)
                    y = torch.tensor([], dtype=torch.long, device=device)
                    for val_data in val_loader:
                        val_images = val_data["img"].to(device)
                        val_labels = val_data["label"].to(device)
                        y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                        y = torch.cat([y, val_labels], dim=0)

                    acc_value = torch.eq(y_pred.argmax(dim=1), y)
                    acc_metric = acc_value.sum().item() / len(acc_value)
                    auc_metric = compute_roc_auc(y_pred,
                                                 y,
                                                 to_onehot_y=True,
                                                 softmax=True)
                    # Model selection is driven by accuracy; AUC is only reported.
                    if acc_metric > best_metric:
                        best_metric = acc_metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model.state_dict(),
                            "best_metric_model_classification3d_dict.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}"
                        .format(epoch + 1, acc_metric, auc_metric, best_metric,
                                best_metric_epoch))
                    writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
    finally:
        # Close the event file even when training or validation raises.
        writer.close()
def main(hparams):
    """Train a UNet segmentation model from the settings in ``hparams``.

    Prints the run configuration, collects CT image / mask paths below
    ``data/``, builds train / val / test loaders for fixed groups of ids via
    ``select_animals``, and fits the model with a ``Trainer`` that logs to
    TensorBoard under ``models/<name>/``.

    Args:
        hparams: object exposing ``name``, ``batch_size``, ``patch_size``,
            ``epochs``, ``learning_rate`` and ``loss`` (``'Dice'`` or
            ``'CrossEntropy'``).

    Raises:
        ValueError: if ``hparams.loss`` is not one of the supported names.
    """
    print('===== INITIAL PARAMETERS =====')
    print('Model name: ', hparams.name)
    print('Batch size: ', hparams.batch_size)
    print('Patch size: ', hparams.patch_size)
    print('Epochs: ', hparams.epochs)
    print('Learning rate: ', hparams.learning_rate)
    print('Loss function: ', hparams.loss)
    print()

    ### Data collection
    data_dir = 'data/'
    print('Available directories: ', os.listdir(data_dir))
    # Get paths for images and masks; sorting both lists keeps them aligned
    # as long as image and mask file names sort in the same order.
    images = sorted(glob.glob(data_dir + '**/*CTImg*', recursive=True))
    masks = sorted(glob.glob(data_dir + '**/*Mask*', recursive=True))
    # Dataset selection
    train_dicts = select_animals(images, masks, [12, 13, 14, 18, 20])
    val_dicts = select_animals(images, masks, [25])
    test_dicts = select_animals(images, masks, [27])
    data_keys = ['image', 'mask']
    # Data transformation
    data_transforms = Compose([
        LoadNiftid(keys=data_keys),
        AddChanneld(keys=data_keys),
        ScaleIntensityd(keys=data_keys),
        CropForegroundd(keys=data_keys, source_key='image'),
        # The crop is one element thick along the last axis
        # (presumably so the 2D network sees single slices - TODO confirm).
        RandSpatialCropd(
            keys=data_keys,
            roi_size=(hparams.patch_size, hparams.patch_size, 1),
            random_size=False
        ),
    ])
    # All three splits currently share the same pipeline.
    train_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])
    val_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])
    test_transforms = Compose([
        data_transforms,
        ToTensord(keys=data_keys)
    ])
    # Data loaders
    data_loaders = {
        'train': create_loader(train_dicts, batch_size=hparams.batch_size, transforms=train_transforms, shuffle=True),
        'val': create_loader(val_dicts, transforms=val_transforms),
        'test': create_loader(test_dicts, transforms=test_transforms)
    }
    for key in data_loaders:
        print(key, len(data_loaders[key]))

    ### Model training
    if hparams.loss == 'Dice':
        criterion = monai.losses.DiceLoss(to_onehot_y=True, do_softmax=True)
    elif hparams.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    else:
        # Without this branch an unknown name left ``criterion`` unbound and
        # surfaced later as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported loss function {hparams.loss!r}; "
            "expected 'Dice' or 'CrossEntropy'"
        )

    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=2,
        # NOTE(review): 258 looks like a typo for 256; left unchanged because
        # altering it would change the architecture of existing checkpoints.
        channels=(64, 128, 258, 512, 1024),
        strides=(2, 2, 2, 2),
        norm=monai.networks.layers.Norm.BATCH,
        criterion=criterion,
        hparams=hparams,
    )

    # Created but intentionally not passed to the Trainer (option disabled below).
    early_stopping = EarlyStopping('val_loss')
    logger = TensorBoardLogger('models/' + hparams.name + '/tb_logs', name=hparams.name)

    trainer = Trainer(
        check_val_every_n_epoch=5,
        default_save_path='models/' + hparams.name + '/checkpoints',
    #     early_stop_callback=early_stopping,
        gpus=1,
        max_epochs=hparams.epochs,
    #     min_epochs=10,
        logger=logger
    )

    trainer.fit(
        model,
        train_dataloader=data_loaders['train'],
        val_dataloaders=data_loaders['val']
    )
Exemplo n.º 24
0
def main(tempdir):
    """Evaluate a saved 2D segmentation UNet on freshly generated synthetic data.

    Five synthetic image / segmentation PNG pairs are written to ``tempdir``,
    loaded back through a dictionary-based transform pipeline and run through
    sliding-window inference. Each thresholded prediction is saved under
    ``./output`` and the mean Dice over all samples is printed.

    Args:
        tempdir: directory that receives the generated ``img*.png`` and
            ``seg*.png`` files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        EnsureTyped(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference need to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    # Post-processing: sigmoid activation followed by a 0.5 threshold
    post_trans = Compose(
        [EnsureType(),
         Activations(sigmoid=True),
         AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output",
                      output_ext=".png",
                      output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # map_location lets a checkpoint saved on a GPU be restored on the
    # CPU-only fallback selected above.
    model.load_state_dict(
        torch.load("best_metric_model_segmentation2d_dict.pth",
                   map_location=device))

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
Exemplo n.º 25
0
 def test_fail_non_random(self):
     """Expect RuntimeError when building TestTimeAugmentation from this pipeline.

     The pipeline holds only ``AddChanneld`` and ``SpatialPadd``; going by
     the test name, the missing random transform is what triggers the error.
     """
     deterministic_pipeline = Compose(
         [AddChanneld("im"), SpatialPadd("im", 1)])
     self.assertRaises(RuntimeError, TestTimeAugmentation,
                       deterministic_pipeline, None, None, None)
Exemplo n.º 26
0
    def test_inverse_inferred_seg(self):
        """Invert the spatial transforms on a model output, singly and per batch.

        Twenty synthetic 2D samples are padded and centre-cropped, one batch of
        labels is passed through a small UNet, and the recorded transforms are
        then inverted on the prediction. The inverted result must have the
        spatial shape of the untransformed label.
        """
        # Build the samples in a single pass; labels are stored as float32.
        test_data = [
            {"image": img, "label": seg.astype(np.float32)}
            for img, seg in (create_test_image_2d(100, 101) for _ in range(20))
        ]

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 0 if sys.platform == "darwin" else 2
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99)),
        ])
        invertible = [
            t for t in transforms.transforms
            if isinstance(t, InvertibleTransform)
        ]
        num_invertible_transforms = len(invertible)

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2, ),
        )
        model = model.to(device)

        data = first(loader)
        # The label tensor doubles as network input here.
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        # Pair the prediction with the transform record taken from the batch.
        segs_dict = {
            "label": segs,
            label_transform_key: data[label_transform_key],
        }

        # inverse of individual segmentation
        seg_dict = first(decollate_batch(segs_dict))
        # test to convert interpolation mode for 1 data of model output batch
        convert_inverse_interp_mode(seg_dict,
                                    mode="nearest",
                                    align_corners=None)

        expected_shape = test_data[0]["label"].shape

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], expected_shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms,
                                               loader,
                                               collate_fn=no_collation)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], expected_shape)
Exemplo n.º 27
0
    def test_test_time_augmentation(self):
        """Train a tiny UNet briefly, then check TestTimeAugmentation's outputs.

        The same random-affine pipeline is used for training and for
        test-time augmentation. The assertions cover the shapes of the
        returned mode / mean / std maps, that the mode is binary, that the
        mean spans exactly [0, 1], and that the fourth value is a float.
        """
        input_size = (20, 20)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        keys = ["image", "label"]
        num_training_ims = 10
        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)

        transforms = Compose([
            AddChanneld(keys),
            RandAffined(
                keys,
                prob=1.0,
                spatial_size=(30, 30),
                rotate_range=(np.pi / 3, np.pi / 3),
                translate_range=(3, 3),
                scale_range=((0.8, 1), (0.8, 1)),
                padding_mode="zeros",
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            CropForegroundd(keys, source_key="image"),
            DivisiblePadd(keys, 4),
        ])

        train_ds = CacheDataset(train_data, transforms)
        # output might be different size, so pad so that they match
        train_loader = DataLoader(train_ds,
                                  batch_size=2,
                                  collate_fn=pad_list_data_collate)

        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # Short training run. The loss value is not tracked because nothing
        # in this test reads it.
        num_epochs = 10
        for _ in trange(num_epochs):
            for batch_data in train_loader:
                inputs, labels = batch_data["image"].to(
                    device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()

        post_trans = Compose([
            Activations(sigmoid=True),
            AsDiscrete(threshold_values=True),
        ])

        def inferrer_fn(x):
            # Network forward pass followed by sigmoid + thresholding.
            return post_trans(model(x))

        tt_aug = TestTimeAugmentation(transforms,
                                      batch_size=5,
                                      num_workers=0,
                                      inferrer_fn=inferrer_fn,
                                      device=device)
        mode, mean, std, vvc = tt_aug(test_data)
        self.assertEqual(mode.shape, (1, ) + input_size)
        self.assertEqual(mean.shape, (1, ) + input_size)
        self.assertTrue(all(np.unique(mode) == (0, 1)))
        self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
        self.assertEqual(std.shape, (1, ) + input_size)
        self.assertIsInstance(vvc, float)
Exemplo n.º 28
0
    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
]
# 2 binary labels for gender classification: man and woman
labels = np.array([
    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
])
# Pair each path in `images` with the label at the same position; zip stops
# at the shorter of the two sequences.
val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)]

# Define the transforms applied to the 'img' entry of every sample:
# load, add a channel dimension, rescale intensities, resize to 96x96x96
# and convert to a tensor.
val_transforms = Compose([
    LoadNiftid(keys=['img']),
    AddChanneld(keys=['img']),
    ScaleIntensityd(keys=['img']),
    Resized(keys=['img'], spatial_size=(96, 96, 96)),
    ToTensord(keys=['img'])
])

# Create a 3D DenseNet121 with one input channel and two output classes
net = monai.networks.nets.densenet.densenet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
)
# Hard-coded to the first CUDA device; there is no CPU fallback here.
device = torch.device("cuda:0")


def prepare_batch(batch, device=None, non_blocking=False):
Exemplo n.º 29
0
Here we use several transforms to augment the dataset:
1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `AddChanneld` adds a channel dimension, since the original data doesn't have one, to construct the "channel first" shape.
1. `Spacingd` adjusts the spacing by `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.
1. `Orientationd` unifies the data orientation based on the affine matrix.
1. `ScaleIntensityRanged` clips intensities to the range [-1000, 300] and scales them to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the big image based on the pos / neg ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform.
1. `ToTensord` converts the numpy array to PyTorch Tensor for further steps.
"""

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    AddChanneld(keys=["image", "label"]),
    Spacingd(keys=["image", "label"],
             pixdim=(1.5, 1.5, 2.0),
             mode=("bilinear", "nearest")),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-1000,
        a_max=300,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    #CropForegroundd(keys=["image", "label"], source_key="image"),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
Exemplo n.º 30
0
# The first 52 entries of `data_dicts` form the test set.
test_files = data_dicts[:52]
"""## Set deterministic training for reproducibility"""

# Fixed seed so repeated runs are reproducible.
set_determinism(seed=0)
'''
Label 1: Bladder
Label 2: Liver
Label 3: Lungs
Label 4: Heart
Label 5: Pancreas
'''

# Test-time preprocessing pipeline; every transform operates on the "image"
# key only.
test_transforms = Compose(
    [
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),
        # Resample to a voxel spacing of (1.5, 1.5, 2.0) with bilinear interpolation.
        Spacingd(keys="image", mode="bilinear", pixdim=(1.5, 1.5, 2.0)),
        Orientationd(keys="image", axcodes="RAS"),
        # Map intensities from [-1000, 300] to [0.0, 1.0], clipping values outside.
        ScaleIntensityRanged(keys="image", a_min=-1000, a_max=300, b_min=0.0, b_max=1.0, clip=True),
        # NOTE: foreground cropping is intentionally left disabled:
        # CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys="image"),
    ]
)

test_ds = CacheDataset(data=test_files,