Example #1
 def test_3d_4d(self, arr: Union[tf.Tensor, np.ndarray]):
     save_array(save_dir=self.save_dir,
                arr=arr,
                name=self.arr_name,
                normalize=True)
     assert self.get_num_files_in_dir(self.png_dir, suffix=".png") == 4
     assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == 1
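None of the examples in this listing show the definition of save_array itself. Based purely on how it is called here, its call surface is roughly the sketch below; it is inferred from usage, so the defaults and the docstring are assumptions rather than the library's actual code. Note that some snippets (Examples #11 and #12) pass gray= where others pass normalize=; the two flags appear to play the same role here.

# Sketch of the save_array interface as inferred from the calls in this
# listing -- NOT the library's real definition; defaults are guesses.
from typing import Union

import numpy as np
import tensorflow as tf


def save_array(
    save_dir: str,                      # output directory
    arr: Union[tf.Tensor, np.ndarray],  # 3d (d1, d2, d3) or 4d (d1, d2, d3, 3)
    name: str,                          # base name, e.g. "moving_image"
    normalize: bool = True,             # rescale values before writing PNGs
    save_nifti: bool = True,            # write <save_dir>/<name>.nii.gz
    save_png: bool = True,              # write one PNG per slice under <save_dir>/<name>/
    overwrite: bool = True,             # if False, keep an existing <name>.nii.gz
) -> None:
    """Save a volume as a NIfTI file and/or a stack of PNG slices (assumed behaviour)."""
    ...

The assertions in Examples #1-#6 and #12 imply one PNG per slice along the third dimension (hence 4 PNG files for a (2, 3, 4) or (2, 3, 4, 3) input), a single .nii.gz per call, and a ValueError for anything other than a 3d array or a 4d array with 3 channels in the last dimension.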
Example #2
 def test_wrong_shape(self, arr: Union[tf.Tensor, np.ndarray], err_msg: str):
     with pytest.raises(ValueError) as err_info:
         save_array(save_dir=self.save_dir,
                    arr=arr,
                    name=self.arr_name,
                    normalize=True)
     assert err_msg in str(err_info.value)
Example #3
 def test_wrong_shape(self, arr: Union[tf.Tensor, np.ndarray], err_msg: str):
     """test TensorFlow/Numpy inputs with incorrect shapes"""
     with pytest.raises(ValueError) as err_info:
         save_array(save_dir=self.save_dir,
                    arr=arr,
                    name=self.arr_name,
                    normalize=True)
     assert err_msg in str(err_info.value)
Example #4
 def test_3d_4d(self, arr: Union[tf.Tensor, np.ndarray]):
     """test 3d/4d TensorFlow/Numpy inputs"""
     save_array(save_dir=self.save_dir,
                arr=arr,
                name=self.arr_name,
                normalize=True)
     assert self.get_num_files_in_dir(self.png_dir, suffix=".png") == 4
     assert self.get_num_files_in_dir(self.save_dir, suffix=".nii.gz") == 1
Example #5
 def test_save_png(self, save_png: bool):
     arr = np.random.rand(2, 3, 4, 3)
     save_array(
         save_dir=self.save_dir,
         arr=arr,
         name=self.arr_name,
         normalize=True,
         save_png=save_png,
     )
     assert (self.get_num_files_in_dir(self.png_dir,
                                       suffix=".png") == int(save_png) * 4)
Example #6
 def test_save_nifti(self, save_nifti: bool):
     arr = np.random.rand(2, 3, 4, 3)
     save_array(
         save_dir=self.save_dir,
         arr=arr,
         name=self.arr_name,
         normalize=True,
         save_nifti=save_nifti,
     )
     assert self.get_num_files_in_dir(self.save_dir,
                                      suffix=".nii.gz") == int(save_nifti)
Example #7
 def test_overwrite(self, overwrite: bool):
     arr1 = np.random.rand(2, 3, 4, 3)
     arr2 = arr1 + 1
     nifti_file_path = os.path.join(self.save_dir,
                                    self.arr_name + ".nii.gz")
     # save arr1
     os.makedirs(self.save_dir, exist_ok=True)
     nib.save(img=nib.Nifti1Image(arr1, affine=np.eye(4)),
              filename=nifti_file_path)
     # save arr2 w/o overwrite
     save_array(
         save_dir=self.save_dir,
         arr=arr2,
         name=self.arr_name,
         normalize=True,
         overwrite=overwrite,
     )
     arr_read = load_nifti_file(file_path=nifti_file_path)
     assert is_equal_np(arr2 if overwrite else arr1, arr_read)
Example #8
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Make predictions on a dataset with the given model, then save the outputs and metrics.

    :param dataset: dataset to predict on
    :param fixed_grid_ref: reference grid, shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: one of ddf / dvf / affine / conditional
    :param save_dir: path to the directory where outputs are saved
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    sample_index_strs = []
    metric_lists = []
    for inputs_dict in dataset:
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)
        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)

        # indices to identify the pair
        # (batch, num_indices), the last index is for the label, -1 means unlabeled data
        indices = inputs_dict.get("indices")
        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        pred_fixed_image = (
            layer_util.resample(vol=moving_image, loc=fixed_grid_ref + ddf)
            if ddf is not None
            else None
        )

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir
            # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            # save image/label
            # if model is conditional, the pred_fixed_image depends on the input label
            conditional = model_method == "conditional"
            arr_save_dirs = [
                pair_dir,
                pair_dir,
                label_dir if conditional else pair_dir,
                label_dir,
                label_dir,
                label_dir,
            ]
            arrs = [
                moving_image,
                fixed_image,
                pred_fixed_image,
                moving_label,
                fixed_label,
                pred_fixed_label,
            ]
            names = [
                "moving_image",
                "fixed_image",
                "pred_fixed_image",  # or warped moving image
                "moving_label",
                "fixed_label",
                "pred_fixed_label",  # or warped moving label
            ]
            for arr_save_dir, arr, name in zip(arr_save_dirs, arrs, names):
                if arr is not None:
                    # for files under pair_dir, do not overwrite
                    save_array(
                        save_dir=arr_save_dir,
                        arr=arr[sample_index, :, :, :],
                        name=name,
                        gray=True,
                        save_nifti=save_nifti,
                        save_png=save_png,
                        overwrite=arr_save_dir == label_dir,
                    )

            # save ddf / dvf
            arrs = [ddf, dvf]
            names = ["ddf", "dvf"]
            for arr, name in zip(arrs, names):
                if arr is not None:
                    arr = normalize_array(arr=arr[sample_index, :, :, :])
                    save_array(
                        save_dir=label_dir if conditional else pair_dir,
                        arr=arr,
                        name=name,
                        gray=False,
                        save_nifti=save_nifti,
                        save_png=save_png,
                    )

            # save affine
            if affine is not None:
                np.savetxt(
                    fname=os.path.join(label_dir if conditional else pair_dir,
                                       "affine.txt"),
                    X=affine[sample_index, :, :].numpy(),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
Example #9
if os.path.exists(SAVE_PATH):
    shutil.rmtree(SAVE_PATH)
os.mkdir(SAVE_PATH)

arrays = [
    tf.squeeze(a)
    for a in [
        moving_image,
        fixed_image,
        warped_moving_image,
        moving_label,
        fixed_label,
        warped_moving_label,
        var_ddf,
    ]
]
arr_names = [
    "moving_image",
    "fixed_image",
    "warped_moving_image",
    "moving_label",
    "fixed_label",
    "warped_moving_label",
    "ddf",
]
for arr, arr_name in zip(arrays, arr_names):
    util.save_array(
        save_dir=SAVE_PATH, arr=arr, name=arr_name, normalize=True, save_png=False
    )

os.chdir(MAIN_PATH)
Example #10
arrays = [
    tf.transpose(a, [1, 2, 3, 0]) if a.ndim == 4 else tf.squeeze(a)
    for a in [
        moving_image,
        fixed_image,
        warped_moving_image,
        moving_labels,
        fixed_labels,
        warped_moving_labels,
    ]
]
arr_names = [
    "moving_image",
    "fixed_image",
    "warped_moving_image",
    "moving_label",
    "fixed_label",
    "warped_moving_label",
]
for arr, arr_name in zip(arrays, arr_names):
    for n in range(arr.shape[-1]):
        util.save_array(
            save_dir=SAVE_PATH,
            arr=arr[..., n],
            name=arr_name + ("_{}".format(n) if arr.shape[-1] > 1 else ""),
            normalize="image" in arr_name,  # label's value is already in [0, 1]
        )

os.chdir(MAIN_PATH)
Example #11
os.mkdir(SAVE_PATH)

arrays = [
    tf.transpose(a, [1, 2, 3, 0]) if a.ndim == 4 else tf.squeeze(a) for a in [
        moving_image,
        fixed_image,
        warped_moving_image,
        moving_labels,
        fixed_labels,
        warped_moving_labels,
    ]
]
arr_names = [
    "moving_image",
    "fixed_image",
    "warped_moving_image",
    "moving_label",
    "fixed_label",
    "warped_moving_label",
]
for arr, arr_name in zip(arrays, arr_names):
    for n in range(arr.shape[-1]):
        util.save_array(
            save_dir=SAVE_PATH,
            arr=arr[..., n],
            name=arr_name + ("_{}".format(n) if arr.shape[-1] > 1 else ""),
            gray=True,
        )

os.chdir(MAIN_PATH)
Example #12
def test_save_array():
    """
    Test save_array with different input shapes and count the output files
    """
    def get_num_pngs_in_dir(dir_paths):
        return len([x for x in os.listdir(dir_paths) if x.endswith(".png")])

    def get_num_niftis_in_dir(dir_paths):
        return len([x for x in os.listdir(dir_paths) if x.endswith(".nii.gz")])

    save_dir = "logs/test_util_save_array"
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    # test 3D tf tensor
    name = "3d_tf"
    out_dir = os.path.join(save_dir, name)
    arr = tf.random.uniform(shape=(2, 3, 4))
    save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert get_num_pngs_in_dir(out_dir) == 4
    assert get_num_niftis_in_dir(save_dir) == 1
    shutil.rmtree(out_dir)
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 4D tf tensor
    name = "4d_tf"
    out_dir = os.path.join(save_dir, name)
    arr = tf.random.uniform(shape=(2, 3, 4, 3))
    save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert get_num_pngs_in_dir(out_dir) == 4
    assert get_num_niftis_in_dir(save_dir) == 1
    shutil.rmtree(out_dir)
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 3D np tensor
    name = "3d_np"
    out_dir = os.path.join(save_dir, name)
    arr = np.random.rand(2, 3, 4)
    save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert get_num_pngs_in_dir(out_dir) == 4
    assert get_num_niftis_in_dir(save_dir) == 1
    shutil.rmtree(out_dir)
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 4D np tensor
    name = "4d_np"
    out_dir = os.path.join(save_dir, name)
    arr = np.random.rand(2, 3, 4, 3)
    save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert get_num_pngs_in_dir(out_dir) == 4
    assert get_num_niftis_in_dir(save_dir) == 1
    shutil.rmtree(out_dir)
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 4D np tensor without nifti
    name = "4d_np"
    out_dir = os.path.join(save_dir, name)
    arr = np.random.rand(2, 3, 4, 3)
    save_array(save_dir=save_dir,
               arr=arr,
               name=name,
               gray=True,
               save_nifti=False)
    assert get_num_pngs_in_dir(out_dir) == 4
    assert get_num_niftis_in_dir(save_dir) == 0
    shutil.rmtree(out_dir)

    # test 4D np tensor without png
    name = "4d_np"
    out_dir = os.path.join(save_dir, name)
    arr = np.random.rand(2, 3, 4, 3)
    assert not os.path.exists(out_dir)
    save_array(save_dir=save_dir,
               arr=arr,
               name=name,
               gray=True,
               save_png=False)
    assert not os.path.exists(out_dir)
    assert get_num_niftis_in_dir(save_dir) == 1
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 4D np tensor with overwrite
    name = "4d_np"
    out_dir = os.path.join(save_dir, name)
    arr1 = np.random.rand(2, 3, 4, 3)
    arr2 = np.random.rand(2, 3, 4, 3)
    assert not is_equal_np(arr1, arr2)
    nifti_file_path = os.path.join(save_dir, name + ".nii.gz")
    # save arr1
    os.makedirs(save_dir, exist_ok=True)
    nib.save(img=nib.Nifti2Image(arr1, affine=np.eye(4)),
             filename=nifti_file_path)
    # save arr2 without overwrite
    save_array(save_dir=save_dir,
               arr=arr2,
               name=name,
               gray=True,
               overwrite=False)
    arr_read = load_nifti_file(file_path=nifti_file_path)
    assert is_equal_np(arr1, arr_read)
    # save arr2 with overwrite
    save_array(save_dir=save_dir,
               arr=arr2,
               name=name,
               gray=True,
               overwrite=True)
    arr_read = load_nifti_file(file_path=nifti_file_path)
    assert is_equal_np(arr2, arr_read)
    shutil.rmtree(out_dir)
    os.remove(os.path.join(save_dir, name + ".nii.gz"))

    # test 5D np tensor
    name = "5d_np"
    arr = np.random.rand(2, 3, 4, 1, 3)
    with pytest.raises(ValueError) as err_info:
        save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert "arr must be 3d or 4d numpy array or tf tensor" in str(
        err_info.value)

    # test 4D np tensor with wrong shape
    name = "5d_np"
    arr = np.random.rand(2, 3, 4, 1)
    with pytest.raises(ValueError) as err_info:
        save_array(save_dir=save_dir, arr=arr, name=name, gray=True)
    assert "4d arr must have 3 channels as last dimension" in str(
        err_info.value)
Example #13
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Make predictions on a dataset with the given model, then save the outputs and metrics.

    :param dataset: dataset to predict on
    :param fixed_grid_ref: reference grid, shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: one of ddf / dvf / affine / conditional
    :param save_dir: path to the directory where outputs are saved
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for inputs in dataset:
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0]
                if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
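For context, a minimal call site for predict_on_dataset might look like the sketch below. The helpers build_test_dataset and build_registration_model are hypothetical placeholders, not functions shown in these examples, and the save_dir path is only an example; the fixed_grid_ref construction simply follows the shape stated in the docstring.

# Hypothetical driver for predict_on_dataset -- build_test_dataset() and
# build_registration_model() are placeholder names, not part of the examples.
import tensorflow as tf

dataset: tf.data.Dataset = build_test_dataset()      # assumed helper
model: tf.keras.Model = build_registration_model()   # assumed helper

# Reference grid of shape (1, f_dim1, f_dim2, f_dim3, 3), as required above.
f_dim1, f_dim2, f_dim3 = 32, 32, 16  # example fixed-image dimensions
fixed_grid_ref = tf.stack(
    tf.meshgrid(
        tf.range(f_dim1, dtype=tf.float32),
        tf.range(f_dim2, dtype=tf.float32),
        tf.range(f_dim3, dtype=tf.float32),
        indexing="ij",
    ),
    axis=-1,
)[None, ...]  # add batch dimension -> (1, f_dim1, f_dim2, f_dim3, 3)

predict_on_dataset(
    dataset=dataset,
    fixed_grid_ref=fixed_grid_ref,
    model=model,
    model_method="ddf",           # one of ddf / dvf / affine / conditional
    save_dir="logs/test_predict", # example output path
    save_nifti=True,
    save_png=False,
)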