Exemplo n.º 1
0
def test_max_statistic_sequential():
    """Smoke-test ``stats.max_statistic_sequential`` on a hand-built setup.

    The analysis state is assigned directly on the object and filled with
    random realisations. The test only requires that the call finishes and
    yields three values; the values themselves are not checked.
    """
    data = Data()
    data.generate_mute_data(104, 10)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,
        'max_lag_sources': 5,
        'min_lag_sources': 1,
        'max_lag_target': 5,
    }
    analysis = MultivariateTE()
    analysis._initialise(settings, data, sources=[0, 1], target=2)

    # Set the selection state by hand instead of running the analysis.
    analysis.current_value = (0, 4)
    analysis.selected_vars_sources = [(1, 1), (1, 2)]
    analysis.selected_vars_full = [(0, 1), (1, 1), (1, 2)]

    # One random column per selected variable.
    n_selected_rows = data.n_realisations(analysis.current_value)
    n_selected_cols = len(analysis.selected_vars_full)
    analysis._selected_vars_realisations = np.random.rand(n_selected_rows,
                                                          n_selected_cols)

    # A single random column for the current value.
    n_current_rows = data.n_realisations(analysis.current_value)
    analysis._current_value_realisations = np.random.rand(n_current_rows, 1)

    significance, p_values, te_values = stats.max_statistic_sequential(
        analysis_setup=analysis, data=data)
Exemplo n.º 2
0
def test_get_permuted_replications():
    """Test if permutation of replications works.

    Covers three cases of ``NetworkComparison._get_permuted_replications``:
    the dependent-samples permutation, the independent-samples permutation,
    and the AssertionError expected when both data sets differ in their
    number of replications.
    """
    # Load previously generated example data. The files are opened through
    # context managers so the handles are closed as soon as unpickling is
    # done instead of being left to the garbage collector.
    # NOTE: pickle must only be used on trusted files such as these local
    # test fixtures.
    path = os.path.join(os.path.dirname(__file__), 'data/')
    with open(path + 'mute_results_0.p', 'rb') as results_file:
        res_0 = pickle.load(results_file)
    with open(path + 'mute_results_1.p', 'rb') as results_file:
        res_1 = pickle.load(results_file)
    comp_settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'n_perm_max_stat': 50,
        'n_perm_min_stat': 50,
        'n_perm_omnibus': 200,
        'n_perm_max_seq': 50,
        'tail': 'two',
        'n_perm_comp': 6,
        'alpha_comp': 0.2,
        'stats_type': 'dependent'
    }
    comp = NetworkComparison()
    comp._initialise(comp_settings)
    comp._create_union(res_0, res_1)

    # Check permutation for dependent samples test: Replace realisations by
    # zeros and ones, check if realisations get swapped correctly.
    dat1 = Data()
    dat1.normalise = False
    dat1.set_data(np.zeros((5, 100, 5)), 'psr')
    dat2 = Data()
    dat2.normalise = False
    dat2.set_data(np.ones((5, 100, 5)), 'psr')
    [cond_a_perm, cv_a_perm, cond_b_perm,
     cv_b_perm] = comp._get_permuted_replications(data_a=dat1,
                                                  data_b=dat2,
                                                  target=1)
    n_vars = cond_a_perm.shape[1]
    # Every row of the summed arrays has to add up to the number of columns,
    # and at each position exactly one of the two arrays must be non-zero.
    assert (np.sum(cond_a_perm + cond_b_perm, axis=1) == n_vars).all(), (
        'Dependent samples permutation did not work correctly.')
    assert np.logical_xor(cond_a_perm, cond_b_perm).all(), (
        'Dependent samples permutation did not work correctly.')

    # Check permutations for independent samples test: Check the sum over
    # realisations.
    comp_settings['stats_type'] = 'independent'
    comp = NetworkComparison()
    comp._initialise(comp_settings)
    comp._create_union(res_0, res_1)
    [cond_a_perm, cv_a_perm, cond_b_perm,
     cv_b_perm] = comp._get_permuted_replications(data_a=dat1,
                                                  data_b=dat2,
                                                  target=1)
    n_samples = n_vars * dat1.n_realisations((0, comp.union['max_lag']))
    assert np.sum(cond_a_perm + cond_b_perm, axis=None) == n_samples, (
        'Independent samples permutation did not work correctly.')

    # Test unequal number of replications: dat2 is regenerated with 7
    # replications and the call is expected to fail an assertion.
    dat2.generate_mute_data(100, 7)
    with pytest.raises(AssertionError):
        comp._get_permuted_replications(data_a=dat1, data_b=dat2, target=1)
Exemplo n.º 3
0
def test_permute_replications():
    """Check surrogates built by permuting replications and local samples."""
    n_samples = 20

    # Four rows, each filled with a single constant (0, 1, 2 and 3).
    # NOTE(review): presumably 'rs' declares rows as replications -- confirm.
    constant_rows = np.vstack(
        [np.ones(n_samples) * value for value in range(4)]).astype(int)
    data = Data(constant_rows, 'rs', normalise=False)
    current_value = (0, n_samples - 1)
    idx_list = [(0, 1), (0, 3), (0, 7)]
    perm, perm_idx = data.permute_replications(
        current_value=current_value, idx_list=idx_list)
    assert np.all(perm[:, 0] == perm_idx), 'Permutation did not work.'

    # Local permutation: within every consecutive window of `perm_range`
    # rows, the first column of the result must hold exactly the values of
    # that window (each data row counts 0 .. n_samples - 1).
    perm_range = 3
    current_value = (0, 3)
    idx_list = [(0, 0), (0, 1), (0, 2)]
    ramp = np.arange(n_samples)
    data = Data(np.vstack([ramp, ramp]).astype(int), 'rs', normalise=False)
    perm, perm_idx = data.permute_samples(
        current_value=current_value,
        idx_list=idx_list,
        perm_opts={'perm_type': 'local', 'perm_range': perm_range})

    n_per_repl = int(data.n_realisations(current_value) / data.n_replications)
    n_full_windows, remainder = divmod(n_per_repl, perm_range)
    for window in range(n_full_windows):
        start = window * perm_range
        expected = np.arange(start, start + perm_range)
        found = np.unique(perm[start:start + perm_range, 0])
        assert (found == expected).all(), (
            'The permutation range was not respected.')
    if remainder > 0:
        start = n_full_windows * perm_range
        expected = np.arange(start, start + remainder)
        found = np.unique(perm[start:start + remainder, 0])
        assert (found == expected).all(), (
            'The remainder did not contain the same realisations.')

    # Invalid permutation ranges: 1 and infinity must trip an assertion, a
    # string other than 'max' must raise a ValueError.
    invalid_opts = {'perm_type': 'local', 'perm_range': 1}
    invalid_cases = (
        (1, AssertionError),
        (np.inf, AssertionError),
        ('foo', ValueError),
    )
    for invalid_range, expected_error in invalid_cases:
        invalid_opts['perm_range'] = invalid_range
        with pytest.raises(expected_error):
            data.permute_samples(current_value=current_value,
                                 idx_list=idx_list,
                                 perm_opts=invalid_opts)
Exemplo n.º 4
0
def test_permute_replications():
    """Check surrogates built by permuting replications and local samples."""
    n_samples = 20

    # Four rows, each filled with a single constant (0, 1, 2 and 3).
    # NOTE(review): presumably 'rs' declares rows as replications -- confirm.
    constant_rows = np.vstack(
        [np.ones(n_samples) * value for value in range(4)]).astype(int)
    data = Data(constant_rows, 'rs', normalise=False)
    current_value = (0, n_samples)
    idx_list = [(0, 1), (0, 3), (0, 7)]
    perm, perm_idx = data.permute_replications(
        current_value=current_value, idx_list=idx_list)
    assert np.all(perm[:, 0] == perm_idx), 'Permutation did not work.'

    # Local permutation: within every consecutive window of `window_size`
    # rows, the first column of the result must hold exactly the values of
    # that window (each data row counts 0 .. n_samples - 1).
    window_size = 3
    current_value = (0, 3)
    idx_list = [(0, 0), (0, 1), (0, 2)]
    ramp = np.arange(n_samples)
    data = Data(np.vstack([ramp, ramp]).astype(int), 'rs', normalise=False)
    perm, perm_idx = data.permute_samples(
        current_value=current_value,
        idx_list=idx_list,
        perm_range=window_size)

    n_per_repl = int(data.n_realisations(current_value) / data.n_replications)
    n_full_windows, remainder = divmod(n_per_repl, window_size)
    for window in range(n_full_windows):
        start = window * window_size
        expected = np.arange(start, start + window_size)
        found = np.unique(perm[start:start + window_size, 0])
        assert (found == expected).all(), (
            'The permutation range was not respected.')
    if remainder > 0:
        start = n_full_windows * window_size
        expected = np.arange(start, start + remainder)
        found = np.unique(perm[start:start + remainder, 0])
        assert (found == expected).all(), (
            'The remainder did not contain the same realisations.')

    # Invalid permutation ranges: 1 and infinity must trip an assertion, a
    # string other than 'max' must raise a ValueError.
    invalid_cases = (
        (1, AssertionError),
        (np.inf, AssertionError),
        ('foo', ValueError),
    )
    for invalid_range, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            data.permute_samples(current_value=current_value,
                                 idx_list=idx_list,
                                 perm_range=invalid_range)
Exemplo n.º 5
0
def test_max_statistic_sequential():
    """Smoke-test ``stats.max_statistic_sequential`` on a hand-built setup.

    The analysis state is assigned directly on the object and filled with
    random realisations. The test only requires that the call finishes and
    yields three values; the values themselves are not checked.
    """
    dat = Data()
    dat.generate_mute_data(104, 10)
    opts = {
        'cmi_calc_name': 'jidt_kraskov',
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,
    }
    analysis = Multivariate_te(max_lag_sources=5,
                               min_lag_sources=1,
                               max_lag_target=5,
                               options=opts)

    # Set the selection state by hand instead of running the analysis.
    analysis.current_value = (0, 4)
    analysis.selected_vars_sources = [(1, 1), (1, 2)]
    analysis.selected_vars_full = [(0, 1), (1, 1), (1, 2)]

    # One random column per selected variable.
    n_selected_rows = dat.n_realisations(analysis.current_value)
    n_selected_cols = len(analysis.selected_vars_full)
    analysis._selected_vars_realisations = np.random.rand(n_selected_rows,
                                                          n_selected_cols)

    # A single random column for the current value.
    n_current_rows = dat.n_realisations(analysis.current_value)
    analysis._current_value_realisations = np.random.rand(n_current_rows, 1)

    significance, p_values, te_values = stats.max_statistic_sequential(
        analysis_setup=analysis, data=dat, opts=opts)
Exemplo n.º 6
0
def test_max_statistic_sequential():
    """Run the sequential max statistic once on manually prepared state.

    No assertion is made on the returned values; the test passes when the
    call completes and its result unpacks into three items.
    """
    dat = Data()
    dat.generate_mute_data(104, 10)
    opts = {
        'cmi_calc_name': 'jidt_kraskov',
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_omnibus': 21,
        'n_perm_max_seq': 21,
    }
    te_setup = Multivariate_te(max_lag_sources=5,
                               min_lag_sources=1,
                               max_lag_target=5,
                               options=opts)

    # The selected variables are written onto the object directly.
    te_setup.current_value = (0, 4)
    te_setup.selected_vars_sources = [(1, 1), (1, 2)]
    te_setup.selected_vars_full = [(0, 1), (1, 1), (1, 2)]

    # Random realisations: one column per entry of selected_vars_full ...
    rows_for_selected = dat.n_realisations(te_setup.current_value)
    cols_for_selected = len(te_setup.selected_vars_full)
    te_setup._selected_vars_realisations = np.random.rand(rows_for_selected,
                                                          cols_for_selected)

    # ... and one column for the current value.
    rows_for_current = dat.n_realisations(te_setup.current_value)
    te_setup._current_value_realisations = np.random.rand(rows_for_current, 1)

    is_significant, p_vals, te_vals = stats.max_statistic_sequential(
        analysis_setup=te_setup, data=dat, opts=opts)
Exemplo n.º 7
0
def test_permute_replications():
    """Check surrogates built by permuting replications and local samples."""
    n_samples = 20

    # Four rows, each filled with a single constant (0, 1, 2 and 3).
    # NOTE(review): presumably 'rs' declares rows as replications -- confirm.
    constant_rows = np.vstack(
        [np.ones(n_samples) * value for value in range(4)]).astype(int)
    data = Data(constant_rows, 'rs', normalise=False)
    current_value = (0, n_samples - 1)
    idx_list = [(0, 1), (0, 3), (0, 7)]
    perm, perm_idx = data.permute_replications(
        current_value=current_value, idx_list=idx_list)
    assert np.all(perm[:, 0] == perm_idx), 'Permutation did not work.'

    # Local permutation: within every consecutive window of `perm_range`
    # rows, the first column of the result must hold exactly the values of
    # that window (each data row counts 0 .. n_samples - 1).
    perm_range = 3
    current_value = (0, 3)
    idx_list = [(0, 0), (0, 1), (0, 2)]
    ramp = np.arange(n_samples)
    data = Data(np.vstack([ramp, ramp]).astype(int), 'rs', normalise=False)
    perm, perm_idx = data.permute_samples(
        current_value=current_value,
        idx_list=idx_list,
        perm_settings={'perm_type': 'local', 'perm_range': perm_range})

    n_per_repl = int(data.n_realisations(current_value) / data.n_replications)
    n_full_windows, remainder = divmod(n_per_repl, perm_range)
    for window in range(n_full_windows):
        start = window * perm_range
        expected = np.arange(start, start + perm_range)
        found = np.unique(perm[start:start + perm_range, 0])
        assert (found == expected).all(), (
            'The permutation range was not respected.')
    if remainder > 0:
        start = n_full_windows * perm_range
        expected = np.arange(start, start + remainder)
        found = np.unique(perm[start:start + remainder, 0])
        assert (found == expected).all(), (
            'The remainder did not contain the same realisations.')
def test_get_permuted_replications():
    """Test if permutation of replications works.

    Covers three cases of ``NetworkComparison._get_permuted_replications``:
    the dependent-samples permutation, the independent-samples permutation,
    and the AssertionError expected when both data sets differ in their
    number of replications.
    """
    # Load previously generated example data. The files are opened through
    # context managers so the handles are closed as soon as unpickling is
    # done instead of being left to the garbage collector.
    # NOTE: pickle must only be used on trusted files such as these local
    # test fixtures.
    path = os.path.join(os.path.dirname(__file__), 'data/')
    with open(path + 'mute_results_0.p', 'rb') as results_file:
        res_0 = pickle.load(results_file)
    with open(path + 'mute_results_1.p', 'rb') as results_file:
        res_1 = pickle.load(results_file)
    comp_settings = {
            'cmi_estimator': 'JidtKraskovCMI',
            'n_perm_max_stat': 50,
            'n_perm_min_stat': 50,
            'n_perm_omnibus': 200,
            'n_perm_max_seq': 50,
            'tail': 'two',
            'n_perm_comp': 6,
            'alpha_comp': 0.2,
            'stats_type': 'dependent'
            }
    comp = NetworkComparison()
    comp._initialise(comp_settings)
    comp._create_union(res_0, res_1)

    # Check permutation for dependent samples test: Replace realisations by
    # zeros and ones, check if realisations get swapped correctly.
    dat1 = Data()
    dat1.normalise = False
    dat1.set_data(np.zeros((5, 100, 5)), 'psr')
    dat2 = Data()
    dat2.normalise = False
    dat2.set_data(np.ones((5, 100, 5)), 'psr')
    [cond_a_perm,
     cv_a_perm,
     cond_b_perm,
     cv_b_perm] = comp._get_permuted_replications(data_a=dat1,
                                                  data_b=dat2,
                                                  target=1)
    n_vars = cond_a_perm.shape[1]
    # Every row of the summed arrays has to add up to the number of columns,
    # and at each position exactly one of the two arrays must be non-zero.
    assert (np.sum(cond_a_perm + cond_b_perm, axis=1) == n_vars).all(), (
                'Dependent samples permutation did not work correctly.')
    assert np.logical_xor(cond_a_perm, cond_b_perm).all(), (
                'Dependent samples permutation did not work correctly.')

    # Check permutations for independent samples test: Check the sum over
    # realisations.
    comp_settings['stats_type'] = 'independent'
    comp = NetworkComparison()
    comp._initialise(comp_settings)
    comp._create_union(res_0, res_1)
    [cond_a_perm,
     cv_a_perm,
     cond_b_perm,
     cv_b_perm] = comp._get_permuted_replications(data_a=dat1,
                                                  data_b=dat2,
                                                  target=1)
    n_samples = n_vars * dat1.n_realisations((0, comp.union['max_lag']))
    assert np.sum(cond_a_perm + cond_b_perm, axis=None) == n_samples, (
                'Independent samples permutation did not work correctly.')

    # Test unequal number of replications: dat2 is regenerated with 7
    # replications and the call is expected to fail an assertion.
    dat2.generate_mute_data(100, 7)
    with pytest.raises(AssertionError):
        comp._get_permuted_replications(data_a=dat1, data_b=dat2, target=1)