def test_calculate_metrics():
    """
    Test calculate_metrics by checking output keys.
    Assuming the metrics functions are correct.
    """
    batch_size = 2
    fixed_image_shape = (4, 4, 4)  # (f_dim1, f_dim2, f_dim3)
    fixed_image = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    fixed_label = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    pred_fixed_image = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    pred_fixed_label = tf.random.uniform(shape=(batch_size,) + fixed_image_shape)
    fixed_grid_ref = tf.random.uniform(shape=(1,) + fixed_image_shape + (3,))
    sample_index = 0

    # labeled and have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=fixed_label,
        pred_fixed_image=pred_fixed_image,
        pred_fixed_label=pred_fixed_label,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is not None
    assert got["label_binary_dice"] is not None
    assert got["label_tre"] is not None
    assert sorted(got.keys()) == sorted(
        ["image_ssd", "label_binary_dice", "label_tre"])

    # labeled and do not have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=fixed_label,
        pred_fixed_image=None,
        pred_fixed_label=pred_fixed_label,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is None
    assert got["label_binary_dice"] is not None
    assert got["label_tre"] is not None

    # unlabeled and have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=None,
        pred_fixed_image=pred_fixed_image,
        pred_fixed_label=None,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is not None
    assert got["label_binary_dice"] is None
    assert got["label_tre"] is None

    # unlabeled and do not have pred_fixed_image
    got = calculate_metrics(
        fixed_image=fixed_image,
        fixed_label=None,
        pred_fixed_image=None,
        pred_fixed_label=None,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    assert got["image_ssd"] is None
    assert got["label_binary_dice"] is None
    assert got["label_tre"] is None
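
# Illustrative sketch only: the test above uses a random fixed_grid_ref, but in
# practice the grid passed to calculate_metrics is a reference grid over the
# fixed image shape. This assumes layer_util.get_reference_grid(grid_size=...)
# returns a (f_dim1, f_dim2, f_dim3, 3) tensor and that layer_util is importable
# here as it is in the prediction code below; verify against the actual API.
def _example_fixed_grid_ref(fixed_image_shape=(4, 4, 4)) -> tf.Tensor:
    grid = layer_util.get_reference_grid(grid_size=fixed_image_shape)
    return tf.expand_dims(grid, axis=0)  # shape (1, f_dim1, f_dim2, f_dim3, 3)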
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on a dataset using the given model and save the outputs.

    :param dataset: dataset on which the prediction is performed
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path of the directory in which outputs are stored
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    sample_index_strs = []
    metric_lists = []
    for _, inputs_dict in enumerate(dataset):
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)

        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)

        # indices to identify the pair
        # (batch, num_indices), last index is for label, -1 means unlabeled data
        indices = inputs_dict.get("indices")

        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        pred_fixed_image = (layer_util.resample(
            vol=moving_image,
            loc=fixed_grid_ref + ddf) if ddf is not None else None)

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label_dir
            # if labeled, label_dir is a sub dir of pair_dir, otherwise label_dir = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            # save image/label
            # if model is conditional, the pred_fixed_image depends on the input label
            conditional = model_method == "conditional"
            arr_save_dirs = [
                pair_dir,
                pair_dir,
                label_dir if conditional else pair_dir,
                label_dir,
                label_dir,
                label_dir,
            ]
            arrs = [
                moving_image,
                fixed_image,
                pred_fixed_image,
                moving_label,
                fixed_label,
                pred_fixed_label,
            ]
            names = [
                "moving_image",
                "fixed_image",
                "pred_fixed_image",  # or warped moving image
                "moving_label",
                "fixed_label",
                "pred_fixed_label",  # or warped moving label
            ]
            for arr_save_dir, arr, name in zip(arr_save_dirs, arrs, names):
                if arr is not None:
                    # for files under pair_dir, do not overwrite
                    save_array(
                        save_dir=arr_save_dir,
                        arr=arr[sample_index, :, :, :],
                        name=name,
                        gray=True,
                        save_nifti=save_nifti,
                        save_png=save_png,
                        overwrite=arr_save_dir == label_dir,
                    )

            # save ddf / dvf
            arrs = [ddf, dvf]
            names = ["ddf", "dvf"]
            for arr, name in zip(arrs, names):
                if arr is not None:
                    arr = normalize_array(arr=arr[sample_index, :, :, :])
                    save_array(
                        save_dir=label_dir if conditional else pair_dir,
                        arr=arr,
                        name=name,
                        gray=False,
                        save_nifti=save_nifti,
                        save_png=save_png,
                    )

            # save affine
            if affine is not None:
                np.savetxt(
                    fname=os.path.join(label_dir if conditional else pair_dir,
                                       "affine.txt"),
                    X=affine[sample_index, :, :].numpy(),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on a dataset using the given model and save the outputs.

    :param dataset: dataset on which the prediction is performed
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path of the directory in which outputs are stored
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for _, inputs in enumerate(dataset):
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label-independent tensors under pair_dir,
            # label-dependent tensors under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0]
                if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
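
# Illustrative usage sketch, not part of the module: the dataset/model objects
# and the "logs"/"test" path components are hypothetical placeholders for
# however the surrounding pipeline builds them; only predict_on_dataset, the
# grid shape convention, and the assumed layer_util.get_reference_grid helper
# come from or relate to the code above.
def _example_run_prediction(dataset: tf.data.Dataset,
                            model: tf.keras.Model,
                            fixed_image_shape=(4, 4, 4)):
    # reference grid of shape (1, f_dim1, f_dim2, f_dim3, 3)
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=fixed_image_shape), axis=0)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method="ddf",
        save_dir=os.path.join("logs", "test"),
        save_nifti=True,
        save_png=False,
    )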