Example #1
0
def _masks_to_label_map(masks, volume_shape):
    """Collapse per-class boolean masks into a single integer-valued label map.

    Voxels selected by ``masks[i]`` receive the value ``i + 1``; unselected
    voxels stay 0. Where masks overlap, the later class wins because it is
    written last. The result (float64, as created by ``np.zeros``) is
    transposed with axes ``[1, 2, 0]``, i.e. the first spatial axis is moved
    last — presumably (slices, H, W) -> (H, W, slices); TODO confirm.
    """
    label_map = np.zeros(volume_shape)
    for class_value, mask in enumerate(masks, start=1):
        label_map[mask] = class_value
    return np.transpose(label_map, [1, 2, 0])


def _save_nifti_like(data, reference, path):
    """Save *data* to *path* as NIfTI, copying affine and pixdim from *reference*."""
    nifti = nib.Nifti1Image(data, reference.affine)
    nifti.header['pixdim'] = reference.header['pixdim']
    nib.save(nifti, str(path))


def main():
    """Segment one image with a trained FCN2DSegmentationModel and save the results.

    Reads the network/data configuration from ``args.network_conf_path``
    (falling back to ``TRAIN_CONF_PATH``), loads the checkpoint at
    ``args.model_path``, runs the network on the ``args.phase`` image under
    ``args.input_dir``, and writes ``seg.nii.gz`` (predicted label map),
    ``image.nii.gz`` (the network input) and ``label.nii.gz`` (ground-truth
    label map) into ``args.output_dir``.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        conf_path = Path(args.network_conf_path)
    else:
        conf_path = TRAIN_CONF_PATH
    train_conf = ConfigFactory.parse_file(str(conf_path))

    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    image_label_format = dataset_config.image_label_format
    input_path = Path(args.input_dir).joinpath(
        image_label_format.image.format(phase=args.phase))
    label_path = input_path.parent.joinpath(
        image_label_format.label.format(phase=args.phase))
    output_dir = Path(args.output_dir)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    # Result is unused; kept in case get_conf raises on a missing key
    # (fail-fast config check) — TODO confirm intent or remove.
    get_conf(train_conf, group="network", key="experiment_dir")
    feature_size = get_conf(train_conf, group="network", key="feature_size")
    n_slices = get_conf(train_conf, group="network", key="in_channels")

    network = FCN2DSegmentationModel(
        in_channels=n_slices,
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=feature_size)
    network.load_state_dict(checkpoint)
    # NOTE(review): .cuda() and prepare_tensors(gpu=True) assume a CUDA
    # device even though map_location honours args.device.
    network.cuda(device=args.device)
    # Inference only: put layers such as dropout / batch norm (if present)
    # into evaluation mode.
    network.eval()

    image = Torch2DSegmentationDataset.read_image(
        input_path,
        feature_size,
        n_slices,
        crop=args.crop_image,
    )
    # Add a leading batch axis of size 1.
    image = torch.from_numpy(np.expand_dims(image, 0)).float()
    image = prepare_tensors(image, gpu=True, device=args.device)

    # No gradients are needed for inference.
    with torch.no_grad():
        predicted = torch.sigmoid(network(image))
        print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
        predicted = (predicted > 0.5).float()
        print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    # Drop the batch axis, leaving one binary mask per class along axis 0.
    predicted = np.squeeze(predicted.cpu().numpy(), axis=0)

    volume_shape = tuple(image.shape[1:4])
    final_predicted = _masks_to_label_map(predicted > 0.5, volume_shape)
    print(predicted.shape, final_predicted.shape, np.max(final_predicted),
          np.mean(final_predicted), np.min(final_predicted))

    # NOTE(review): the volumes below have the size produced by read_image
    # (feature_size / crop) but reuse the input file's affine and pixdim.
    # No restoration to the original size is performed, so if read_image
    # resizes or crops, the outputs will not align with the input — TODO.
    nim = nib.load(str(input_path))
    output_dir.mkdir(parents=True, exist_ok=True)
    _save_nifti_like(final_predicted, nim, output_dir / "seg.nii.gz")

    final_image = np.squeeze(image.detach().cpu().numpy(), 0)
    final_image = np.transpose(final_image, [1, 2, 0])
    _save_nifti_like(final_image, nim, output_dir / "image.nii.gz")

    label = Torch2DSegmentationDataset.read_label(
        label_path,
        feature_size=feature_size,
        n_slices=n_slices,
        crop=args.crop_image,
    )
    final_label = _masks_to_label_map(np.asarray(label) == 1.0, volume_shape)
    _save_nifti_like(final_label, nim, output_dir / "label.nii.gz")
Example #2
0
    label_paths = data_table.select_column("label_path")
    image_paths = [Path(path) for path in image_paths]
    label_paths = [Path(path) for path in label_paths]
    return image_paths, label_paths


# Load the image/label path pairs and prepare the output folders.
image_paths, label_paths = read_dataframe(config.dataframe_path)

images, labels = [], []
output_dir = Path(__file__).parent / "output"
for subdir_name in ("image", "label"):
    (output_dir / subdir_name).mkdir(parents=True, exist_ok=True)
pbar = list(zip(image_paths, label_paths))
for image_path, label_path in tqdm(pbar):
    image = Torch2DSegmentationDataset.read_image(image_path, None, None)
    label = Torch2DSegmentationDataset.read_label(label_path, None, None)
    middle_slice_index = image.shape[0] // 2
    image = image[middle_slice_index, :, :]
    label = label[:, middle_slice_index, :, :]
    image = np.expand_dims(image, 0)
    image = np.transpose(image, (1, 2, 0))
    label = np.transpose(label, (1, 2, 0))
    # cv2.imshow("label", label)
    # cv2.imshow("image", image)
    # cv2.waitKey()
    image = image * 255
    label = label * 255
    cv2.imwrite(
        str(
            output_dir.joinpath(