def test_vmf_log_detect_breakage():
    '''
    Find where the ``_vmf_log`` computation breaks down.

    For every concentration ``kappa`` the feature dimension is swept upwards
    from 2 to 499 and the first dimension at which
    ``von_mises_fisher_mixture._vmf_log`` raises is recorded (``None`` if it
    never raises in that range). The recorded dimensions are printed and
    compared against a fixed list, which shows where an approximation
    should be applied instead.
    '''
    n_examples = 3
    kappas = [5, 30, 100, 1000, 5000]
    n_features = range(2, 500)

    breakage_points = []
    for kappa in kappas:
        first_breakage = None
        for n_f in n_features:
            # Random unit-norm mean direction.
            mu = np.random.randn(n_f)
            mu /= np.linalg.norm(mu)

            # Random examples, each row scaled to unit norm.
            X = np.random.randn(n_examples, n_f)
            X /= np.linalg.norm(X, axis=1, keepdims=True)

            try:
                von_mises_fisher_mixture._vmf_log(X, kappa, mu)
            except Exception:
                # Only the first failing dimension is of interest, so stop
                # sweeping here. ``Exception`` (rather than a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                first_breakage = n_f
                break

        breakage_points.append(first_breakage)
        print('Scipy vmf_log breaks for kappa={} at n_features={}'.format(
            kappa, first_breakage))

    print(breakage_points)
    assert_array_equal(breakage_points, [141, 420, 311, 3, 3])
def test_vmf_log_dense():
    '''
    Compare ``_vmf_log`` with ``_vmf_log_asymptotic`` over increasing kappa.

    The relative difference between the two results is printed for each
    kappa; the difference at the smallest kappa must be more than ten times
    the difference at the largest one.
    '''
    n_features = 50
    n_examples = 2
    kappas = np.linspace(2, 600, 20)

    # Random unit-norm mean direction.
    mu = np.random.randn(n_features)
    mu /= np.linalg.norm(mu)

    # Random examples; each row (a view into X) is normalised in place.
    X = np.random.randn(n_examples, n_features)
    for row in X:
        row /= np.linalg.norm(row)

    def relative_gap(kappa):
        # Norm of (exact - approximate), relative to the norm of exact.
        exact = von_mises_fisher_mixture._vmf_log(X, kappa, mu)
        approx = von_mises_fisher_mixture._vmf_log_asymptotic(X, kappa, mu)
        return np.linalg.norm(exact - approx) / np.linalg.norm(exact)

    diffs = []
    for kappa in kappas:
        gap = relative_gap(kappa)
        print(gap)
        diffs.append(gap)

    assert diffs[0] > 10 * diffs[-1]
# Example no. 3
# 0
def _dbscan_core_centers(X):
    '''
    Cluster the rows of X with DBSCAN (default parameters) and average the
    core samples of each cluster.

    Returns a tuple ``(unique_labels, centers)``: the set of labels DBSCAN
    assigned (``-1`` marks noise) and an array with one row per non-noise
    cluster, ordered by ascending label. ``centers`` has zero rows when no
    cluster was found.
    '''
    db = DBSCAN().fit(X)
    labels = db.labels_
    core_samples_mask = np.zeros_like(labels, dtype=bool)
    core_samples_mask[db.core_sample_indices_] = True

    # Sorted ids give a deterministic center order that does not depend on
    # set iteration order or on label 0 being visited first.
    cluster_ids = [k for k in np.unique(labels) if k != -1]
    centers = [
        np.mean(X[(labels == k) & core_samples_mask], axis=0)
        for k in cluster_ids
    ]
    if not centers:
        return set(labels), np.empty((0, X.shape[1]))
    return set(labels), np.vstack(centers)


def sim_multimodal():
    '''
    Fit a soft von Mises-Fisher mixture to each cell of the multimodal
    simulation and save one diagnostic figure per cell.

    Every cell is first pre-processed with ``read_sim``. For each cell,
    column 4 of ``data`` is mapped to points on the unit circle, DBSCAN
    supplies the number of clusters and the initial centers, and the fitted
    mixture density is evaluated at 360 query angles. Cells with at most one
    data row, with no DBSCAN cluster, or whose processing raises are skipped.

    The evaluated densities and cell indices are collected in local lists
    but are neither returned nor written out by this function.
    '''
    n_cells = 40
    to_save = []
    data_cells = []

    # Data pre-processing
    for i in range(n_cells):
        read_sim(i,
                 f_in='datasets/multimodal_sim2.npy',
                 f_out='datasets/transformed/multimodal_sim2_cell' + str(i))

    # Angles to query
    Thq = np.linspace(-np.pi, np.pi, 360)[:, None]
    Xq = np.hstack((np.cos(Thq), np.sin(Thq)))

    # Fit one cell at a time
    for i in range(n_cells):
        print('\ncell no={}'.format(i))

        try:
            # Read data; the context manager closes the .npz archive.
            with np.load('datasets/transformed/multimodal_sim2_cell' +
                         str(i) + '.npz') as read_data:
                data = read_data['data']
                xx = read_data['xx']
                yy = read_data['yy']
            if data.shape[0] <= 1:
                continue

            # Data
            Th = data[:, 4][:, None]
            X = np.hstack((np.cos(Th), np.sin(Th)))

            # Initial centers are recomputed for every cell so nothing
            # leaks over from the previous iteration.
            unique_labels, db_centers = _dbscan_core_centers(X)
            n_clusters_ = db_centers.shape[0]
            print("n_clusters_={}, labels={}".format(n_clusters_,
                                                     unique_labels))
            if n_clusters_ == 0:
                # Only noise was found: there is nothing to initialise from.
                print(' skipped... (no clusters found)')
                continue
            print("db_centers=", db_centers)

            # NOTE: play with max_iter if you get the denom=inf error

            # Mixture of von Mises Fisher clustering (soft)
            vmf_soft = VonMisesFisherMixture(n_clusters=n_clusters_,
                                             posterior_type='soft',
                                             init=db_centers,
                                             n_init=1,
                                             verbose=True,
                                             max_iter=20)
            vmf_soft.fit(X)

            # Weighted sum of the component densities at the query points.
            y = sum(
                vmf_soft.weights_[cn] * np.exp(
                    von_mises_fisher_mixture._vmf_log(
                        Xq, vmf_soft.concentrations_[cn],
                        vmf_soft.cluster_centers_[cn]))
                for cn in range(n_clusters_))
            yq = np.array(y)[:, None]
            to_save.append(yq)
            data_cells.append(i)

            # Plot
            fig = pl.figure(figsize=(15, 4))
            try:
                pl.subplot(131)
                mesh = np.vstack((xx.ravel(), yy.ravel())).T
                pl.scatter(mesh[:, 0], mesh[:, 1], c='k', marker='.')
                pl.scatter(data[:, 1],
                           data[:, 2],
                           c=data[:, 0],
                           marker='*',
                           cmap='jet')
                pl.colorbar()
                pl.xlim([0, 20])
                pl.ylim([-5, 30])
                pl.title('data')

                pl.subplot(132)
                # Flattened so the colour array has one value per point.
                pl.scatter(Xq[:, 0], Xq[:, 1], c=yq.ravel(), cmap='jet')
                pl.colorbar()
                pl.scatter(X[:, 0] * 0.9, X[:, 1] * 0.9, c='k', marker='+')
                pl.title('data and extimated distribution')

                pl.subplot(133, projection='polar')
                pl.polar(Thq, yq)
                pl.title('polar plot')
                pl.savefig('outputs/multimodal_sim2_cell{}'.format(i))
            finally:
                # Release the figure so they do not pile up across cells.
                pl.close(fig)
        except Exception as exc:
            # Best effort: report why the cell failed and move on.
            print(' skipped... ({!r})'.format(exc))
# Example no. 4
# 0
def sim_unimodal():
    '''
    Fit a single-component soft von Mises-Fisher model to each cell of the
    unimodal simulation and save one diagnostic figure per cell.

    Every cell is first pre-processed with ``read_sim``. For each cell,
    column 4 of ``data`` is mapped to points on the unit circle, the model
    is fitted, and its density is evaluated at 360 query angles. Cells with
    at most one data row, or whose processing raises, are skipped.

    The evaluated densities and cell indices are collected in local lists
    but are neither returned nor written out by this function.
    '''
    n_cells = 40
    to_save = []
    data_cells = []

    # Data pre-processing
    for i in range(n_cells):
        read_sim(i,
                 f_in='datasets/unimodal_sim1.npy',
                 f_out='datasets/transformed/unimodal_sim1_cell' + str(i))

    # Angles to query
    Thq = np.linspace(-np.pi, np.pi, 360)[:, None]
    Xq = np.hstack((np.cos(Thq), np.sin(Thq)))

    # Fit one cell at a time
    for i in range(n_cells):
        print('cell no={}'.format(i))

        try:
            # Read data; the context manager closes the .npz archive.
            with np.load('datasets/transformed/unimodal_sim1_cell' +
                         str(i) + '.npz') as read_data:
                data = read_data['data']
                xx = read_data['xx']
                yy = read_data['yy']
            if data.shape[0] <= 1:
                continue

            # Data
            Th = data[:, 4][:, None]
            X = np.hstack((np.cos(Th), np.sin(Th)))

            # Von Mises clustering (soft)
            vmf_soft = VonMisesFisherMixture(n_clusters=1,
                                             posterior_type='soft',
                                             n_init=20)
            vmf_soft.fit(X)
            y0 = np.exp(
                von_mises_fisher_mixture._vmf_log(
                    Xq, vmf_soft.concentrations_[0],
                    vmf_soft.cluster_centers_[0]))
            y = y0 * vmf_soft.weights_[0]

            # Query
            yq = np.array(y)[:, None]
            to_save.append(yq)
            data_cells.append(i)

            # Plot
            fig = pl.figure(figsize=(15, 4))
            try:
                pl.subplot(131)
                mesh = np.vstack((xx.ravel(), yy.ravel())).T
                pl.scatter(mesh[:, 0], mesh[:, 1], c='k', marker='.')
                pl.scatter(data[:, 1],
                           data[:, 2],
                           c=data[:, 0],
                           marker='*',
                           cmap='jet')
                pl.colorbar()
                pl.xlim([0, 20])
                pl.ylim([-5, 30])
                pl.title('data')

                pl.subplot(132)
                # Coloured by the unweighted component density y0.
                pl.scatter(Xq[:, 0], Xq[:, 1], c=y0, cmap='jet')
                pl.colorbar()
                pl.scatter(X[:, 0] * 0.9, X[:, 1] * 0.9, c='k', marker='+')
                pl.title('data and extimated distribution')

                pl.subplot(133, projection='polar')
                pl.polar(Thq, yq)
                pl.title('polar plot')
                pl.savefig('outputs/unimodal_sim1_cell{}'.format(i))
            finally:
                # Release the figure so they do not pile up across cells.
                pl.close(fig)
        except Exception as exc:
            # Best effort: report why the cell failed and move on.
            print(' skipped... ({!r})'.format(exc))