def compute_persistence_2DImg(f, dimension=1.0, threshold=0.4):
    """
    Compute the persistence diagram and critical points of a 2D probability map.

    Copied from Hu's code under https://github.com/HuXiaoling/TopoLoss.
    The map is padded with a constant border and passed to the compiled
    ``cubePers`` routine. The persistence pairs of the requested homology
    dimension are returned together with their birth/death critical points.

    :param f: numpy array of shape NxN
            The input image/probability map from which we calculate the
            topological features. Must be 2-dimensional (asserted).
    :param dimension: float
            Homology dimension of the features to keep. Rows of the ``cubePers``
            output whose first column equals ``float(dimension)`` are retained.
    :param threshold: float
            Forwarded unchanged as the third argument of ``cubePers``.
            NOTE(review): this was previously documented as a boundary/hole
            cut-off (above = boundary = white = 1, below = hole = black = 0).
            The sibling functions in this file pass the same argument slot as
            ``pers_thd`` (a persistence threshold, 0.001 there). Confirm the
            meaning against PersistencePython.
    :return: tuple ``(dgm, birth_cp_list, death_cp_list)``
            ``dgm`` holds columns 1-2 of each kept row (the diagram
            coordinates). ``birth_cp_list`` and ``death_cp_list`` hold the
            critical-point coordinates, shifted back into the coordinate frame
            of the unpadded ``f``. All three are empty 1-D arrays when no pair
            of the requested dimension is found.
    """
    # f has to be a 2D function
    assert len(f.shape) == 2
    dim = 2

    # Pad the border with a constant no larger than min(f.min(), 0.0), so that
    # 1D topology touching the image border can still be computed as loops.
    # Critical points are shifted back to the original coordinates at the end.
    padwidth = 3
    padvalue = min(f.min(), 0.0)
    f_padded = np.pad(f, padwidth, 'constant', constant_values=padvalue)

    # perf_counter is monotonic and intended for elapsed-time measurement,
    # unlike time.time(), which jumps if the system clock is adjusted.
    start = time.perf_counter()

    # Call the compiled persistence code (PersistencePython, built from C++).
    # It takes the flattened (C-order) values and the padded shape.
    persistence_result = cubePers(
        f_padded.ravel().tolist(), list(f_padded.shape), threshold)

    print('cubePers time elapsed ' + str(time.perf_counter() - start))

    # Keep only rows of the requested homology dimension; column 0 is the dimension.
    target_dim = float(dimension)
    persistence_result_filtered = np.array(
        [row for row in persistence_result if row[0] == target_dim])

    # An empty selection becomes a 1-D array of shape (0,) that cannot be
    # sliced by column, so return empty results instead.
    if persistence_result_filtered.ndim < 2:
        return np.array([]), np.array([]), np.array([])

    # Persistence diagram (second and third columns are its coordinates)
    dgm = persistence_result_filtered[:, 1:3]

    # Critical points, mapped back from padded to original coordinates
    birth_cp_list = persistence_result_filtered[:, 4:4 + dim] - padwidth
    death_cp_list = persistence_result_filtered[:, 4 + dim:] - padwidth
    return dgm, birth_cp_list, death_cp_list
def compute_persistence_2DImg_1DHom(f):
    """
    Compute the 1D persistence diagram and critical points of an array.

    Despite the name, ``f`` may have any number of dimensions. The number of
    critical-point coordinates per row is taken from ``len(f.shape)``. Only
    rows of the ``cubePers`` output whose homology dimension (column 0) is 1
    are kept. A persistence threshold of 0.001 is passed to ``cubePers``.

    :param f: numpy array (any rank)
            Input function values.
    :return: tuple ``(dgm, birth_cp_list, death_cp_list)``
            ``dgm`` holds columns 1-2 of each kept row (the diagram
            coordinates). The critical-point lists are shifted back into the
            coordinate frame of the unpadded ``f``. All three are empty 1-D
            arrays when no 1D pair is found.
    """
    dim = len(f.shape)

    # Pad the function with a few pixels of minimum values (clamped to <= 0).
    # This way one can compute the 1D topology as loops.
    # Critical points are transformed back to the original coordinates below.
    padwidth = 2
    padvalue = min(f.min(), 0.0)
    f_padded = np.pad(f, padwidth, 'constant', constant_values=padvalue)

    # Call persistence code to compute diagrams.
    # Loads PersistencePython.so (compiled from C++); should be in current dir.
    from PersistencePython import cubePers
    persistence_result = cubePers(
        np.reshape(f_padded, f_padded.size).tolist(), list(f_padded.shape),
        0.001)

    # Only take 1-dim topology; the first column of persistence_result is the dimension.
    persistence_result_filtered = np.array(
        [row for row in persistence_result if row[0] == 1])

    # With no 1D pairs, the filtered array has shape (0,), and the column
    # slicing below would raise IndexError. Return empty results instead,
    # as compute_persistence_2DImg_1DHom_gt does.
    if persistence_result_filtered.shape[0] == 0:
        return np.array([]), np.array([]), np.array([])

    # Persistence diagram (second and third columns are its coordinates)
    dgm = persistence_result_filtered[:, 1:3]

    # Critical points, shifted back from padded to original coordinates
    birth_cp_list = persistence_result_filtered[:, 4:4 + dim] - padwidth
    death_cp_list = persistence_result_filtered[:, 4 + dim:] - padwidth

    return dgm, birth_cp_list, death_cp_list
def compute_persistence_2DImg_1DHom_gt(f,
                                       padwidth=2,
                                       homo_dim=1,
                                       pers_thd=0.001):
    """
    Compute a persistence diagram of a 2D map together with its critical points.

    ``f`` may be a numpy array or a tensor-like object that is converted via
    ``.cpu().detach().numpy()`` (presumably a torch tensor; confirm with callers).

    :param f: 2D array or tensor (asserted to be 2-dimensional).
    :param padwidth: width of the constant border added before calling
            ``cubePers``. Critical points are shifted back by this amount.
    :param homo_dim: homology dimension to keep; compared with column 0 of each
            ``cubePers`` row.
    :param pers_thd: forwarded as the third argument of ``cubePers``.
    :return: tuple ``(dgm, birth_cp_list, death_cp_list)``
            ``dgm`` holds columns 1-2 of each kept row. The critical-point lists
            are in the coordinate frame of the unpadded ``f``. All three are
            empty 1-D arrays when nothing of dimension ``homo_dim`` is found.
    """
    assert len(f.shape) == 2  # f has to be 2D function
    n_coords = 2

    # The border value is the plain minimum of f. Unlike the sibling
    # functions, it is not clamped to <= 0.
    lowest = f.min()
    if isinstance(f, np.ndarray):
        values, fill = f, lowest
    else:
        # Tensor input: detach and move both the data and its minimum to numpy.
        values = f.cpu().detach().numpy()
        fill = lowest.cpu().detach().numpy()
    f_padded = np.pad(values, padwidth, 'constant', constant_values=fill)

    # Loads PersistencePython.so (compiled from C++); should be in current dir.
    from PersistencePython import cubePers

    flat_values = f_padded.reshape(f_padded.size).tolist()
    rows = cubePers(flat_values, list(f_padded.shape), pers_thd)

    # Keep only pairs of the requested homology dimension.
    selected = np.array([r for r in rows if r[0] == homo_dim])
    if selected.shape[0] == 0:
        return np.array([]), np.array([]), np.array([])

    dgm = selected[:, 1:3]
    # Undo the padding offset so coordinates refer to the original f.
    birth_cp_list = selected[:, 4:4 + n_coords] - padwidth
    death_cp_list = selected[:, 4 + n_coords:] - padwidth
    return dgm, birth_cp_list, death_cp_list