def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Attach the label dissimilarity loss and label-based metrics to a model.

    When ``fixed_label`` is None nothing is added and the model is returned
    unchanged.

    :param model: tf.keras.Model receiving the loss and metrics (in place)
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3), used as the
        grid of the centroid-distance (TRE) metric
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param loss_config: config for loss; ``loss_config["dissimilarity"]["label"]``
        selects the dissimilarity function and provides its ``"weight"``
    :return: the same model object
    """
    if fixed_label is None:
        return model

    label_config = loss_config["dissimilarity"]["label"]

    # dissimilarity loss: reduced to a scalar mean, then scaled by its weight
    dissimilarity_fn = label_loss.get_dissimilarity_fn(config=label_config)
    unweighted = tf.reduce_mean(
        dissimilarity_fn(y_true=fixed_label, y_pred=pred_fixed_label)
    )
    weighted = unweighted * label_config["weight"]
    model.add_loss(weighted)
    model.add_metric(unweighted, name="loss/label_dissimilarity", aggregation="mean")
    model.add_metric(
        weighted,
        name="loss/weighted_label_dissimilarity",
        aggregation="mean",
    )

    # metrics: registered with add_metric only, so they do not enter the loss
    monitored = [
        (
            "metric/dice_binary",
            label_loss.dice_score(
                y_true=fixed_label, y_pred=pred_fixed_label, binary=True
            ),
        ),
        (
            "metric/dice_float",
            label_loss.dice_score(
                y_true=fixed_label, y_pred=pred_fixed_label, binary=False
            ),
        ),
        (
            "metric/tre",
            label_loss.compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=grid_fixed
            ),
        ),
        ("metric/foreground_label", label_loss.foreground_proportion(y=fixed_label)),
        (
            "metric/foreground_pred",
            label_loss.foreground_proportion(y=pred_fixed_label),
        ),
    ]
    for metric_name, metric_value in monitored:
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model
def test_dice_binary():
    """
    Testing dice score with not binary tensor to assert thresholding works.

    Foreground voxels of both label and prediction hold 0.6, and the score
    is computed with ``binary=True``.
    """
    foreground = 0.6 * np.identity(3)

    y_true = np.zeros((3, 3, 3, 3))
    y_true[:, :, 0:3, 0:3] = foreground  # 3 x 3 = 9 foreground voxels per sample
    y_pred = np.zeros((3, 3, 3, 3))
    y_pred[:, 0:2, :, :] = foreground  # 2 x 3 = 6 voxels per sample, all inside y_true

    # dice = 2 * |intersection| / (|y_true| + |y_pred|), one value per sample
    expected = 2 * np.array([6, 6, 6]) / (np.array([9, 9, 9]) + np.array([6, 6, 6]))
    actual = label.dice_score(y_true, y_pred, binary=True)
    assert assertTensorsEqual(expected, actual)
def test_dice_not_binary():
    """
    Testing dice score with binary tensor comparing to a precomputed value.

    Inputs are already 0/1, so the default (non-thresholded) score is used.
    """
    foreground = np.identity(3)

    y_true = np.zeros((3, 3, 3, 3))
    y_true[:, :, 0:3, 0:3] = foreground  # 3 x 3 = 9 foreground voxels per sample
    y_pred = np.zeros((3, 3, 3, 3))
    y_pred[:, 0:2, :, :] = foreground  # 2 x 3 = 6 voxels per sample, all inside y_true

    # dice = 2 * |intersection| / (|y_true| + |y_pred|), one value per sample
    expected = 2 * np.array([6, 6, 6]) / (np.array([9, 9, 9]) + np.array([6, 6, 6]))
    actual = label.dice_score(y_true, y_pred)
    assert assertTensorsEqual(expected, actual)
def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_image: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Compute image- and label-based metrics for one sample of a batch.

    Each metric is None when the tensors it needs are missing.

    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: int, index of the sample within the batch
    :return: dict with keys "image_ssd", "label_binary_dice" and "label_tre"
    """
    # a length-1 slice keeps the batch axis
    batch_slice = slice(sample_index, sample_index + 1)

    ssd = None
    if pred_fixed_image is not None:
        # add a trailing axis (axis=4) before calling the image loss
        image_true = tf.expand_dims(fixed_image[batch_slice, :, :, :], axis=4)
        image_pred = tf.expand_dims(pred_fixed_image[batch_slice, :, :, :], axis=4)
        ssd = image_loss.ssd(y_true=image_true, y_pred=image_pred).numpy()[0]

    dice = None
    tre = None
    if fixed_label is not None and pred_fixed_label is not None:
        label_true = fixed_label[batch_slice, :, :, :]
        label_pred = pred_fixed_label[batch_slice, :, :, :]
        dice = label_loss.dice_score(
            y_true=label_true, y_pred=label_pred, binary=True
        ).numpy()[0]
        # the grid is passed without its leading size-1 axis
        tre = label_loss.compute_centroid_distance(
            y_true=label_true, y_pred=label_pred, grid=fixed_grid_ref[0, :, :, :, :]
        ).numpy()[0]

    return dict(image_ssd=ssd, label_binary_dice=dice, label_tre=tre)
def test_dice_binary():
    """
    Testing dice score with not binary tensor to assert thresholding works.

    Foreground voxels of both label and prediction hold 0.6 (float32 tensors),
    and the score is computed with ``binary=True``.
    """
    foreground = 0.6 * np.identity(3, dtype=np.float32)

    label_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    label_np[:, :, 0:3, 0:3] = foreground  # 3 x 3 = 9 foreground voxels per sample
    pred_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    pred_np[:, 0:2, :, :] = foreground  # 2 x 3 = 6 voxels per sample, all inside label

    y_true = tf.convert_to_tensor(label_np, dtype=tf.float32)
    y_pred = tf.convert_to_tensor(pred_np, dtype=tf.float32)

    # dice = 2 * |intersection| / (|y_true| + |y_pred|), one value per sample
    expected = 2 * np.array([6, 6, 6]) / (np.array([9, 9, 9]) + np.array([6, 6, 6]))
    actual = label.dice_score(y_true, y_pred, binary=True)
    assert is_equal_tf(expected, actual)
def test_dice_not_binary():
    """
    Testing dice score with binary tensor comparing to a precomputed value.

    Inputs are already 0/1 float32 tensors, so the default scoring is used.
    """
    foreground = np.identity(3, dtype=np.float32)

    label_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    label_np[:, :, 0:3, 0:3] = foreground  # 3 x 3 = 9 foreground voxels per sample
    pred_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    pred_np[:, 0:2, :, :] = foreground  # 2 x 3 = 6 voxels per sample, all inside label

    y_true = tf.convert_to_tensor(label_np, dtype=tf.float32)
    y_pred = tf.convert_to_tensor(pred_np, dtype=tf.float32)

    # dice = 2 * |intersection| / (|y_true| + |y_pred|), one value per sample
    expected = 2 * np.array([6, 6, 6]) / (np.array([9, 9, 9]) + np.array([6, 6, 6]))
    actual = label.dice_score(y_true, y_pred)
    assert is_equal_tf(expected, actual)
def predict(data_loader, dataset, fixed_grid_ref, model, save_dir):
    """
    Predict on a dataset, save per-depth PNG slices and write label metrics.

    For every sample, each depth slice is saved as a PNG. This covers the
    fixed image, fixed label, predicted fixed label, moving image and moving
    label. When the model has a ``ddf`` attribute, the min-max normalised ddf
    is saved as well. Files go under ``save_dir/<image dir>/label<label_index>/``.
    The binary dice score and centroid distance of every sample are written
    to ``save_dir/metric.log``.

    :param data_loader: object providing ``split_indices`` and
        ``image_index_to_dir``
    :param dataset: iterable of ``(inputs, labels)`` pairs, where
        ``inputs = (moving_image, fixed_image, moving_label, indices)`` and
        ``labels`` is the fixed label
    :param fixed_grid_ref: grid passed to
        ``label_loss.compute_centroid_distance``
    :param model: tf.keras.Model, optionally exposing a ``ddf`` tensor
    :param save_dir: str, root output directory
    """
    # Guard for the ddf min-max normalisation: a constant field would give 0/0.
    eps = 1.0e-7

    # Build the ddf-returning model once instead of once per batch; it is
    # wired from model.inputs/model.outputs, so it reuses the same layers.
    model_ddf = None
    if hasattr(model, "ddf"):
        model_ddf = tf.keras.Model(
            inputs=model.inputs, outputs=model.outputs + [model.ddf]
        )

    image_dir_format = save_dir + "/{image_dir:s}/label{label_index:d}"
    metric_map = dict()  # map[image_index][label_index][metric_name] = metric_value
    for inputs, labels in dataset:
        # pred_fixed_label [batch, f_dim1, f_dim2, f_dim3]
        # moving_image [batch, m_dim1, m_dim2, m_dim3]
        # fixed_image [batch, f_dim1, f_dim2, f_dim3]
        # moving_label [batch, m_dim1, m_dim2, m_dim3]
        # fixed_label [batch, f_dim1, f_dim2, f_dim3]
        if model_ddf is not None:
            pred_fixed_label, ddf = model_ddf.predict(x=inputs)
        else:
            pred_fixed_label = model.predict(x=inputs)
            ddf = None
        moving_image, fixed_image, moving_label, indices = inputs
        fixed_label = labels

        num_samples = moving_image.shape[0]
        moving_depth = moving_image.shape[3]
        fixed_depth = fixed_image.shape[3]

        for sample_index in range(num_samples):
            image_index, label_index = data_loader.split_indices(
                indices[sample_index, :].numpy().astype(int).tolist()
            )
            # fixed, moving and ddf slices of a sample share one directory
            image_dir = image_dir_format.format(
                image_dir=data_loader.image_index_to_dir(image_index),
                label_index=label_index,
            )
            filename_format = image_dir + "/depth{depth_index:d}_{name:s}.png"
            os.makedirs(image_dir, exist_ok=True)

            # save fixed
            for fixed_depth_index in range(fixed_depth):
                fixed_image_d = fixed_image[sample_index, :, :, fixed_depth_index]
                fixed_label_d = fixed_label[sample_index, :, :, fixed_depth_index]
                fixed_pred_d = pred_fixed_label[sample_index, :, :, fixed_depth_index]
                # no vmin/vmax: value range for h5 and nifti might be different
                plt.imsave(
                    filename_format.format(
                        depth_index=fixed_depth_index, name="fixed_image"
                    ),
                    fixed_image_d,
                    cmap="gray",
                )
                plt.imsave(
                    filename_format.format(
                        depth_index=fixed_depth_index, name="fixed_label"
                    ),
                    fixed_label_d,
                    vmin=0,
                    vmax=1,
                    cmap="gray",
                )
                plt.imsave(
                    filename_format.format(
                        depth_index=fixed_depth_index, name="fixed_pred"
                    ),
                    fixed_pred_d,
                    vmin=0,
                    vmax=1,
                    cmap="gray",
                )

            # save moving
            for moving_depth_index in range(moving_depth):
                moving_image_d = moving_image[sample_index, :, :, moving_depth_index]
                moving_label_d = moving_label[sample_index, :, :, moving_depth_index]
                # no vmin/vmax: value range for h5 and nifti might be different
                plt.imsave(
                    filename_format.format(
                        depth_index=moving_depth_index, name="moving_image"
                    ),
                    moving_image_d,
                    cmap="gray",
                )
                plt.imsave(
                    filename_format.format(
                        depth_index=moving_depth_index, name="moving_label"
                    ),
                    moving_label_d,
                    vmin=0,
                    vmax=1,
                    cmap="gray",
                )

            # save ddf if exists
            if ddf is not None:
                for fixed_depth_index in range(fixed_depth):
                    ddf_d = ddf[sample_index, :, :, fixed_depth_index, :]  # [f_dim1, f_dim2, 3]
                    ddf_max, ddf_min = np.max(ddf_d), np.min(ddf_d)
                    # min-max normalise to [0, 1]; eps avoids 0/0 on a constant field
                    ddf_d = (ddf_d - ddf_min) / np.maximum(ddf_max - ddf_min, eps)
                    plt.imsave(
                        filename_format.format(depth_index=fixed_depth_index, name="ddf"),
                        ddf_d,
                    )

            # calculate metric
            label = fixed_label[sample_index : (sample_index + 1), :, :, :]
            pred = pred_fixed_label[sample_index : (sample_index + 1), :, :, :]
            dice = label_loss.dice_score(y_true=label, y_pred=pred, binary=True)
            dist = label_loss.compute_centroid_distance(
                y_true=label, y_pred=pred, grid=fixed_grid_ref
            )

            # save metric
            image_metrics = metric_map.setdefault(image_index, dict())
            assert label_index not in image_metrics  # label should not be repeated
            image_metrics[label_index] = dict(
                dice=dice.numpy()[0], dist=dist.numpy()[0]
            )

    # print metric
    line_format = "{image_dir:s}, label {label_index:d}, dice {dice:.4f}, dist {dist:.4f}\n"
    with open(save_dir + "/metric.log", "w+") as f:
        for image_index in sorted(metric_map.keys()):
            for label_index in sorted(metric_map[image_index].keys()):
                f.write(
                    line_format.format(
                        image_dir=data_loader.image_index_to_dir(image_index),
                        label_index=label_index,
                        **metric_map[image_index][label_index],
                    )
                )
def predict_on_dataset(dataset, fixed_grid_ref, model, save_dir):
    """
    Function to predict results from a dataset from some model.

    For each sample, every depth slice is saved as a PNG. This always covers
    the fixed and moving images. For labeled data it also covers the fixed
    label, predicted fixed label and moving label. The "ddf"/"dvf" fields are
    saved too when the model outputs them. For labeled data, the binary dice
    score and centroid distance of each sample are written to
    ``save_dir/metric.log``. Entries are sorted numerically by image and
    label index.

    :param dataset: iterable of input dicts holding "moving_image",
        "fixed_image" and "indices", and optionally "moving_label" and
        "fixed_label"; data counts as labeled when "moving_label" is present
    :param fixed_grid_ref: grid passed to
        ``label_loss.compute_centroid_distance``
    :param model: model whose ``predict`` returns a dict that may contain
        "ddf", "dvf" and "pred_fixed_label"
    :param save_dir: str, path to store dir
    """
    metric_map = dict()  # map[image_index][label_index][metric_name] = metric_value
    for inputs_dict in dataset:
        # pred_fixed_label [batch, f_dim1, f_dim2, f_dim3]
        # moving_image [batch, m_dim1, m_dim2, m_dim3]
        # fixed_image [batch, f_dim1, f_dim2, f_dim3]
        # moving_label [batch, m_dim1, m_dim2, m_dim3]
        # fixed_label [batch, f_dim1, f_dim2, f_dim3]
        outputs_dict = model.predict(x=inputs_dict)

        moving_image = inputs_dict.get("moving_image")
        fixed_image = inputs_dict.get("fixed_image")
        indices = inputs_dict.get("indices")
        moving_label = inputs_dict.get("moving_label", None)
        fixed_label = inputs_dict.get("fixed_label", None)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        labeled = moving_label is not None

        num_samples = moving_image.shape[0]
        moving_depth = moving_image.shape[3]
        fixed_depth = fixed_image.shape[3]

        for sample_index in range(num_samples):
            # all but the last index identify the image, the last one the label
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            image_index = "_".join([str(x) for x in indices_i[:-1]])
            label_index = str(indices_i[-1])

            # fixed, moving and field slices of a sample share one directory
            image_dir = os.path.join(save_dir, "image%s" % image_index)
            if labeled:
                image_dir = os.path.join(image_dir, "label%s" % label_index)
            filename_format = os.path.join(
                image_dir, "depth{depth_index:d}_{name:s}.png"
            )
            os.makedirs(image_dir, exist_ok=True)

            # save fixed
            for fixed_depth_index in range(fixed_depth):
                fixed_image_d = fixed_image[sample_index, :, :, fixed_depth_index]
                # no vmin/vmax: value range for h5 and nifti might be different
                plt.imsave(
                    filename_format.format(
                        depth_index=fixed_depth_index, name="fixed_image"
                    ),
                    fixed_image_d,
                    cmap="gray",
                )
                if labeled:
                    fixed_label_d = fixed_label[sample_index, :, :, fixed_depth_index]
                    fixed_pred_d = pred_fixed_label[
                        sample_index, :, :, fixed_depth_index
                    ]
                    plt.imsave(
                        filename_format.format(
                            depth_index=fixed_depth_index, name="fixed_label"
                        ),
                        fixed_label_d,
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )
                    plt.imsave(
                        filename_format.format(
                            depth_index=fixed_depth_index, name="fixed_label_pred"
                        ),
                        fixed_pred_d,
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )

            # save moving
            for moving_depth_index in range(moving_depth):
                moving_image_d = moving_image[sample_index, :, :, moving_depth_index]
                # no vmin/vmax: value range for h5 and nifti might be different
                plt.imsave(
                    filename_format.format(
                        depth_index=moving_depth_index, name="moving_image"
                    ),
                    moving_image_d,
                    cmap="gray",
                )
                if labeled:
                    moving_label_d = moving_label[
                        sample_index, :, :, moving_depth_index
                    ]
                    plt.imsave(
                        filename_format.format(
                            depth_index=moving_depth_index, name="moving_label"
                        ),
                        moving_label_d,
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )

            # save ddf / dvf if exists
            for field, field_name in zip([ddf, dvf], ["ddf", "dvf"]):
                if field is None:
                    continue
                for fixed_depth_index in range(fixed_depth):
                    field_d = field[sample_index, :, :, fixed_depth_index, :]  # [f_dim1, f_dim2, 3]
                    field_max, field_min = np.max(field_d), np.min(field_d)
                    # min-max normalise to [0, 1]; EPS avoids 0/0 on a constant field
                    field_d = (field_d - field_min) / np.maximum(
                        field_max - field_min, EPS
                    )
                    plt.imsave(
                        filename_format.format(
                            depth_index=fixed_depth_index, name=field_name
                        ),
                        field_d,
                    )

            # calculate metric
            if labeled:
                label = fixed_label[sample_index : (sample_index + 1), :, :, :]
                pred = pred_fixed_label[sample_index : (sample_index + 1), :, :, :]
                dice = label_loss.dice_score(y_true=label, y_pred=pred, binary=True)
                dist = label_loss.compute_centroid_distance(
                    y_true=label, y_pred=pred, grid=fixed_grid_ref
                )

                # save metric
                image_metrics = metric_map.setdefault(image_index, dict())
                # label should not be repeated - assert that it is not in keys
                assert label_index not in image_metrics
                image_metrics[label_index] = dict(
                    dice=dice.numpy()[0], dist=dist.numpy()[0]
                )

    def _numeric_key(index):
        # indices are "_"-joined ints; compare them as ints so "2" < "10"
        return tuple(int(part) for part in index.split("_") if part)

    # print metric
    line_format = (
        "{image_index:s}, label {label_index:s}, dice {dice:.4f}, dist {dist:.4f}\n"
    )
    with open(os.path.join(save_dir, "metric.log"), "w") as file:
        for image_index in sorted(metric_map, key=_numeric_key):
            for label_index in sorted(metric_map[image_index], key=_numeric_key):
                file.write(
                    line_format.format(
                        image_index=image_index,
                        label_index=label_index,
                        **metric_map[image_index][label_index],
                    )
                )
def fn(self, y_true, y_pred):
    """
    Return the binary dice score between ``y_true`` and ``y_pred``.

    Delegates to ``label_loss.dice_score`` with ``binary=True``.
    """
    score = label_loss.dice_score(y_true=y_true, y_pred=y_pred, binary=True)
    return score