Exemple #1
0
def test_data_properties():
    """Check the number of sample realisations reported by a Data object.

    Builds a non-normalised Data object from a range of integers and
    verifies the value returned by n_realisations_samples(), once without
    and once with a current value.
    """
    n_samples = 10
    data = Data(np.arange(n_samples), 's', normalise=False)

    # Without a current value the count must equal the number of samples.
    assert data.n_realisations_samples() == n_samples, (
        'Realisations in time are not returned correctly.')

    # With a current value the expected count is reduced by the second
    # entry of the current value tuple.
    current_value = (0, 8)
    n_expected = n_samples - current_value[1]
    n_returned = data.n_realisations_samples(current_value=current_value)
    assert n_returned == n_expected, (
        'Realisations in time are not returned correctly when current value '
        'is set.')
Exemple #2
0
def test_data_properties():
    """Test the realisation counts returned by a Data object.

    The count from n_realisations_samples() is checked against the number
    of samples, and against the number of samples minus the second entry of
    the current value when a current value is passed.
    """
    n = 10
    dat = Data(np.arange(n), 's', normalise=False)

    count_plain = dat.n_realisations_samples()
    msg_plain = 'Realisations in time are not returned correctly.'
    assert count_plain == n, msg_plain

    cur_val = (0, 8)
    count_cv = dat.n_realisations_samples(current_value=cur_val)
    msg_cv = ('Realisations in time are not returned correctly when '
              'current value is set.')
    assert count_cv == n - cur_val[1], msg_cv
def test_return_local_values():
    """Test estimation of local values.

    Runs an active information storage analysis with 'local_values' enabled
    for two processes and checks that each returned estimate is an array
    whose first dimension equals the number of replications and whose
    second dimension equals the number of sample realisations for a current
    value of (0, max_lag). Processes with a NaN estimate are skipped.
    """
    max_lag = 5
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_mi': 21,
        'max_lag': max_lag,
        'tau': 1
    }
    data = Data()
    data.generate_mute_data(100, 3)
    ais = ActiveInformationStorage()
    processes = [1, 2]
    results = ais.analyse_network(settings, data, processes)

    for p in processes:
        lais = results.get_single_process(p, fdr=False)['ais']
        # Skip processes with a NaN estimate. Check the value instead of the
        # identity: `lais is np.nan` only matches the np.nan object itself
        # and misses any other NaN float.
        if isinstance(lais, float) and np.isnan(lais):
            continue
        assert isinstance(lais, np.ndarray), (
            'LAIS estimation did not return an array of values: {0}'.format(
                lais))
        assert lais.shape[0] == data.n_replications, (
            'Wrong dim (no. replications) in LAIS estimate: {0}'.format(
                lais.shape))
        assert lais.shape[1] == data.n_realisations_samples((0, max_lag)), (
            'Wrong dim (no. samples) in LAIS estimate: {0}'.format(lais.shape))
def test_return_local_values():
    """Test estimation of local values.

    Analyses the active information storage of two processes with
    'local_values' enabled. For every process with a non-NaN result, the
    estimate must be an array of shape (replications, sample realisations),
    where the number of realisations is taken for the current value
    (0, max_lag).
    """
    max_lag = 5
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_mi': 21,
        'max_lag': max_lag,
        'tau': 1}
    data = Data()
    data.generate_mute_data(100, 3)
    ais = ActiveInformationStorage()
    processes = [1, 2]
    results = ais.analyse_network(settings, data, processes)

    n_samples_expected = data.n_realisations_samples((0, max_lag))
    for p in processes:
        lais = results.get_single_process(p, fdr=False)['ais']
        # A NaN estimate is skipped. Use a value test here: an identity test
        # against np.nan is only true for the np.nan object itself.
        if isinstance(lais, float) and np.isnan(lais):
            continue
        assert isinstance(lais, np.ndarray), (
            'LAIS estimation did not return an array of values: {0}'.format(
                lais))
        assert lais.shape[0] == data.n_replications, (
            'Wrong dim (no. replications) in LAIS estimate: {0}'.format(
                lais.shape))
        assert lais.shape[1] == n_samples_expected, (
            'Wrong dim (no. samples) in LAIS estimate: {0}'.format(lais.shape))
def test_return_local_values():
    """Test estimation of local values.

    Runs a multivariate TE analysis for one target with 'local_values'
    enabled and checks type and dimensions (sources, samples, replications)
    of the local TE estimate. The analysis is then repeated with averaged
    values. If both runs inferred the same single source, the single-link
    TE is compared to the omnibus TE and to the mean of the local TE.

    The test returns early without further checks if no sources are
    inferred, or if the two runs do not agree on a single source.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(500, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag
    }
    target = 1
    te = MultivariateTE()
    results = te.analyse_network(settings, data, targets=[target])

    # Test if any sources were inferred. If not, return (this may happen
    # sometimes due to too few samples, however, a higher no. samples is not
    # feasible for a unit test).
    lte = results.get_single_target(target, fdr=False)['te']
    if lte is None:
        return

    sources_local = results.get_target_sources(target, fdr=False)
    n_sources = len(sources_local)
    assert isinstance(lte, np.ndarray), (
        'LTE estimation did not return an array of values: {0}'.format(lte))
    assert lte.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[1] == data.n_realisations_samples(
        (0, max_lag)), ('Wrong dim (no. samples) in LTE estimate: {0}'.format(
            lte.shape))
    assert lte.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LTE estimate: {0}'.format(lte.shape))

    # Test for correctness of single link TE estimation by comparing it to
    # the omnibus TE. In the single-source case the two should be the same.
    settings['local_values'] = False
    results_avg = te.analyse_network(settings, data, targets=[target])
    res_avg = results_avg.get_single_target(target, fdr=False)
    if res_avg['te'] is None:
        return
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    # The comparisons below only hold if exactly one source was inferred and
    # it is the same source in both runs. Inferred sources may differ between
    # the two calls to analyse_network() due to random data and the low
    # number of permutations used in unit testing; skip in that case.
    if len(sources_avg) != 1 or list(sources_avg) != list(sources_local):
        return
    te_single_link = res_avg['te'][0]
    te_omnibus = res_avg['omnibus_te']
    assert np.isclose(te_single_link, te_omnibus), (
        'Single link TE is not equal to omnibus information transfer.')
    # Compare mean local TE to average TE. With a single source, the mean
    # over the whole local TE array is the mean for that one link.
    assert np.isclose(
        te_single_link,
        np.mean(lte)), ('Single link average TE and mean LTE deviate.')
def test_return_local_values():
    """Test estimation of local values.

    A multivariate TE analysis of a single target is run twice, first with
    local and then with averaged values. The local estimate is checked for
    type and dimensions (sources, samples, replications). The averaged
    single-link TE is compared to the omnibus TE and to the mean local TE,
    but only when both runs inferred the same single source.

    The test returns early if no sources are inferred in either run.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(500, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag}
    target = 1
    te = MultivariateTE()
    results = te.analyse_network(settings, data, targets=[target])

    # Test if any sources were inferred. If not, return (this may happen
    # sometimes due to too few samples, however, a higher no. samples is not
    # feasible for a unit test).
    lte = results.get_single_target(target, fdr=False)['te']
    if lte is None:
        return

    sources_local = results.get_target_sources(target, fdr=False)
    assert isinstance(lte, np.ndarray), (
        'LTE estimation did not return an array of values: {0}'.format(lte))
    assert lte.shape[0] == len(sources_local), (
        'Wrong dim (no. sources) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[1] == data.n_realisations_samples((0, max_lag)), (
        'Wrong dim (no. samples) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LTE estimate: {0}'.format(lte.shape))

    # Test for correctness of single link TE estimation by comparing it to
    # the omnibus TE. In the single-source case the two should be the same.
    settings['local_values'] = False
    results_avg = te.analyse_network(settings, data, targets=[target])
    single_avg = results_avg.get_single_target(target, fdr=False)
    if single_avg['te'] is None:
        return
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    # Skip the comparisons unless both runs found the same single source:
    # with several sources the first link does not equal the omnibus TE, and
    # the mean over all local values mixes different links. Inferred sources
    # may differ between runs due to the low number of permutations.
    single_source = len(sources_avg) == 1
    if not single_source or list(sources_avg) != list(sources_local):
        return
    te_single_link = single_avg['te'][0]
    te_omnibus = single_avg['omnibus_te']
    assert np.isclose(te_single_link, te_omnibus), (
        'Single link TE is not equal to omnibus information transfer.')
    # Compare mean local TE to average TE.
    assert np.isclose(te_single_link, np.mean(lte)), (
        'Single link average TE and mean LTE deviate.')
def test_return_local_values():
    """Test estimation of local values.

    Analyses a single target with a fixed list of candidate sources twice,
    once with local and once with averaged values. The local MI estimate is
    checked for type and dimensions (sources, samples, replications). For
    every source found in both runs with identical selected variables, the
    averaged MI is compared to the mean of the local MI.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(500, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'noise_level': 0,
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag
    }
    target = 3
    sources = [0, 4]
    mi = MultivariateMI()
    results = mi.analyse_single_target(
        settings, data, target=target, sources=sources)
    settings['local_values'] = False
    results_avg = mi.analyse_single_target(
        settings, data, target=target, sources=sources)

    # Stop here if either run inferred no sources. This can happen with the
    # small number of samples that is feasible for a unit test.
    lmi = results.get_single_target(target, fdr=False)['mi']
    if lmi is None:
        return
    mi_single_link = results_avg.get_single_target(target, fdr=False)['mi']
    if mi_single_link is None:
        return

    sources_local = results.get_target_sources(target, fdr=False)
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    n_sources = len(sources_local)
    n_samples = data.n_realisations_samples((0, max_lag))
    assert type(lmi) is np.ndarray, (
        'LMI estimation did not return an array of values: {0}'.format(lmi))
    assert lmi.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[1] == n_samples, (
        'Wrong dim (no. samples) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LMI estimate: {0}'.format(lmi.shape))

    # Compare averaged and mean local values per source. The two calls to
    # analyse_single_target() may infer different sources and variables
    # because few surrogates are used in unit testing, so only sources found
    # in both runs are considered.
    for s in set(sources_avg).intersection(sources_local):
        idx_avg = np.where(sources_avg == s)[0][0]
        idx_local = np.where(sources_local == s)[0][0]
        # Links built from different selected variables are not comparable.
        selected_local = results.get_single_target(
            target, fdr=False).selected_vars_sources
        selected_avg = results_avg.get_single_target(
            target, fdr=False).selected_vars_sources
        vars_local = [v for v in selected_local if v[0] == s]
        vars_avg = [v for v in selected_avg if v[0] == s]
        if vars_local != vars_avg:
            continue
        mean_local = np.mean(lmi[idx_local, :, :])
        print('Compare average ({0:.4f}) and local values ({1:.4f}).'.format(
            mi_single_link[idx_avg], mean_local))
        assert np.isclose(mi_single_link[idx_avg], mean_local, rtol=0.00005), (
            'Single link average MI ({0:.6f}) and mean LMI ({1:.6f}) '
            ' deviate.'.format(mi_single_link[idx_avg], mean_local))
Exemple #8
0
def test_return_local_values():
    """Test estimation of local values.

    Runs a multivariate MI analysis for one target with 'local_values'
    enabled and checks type and dimensions (sources, samples, replications)
    of the local MI estimate. The analysis is repeated with averaged values
    to compare the single-link MI to the omnibus MI (only if a single source
    was inferred) and to the mean local MI of each source found in both
    runs.

    The test returns early if no sources are inferred in either run.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(500, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag
    }
    target = 1
    mi = MultivariateMI()
    results = mi.analyse_network(settings, data, targets=[target])

    # Test if any sources were inferred. If not, return (this may happen
    # sometimes due to too few samples, however, a higher no. samples is not
    # feasible for a unit test).
    lmi = results.get_single_target(target, fdr=False)['mi']
    if lmi is None:
        return

    n_sources = len(results.get_target_sources(target, fdr=False))
    assert isinstance(lmi, np.ndarray), (
        'LMI estimation did not return an array of values: {0}'.format(lmi))
    assert lmi.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[1] == data.n_realisations_samples(
        (0, max_lag)), ('Wrong dim (no. samples) in LMI estimate: {0}'.format(
            lmi.shape))
    assert lmi.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LMI estimate: {0}'.format(lmi.shape))

    # Test for correctness of single link MI estimation by comparing it to
    # the omnibus MI. In this case (single source), the two should be the
    # same. Skip assertion if more than one source was inferred (this happens
    # sometime due to random data and low no. permutations for statistical
    # testing in unit tests).
    settings['local_values'] = False
    results_avg = mi.analyse_network(settings, data, targets=[target])
    if results_avg.get_single_target(target, fdr=False)['mi'] is None:
        return
    mi_single_link = results_avg.get_single_target(target, fdr=False)['mi']
    mi_omnibus = results_avg.get_single_target(target, fdr=False)['omnibus_mi']
    sources_local = results.get_target_sources(target, fdr=False)
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    if len(sources_avg) == 1:
        print('Compare single link and omnibus MI.')
        assert np.isclose(mi_single_link, mi_omnibus, rtol=0.00005), (
            'Single link MI ({0:.6f}) is not equal to omnibus information '
            '({1:.6f}).'.format(mi_single_link[0], mi_omnibus))
    # Check if average and mean local values are the same. Test each source
    # separately. Inferred sources may differ between the two calls to
    # analyse_network() due to low number of surrogates used in unit testing.
    for s in list(set(sources_avg).intersection(sources_local)):
        print('Compare average and local values.')
        i1 = np.where(sources_avg == s)[0][0]
        i2 = np.where(sources_local == s)[0][0]
        mean_lmi = np.mean(lmi[i2, :, :])
        # Report the per-source values that are actually compared. Passing
        # the whole mi_single_link array to a '{:.6f}' field would raise a
        # TypeError instead of showing the assertion message.
        assert np.isclose(mi_single_link[i1], mean_lmi, rtol=0.00005), (
            'Single link average MI ({0:.6f}) and mean LMI ({1:.6f}) '
            ' deviate.'.format(mi_single_link[i1], mean_lmi))
Exemple #9
0
def test_return_local_values():
    """Test estimation of local values.

    Runs a bivariate MI analysis for one target, first with local and then
    with averaged values. The local estimate is checked for type and
    dimensions (sources, samples, replications). The averaged single-link MI
    is compared to the MI of the selected sources and, for each source found
    in both runs, to the mean of the local MI.

    The test returns early if either run yields no MI estimate.
    """
    max_lag = 5
    data = Data(seed=SEED)
    data.generate_mute_data(200, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': max_lag,
        'max_lag_target': max_lag
    }
    target = 1
    mi = BivariateMI()
    results_local = mi.analyse_network(settings, data, targets=[target])

    lmi = results_local.get_single_target(target, fdr=False)['mi']
    if lmi is None:
        return
    sources_local = results_local.get_target_sources(target, fdr=False)
    n_sources = len(sources_local)
    n_samples = data.n_realisations_samples((0, max_lag))
    assert type(lmi) is np.ndarray, (
        'LMI estimation did not return an array of values: {0}'.format(lmi))
    assert lmi.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[1] == n_samples, (
        'Wrong dim (no. samples) in LMI estimate {0}'.format(lmi.shape))
    assert lmi.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LMI estimate {0}'.format(lmi.shape))

    # Re-run with averaged values. The single-link MI is expected to match
    # the MI between the selected source variables and the target when only
    # one significant past variable is found per source.
    settings['local_values'] = False
    results_avg = mi.analyse_network(settings, data, targets=[target])
    single_avg = results_avg.get_single_target(target, fdr=False)
    mi_single_link = single_avg['mi']
    mi_selected_sources = single_avg['selected_sources_mi']
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    print('Single link average MI: {0}, single source MI: {1}.'.format(
        mi_single_link, mi_selected_sources))
    if mi_single_link is None:
        return
    all_close = np.isclose(
        mi_single_link, mi_selected_sources, atol=0.005).all()
    assert all_close, (
        'Single link average MI {0} and single source MI {1} deviate.'.format(
            mi_single_link, mi_selected_sources))

    # Compare averaged and local values source by source. Only sources
    # inferred in both runs are compared, because the inferred sources may
    # differ between runs with the low number of surrogates used here.
    print('Compare average and local values.')
    for s in set(sources_avg).intersection(sources_local):
        idx_avg = np.where(sources_avg == s)[0][0]
        idx_local = np.where(sources_local == s)[0][0]
        link_mi = mi_single_link[idx_avg]
        mean_local = np.mean(lmi[idx_local, :, :])
        assert np.isclose(link_mi, mean_local, atol=0.005), (
            'Single link average MI {0:0.6f} and mean LMI {1:0.6f} '
            'deviate.'.format(link_mi, mean_local))
        source_mi = mi_selected_sources[idx_avg]
        assert np.isclose(link_mi, source_mi, atol=0.005), (
            'Single link average MI {0:0.6f} and single source MI {1:0.6f} '
            'deviate.'.format(link_mi, source_mi))
Exemple #10
0
def test_return_local_values():
    """Test estimation of local values.

    Runs a bivariate TE analysis for one target, first with local and then
    with averaged values. The local estimate is checked for type and
    dimensions (sources, samples, replications). The averaged single-link TE
    is compared to the TE of the selected sources and, if both runs selected
    the same target past variables, to the mean local TE of each source
    found in both runs.

    The test returns early if either run yields no TE estimate.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(200, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag
    }
    target = 2
    te = BivariateTE()
    results_local = te.analyse_network(settings, data, targets=[target])

    lte = results_local.get_single_target(target, fdr=False)['te']
    # Return if no estimate is available, as the sibling tests do. Otherwise
    # the type assertion below fails although nothing is wrong with the
    # estimation itself.
    if lte is None:
        return
    n_sources = len(results_local.get_target_sources(target, fdr=False))
    assert isinstance(lte, np.ndarray), (
        'LTE estimation did not return an array of values: {0}'.format(lte))
    assert lte.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[1] == data.n_realisations_samples(
        (0, max_lag)), ('Wrong dim (no. samples) in LTE estimate: {0}'.format(
            lte.shape))
    assert lte.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LTE estimate: {0}'.format(lte.shape))

    # Test for correctness of single link TE estimation by comparing it to
    # the TE between single variables and the target. For this test case
    # where we find only one significant past variable per source, the two
    # should be the same. Also compare single link average TE to mean local
    # TE for each link.
    settings['local_values'] = False
    results_avg = te.analyse_network(settings, data, targets=[target])
    te_single_link = results_avg.get_single_target(target, fdr=False)['te']
    # Without an averaged estimate there is nothing to compare, and
    # np.isclose() cannot handle None.
    if te_single_link is None:
        return
    te_selected_sources = results_avg.get_single_target(
        target, fdr=False)['selected_sources_te']
    sources_local = results_local.get_target_sources(target, fdr=False)
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    assert np.isclose(te_single_link, te_selected_sources, atol=0.005).all(), (
        'Single link average TE {0} and single source TE {1} deviate.'.format(
            te_single_link, te_selected_sources))
    # Check if average and local values are the same. Make sure target pasts
    # are the same and test each source separately. Inferred source and target
    # may differ between the two calls to analyse_network() due to random data
    # and low number of surrogates used in unit testing. Different no. inferred
    # past variables will also lead to differences in estimates.
    target_vars_avg = results_avg.get_single_target(
        target, fdr=False).selected_vars_target
    target_vars_local = results_local.get_single_target(
        target, fdr=False).selected_vars_target
    if target_vars_avg == target_vars_local:
        print('Compare average and local values.')
        for s in list(set(sources_avg).intersection(sources_local)):
            i1 = np.where(sources_avg == s)[0][0]
            i2 = np.where(sources_local == s)[0][0]
            assert np.isclose(
                te_single_link[i1], np.mean(lte[i2, :, :]), atol=0.005
            ), ('Single link average TE {0:.6f} and mean LTE {1:.6f} deviate for '
                'source {2}.'.format(te_single_link[i1],
                                     np.mean(lte[i2, :, :]), s))
            assert np.isclose(
                te_single_link[i1], te_selected_sources[i1], atol=0.005
            ), ('Single link average TE {0:.6f} and single source TE {1:.6f} '
                'deviate.'.format(te_single_link[i1], te_selected_sources[i1]))
def test_return_local_values():
    """Test estimation of local values.

    A bivariate TE analysis of one target is run with local and then with
    averaged values. The local TE must be an array with dimensions
    (sources, samples, replications). The averaged single-link TE is
    compared to the TE of the selected sources; if the selected target past
    variables agree between runs, it is also compared to the mean local TE
    for each source found in both runs.

    The test returns early if either run yields no TE estimate.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(200, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag}
    target = 2
    te = BivariateTE()
    results_local = te.analyse_network(settings, data, targets=[target])

    lte = results_local.get_single_target(target, fdr=False)['te']
    # Skip the test if there is no local estimate, consistent with the
    # sibling tests, instead of failing the type assertion below.
    if lte is None:
        return
    n_sources = len(results_local.get_target_sources(target, fdr=False))
    assert isinstance(lte, np.ndarray), (
        'LTE estimation did not return an array of values: {0}'.format(lte))
    assert lte.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[1] == data.n_realisations_samples((0, max_lag)), (
        'Wrong dim (no. samples) in LTE estimate: {0}'.format(lte.shape))
    assert lte.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LTE estimate: {0}'.format(lte.shape))

    # Test for correctness of single link TE estimation by comparing it to
    # the TE between single variables and the target. For this test case
    # where we find only one significant past variable per source, the two
    # should be the same. Also compare single link average TE to mean local
    # TE for each link.
    settings['local_values'] = False
    results_avg = te.analyse_network(settings, data, targets=[target])
    single_avg = results_avg.get_single_target(target, fdr=False)
    te_single_link = single_avg['te']
    # No averaged estimate means nothing to compare; np.isclose() would
    # raise on None.
    if te_single_link is None:
        return
    te_selected_sources = single_avg['selected_sources_te']
    sources_local = results_local.get_target_sources(target, fdr=False)
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    assert np.isclose(te_single_link, te_selected_sources, atol=0.005).all(), (
        'Single link average TE {0} and single source TE {1} deviate.'.format(
            te_single_link, te_selected_sources))
    # Check if average and local values are the same. Make sure target pasts
    # are the same and test each source separately. Inferred source and target
    # may differ between the two calls to analyse_network() due to random data
    # and low number of surrogates used in unit testing. Different no. inferred
    # past variables will also lead to differences in estimates.
    single_local = results_local.get_single_target(target, fdr=False)
    if single_avg.selected_vars_target == single_local.selected_vars_target:
        print('Compare average and local values.')
        for s in list(set(sources_avg).intersection(sources_local)):
            i1 = np.where(sources_avg == s)[0][0]
            i2 = np.where(sources_local == s)[0][0]
            mean_lte = np.mean(lte[i2, :, :])
            assert np.isclose(te_single_link[i1], mean_lte, atol=0.005), (
                'Single link average TE {0:.6f} and mean LTE {1:.6f} deviate '
                'for source {2}.'.format(te_single_link[i1], mean_lte, s))
            assert np.isclose(
                te_single_link[i1], te_selected_sources[i1], atol=0.005), (
                    'Single link average TE {0:.6f} and single source TE '
                    '{1:.6f} deviate.'.format(
                        te_single_link[i1], te_selected_sources[i1]))
def test_return_local_values():
    """Test estimation of local values.

    A multivariate MI analysis of one target is run with local and then
    with averaged values. The local MI must be an array with dimensions
    (sources, samples, replications). If a single source was inferred in the
    averaged run, the single-link MI is compared to the omnibus MI. For each
    source found in both runs, the averaged MI is compared to the mean of
    the local MI.

    The test returns early if no sources are inferred in either run.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(500, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag}
    target = 1
    mi = MultivariateMI()
    results = mi.analyse_network(settings, data, targets=[target])

    # Test if any sources were inferred. If not, return (this may happen
    # sometimes due to too few samples, however, a higher no. samples is not
    # feasible for a unit test).
    lmi = results.get_single_target(target, fdr=False)['mi']
    if lmi is None:
        return

    sources_local = results.get_target_sources(target, fdr=False)
    assert isinstance(lmi, np.ndarray), (
        'LMI estimation did not return an array of values: {0}'.format(lmi))
    assert lmi.shape[0] == len(sources_local), (
        'Wrong dim (no. sources) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[1] == data.n_realisations_samples((0, max_lag)), (
        'Wrong dim (no. samples) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LMI estimate: {0}'.format(lmi.shape))

    # Test for correctness of single link MI estimation by comparing it to
    # the omnibus MI. In this case (single source), the two should be the
    # same. Skip assertion if more than one source was inferred (this happens
    # sometime due to random data and low no. permutations for statistical
    # testing in unit tests).
    settings['local_values'] = False
    results_avg = mi.analyse_network(settings, data, targets=[target])
    single_avg = results_avg.get_single_target(target, fdr=False)
    mi_single_link = single_avg['mi']
    if mi_single_link is None:
        return
    mi_omnibus = single_avg['omnibus_mi']
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    if len(sources_avg) == 1:
        print('Compare single link and omnibus MI.')
        assert np.isclose(mi_single_link, mi_omnibus, rtol=0.00005), (
            'Single link MI ({0:.6f}) is not equal to omnibus information '
            '({1:.6f}).'.format(mi_single_link[0], mi_omnibus))
    # Check if average and mean local values are the same. Test each source
    # separately. Inferred sources may differ between the two calls to
    # analyse_network() due to low number of surrogates used in unit testing.
    for s in list(set(sources_avg).intersection(sources_local)):
        print('Compare average and local values.')
        i1 = np.where(sources_avg == s)[0][0]
        i2 = np.where(sources_local == s)[0][0]
        mean_lmi = np.mean(lmi[i2, :, :])
        # Format the two scalars that are compared. The complete
        # mi_single_link array cannot be rendered with '{:.6f}' and would
        # turn an assertion failure into a TypeError.
        assert np.isclose(mi_single_link[i1], mean_lmi, rtol=0.00005), (
            'Single link average MI ({0:.6f}) and mean LMI ({1:.6f}) '
            ' deviate.'.format(mi_single_link[i1], mean_lmi))
def test_return_local_values():
    """Test estimation of local values.

    A bivariate MI analysis of one target is run with local and then with
    averaged values. The local MI must be an array with dimensions
    (sources, samples, replications). The averaged single-link MI is
    compared to the MI of the selected sources and, for each source found in
    both runs, to the mean of the local MI.

    The test returns early if either run yields no MI estimate.
    """
    max_lag = 5
    data = Data()
    data.generate_mute_data(200, 5)
    settings = {
        'cmi_estimator': 'JidtKraskovCMI',
        'local_values': True,  # request calculation of local values
        'n_perm_max_stat': 21,
        'n_perm_min_stat': 21,
        'n_perm_max_seq': 21,
        'n_perm_omnibus': 21,
        'max_lag_sources': max_lag,
        'min_lag_sources': 4,
        'max_lag_target': max_lag}
    target = 1
    mi = BivariateMI()
    results_local = mi.analyse_network(settings, data, targets=[target])

    lmi = results_local.get_single_target(target, fdr=False)['mi']
    # Skip the test if there is no local estimate, consistent with the
    # sibling tests, instead of failing the type assertion below.
    if lmi is None:
        return
    n_sources = len(results_local.get_target_sources(target, fdr=False))
    assert isinstance(lmi, np.ndarray), (
        'LMI estimation did not return an array of values: {0}'.format(lmi))
    assert lmi.shape[0] == n_sources, (
        'Wrong dim (no. sources) in LMI estimate: {0}'.format(lmi.shape))
    assert lmi.shape[1] == data.n_realisations_samples((0, max_lag)), (
        'Wrong dim (no. samples) in LMI estimate {0}'.format(lmi.shape))
    assert lmi.shape[2] == data.n_replications, (
        'Wrong dim (no. replications) in LMI estimate {0}'.format(lmi.shape))

    # Test for correctness of single link MI estimation by comparing it to
    # the MI between single variables and the target. For this test case
    # where we find only one significant past variable per source, the two
    # should be the same. Also compare single link average MI to mean local
    # MI for each link.
    settings['local_values'] = False
    results_avg = mi.analyse_network(settings, data, targets=[target])
    mi_single_link = results_avg.get_single_target(target, fdr=False)['mi']
    # No averaged estimate means nothing to compare; np.isclose() would
    # raise on None.
    if mi_single_link is None:
        return
    mi_selected_sources = results_avg.get_single_target(
        target, fdr=False)['selected_sources_mi']
    sources_local = results_local.get_target_sources(target, fdr=False)
    sources_avg = results_avg.get_target_sources(target, fdr=False)
    assert np.isclose(mi_single_link, mi_selected_sources, atol=0.005).all(), (
        'Single link average MI {0} and single source MI {1} deviate.'.format(
            mi_single_link, mi_selected_sources))
    # Check if average and local values are the same. Test each source
    # separately. Inferred sources may differ between the two calls to
    # analyse_network() due to low number of surrogates used in unit testing.
    print('Compare average and local values.')
    for s in list(set(sources_avg).intersection(sources_local)):
        i1 = np.where(sources_avg == s)[0][0]
        i2 = np.where(sources_local == s)[0][0]
        mean_lmi = np.mean(lmi[i2, :, :])
        assert np.isclose(mi_single_link[i1], mean_lmi, atol=0.005), (
            'Single link average MI {0:0.6f} and mean LMI {1:0.6f} '
            'deviate.'.format(mi_single_link[i1], mean_lmi))
        assert np.isclose(
            mi_single_link[i1], mi_selected_sources[i1], atol=0.005), (
                'Single link average MI {0:0.6f} and single source MI '
                '{1:0.6f} deviate.'.format(
                    mi_single_link[i1], mi_selected_sources[i1]))