Example #1
0
def construct_training_validation_dataset(
    data_config: DataConfig,
    feature_size: int,
    n_slices: int,
    output_dir: Path,
    is_3d: bool = False,
    augmentation_config: AugmentationConfig = None,
    seed: int = None,
    template_path: Path = None,
) -> Tuple[List["Torch2DSegmentationDataset"],
           List["Torch2DSegmentationDataset"],
           List["Torch2DSegmentationDataset"]]:
    """Construct training, validation and extra-validation datasets.

    Each dataset named in ``data_config.training_datasets`` is split into a
    training and a validation set; datasets named in
    ``data_config.extra_validation_datasets`` contribute validation sets only
    (``only_val=True``).

    Args:
        data_config: Source of dataset names, data mode, mount prefix,
            augmentation probability, validation split and dataframe renewal.
        feature_size: Spatial feature size passed through to each dataset.
        n_slices: Number of slices passed through to each dataset.
        output_dir: Directory handed to ``train_val_dataset_from_config``.
        is_3d: Whether to build 3D datasets.
        augmentation_config: Optional augmentation settings (training sets
            only; not used for the extra validation sets).
        seed: Optional random seed for the train/validation split.
        template_path: Optional initial template path; see loop note below.

    Returns:
        Tuple of (training_sets, validation_sets, extra_validation_sets),
        each a list of datasets in the same order as the configured names.
    """
    def _configs_from_names(names):
        # Both dataset groups are configured identically; shared helper
        # removes the duplicated comprehension.
        return [
            DatasetConfig.from_conf(name,
                                    mode=data_config.data_mode,
                                    mount_prefix=data_config.mount_prefix)
            for name in names
        ]

    training_set_configs = _configs_from_names(data_config.training_datasets)
    extra_val_set_configs = _configs_from_names(
        data_config.extra_validation_datasets)

    training_sets = []
    validation_sets = []
    for config in training_set_configs:
        # NOTE(review): template_path is intentionally re-bound on every
        # iteration — the template produced while processing one dataset is
        # fed into the next dataset (and into the extra-validation loop
        # below). TODO confirm this propagation is the intended behavior.
        train, val, template_path = train_val_dataset_from_config(
            dataset_config=config,
            augmentation_config=augmentation_config,
            augmentation_prob=data_config.augmentation_prob,
            validation_split=data_config.validation_split,
            feature_size=feature_size,
            n_slices=n_slices,
            is_3d=is_3d,
            renew_dataframe=data_config.renew_dataframe,
            seed=seed,
            output_dir=output_dir,
            template_path=template_path)
        training_sets.append(train)
        validation_sets.append(val)

    extra_val_sets = []
    for config in extra_val_set_configs:
        # only_val=True: the training portion of these datasets is discarded.
        __, val, __ = train_val_dataset_from_config(
            dataset_config=config,
            validation_split=data_config.validation_split,
            feature_size=feature_size,
            n_slices=n_slices,
            is_3d=is_3d,
            only_val=True,
            renew_dataframe=data_config.renew_dataframe,
            seed=seed,
            output_dir=output_dir,
            template_path=template_path,
        )
        extra_val_sets.append(val)
    return training_sets, validation_sets, extra_val_sets
Example #2
0
def main():
    """Run single-volume inference with a 3D UNet segmentation network.

    Loads a trained UNet checkpoint, builds a one-item dataset around the
    requested image/label pair and delegates prediction and output writing
    to ``inference``.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    # Fall back to the default training config when none is supplied.
    # (De-duplicated: both branches only differ in which path is parsed.)
    if args.network_conf_path is not None:
        conf_path = Path(args.network_conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))
    data_config = DataConfig.from_conf(conf_path)
    # The first configured training dataset supplies the image/label
    # filename format used to locate the input volume.
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # Return value deliberately unused — presumably validates that the key
    # exists in the config (TODO confirm get_conf semantics before removing).
    get_conf(train_conf, group="network", key="experiment_dir")
    network = UNet(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)

    # One-item dataset wrapping the requested image and its matching label.
    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[
            input_path.parent.joinpath(
                dataset_config.image_label_format.label.format(
                    phase=args.phase))
        ],
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
    )

    image = dataset.get_image_tensor_from_index(0)
    image = torch.unsqueeze(image, 0)  # add batch dimension
    image = prepare_tensors(image, True, args.device)

    label = dataset.get_label_tensor_from_index(0)
    inference(
        image=image,
        label=label,
        image_path=input_path,
        network=network,
        output_dir=output_dir,
    )
Example #3
0
def main():
    """Run single-volume inference with a 2D FCN segmentation network.

    Loads a trained FCN2DSegmentationModel checkpoint, reads one image
    volume, predicts per-class binary masks, collapses them into a single
    label map, and saves segmentation, input image and ground-truth label
    as NIfTI files in the output directory.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    # Fall back to the default training config when none is supplied.
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    # The first configured dataset supplies the image/label filename format.
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # Return value deliberately unused — presumably validates that the key
    # exists in the config (TODO confirm get_conf semantics).
    get_conf(train_conf, group="network", key="experiment_dir")
    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"))
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # image = nib.load(str(input_path)).get_data()
    # if image.ndim == 4:
    #     image = np.squeeze(image, axis=-1).astype(np.int16)
    # image = image.astype(np.int16)
    # image = np.transpose(image, (2, 0, 1))
    # NOTE(review): the second argument of read_image is fed from
    # "in_channels" here — presumably slices double as channels for the 2D
    # network; confirm against read_image's signature.
    image = Torch2DSegmentationDataset.read_image(
        input_path,
        get_conf(train_conf, group="network", key="feature_size"),
        get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    image = np.expand_dims(image, 0)  # add batch dimension
    image = torch.from_numpy(image).float()
    image = prepare_tensors(image, gpu=True, device=args.device)
    predicted = network(image)
    # Per-class probabilities, then hard binary masks at 0.5.
    predicted = torch.sigmoid(predicted)
    print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = (predicted > 0.5).float()
    print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = predicted.cpu().detach().numpy()

    nim = nib.load(str(input_path))
    # Transpose and crop the segmentation to recover the original size
    predicted = np.squeeze(predicted, axis=0)  # drop batch dimension
    print(predicted.shape)

    # map back to original size
    # Buffer shaped from the (batched) input tensor's non-batch dimensions.
    final_predicted = np.zeros(
        (image.shape[1], image.shape[2], image.shape[3]))
    print(predicted.shape, final_predicted.shape)

    # Collapse the per-class binary masks into one label map: class i is
    # written as label value i + 1 (0 stays background). Later classes
    # overwrite earlier ones where masks overlap.
    for i in range(predicted.shape[0]):
        a = predicted[i, :, :, :] > 0.5
        print(a.shape)
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    # image = nim.get_data()
    # Reorder axes to match the on-disk NIfTI orientation (slices last).
    final_predicted = np.transpose(final_predicted, [1, 2, 0])
    print(predicted.shape, final_predicted.shape)
    # final_predicted = np.resize(final_predicted, (image.shape[0], image.shape[1], image.shape[2]))

    print(predicted.shape, final_predicted.shape, np.max(final_predicted),
          np.mean(final_predicted), np.min(final_predicted))
    # if Z < 64:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, z1_ - z1:z1_ - z1 + Z]
    # else:
    #     pred_segt = pred_segt[x_pre:x_pre + X, y_pre:y_pre + Y, :]
    #     pred_segt = np.pad(pred_segt, ((0, 0), (0, 0), (z_pre, z_post)), 'constant')

    # Save segmentation with the input volume's affine and pixel dimensions.
    nim2 = nib.Nifti1Image(final_predicted, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/seg.nii.gz'.format(str(output_dir)))

    # Save the (preprocessed) input image alongside the segmentation.
    final_image = image.cpu().detach().numpy()
    final_image = np.squeeze(final_image, 0)  # drop batch dimension
    final_image = np.transpose(final_image, [1, 2, 0])
    print(final_image.shape)
    nim2 = nib.Nifti1Image(final_image, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/image.nii.gz'.format(str(output_dir)))
    # shutil.copy(str(input_path), str(output_dir.joinpath("image.nii.gz")))

    # NOTE(review): n_slices is read from the "in_channels" key — mirrors the
    # read_image call above, but verify this is intentional.
    label = Torch2DSegmentationDataset.read_label(
        input_path.parent.joinpath(
            dataset_config.image_label_format.label.format(phase=args.phase)),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Collapse the one-hot ground-truth label the same way as the prediction.
    final_label = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
    for i in range(label.shape[0]):
        final_label[label[i, :, :, :] == 1.0] = i + 1

    final_label = np.transpose(final_label, [1, 2, 0])
    print(final_label.shape)
    nim2 = nib.Nifti1Image(final_label, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/label.nii.gz'.format(str(output_dir)))
Example #4
0
import cv2
import numpy as np
from pathlib import Path
from CMRSegment.common.config import DatasetConfig
from CMRSegment.common.data_table import DataTable
from CMRSegment.common.nn.torch.data import Torch2DSegmentationDataset
from tqdm import tqdm
# Dataset configuration for the RBH 3D atlases, resolved from the project's
# conf file against the server mount point.
config = DatasetConfig.from_conf(
    name="RBH_3D_atlases",
    mode="3D",
    mount_prefix=Path("/mnt/storage/home/suruli/"))
# config = DatasetConfig.from_conf(name="RBH_3D_atlases", mode="3D", mount_prefix=Path("D:/surui/rbh/"))


def read_dataframe(dataframe_path: Path):
    """Load image/label path pairs from a dataset CSV.

    Returns two parallel lists of ``Path`` objects: image paths and the
    corresponding label paths, in CSV row order.
    """
    table = DataTable.from_csv(dataframe_path)
    images = [Path(entry) for entry in table.select_column("image_path")]
    labels = [Path(entry) for entry in table.select_column("label_path")]
    return images, labels


# Resolve the image/label file lists recorded in the dataset's dataframe CSV.
image_paths, label_paths = read_dataframe(config.dataframe_path)

images = []
labels = []
# Output folders live next to this script: output/image and output/label.
output_dir = Path(__file__).parent.joinpath("output")
output_dir.joinpath("image").mkdir(parents=True, exist_ok=True)
output_dir.joinpath("label").mkdir(parents=True, exist_ok=True)
# Materialize the pairs up front — presumably wrapped by tqdm and iterated
# below this visible chunk (TODO confirm).
pbar = list(zip(image_paths, label_paths))