Example #1
0
    def read_label(label_path: Path,
                   feature_size: int,
                   n_slices: int,
                   crop: bool = False) -> np.ndarray:
        """Load a segmentation label volume and binarize it to class 2.

        Reads the label image at ``label_path`` via SimpleITK, reorders
        axes to (X, Y, Z), merges class 4 into class 3, then either
        center-crops or resizes the volume to
        ``(feature_size, feature_size, n_slices)``.  The result keeps only
        class 2 (1.0 where the original label equals 2, 0.0 elsewhere) and
        is returned with shape ``(1, n_slices, feature_size, feature_size)``.
        """
        volume = sitk.GetArrayFromImage(sitk.ReadImage(str(label_path)))
        # SimpleITK arrays come out (Z, Y, X); flip to (X, Y, Z).
        volume = np.transpose(volume, axes=(2, 1, 0))
        if volume.ndim == 4:
            # Drop a trailing singleton axis if present.
            volume = np.squeeze(volume, axis=-1).astype(np.int16)
        volume = volume.astype(np.float32)
        # Remap class 4 onto class 3.
        volume[volume == 4] = 3

        x_dim, y_dim, z_dim = volume.shape
        center_x = x_dim // 2
        center_y = y_dim // 2
        center_z = z_dim // 2
        if crop:
            volume = Torch2DSegmentationDataset.crop_3D_image(
                volume, center_x, center_y, feature_size, center_z, n_slices)
        else:
            # Order-0 (nearest-neighbour) resize keeps label values discrete.
            volume = resize_image(
                volume, (feature_size, feature_size, n_slices), 0)

        mask = np.zeros((feature_size, feature_size, n_slices))
        mask[volume == 2] = 1
        mask = np.expand_dims(mask, axis=0)
        # (1, X, Y, Z) -> (1, Z, X, Y): channel first, slices as depth.
        return np.transpose(mask, (0, 3, 1, 2))
Example #2
0
def main():
    """Run 3D UNet inference on a single image/label pair.

    Loads a trained UNet checkpoint, builds a one-item
    Torch2DSegmentationDataset from the configured image/label paths, and
    delegates prediction and result writing to ``inference``.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    # Prefer an explicitly supplied network config; otherwise fall back to
    # the packaged default TRAIN_CONF_PATH.
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    # Only the first training dataset's layout is used to locate the input.
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # NOTE(review): return value unused -- presumably kept as a config
    # sanity check; confirm whether this call can be removed.
    get_conf(train_conf, group="network", key="experiment_dir")
    network = UNet(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # image = nib.load(str(input_path)).get_data()
    # if image.ndim == 4:
    #     image = np.squeeze(image, axis=-1).astype(np.int16)
    # image = image.astype(np.int16)
    # image = np.transpose(image, (2, 0, 1))
    # Single-item dataset: the one configured image and its matching label.
    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[
            input_path.parent.joinpath(
                dataset_config.image_label_format.label.format(
                    phase=args.phase))
        ],
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
    )

    # Add a batch dimension and move the tensor to the GPU.
    image = dataset.get_image_tensor_from_index(0)
    image = torch.unsqueeze(image, 0)
    image = prepare_tensors(image, True, args.device)

    label = dataset.get_label_tensor_from_index(0)
    inference(
        image=image,
        label=label,
        image_path=input_path,
        network=network,
        output_dir=output_dir,
    )
Example #3
0
 def read_image(image_path: Path, feature_size: int, n_slices: int, crop: bool = False) -> np.ndarray:
     """Load an image volume and normalize it for the network.

     Loads the file at ``image_path`` with nibabel, drops any trailing
     singleton axis, then either center-crops or resizes the volume to
     ``(feature_size, feature_size, n_slices)``.  Axes are reordered to
     slice-first (Z, X, Y) before intensity rescaling.
     """
     # NOTE(review): nibabel's get_data() is deprecated in favour of
     # get_fdata(); left as-is to preserve exact dtype behaviour.
     volume = nib.load(str(image_path)).get_data()
     if volume.ndim == 4:
         # Drop a trailing singleton dimension (e.g. a time axis of 1).
         volume = np.squeeze(volume, axis=-1).astype(np.int16)
     volume = volume.astype(np.float32)
     x_dim, y_dim, z_dim = volume.shape
     center_x, center_y, center_z = x_dim // 2, y_dim // 2, z_dim // 2
     if crop:
         volume = Torch2DSegmentationDataset.crop_3D_image(volume, center_x, center_y, feature_size, center_z, n_slices)
     else:
         # Order-0 resize to the target grid.
         volume = resize_image(volume, (feature_size, feature_size, n_slices), 0)
     # Slice axis first: (X, Y, Z) -> (Z, X, Y).
     volume = np.transpose(volume, (2, 0, 1))
     # Presumably rescales intensities between the 1st and 99th
     # percentiles -- confirm against rescale_intensity's definition.
     return rescale_intensity(volume, (1.0, 99.0))
Example #4
0
def main():
    """Run 2D FCN inference on one volume and write NIfTI outputs.

    Loads an FCN2DSegmentationModel checkpoint, predicts per-class sigmoid
    masks for the configured input image, collapses the class channels into
    a single integer label map, and saves ``seg.nii.gz``, ``image.nii.gz``
    and ``label.nii.gz`` under the output directory, reusing the input's
    affine and pixdim header.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    # Prefer an explicitly supplied network config; otherwise fall back to
    # the packaged default TRAIN_CONF_PATH.
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # NOTE(review): return value unused -- presumably kept as a config
    # sanity check; confirm whether this call can be removed.
    get_conf(train_conf, group="network", key="experiment_dir")
    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"))
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # image = nib.load(str(input_path)).get_data()
    # if image.ndim == 4:
    #     image = np.squeeze(image, axis=-1).astype(np.int16)
    # image = image.astype(np.int16)
    # image = np.transpose(image, (2, 0, 1))
    # NOTE(review): "in_channels" is passed where read_image expects
    # n_slices -- correct only if in_channels == n_slices for this 2D
    # model (slices fed as channels); confirm.
    image = Torch2DSegmentationDataset.read_image(
        input_path,
        get_conf(train_conf, group="network", key="feature_size"),
        get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Add a batch axis and move to GPU.
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image).float()
    image = prepare_tensors(image, gpu=True, device=args.device)
    predicted = network(image)
    predicted = torch.sigmoid(predicted)
    print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
    # Hard-threshold the sigmoid outputs at 0.5.
    predicted = (predicted > 0.5).float()
    print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = predicted.cpu().detach().numpy()

    nim = nib.load(str(input_path))
    # Transpose and crop the segmentation to recover the original size
    predicted = np.squeeze(predicted, axis=0)
    print(predicted.shape)

    # map back to original size
    # Collapse per-class masks into a single label map: class i -> i + 1.
    final_predicted = np.zeros(
        (image.shape[1], image.shape[2], image.shape[3]))
    print(predicted.shape, final_predicted.shape)

    for i in range(predicted.shape[0]):
        a = predicted[i, :, :, :] > 0.5
        print(a.shape)
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    # image = nim.get_data()
    # Move the slice axis last for NIfTI saving.
    final_predicted = np.transpose(final_predicted, [1, 2, 0])
    print(predicted.shape, final_predicted.shape)
    # final_predicted = np.resize(final_predicted, (image.shape[0], image.shape[1], image.shape[2]))

    print(predicted.shape, final_predicted.shape, np.max(final_predicted),
          np.mean(final_predicted), np.min(final_predicted))
    # if Z < 64:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, z1_ - z1:z1_ - z1 + Z]
    # else:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, :]
    #     pred_segt = np.pad(pred_segt, ((0, 0), (0, 0), (z_pre, z_post)), 'constant')

    # Save the prediction with the input's affine/pixdim.
    nim2 = nib.Nifti1Image(final_predicted, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/seg.nii.gz'.format(str(output_dir)))

    # Save the (resized) network input for side-by-side inspection.
    final_image = image.cpu().detach().numpy()
    final_image = np.squeeze(final_image, 0)
    final_image = np.transpose(final_image, [1, 2, 0])
    print(final_image.shape)
    nim2 = nib.Nifti1Image(final_image, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/image.nii.gz'.format(str(output_dir)))
    # shutil.copy(str(input_path), str(output_dir.joinpath("image.nii.gz")))

    # NOTE(review): as above, "in_channels" is used for n_slices here too.
    label = Torch2DSegmentationDataset.read_label(
        input_path.parent.joinpath(
            dataset_config.image_label_format.label.format(phase=args.phase)),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Collapse ground-truth channels the same way as the prediction.
    final_label = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
    for i in range(label.shape[0]):
        final_label[label[i, :, :, :] == 1.0] = i + 1

    final_label = np.transpose(final_label, [1, 2, 0])
    print(final_label.shape)
    nim2 = nib.Nifti1Image(final_label, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/label.nii.gz'.format(str(output_dir)))
Example #5
0
    label_paths = data_table.select_column("label_path")
    image_paths = [Path(path) for path in image_paths]
    label_paths = [Path(path) for path in label_paths]
    return image_paths, label_paths


# Export middle-slice previews for every image/label pair in the dataframe.
image_paths, label_paths = read_dataframe(config.dataframe_path)

images = []
labels = []
output_dir = Path(__file__).parent.joinpath("output")
output_dir.joinpath("image").mkdir(parents=True, exist_ok=True)
output_dir.joinpath("label").mkdir(parents=True, exist_ok=True)
pbar = list(zip(image_paths, label_paths))
for image_path, label_path in tqdm(pbar):
    # NOTE(review): feature_size/n_slices are passed as None -- this only
    # works if read_image/read_label skip resizing for None; confirm.
    image = Torch2DSegmentationDataset.read_image(image_path, None, None)
    label = Torch2DSegmentationDataset.read_label(label_path, None, None)
    # image appears slice-first; label carries a leading channel axis.
    middle_slice_index = image.shape[0] // 2
    image = image[middle_slice_index, :, :]
    label = label[:, middle_slice_index, :, :]
    image = np.expand_dims(image, 0)
    # Channel-last (H, W, 1) layout for OpenCV.
    image = np.transpose(image, (1, 2, 0))
    label = np.transpose(label, (1, 2, 0))
    # cv2.imshow("label", label)
    # cv2.imshow("image", image)
    # cv2.waitKey()
    # Presumably values are in [0, 1]; scale to 8-bit for writing.
    image = image * 255
    label = label * 255
    cv2.imwrite(
        str(
            output_dir.joinpath(
Example #6
0
def main():
    """Train a DefSegNet deformable-segmentation experiment.

    Parses the training config, loads a 3D atlas label as the deformation
    template, builds the network, datasets, optimizer and loss, exports
    the dataset splits as CSVs, and runs ``Experiment.train()``.
    """
    args = parse_args()
    # Use the packaged default config unless one is supplied on the CLI.
    if args.conf_path is None:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    else:
        train_conf = ConfigFactory.parse_file(str(Path(args.conf_path)))
        conf_path = Path(args.conf_path)

    if get_conf(train_conf, group="experiment",
                key="experiment_dir") is not None:
        experiment_dir = Path(
            get_conf(train_conf, group="experiment", key="experiment_dir"))
    else:
        # NOTE(review): if this stays None, the shutil.copy below on
        # config.experiment_dir will raise -- presumably ExperimentConfig
        # fills in a default directory; confirm.
        experiment_dir = None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=get_conf(train_conf, group="experiment", key="batch_size"),
        num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"),
        gpu=get_conf(train_conf, group="experiment", key="gpu"),
        device=get_conf(train_conf, group="experiment", key="device"),
        num_workers=get_conf(train_conf, group="experiment",
                             key="num_workers"),
        pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"),
        n_inference=get_conf(train_conf, group="experiment",
                             key="n_inference"),
        seed=get_conf(train_conf, group="experiment", key="seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    # Snapshot the config into the experiment directory for reproducibility.
    shutil.copy(str(conf_path),
                str(config.experiment_dir.joinpath("train.conf")))
    # HACK: hardcoded Windows path; also not a raw string, so the single
    # backslashes survive only because they are not valid escapes. Should
    # come from the config instead.
    atlas_path = "D:\surui\\rbh\cardiac\DL_segmentation\RBH_3D_atlases\IHD586_10RZ04876_RBH_IHD_201410241040_MRC25598\seg_lvsa_SR_ED.nii.gz"
    # The atlas label volume serves as the registration template.
    atlas = Torch2DSegmentationDataset.read_label(
        label_path=Path(atlas_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
    )
    network = DefSegNet(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        template=atlas,
        int_downsize=2,
    )

    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
    )
    # Export each split as CSV so the exact data partition is recorded.
    for train in training_sets:
        train.export(
            config.experiment_dir.joinpath("training_set_{}.csv".format(
                train.name)))
    for val in validation_sets:
        val.export(
            config.experiment_dir.joinpath("validation_set_{}.csv".format(
                val.name)))
    for val in extra_validation_sets:
        val.export(
            config.experiment_dir.joinpath(
                "extra_validation_set_{}.csv".format(val.name)))
    # Optimizer selection: SGD with momentum, or Adam by default.
    if get_conf(train_conf, group="optimizer", key="type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
            momentum=get_conf(train_conf, group="optimizer", key="momentum"),
        )
    else:
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=get_conf(train_conf,
                                                 group="optimizer",
                                                 key="learning_rate"))
    # Loss selection: FocalLoss when configured, otherwise the
    # deformable-registration loss against the atlas template.
    if get_conf(train_conf, group="loss", key="type") == "FocalLoss":
        loss = FocalLoss(
            alpha=get_conf(train_conf, group="loss", key="alpha"),
            gamma=get_conf(train_conf, group="loss", key="gamma"),
            logits=True,
        )
    else:
        loss = DefSegLoss(penalty="l2", loss_mult=2, template=atlas)
    experiment = Experiment(config=config,
                            network=network,
                            training_sets=training_sets,
                            validation_sets=validation_sets,
                            extra_validation_sets=extra_validation_sets,
                            optimizer=optimizer,
                            loss=loss,
                            other_validation_metrics=[
                                DefSegWarpedDice(template=atlas),
                                DefSegPredDice()
                            ],
                            inference_func=inference)
    experiment.train()