Esempio n. 1
0
def test_knn_memory():
    """Check that an estimator given a ``memory`` directory reuses results.

    The estimator is fitted once with ``memory`` pointing at a fresh
    temporary directory.  Calling ``transform`` again on the same bags must
    emit no INFO-level records on the ``skl_groups.divergences.knn`` logger,
    and re-running ``fit_transform`` must not emit any record whose message
    starts with "Getting divergences".  All three results must be equal.

    Raises:
        SkipTest: if flann is not available.
    """
    if not have_flann:
        raise SkipTest("No flann, so skipping knn tests.")

    # Imported locally so this test does not rely on a module-level import.
    import shutil

    dim = 3
    n = 20
    np.random.seed(47)
    bags = Features([np.random.randn(np.random.randint(30, 100), dim)
                     for _ in xrange(n)])

    div_funcs = ('kl', 'js', 'renyi:.9', 'l2', 'tsallis:.8')
    Ks = (3, 4)

    tdir = tempfile.mkdtemp()
    try:
        est = KNNDivergenceEstimator(div_funcs=div_funcs, Ks=Ks, memory=tdir)
        res1 = est.fit_transform(bags)

        # Second pass over the same data: nothing should be logged at INFO.
        with LogCapture('skl_groups.divergences.knn',
                        level=logging.INFO) as l:
            res2 = est.transform(bags)
            assert len(l.records) == 0
        assert np.all(res1 == res2)

        # Re-fitting may log, but must not announce a divergence computation.
        with LogCapture('skl_groups.divergences.knn',
                        level=logging.INFO) as l:
            res3 = est.fit_transform(bags)
            for r in l.records:
                assert not r.message.startswith("Getting divergences")
        assert np.all(res1 == res3)
    finally:
        # mkdtemp() never cleans up after itself; remove the directory even
        # when an assertion above fails so test runs do not leak temp dirs.
        shutil.rmtree(tdir, ignore_errors=True)
Esempio n. 2
0
def test_knn_memory():
    """Verify that results are reused when ``memory`` is set on the estimator.

    Fits a ``KNNDivergenceEstimator`` whose ``memory`` argument is a newly
    created temporary directory, then checks that:

    * a following ``transform`` on the same bags logs nothing at INFO level
      on the ``skl_groups.divergences.knn`` logger, and
    * a following ``fit_transform`` logs no message starting with
      "Getting divergences",

    with both calls returning exactly the first result.

    Raises:
        SkipTest: if flann is not available.
    """
    if not have_flann:
        raise SkipTest("No flann, so skipping knn tests.")

    # Function-scope import: only needed for the cleanup at the end.
    import shutil

    dim = 3
    n = 20
    np.random.seed(47)
    bags = Features(
        [np.random.randn(np.random.randint(30, 100), dim) for _ in xrange(n)])

    div_funcs = ('kl', 'js', 'renyi:.9', 'l2', 'tsallis:.8')
    Ks = (3, 4)
    logger_name = 'skl_groups.divergences.knn'

    tdir = tempfile.mkdtemp()
    try:
        est = KNNDivergenceEstimator(div_funcs=div_funcs, Ks=Ks, memory=tdir)
        res1 = est.fit_transform(bags)

        with LogCapture(logger_name, level=logging.INFO) as l:
            res2 = est.transform(bags)
            assert len(l.records) == 0
        assert np.all(res1 == res2)

        with LogCapture(logger_name, level=logging.INFO) as l:
            res3 = est.fit_transform(bags)
            for r in l.records:
                assert not r.message.startswith("Getting divergences")
        assert np.all(res1 == res3)
    finally:
        # The directory from mkdtemp() is the caller's to delete; do it here
        # so a failing assertion cannot leave it behind.
        shutil.rmtree(tdir, ignore_errors=True)
Esempio n. 3
0
def test_knn_sanity_slow():
    """Smoke-test ``KNNDivergenceEstimator`` on small random bags.

    Covers four things: a plain ``fit_transform`` produces a finite array of
    the expected shape; adding a much larger bag makes ``fit_transform``
    raise ``ValueError``; ``fit`` followed by ``transform`` on a
    differently-sized bag logs exactly one warning; and listing the same
    divergence function twice raises ``ValueError``.

    Raises:
        SkipTest: if flann is not available.
    """
    if not have_flann:
        raise SkipTest("No flann, so skipping knn tests.")

    n_dims, n_bags = 3, 20
    np.random.seed(47)
    samples = []
    for _ in xrange(n_bags):
        # Draw the bag size first, then its points (same RNG order as a
        # nested randn(randint(...), dim) call).
        n_points = np.random.randint(30, 100)
        samples.append(np.random.randn(n_points, n_dims))
    bags = Features(samples)

    # just make sure it runs
    funcs = ('kl', 'js', 'renyi:.9', 'l2', 'tsallis:.8')
    neighbor_counts = (3, 4)
    est = KNNDivergenceEstimator(div_funcs=funcs, Ks=neighbor_counts)
    divs = est.fit_transform(bags)
    expected_shape = (len(funcs), len(neighbor_counts), n_bags, n_bags)
    assert divs.shape == expected_shape
    assert np.all(np.isfinite(divs))

    # test that JS blows up when there's a huge difference in bag sizes
    # (so that K is too low)
    lopsided = bags + [np.random.randn(1000, n_dims)]
    assert_raises(ValueError, partial(est.fit_transform, lopsided))

    # test fit() and then transform() with JS, with different-sized test bags
    js_est = KNNDivergenceEstimator(div_funcs=('js',), Ks=(5,))
    js_est.fit(bags, get_rhos=True)
    with LogCapture('skl_groups.divergences.knn',
                    level=logging.WARNING) as captured:
        big_bag = np.random.randn(300, n_dims)
        divs = js_est.transform([big_bag])
        assert divs.shape == (1, 1, 1, len(bags))
        records = captured.records
        assert len(records) == 1
        assert records[0].message.startswith('Y_rhos had a lower max_K')

    # test that passing div func more than once raises
    def fit_duplicated(df):
        return KNNDivergenceEstimator(div_funcs=[df, df]).fit(bags)

    for repeated in ('kl', 'renyi:.8', 'l2'):
        assert_raises(ValueError, partial(fit_duplicated, repeated))
Esempio n. 4
0
def test_knn_sanity_slow():
    """Exercise the basic ``KNNDivergenceEstimator`` code paths end to end.

    Steps, in order:

    1. ``fit_transform`` on 20 random 3-d bags returns a finite array shaped
       ``(n_div_funcs, n_Ks, n_bags, n_bags)``.
    2. ``fit_transform`` with an extra 1000-point bag raises ``ValueError``.
    3. ``fit(..., get_rhos=True)`` then ``transform`` of one 300-point bag
       returns shape ``(1, 1, 1, n_bags)`` and logs a single warning.
    4. Passing the same divergence function twice raises ``ValueError``.

    Raises:
        SkipTest: if flann is not available.
    """
    if not have_flann:
        raise SkipTest("No flann, so skipping knn tests.")

    logger_name = 'skl_groups.divergences.knn'
    dim = 3
    num_bags = 20
    np.random.seed(47)
    bags = Features([
        np.random.randn(np.random.randint(30, 100), dim)
        for _ in xrange(num_bags)
    ])

    # just make sure it runs
    div_names = ('kl', 'js', 'renyi:.9', 'l2', 'tsallis:.8')
    k_values = (3, 4)
    estimator = KNNDivergenceEstimator(div_funcs=div_names, Ks=k_values)
    out = estimator.fit_transform(bags)
    assert out.shape == (len(div_names), len(k_values), num_bags, num_bags)
    assert np.all(np.isfinite(out))

    # test that JS blows up when there's a huge difference in bag sizes
    # (so that K is too low)
    with_huge_bag = bags + [np.random.randn(1000, dim)]
    failing_call = partial(estimator.fit_transform, with_huge_bag)
    assert_raises(ValueError, failing_call)

    # test fit() and then transform() with JS, with different-sized test bags
    estimator = KNNDivergenceEstimator(div_funcs=('js',), Ks=(5,))
    estimator.fit(bags, get_rhos=True)
    with LogCapture(logger_name, level=logging.WARNING) as log:
        out = estimator.transform([np.random.randn(300, dim)])
        assert out.shape == (1, 1, 1, len(bags))
        assert len(log.records) == 1
        first_message = log.records[0].message
        assert first_message.startswith('Y_rhos had a lower max_K')

    # test that passing div func more than once raises
    def fit_with_repeated(div_func):
        doubled = KNNDivergenceEstimator(div_funcs=[div_func, div_func])
        return doubled.fit(bags)

    for div_func in ('kl', 'renyi:.8', 'l2'):
        assert_raises(ValueError, partial(fit_with_repeated, div_func))
Esempio n. 5
0
def distribution_divergence(X_s, X_l, k=10, n_jobs=4, max_dim=50):
    """
    Estimate the l2 and js divergences from samples of two distributions.

    The implementation uses `skl-groups`, which implements non-parametric
    (k-nearest-neighbour based) estimation of divergences. The estimator is
    fitted on `X_s` and then evaluated on `X_l`.

    Args:
        + X_s: a 2-d numpy array containing the point cloud in state space
          (one row per point)
        + X_l: a 2-d numpy array containing the point cloud in latent space
          (one row per point)
        + k: the number of neighbours passed to the estimator as `Ks=[k]`
        + n_jobs: the `n_jobs` value forwarded to the estimator
        + max_dim: if `X_s` has more than this many columns, no estimation
          is attempted

    Returns:
        A dictionary with the keys 'l2_divergence' and 'js_divergence'.
        Both values are -1. when `X_s.shape[1] > max_dim`.
    """

    # We discard cases with too large dimensions and return a sentinel value
    if X_s.shape[1] > max_dim:
        return {'l2_divergence': -1., 'js_divergence': -1.}

    # We instantiate the divergence object
    div = KNNDivergenceEstimator(div_funcs=['l2', 'js'],
                                 Ks=[k],
                                 n_jobs=n_jobs,
                                 clamp=True)

    # We turn both data to float32
    X_s = X_s.astype(np.float32)
    X_l = X_l.astype(np.float32)

    # We wrap each point cloud as a single bag holding all of its points
    f_s = Features(X_s, n_pts=[X_s.shape[0]])
    f_l = Features(X_l, n_pts=[X_l.shape[0]])

    # We create the knn graph
    div.fit(X=f_s)

    # We compute the divergences; after squeeze() the two remaining values
    # follow the order of div_funcs, i.e. (l2, js)
    l2, js = div.transform(X=f_l).squeeze()

    # We construct the returned dictionary
    output = {'l2_divergence': l2, 'js_divergence': js}

    return output