コード例 #1
0
ファイル: util.py プロジェクト: zcemycl/DeepReg
def add_label_loss(
    model: tf.keras.Model,
    grid_fixed: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    loss_config: dict,
) -> tf.keras.Model:
    """
    Add label dissimilarity loss of ddf into model, plus label-based metrics.

    The loss and metrics are only attached when both the ground-truth label
    and the predicted label are given; otherwise the model is returned
    unchanged.

    :param model: tf.keras.Model
    :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3)
    :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3)
        or None
    :param loss_config: config for loss; ``loss_config["dissimilarity"]["label"]``
        is passed to ``label_loss.get_dissimilarity_fn`` and its ``"weight"``
        entry scales the loss added to the model
    :return: the same model, with the loss and metrics attached
    """
    # Both labels are required (as in calculate_metrics): without a prediction
    # there is nothing to compare the ground truth against.
    if fixed_label is None or pred_fixed_label is None:
        return model

    label_loss_config = loss_config["dissimilarity"]["label"]

    # loss
    loss_label = tf.reduce_mean(
        label_loss.get_dissimilarity_fn(config=label_loss_config)(
            y_true=fixed_label, y_pred=pred_fixed_label
        )
    )
    weighted_loss_label = loss_label * label_loss_config["weight"]
    model.add_loss(weighted_loss_label)
    model.add_metric(loss_label,
                     name="loss/label_dissimilarity",
                     aggregation="mean")
    model.add_metric(
        weighted_loss_label,
        name="loss/weighted_label_dissimilarity",
        aggregation="mean",
    )

    # metrics (insertion order is preserved, so they are added in this order)
    metrics = {
        "metric/dice_binary": label_loss.dice_score(
            y_true=fixed_label, y_pred=pred_fixed_label, binary=True),
        "metric/dice_float": label_loss.dice_score(
            y_true=fixed_label, y_pred=pred_fixed_label, binary=False),
        "metric/tre": label_loss.compute_centroid_distance(
            y_true=fixed_label, y_pred=pred_fixed_label, grid=grid_fixed),
        "metric/foreground_label": label_loss.foreground_proportion(
            y=fixed_label),
        "metric/foreground_pred": label_loss.foreground_proportion(
            y=pred_fixed_label),
    }
    for metric_name, metric_value in metrics.items():
        model.add_metric(metric_value, name=metric_name, aggregation="mean")
    return model
コード例 #2
0
def test_dice_binary():
    """
    Check that dice_score with binary=True handles non-binary (0.6) inputs.

    Each batch item of the label has 9 non-zero voxels and the prediction has
    6, all of which overlap the label, so the expected score is 2*6/(9+6).
    """
    diagonal = 0.6 * np.identity(3)

    y_true = np.zeros((3, 3, 3, 3))
    y_true[:, :, 0:3, 0:3] = diagonal

    y_pred = np.zeros((3, 3, 3, 3))
    y_pred[:, 0:2, :, :] = diagonal

    overlap = np.array([6, 6, 6])
    total = np.array([9, 9, 9]) + np.array([6, 6, 6])
    expected = (2 * overlap) / total

    got = label.dice_score(y_true, y_pred, binary=True)
    assert assertTensorsEqual(expected, got)
コード例 #3
0
def test_dice_not_binary():
    """
    Check dice_score with its default ``binary`` argument on 0/1 inputs
    against a hand-computed value.

    Each batch item of the label has 9 ones and the prediction has 6, all of
    which overlap the label, so the expected score is 2*6/(9+6).
    """
    diagonal = np.identity(3)

    y_true = np.zeros((3, 3, 3, 3))
    y_true[:, :, 0:3, 0:3] = diagonal

    y_pred = np.zeros((3, 3, 3, 3))
    y_pred[:, 0:2, :, :] = diagonal

    overlap = np.array([6, 6, 6])
    total = np.array([9, 9, 9]) + np.array([6, 6, 6])
    expected = (2 * overlap) / total

    got = label.dice_score(y_true, y_pred)
    assert assertTensorsEqual(expected, got)
コード例 #4
0
def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_image: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Compute image- and label-based metrics for one sample of a batch.

    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: index of the sample within the batch
    :return: dict with keys image_ssd, label_binary_dice and label_tre; a
        value is None when the inputs it needs are missing
    """
    # Slicing with start:stop keeps the batch axis (size 1).
    start, stop = sample_index, sample_index + 1

    ssd = None
    if pred_fixed_image is not None:
        # A trailing axis (axis=4) is added before calling image_loss.ssd.
        image_true = tf.expand_dims(fixed_image[start:stop, :, :, :], axis=4)
        image_pred = tf.expand_dims(pred_fixed_image[start:stop, :, :, :], axis=4)
        ssd = image_loss.ssd(y_true=image_true, y_pred=image_pred).numpy()[0]

    dice = None
    tre = None
    if fixed_label is not None and pred_fixed_label is not None:
        label_true = fixed_label[start:stop, :, :, :]
        label_pred = pred_fixed_label[start:stop, :, :, :]
        dice_tensor = label_loss.dice_score(
            y_true=label_true, y_pred=label_pred, binary=True)
        dice = dice_tensor.numpy()[0]
        # fixed_grid_ref carries a leading batch axis of size 1; drop it.
        tre_tensor = label_loss.compute_centroid_distance(
            y_true=label_true, y_pred=label_pred,
            grid=fixed_grid_ref[0, :, :, :, :])
        tre = tre_tensor.numpy()[0]

    return {"image_ssd": ssd, "label_binary_dice": dice, "label_tre": tre}
コード例 #5
0
def test_dice_binary():
    """
    Check that dice_score with binary=True handles non-binary (0.6) float32
    tensors.

    Each batch item of the label has 9 non-zero voxels and the prediction has
    6, all of which overlap the label, so the expected score is 2*6/(9+6).
    """
    diagonal = 0.6 * np.identity(3, dtype=np.float32)

    y_true_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    y_true_np[:, :, 0:3, 0:3] = diagonal
    y_true = tf.convert_to_tensor(y_true_np, dtype=tf.float32)

    y_pred_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    y_pred_np[:, 0:2, :, :] = diagonal
    y_pred = tf.convert_to_tensor(y_pred_np, dtype=tf.float32)

    overlap = np.array([6, 6, 6])
    total = np.array([9, 9, 9]) + np.array([6, 6, 6])
    expected = (2 * overlap) / total

    got = label.dice_score(y_true, y_pred, binary=True)
    assert is_equal_tf(expected, got)
コード例 #6
0
def test_dice_not_binary():
    """
    Check dice_score with its default ``binary`` argument on 0/1 float32
    tensors against a hand-computed value.

    Each batch item of the label has 9 ones and the prediction has 6, all of
    which overlap the label, so the expected score is 2*6/(9+6).
    """
    diagonal = np.identity(3, dtype=np.float32)

    y_true_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    y_true_np[:, :, 0:3, 0:3] = diagonal
    y_true = tf.convert_to_tensor(y_true_np, dtype=tf.float32)

    y_pred_np = np.zeros((3, 3, 3, 3), dtype=np.float32)
    y_pred_np[:, 0:2, :, :] = diagonal
    y_pred = tf.convert_to_tensor(y_pred_np, dtype=tf.float32)

    overlap = np.array([6, 6, 6])
    total = np.array([9, 9, 9]) + np.array([6, 6, 6])
    expected = (2 * overlap) / total

    got = label.dice_score(y_true, y_pred)
    assert is_equal_tf(expected, got)
コード例 #7
0
ファイル: predict.py プロジェクト: NMontanaBrown/DeepReg
def predict(data_loader, dataset, fixed_grid_ref, model, save_dir):
    """
    Run ``model`` over ``dataset``, save per-depth PNG slices and label metrics.

    For every sample, depth slices of the fixed image, fixed label, predicted
    fixed label, moving image, moving label and (if the model has a ``ddf``
    attribute) the ddf are written to
    ``save_dir/<image dir>/label<label_index>/depth<d>_<name>.png``.
    Binary Dice and centroid distance per (image, label) pair are written to
    ``save_dir/metric.log``.

    :param data_loader: object providing ``split_indices`` and
        ``image_index_to_dir``
    :param dataset: iterable of ``(inputs, labels)`` where ``inputs`` unpacks
        to ``(moving_image, fixed_image, moving_label, indices)`` and
        ``labels`` is the fixed label
    :param fixed_grid_ref: grid passed to
        ``label_loss.compute_centroid_distance``
    :param model: Keras model; if it has a ``ddf`` attribute, that tensor is
        predicted alongside the outputs and saved
    :param save_dir: output directory
    """
    # Lower bound on the ddf value range when normalising it for display, so a
    # constant ddf (e.g. all zeros) does not cause a division by zero.
    eps = 1e-7

    def _filename(image_dir, depth_index, name):
        # Concatenate rather than str.format the directory, so braces that
        # appear in save_dir or image dirs cannot break the template.
        return image_dir + "/depth{depth_index:d}_{name:s}.png".format(
            depth_index=depth_index, name=name)

    # Build the ddf-returning model once instead of once per batch.
    if hasattr(model, "ddf"):
        model_ddf = tf.keras.Model(inputs=model.inputs,
                                   outputs=model.outputs + [model.ddf])
    else:
        model_ddf = None

    metric_map = dict()  # map[image_index][label_index][metric_name] = metric_value
    for inputs, labels in dataset:
        # pred_fixed_label [batch, f_dim1, f_dim2, f_dim3]
        # moving_image     [batch, m_dim1, m_dim2, m_dim3]
        # fixed_image      [batch, f_dim1, f_dim2, f_dim3]
        # moving_label     [batch, m_dim1, m_dim2, m_dim3]
        # fixed_label      [batch, f_dim1, f_dim2, f_dim3]
        if model_ddf is not None:
            pred_fixed_label, ddf = model_ddf.predict(x=inputs)
        else:
            pred_fixed_label = model.predict(x=inputs)
            ddf = None

        moving_image, fixed_image, moving_label, indices = inputs
        fixed_label = labels
        num_samples = moving_image.shape[0]
        moving_depth = moving_image.shape[3]
        fixed_depth = fixed_image.shape[3]

        for sample_index in range(num_samples):
            image_index, label_index = data_loader.split_indices(
                indices[sample_index, :].numpy().astype(int).tolist())

            # fixed, moving and ddf outputs of one sample share a directory
            image_dir = save_dir + "/{image_dir:s}/label{label_index:d}".format(
                image_dir=data_loader.image_index_to_dir(image_index),
                label_index=label_index)
            os.makedirs(image_dir, exist_ok=True)

            # save fixed
            for depth_index in range(fixed_depth):
                plt.imsave(_filename(image_dir, depth_index, "fixed_image"),
                           fixed_image[sample_index, :, :, depth_index],
                           cmap='gray'
                           )  # value range for h5 and nifti might be different
                plt.imsave(_filename(image_dir, depth_index, "fixed_label"),
                           fixed_label[sample_index, :, :, depth_index],
                           vmin=0,
                           vmax=1,
                           cmap='gray')
                plt.imsave(_filename(image_dir, depth_index, "fixed_pred"),
                           pred_fixed_label[sample_index, :, :, depth_index],
                           vmin=0,
                           vmax=1,
                           cmap='gray')

            # save moving
            for depth_index in range(moving_depth):
                plt.imsave(_filename(image_dir, depth_index, "moving_image"),
                           moving_image[sample_index, :, :, depth_index],
                           cmap='gray'
                           )  # value range for h5 and nifti might be different
                plt.imsave(_filename(image_dir, depth_index, "moving_label"),
                           moving_label[sample_index, :, :, depth_index],
                           vmin=0,
                           vmax=1,
                           cmap='gray')

            # save ddf if exists, min-max normalised per depth slice
            if ddf is not None:
                for depth_index in range(fixed_depth):
                    ddf_d = ddf[sample_index, :, :,
                                depth_index, :]  # [f_dim1, f_dim2,  3]
                    ddf_max, ddf_min = np.max(ddf_d), np.min(ddf_d)
                    ddf_d = (ddf_d - ddf_min) / np.maximum(ddf_max - ddf_min,
                                                           eps)
                    plt.imsave(_filename(image_dir, depth_index, "ddf"), ddf_d)

            # calculate metric (slicing keeps a batch axis of size 1)
            label = fixed_label[sample_index:(sample_index + 1), :, :, :]
            pred = pred_fixed_label[sample_index:(sample_index + 1), :, :, :]
            dice = label_loss.dice_score(y_true=label,
                                         y_pred=pred,
                                         binary=True)
            dist = label_loss.compute_centroid_distance(y_true=label,
                                                        y_pred=pred,
                                                        grid=fixed_grid_ref)

            # save metric
            image_metrics = metric_map.setdefault(image_index, dict())
            assert label_index not in image_metrics  # label should not be repeated
            image_metrics[label_index] = dict(dice=dice.numpy()[0],
                                              dist=dist.numpy()[0])

    # print metric
    line_format = "{image_dir:s}, label {label_index:d}, dice {dice:.4f}, dist {dist:.4f}\n"
    with open(save_dir + "/metric.log", "w+") as f:
        for image_index in sorted(metric_map):
            image_dir = data_loader.image_index_to_dir(image_index)
            for label_index in sorted(metric_map[image_index]):
                f.write(
                    line_format.format(
                        image_dir=image_dir,
                        label_index=label_index,
                        **metric_map[image_index][label_index]))
コード例 #8
0
ファイル: predict.py プロジェクト: knvsmadhav/DeepReg
def predict_on_dataset(dataset, fixed_grid_ref, model, save_dir):
    """
    Function to predict results from a dataset from some model.

    For every sample, per-depth PNG slices are written under
    ``save_dir/image<image_index>`` (plus ``/label<label_index>`` when the
    inputs carry a moving label). The slices cover the fixed and moving
    images, each label tensor that is present, and the ddf/dvf when the model
    outputs them. When both the fixed label and its prediction are available,
    binary Dice and centroid distance are written to ``save_dir/metric.log``,
    ordered numerically by image and label index.

    :param dataset: iterable of input dicts with keys "moving_image",
        "fixed_image", "indices" and optionally "moving_label", "fixed_label"
    :param fixed_grid_ref: grid passed to
        ``label_loss.compute_centroid_distance``
    :param model: model whose ``predict`` returns a dict with optional keys
        "ddf", "dvf" and "pred_fixed_label"
    :param save_dir: str, path to store dir
    """

    def _filename(image_dir, depth_index, name):
        # Join rather than str.format the directory, so braces that appear in
        # save_dir cannot break the filename template.
        return os.path.join(
            image_dir,
            "depth{depth_index:d}_{name:s}.png".format(depth_index=depth_index,
                                                       name=name))

    def _index_key(index):
        # Indices are "_"-joined integers; compare them numerically so that
        # e.g. label "10" is listed after label "2" instead of before it.
        return tuple(int(part) for part in index.split("_") if part)

    metric_map = dict()  # map[image_index][label_index][metric_name] = metric_value
    for inputs_dict in dataset:
        # pred_fixed_label [batch, f_dim1, f_dim2, f_dim3]
        # moving_image     [batch, m_dim1, m_dim2, m_dim3]
        # fixed_image      [batch, f_dim1, f_dim2, f_dim3]
        # moving_label     [batch, m_dim1, m_dim2, m_dim3]
        # fixed_label      [batch, f_dim1, f_dim2, f_dim3]
        outputs_dict = model.predict(x=inputs_dict)

        moving_image = inputs_dict.get("moving_image")
        fixed_image = inputs_dict.get("fixed_image")
        indices = inputs_dict.get("indices")
        moving_label = inputs_dict.get("moving_label", None)
        fixed_label = inputs_dict.get("fixed_label", None)

        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)

        # A moving label decides the directory layout; each label tensor is
        # only saved when present, and metrics need both the ground-truth
        # fixed label and its prediction (indexing None would crash).
        labeled = moving_label is not None
        has_fixed_label = labeled and fixed_label is not None
        has_pred_label = labeled and pred_fixed_label is not None
        evaluate = has_fixed_label and has_pred_label

        num_samples = moving_image.shape[0]
        moving_depth = moving_image.shape[3]
        fixed_depth = fixed_image.shape[3]

        for sample_index in range(num_samples):
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            image_index = "_".join([str(x) for x in indices_i[:-1]])
            label_index = str(indices_i[-1])

            image_dir = os.path.join(save_dir, "image%s" % image_index)
            if labeled:
                image_dir = os.path.join(image_dir, "label%s" % label_index)
            os.makedirs(image_dir, exist_ok=True)

            # save fixed
            for depth_index in range(fixed_depth):
                plt.imsave(
                    _filename(image_dir, depth_index, "fixed_image"),
                    fixed_image[sample_index, :, :, depth_index],
                    cmap="gray",
                )  # value range for h5 and nifti might be different
                if has_fixed_label:
                    plt.imsave(
                        _filename(image_dir, depth_index, "fixed_label"),
                        fixed_label[sample_index, :, :, depth_index],
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )
                if has_pred_label:
                    plt.imsave(
                        _filename(image_dir, depth_index, "fixed_label_pred"),
                        pred_fixed_label[sample_index, :, :, depth_index],
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )

            # save moving
            for depth_index in range(moving_depth):
                plt.imsave(
                    _filename(image_dir, depth_index, "moving_image"),
                    moving_image[sample_index, :, :, depth_index],
                    cmap="gray",
                )  # value range for h5 and nifti might be different
                if labeled:
                    plt.imsave(
                        _filename(image_dir, depth_index, "moving_label"),
                        moving_label[sample_index, :, :, depth_index],
                        vmin=0,
                        vmax=1,
                        cmap="gray",
                    )

            # save ddf / dvf if exists, min-max normalised per depth slice
            for field, field_name in ((ddf, "ddf"), (dvf, "dvf")):
                if field is None:
                    continue
                for depth_index in range(fixed_depth):
                    field_d = field[sample_index, :, :,
                                    depth_index, :]  # [f_dim1, f_dim2,  3]
                    field_max, field_min = np.max(field_d), np.min(field_d)
                    field_d = (field_d - field_min) / np.maximum(
                        field_max - field_min, EPS)
                    plt.imsave(
                        _filename(image_dir, depth_index, field_name),
                        field_d,
                    )

            # calculate metric (slicing keeps a batch axis of size 1)
            if evaluate:
                label = fixed_label[sample_index:(sample_index + 1), :, :, :]
                pred = pred_fixed_label[sample_index:(sample_index +
                                                      1), :, :, :]
                dice = label_loss.dice_score(y_true=label,
                                             y_pred=pred,
                                             binary=True)
                dist = label_loss.compute_centroid_distance(
                    y_true=label, y_pred=pred, grid=fixed_grid_ref)

                # save metric
                image_metrics = metric_map.setdefault(image_index, dict())
                # label should not be repeated - assert that it is not in keys
                assert label_index not in image_metrics
                image_metrics[label_index] = dict(dice=dice.numpy()[0],
                                                  dist=dist.numpy()[0])

    # print metric
    line_format = (
        "{image_index:s}, label {label_index:s}, dice {dice:.4f}, dist {dist:.4f}\n"
    )
    with open(os.path.join(save_dir, "metric.log"), "w+") as file:
        for image_index in sorted(metric_map, key=_index_key):
            image_metrics = metric_map[image_index]
            for label_index in sorted(image_metrics, key=_index_key):
                file.write(
                    line_format.format(
                        image_index=image_index,
                        label_index=label_index,
                        **image_metrics[label_index],
                    ))
コード例 #9
0
 def fn(self, y_true, y_pred):
     """Return label_loss.dice_score of y_pred against y_true with binary=True."""
     score = label_loss.dice_score(y_true=y_true, y_pred=y_pred, binary=True)
     return score