Exemple #1
0
def test_calculate_metrics():
    """
    Check which metrics calculate_metrics reports for each input combination.

    Only the presence / absence of each metric value is verified,
    the metric functions themselves are assumed to be correct.
    """
    num_samples = 2
    vol_shape = (4, 4, 4)  # (f_dim1, f_dim2, f_dim3)
    batched_shape = (num_samples, *vol_shape)

    # random inputs of shape (batch, f_dim1, f_dim2, f_dim3)
    fixed_image = tf.random.uniform(shape=batched_shape)
    fixed_label = tf.random.uniform(shape=batched_shape)
    pred_fixed_image = tf.random.uniform(shape=batched_shape)
    pred_fixed_label = tf.random.uniform(shape=batched_shape)
    # reference grid of shape (1, f_dim1, f_dim2, f_dim3, 3)
    fixed_grid_ref = tf.random.uniform(shape=(1, *vol_shape, 3))

    expected_keys = ["image_ssd", "label_binary_dice", "label_tre"]

    # (labeled, has_pred_image) combinations, in the order they are checked
    scenarios = [
        (True, True),  # labeled and have pred_fixed_image
        (True, False),  # labeled and do not have pred_fixed_image
        (False, True),  # unlabeled and have pred_fixed_image
        (False, False),  # unlabeled and do not have pred_fixed_image
    ]
    for labeled, has_pred_image in scenarios:
        metrics = calculate_metrics(
            fixed_image=fixed_image,
            fixed_label=fixed_label if labeled else None,
            pred_fixed_image=pred_fixed_image if has_pred_image else None,
            pred_fixed_label=pred_fixed_label if labeled else None,
            fixed_grid_ref=fixed_grid_ref,
            sample_index=0,
        )

        # image metric exists only when a predicted image is provided
        assert (metrics["image_ssd"] is not None) == has_pred_image
        # label metrics exist only when labels are provided
        assert (metrics["label_binary_dice"] is not None) == labeled
        assert (metrics["label_tre"] is not None) == labeled

        # the exact key set is verified for the fully specified case only
        if labeled and has_pred_image:
            assert sorted(metrics.keys()) == sorted(expected_keys)
Exemple #2
0
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on every batch of a dataset, then save the inputs, the outputs
    and the metrics of each sample under save_dir.

    An existing save_dir is deleted before anything is written.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: str, ddf / dvf / affine / conditional
    :param save_dir: str, path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are met more than once
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    # if model is conditional, the predictions depend on the input label,
    # so they are saved under the label dir instead of the pair dir
    conditional = model_method == "conditional"

    # a set keeps the duplicate check O(1) per sample
    seen_sample_index_strs = set()
    metric_lists = []
    for inputs_dict in dataset:
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)
        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)

        # indices to identify the pair
        # (batch, num_indices) last index is for label, -1 means unlabeled data
        indices = inputs_dict.get("indices")
        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        # the predicted fixed image is the moving image warped by the ddf
        pred_fixed_image = None
        if ddf is not None:
            pred_fixed_image = layer_util.resample(vol=moving_image,
                                                   loc=fixed_grid_ref + ddf)

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir
            # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)
            # directory for outputs that are label dependent only if conditional
            model_out_dir = label_dir if conditional else pair_dir

            # save image/label, each entry is (directory, batched array, name)
            gray_arrays = [
                (pair_dir, moving_image, "moving_image"),
                (pair_dir, fixed_image, "fixed_image"),
                # or warped moving image
                (model_out_dir, pred_fixed_image, "pred_fixed_image"),
                (label_dir, moving_label, "moving_label"),
                (label_dir, fixed_label, "fixed_label"),
                # or warped moving label
                (label_dir, pred_fixed_label, "pred_fixed_label"),
            ]
            for arr_save_dir, arr, name in gray_arrays:
                if arr is None:
                    continue
                # for files under pair_dir, do not overwrite
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    gray=True,
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # save ddf / dvf
            for arr, name in ((ddf, "ddf"), (dvf, "dvf")):
                if arr is None:
                    continue
                save_array(
                    save_dir=model_out_dir,
                    arr=normalize_array(arr=arr[sample_index, :, :, :]),
                    name=name,
                    gray=False,
                    save_nifti=save_nifti,
                    save_png=save_png,
                )

            # save affine
            if affine is not None:
                # np.savetxt names its data argument `X`;
                # np.asarray accepts both tensors and numpy arrays
                np.savetxt(
                    fname=os.path.join(model_out_dir, "affine.txt"),
                    X=np.asarray(affine[sample_index, :, :]),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            seen_sample_index_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
Exemple #3
0
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on every batch of a dataset, then save the post-processed
    tensors and the metrics of each sample under save_dir.

    An existing save_dir is deleted before anything is written.
    model.postprocess is expected to return the sample indices and a dict
    mapping a name to a tuple (array, normalize, on_label).

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are met more than once
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set keeps the duplicate check O(1) per sample
    seen_sample_index_strs = set()
    metric_lists = []
    for inputs in dataset:
        # the batch size is read from the first input tensor
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k:
            (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # theta is written as text instead of an image
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            seen_sample_index_strs.add(sample_index_str)

            # label tensors are only passed when the model is labeled
            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0]
                if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)