Exemple #1
0
def test_save_metric_dict():
    """
    Test save_metric_dict by checking output files.

    Three metric dicts are passed to save_metric_dict and the test expects
    exactly three csv files in the output directory afterwards.
    The output directory is removed at the end even if the check fails,
    so a failing run does not leave files behind.
    """

    save_dir = "logs/test_save_metric_dict"
    metrics = [
        dict(image_ssd=0.1, label_dice=0.8, pair_index=[0], label_index=0),
        dict(image_ssd=0.2, label_dice=0.7, pair_index=[1], label_index=1),
        dict(image_ssd=0.3, label_dice=0.6, pair_index=[2], label_index=0),
    ]
    try:
        save_metric_dict(save_dir=save_dir, metrics=metrics)
        csv_files = [x for x in os.listdir(save_dir) if x.endswith(".csv")]
        assert len(csv_files) == 3
    finally:
        # always clean up; ignore_errors avoids masking the original failure
        # when the directory was never created
        shutil.rmtree(save_dir, ignore_errors=True)
Exemple #2
0
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model

    For each batch, the model outputs are predicted, the inputs/outputs of
    each sample are saved to disk, and one metric dict per sample is
    collected. All metrics are saved at the end via save_metric_dict.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: str, ddf / dvf / affine / conditional
    :param save_dir: str, path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are encountered twice
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    # if model is conditional, the predictions depend on the input label,
    # so they are stored under label_dir instead of pair_dir
    conditional = model_method == "conditional"

    # a set gives O(1) duplicate detection per sample
    seen_sample_index_strs = set()
    metric_lists = []
    for inputs_dict in dataset:
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)
        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)

        # indices to identify the pair
        # (batch, num_indices) last index is for label, -1 means unlabeled data
        indices = inputs_dict.get("indices")
        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        # the warped moving image can only be computed when a ddf is available
        pred_fixed_image = (
            layer_util.resample(vol=moving_image, loc=fixed_grid_ref + ddf)
            if ddf is not None
            else None
        )

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir
            # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)
            # directory for outputs that depend on the label only
            # when the model is conditional
            transform_dir = label_dir if conditional else pair_dir

            # save image/label
            # each entry is (directory, array, file name)
            image_entries = [
                (pair_dir, moving_image, "moving_image"),
                (pair_dir, fixed_image, "fixed_image"),
                # or warped moving image
                (transform_dir, pred_fixed_image, "pred_fixed_image"),
                (label_dir, moving_label, "moving_label"),
                (label_dir, fixed_label, "fixed_label"),
                # or warped moving label
                (label_dir, pred_fixed_label, "pred_fixed_label"),
            ]
            for arr_save_dir, arr, name in image_entries:
                if arr is None:
                    continue
                # for files under pair_dir, do not overwrite
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    gray=True,
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # save ddf / dvf
            for field, name in ((ddf, "ddf"), (dvf, "dvf")):
                if field is None:
                    continue
                save_array(
                    save_dir=transform_dir,
                    arr=normalize_array(arr=field[sample_index, :, :, :]),
                    name=name,
                    gray=False,
                    save_nifti=save_nifti,
                    save_png=save_png,
                )

            # save affine
            if affine is not None:
                # np.savetxt takes the data as ``X`` (capital letter);
                # np.asarray accepts both numpy arrays, as returned by
                # model.predict, and eager tensors
                np.savetxt(
                    fname=os.path.join(transform_dir, "affine.txt"),
                    X=np.asarray(affine[sample_index, :, :]),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            seen_sample_index_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            # all indices but the last identify the pair, the last one the label
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
Exemple #3
0
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model

    For each batch, the model outputs are predicted and post-processed by
    model.postprocess, which returns the sample indices and a dict mapping
    a name to a tuple (array, normalize, on_label). Each array is saved per
    sample, one metric dict per sample is collected, and all metrics are
    saved at the end via save_metric_dict.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are encountered twice
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set gives O(1) duplicate detection per sample
    seen_sample_index_strs = set()
    metric_lists = []
    for inputs in dataset:
        # batch size is read from the first input's leading dimension
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            name: (
                arr.numpy() if isinstance(arr, tf.Tensor) else arr,
                normalize,
                on_label,
            )
            for name, (arr, normalize, on_label) in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # theta is written as a text file instead of an image
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    # files under pair_dir are not overwritten
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            seen_sample_index_strs.add(sample_index_str)

            # labels are only passed to the metrics when the model is labeled
            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0]
                if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            # all indices but the last identify the pair, the last one the label
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)