示例#1
0
import gudhi

from scripts.ssc.persistence_pairings_visualization.utils_definitions import make_plot
from src.datasets.datasets import SwissRoll

def zero_filtration_edges(skeleton):
    """Return the edges of *skeleton* whose filtration value is exactly 0.

    Args:
        skeleton: iterable of ``(simplex, filtration)`` tuples, e.g. the
            output of ``SimplexTree.get_skeleton(1)``.

    Returns:
        List of the 2-vertex simplices with filtration ``0``, in input order.
        No sort by filtration is needed: every kept edge has the same
        filtration value, so a stable sort would not change their order.
    """
    return [simplex for simplex, filtration in skeleton
            if len(simplex) == 2 and filtration == 0]


if __name__ == "__main__":
    dataset_sampler = SwissRoll()
    n_points = 2048
    seed = 13
    samples, color = dataset_sampler.sample(n_points, seed=seed)

    # `intrisic_dim` is the keyword as spelled by gudhi's TangentialComplex
    # API; do not "correct" the spelling.
    tc = gudhi.TangentialComplex(intrisic_dim=1, points=samples)
    tc.compute_tangential_complex()
    simplex_tree = tc.create_simplex_tree()

    # Materialise the 1-skeleton once. Depending on the gudhi version,
    # get_skeleton returns a generator, so printing the call directly would
    # only show the generator object and the skeleton would be traversed twice.
    skeleton = list(simplex_tree.get_skeleton(1))
    print(skeleton)

    pairings = zero_filtration_edges(skeleton)
    for pair in pairings:
        print(pair)

    make_plot(samples, pairings, color, name='witness_TEST')
示例#2
0
                                                   r_max=10,
                                                   create_simplex_tree=False,
                                                   create_metric=True)

        # The landmark distance matrix and its row-wise ordering do not depend
        # on k, so build and sort them once rather than on every iteration.
        # Only the sort indices are used; the sorted values are discarded
        # instead of being bound to `sorted`, which would shadow the builtin.
        landmarks_dist = torch.tensor(witness_complex.landmarks_dist)
        _, indices = torch.sort(landmarks_dist)

        for k in ks:

            print('{} out of {}'.format(counter, ntot))

            # Row i marks the k nearest landmarks of landmark i. Sort column 0
            # is skipped, presumably because it is the landmark itself
            # (zero self-distance).
            # Assumes landmarks_dist is (n_samples, n_samples); TODO confirm.
            kNN_mask = torch.zeros(
                (n_samples, n_samples),
                device='cpu').scatter(1, indices[:, 1:(k + 1)], 1)

            # (row, col) index pairs of the marked entries, in row-major order.
            pairings_i = np.where(kNN_mask.numpy() == 1)
            pairings = np.column_stack((pairings_i[0], pairings_i[1]))

            name = 'wc{nw}_k{k}_seed{seed}'.format(nw=N_WITNESSES,
                                                   k=k,
                                                   seed=seed)

            make_plot(landmarks,
                      pairings,
                      color,
                      name=name,
                      path_root=path_to_save,
                      knn=False,
                      show=False,
                      dpi=50)

            counter += 1
示例#3
0
    #
    #             count_pairings(n_samples, pairs_filtered)
    #             make_plot(landmarks, pairs_filtered, color, name=name)

    # Parameter grid. Set run_full_sweep to True to evaluate every combination;
    # by default only the single configuration in the else-branch runs.
    run_full_sweep = False
    if run_full_sweep:
        n_samples_array = [32, 48, 64, 96, 128]
        n_witnesses_array = [256, 512, 1024]
        seeds = [10, 13, 20]
    else:
        n_samples_array = [64]
        n_witnesses_array = [512]
        seeds = [27]

    for n_witnesses in n_witnesses_array:
        for seed in seeds:
            for n_samples in n_samples_array:

                name = 'witness_ssc_corrected_nl{}_nw{}_seed{}'.format(
                    n_samples, n_witnesses, seed)
                n_landmarks = n_samples

                dataset_sampler = SwissRoll()
                landmarks, color = dataset_sampler.sample(n_landmarks,
                                                          seed=seed)
                # Witnesses use an offset seed, presumably so they form a
                # sample distinct from the landmarks.
                witnesses, _ = dataset_sampler.sample(n_witnesses,
                                                      seed=(seed + 17))

                # wl_table / get_pairs_0 / get_persistence_pairs are defined
                # elsewhere; their exact semantics are not visible here.
                distances = wl_table(witnesses, landmarks)
                pairs = get_pairs_0(distances)
                pairs_filtered = get_persistence_pairs(pairs, n_samples)

                count_pairings(n_samples, pairs_filtered)
                make_plot(landmarks, pairs_filtered, color, name=name)

from scripts.ssc.persistence_pairings_visualization.utils_definitions import make_plot
from src.datasets.datasets import SwissRoll
from src.topology.witness_complex import WitnessComplex

PATH = '/Users/simons/PycharmProjects/MT-VAEs-TDA/output/SwissRoll_pairings/witness_complex_k/'

if __name__ == "__main__":

    # Experiment configuration.
    seed = 0
    n_landmarks = 512
    n_witnesses = 2048

    # Landmarks and witnesses are drawn from the same sampler with offset seeds.
    dataset_sampler = SwissRoll()
    landmarks, color = dataset_sampler.sample(n_landmarks, seed=seed)
    witnesses, _ = dataset_sampler.sample(n_witnesses, seed=(seed + 17))

    witness_complex = WitnessComplex(landmarks, witnesses)
    witness_complex.compute_simplicial_complex(1, True, r_max=7)

    # NOTE(review): `NearestNeighbors` is not among the visible imports of this
    # script (presumably sklearn.neighbors.NearestNeighbors) -- confirm it is
    # imported.
    for k in range(1, 16):
        name = 'nl{}_nw{}_k{}_seed{}'.format(n_landmarks, n_witnesses, k, seed)

        # Request k + 1 neighbours: querying with the same precomputed distance
        # matrix also returns each landmark itself (distance 0).
        knn_model = NearestNeighbors(n_neighbors=(k + 1), metric='precomputed')
        knn_model.fit(witness_complex.landmarks_dist)
        distances, pairings = knn_model.kneighbors(
            witness_complex.landmarks_dist)
        print(distances)
        print(pairings)

        make_plot(landmarks, pairings, color, name, path_root=PATH, knn=True)
示例#5
0
        ax.legend((plt_obs, plt_out), ('data', 'outliers'), loc='lower left')
    return data


def persistence_pairs_to_edges(pers_pairs):
    """Convert persistence pairs into an ``(n, 2)`` integer array of vertex pairs.

    Args:
        pers_pairs: sequence of ``(birth_simplex, death_simplex)`` pairs, e.g.
            as returned by gudhi's ``SimplexTree.persistence_pairs()``.

    Returns:
        ``numpy.ndarray`` of shape ``(n, 2)`` holding the first two vertices of
        every death simplex, in input order. Pairs whose death simplex has
        fewer than two vertices are skipped. gudhi reports an empty death
        simplex for classes that never die, and indexing into it would raise
        ``IndexError``. An empty input yields an array of shape ``(0, 2)``.
    """
    edges = [list(death[:2]) for _, death in pers_pairs if len(death) >= 2]
    # Build the array in one go (repeated np.vstack is quadratic); reshape so
    # the empty case still has two columns.
    return np.array(edges, dtype=int).reshape(-1, 2)


if __name__ == "__main__":

    dataset_sampler = SwissRoll()
    name = '128_noise4_DTM'
    data, color = dataset_sampler.sample(128, noise=0.4)

    st = DTMFiltration(data, m=0.01, p=10, dimension_max=1)
    st.persistence()
    pers_pairs = st.persistence_pairs()
    print(pers_pairs)
    # Use every pair with a finite death simplex, regardless of its position
    # in the list (a positional slice would drop the last pair and crash on
    # infinite pairs elsewhere in the list).
    pairings = persistence_pairs_to_edges(pers_pairs)

    make_plot(data, pairings, color, name=name)

    # name = '128_noise4_reg'
    # make_data(data, color, name = name)
    #
    # path_pairings = '{}pairings_{}.npy'.format(PATH_ROOT, name)
    # path_data = '{}data_{}.npy'.format(PATH_ROOT, name)
    # path_color = '{}color_{}.npy'.format(PATH_ROOT, name)
    # pairings, data, color = np.load(path_pairings), np.load(path_data), np.load(path_color)
    # print(type(pairings))
    # print(pairings.shape)
    # #
    # make_plot(data, pairings, color, name = name)
示例#6
0
        # Dense Euclidean distance matrix. The broadcasting below assumes
        # `points` is a 2-D (n_points, dim) array; TODO confirm.
        points_tensor = torch.from_numpy(points)
        pairwise_distances = torch.norm(points_tensor[:, None] - points_tensor,
                                        dim=2,
                                        p=2)
        # Only the per-row neighbour ordering is used. The sorted values are
        # discarded instead of being bound to `sorted`, which would shadow the
        # builtin.
        _, indices = torch.sort(pairwise_distances)

        for k in ks:

            print('{} out of {}'.format(counter, ntot))

            # Mark the k nearest neighbours of every point, skipping sort
            # column 0 (presumably the point itself at distance 0).
            kNN_mask = torch.zeros(
                (points_tensor.size(0),
                 points_tensor.size(0))).scatter(1, indices[:, 1:(k + 1)], 1)

            # (row, col) index pairs of the marked entries, in row-major order.
            pairings_i = np.where(kNN_mask.numpy() == 1)
            pairings = np.column_stack((pairings_i[0], pairings_i[1]))

            name = 'knn_k{k}_seed{seed}'.format(k=k, seed=seed)

            make_plot(points,
                      pairings,
                      color,
                      name=name,
                      path_root=path_to_save,
                      knn=False,
                      show=True,
                      dpi=400,
                      cmap=plt.cm.viridis)

            counter += 1
if __name__ == "__main__":
    # Twenty (alpha, beta) pairs on a 2**(i / 4) grid; alpha's exponent is
    # nudged by 1e-7 so that alpha is marginally larger than beta.
    alpha_beta = [[2**(0.25 * i + 0.0000001), 2**(0.25 * i)]
                  for i in range(20)]

    n_samples_array = [32, 48, 64, 96, 128]
    n_witnesses_array = [256, 512]
    seeds = [10, 13, 20]

    for n_witnesses in n_witnesses_array:
        for seed in seeds:
            for n_samples in n_samples_array:
                n_landmarks = n_samples
                name = 'witness_alphabeta_nl{}_nw{}_seed{}'.format(
                    n_samples, n_witnesses, seed)

                # Landmarks and witnesses come from the same sampler with
                # offset seeds.
                dataset_sampler = SwissRoll()
                landmarks, color = dataset_sampler.sample(n_landmarks,
                                                          seed=seed)
                witnesses, _ = dataset_sampler.sample(n_witnesses,
                                                      seed=(seed + 17))

                nlt = get_nlt(landmarks, witnesses)
                pairings = pseudo_alpha_beta_witness_complex(
                    nlt, alpha_beta, n_landmarks)

                count_pairings(n_landmarks, pairings)
                make_plot(landmarks, pairings, color, name=name)