예제 #1
0
def test_define_candidates():
    """Test candidate definition from a list of procs and a list of samples.

    The test expects every process in ``procs`` to be paired with every
    sample in ``samples``. Conditionals that the user adds manually via the
    'add_conditionals' setting must be removed from the candidate set, while
    conditionals that do not overlap with it must leave it unchanged.
    """
    target = 1
    tau_sources = 3
    max_lag_sources = 10
    current_val = (target, 10)
    procs = [target]
    # Yields sample indices [9, 6, 3]: start one sample before the current
    # value and step back by tau_sources while staying within max_lag_sources.
    samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
                        -tau_sources)
    nw = MultivariateMI()

    # Conditionals that do not overlap with the candidates must not remove
    # any candidate.
    settings = [{
        'add_conditionals': None
    }, {
        'add_conditionals': (2, 3)
    }, {
        'add_conditionals': [(2, 3), (4, 1)]
    }]
    for s in settings:
        nw.settings = s
        candidates = nw._define_candidates(procs, samples)
        assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
        assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
        assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

    # Test if candidates that are added manually to the conditioning set are
    # removed from the candidate set. The assertions live inside the loop so
    # that every setting is checked, not only the last one.
    settings = [{
        'add_conditionals': [(1, 9)]
    }, {
        'add_conditionals': [(1, 9), (2, 3), (4, 1)]
    }]
    for s in settings:
        nw.settings = s
        candidates = nw._define_candidates(procs, samples)
        assert (1, 9) not in candidates, (
            'Conditional not removed from candidates: (1, 9).')
        assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
        assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'
def test_define_candidates():
    """Test candidate definition from a list of procs and a list of samples.

    The test expects every process in ``procs`` to be paired with every
    sample in ``samples``. For each 'add_conditionals' setting, the manually
    added conditionals are mapped through ``nw._lag_to_idx`` and none of the
    resulting entries may appear among the candidates.
    """
    target = 1
    tau_sources = 3
    max_lag_sources = 10
    current_val = (target, 10)
    procs = [target]
    # Yields sample indices [9, 6, 3]: start one sample before the current
    # value and step back by tau_sources while staying within max_lag_sources.
    samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
                        -tau_sources)
    # Test if candidates that are added manually to the conditioning set are
    # removed from the candidate set.
    nw = MultivariateMI()
    # NOTE(review): presumably needed because _lag_to_idx converts lags
    # relative to the current value -- confirm.
    nw.current_value = current_val
    settings = [{
        'add_conditionals': None
    }, {
        'add_conditionals': (2, 3)
    }, {
        'add_conditionals': [(2, 3), (4, 1)]
    }, {
        'add_conditionals': [(1, 9)]
    }, {
        'add_conditionals': [(1, 9), (2, 3), (4, 1)]
    }]
    for s in settings:
        nw.settings = s
        candidates = nw._define_candidates(procs, samples)
        assert (1, 9) in candidates, 'Sample missing from candidates: (1, 9).'
        assert (1, 6) in candidates, 'Sample missing from candidates: (1, 6).'
        assert (1, 3) in candidates, 'Sample missing from candidates: (1, 3).'

        conditionals = s['add_conditionals']
        if conditionals is None:
            continue
        # A single tuple is allowed as well as a list of tuples; normalise to
        # a list before converting.
        if isinstance(conditionals, tuple):
            conditionals = [conditionals]
        cond_ind = nw._lag_to_idx(conditionals)
        for c in cond_ind:
            assert c not in candidates, (
                'Sample added erroneously to candidates: {}.'.format(c))
def test_check_source_set():
    """Test the method _check_source_set.

    This method sets the list of source processes from which candidates are
    taken for multivariate MI estimation. The test covers an explicit source
    list, the target appearing among the sources, a single integer source,
    the keyword 'all', and out-of-range or negative source indices.
    """
    data = Data()
    data.generate_mute_data(100, 5)
    nw_0 = MultivariateMI()
    nw_0.settings = {'verbose': True}
    # Add list of sources.
    sources = [1, 2, 3]
    nw_0._check_source_set(sources, data.n_processes)
    assert nw_0.source_set == sources, 'Sources were not added correctly.'

    # Assert that initialisation fails if the target is also in the source list
    sources = [0, 1, 2, 3]
    nw_0.target = 0
    with pytest.raises(RuntimeError):
        nw_0._check_source_set(sources=sources,
                               n_processes=data.n_processes)

    # Test if a single source, no list is added correctly: it must end up
    # wrapped in a list holding just that source.
    sources = 1
    nw_0._check_source_set(sources, data.n_processes)
    assert isinstance(nw_0.source_set, list), (
        'Single source was not converted to a list.')
    assert nw_0.source_set == [sources], 'Sources were not added correctly.'

    # Test if 'all' is handled correctly: every process except the target
    # (process 0) becomes a source.
    nw_0.target = 0
    nw_0._check_source_set('all', data.n_processes)
    assert nw_0.source_set == [1, 2, 3, 4], 'Sources were not added correctly.'

    # Test invalid inputs.
    with pytest.raises(RuntimeError):  # sources greater than no. procs
        nw_0._check_source_set(8, data.n_processes)
    with pytest.raises(RuntimeError):  # negative value as source
        nw_0._check_source_set(-3, data.n_processes)
예제 #4
0
def test_check_source_set():
    """Test the method _check_source_set.

    This method sets the list of source processes from which candidates are
    taken for multivariate MI estimation. The test covers an explicit source
    list, the target appearing among the sources, a single integer source,
    the keyword 'all', and out-of-range or negative source indices.
    """
    data = Data()
    data.generate_mute_data(100, 5)
    nw_0 = MultivariateMI()
    nw_0.settings = {'verbose': True}
    # Add list of sources.
    sources = [1, 2, 3]
    nw_0._check_source_set(sources, data.n_processes)
    assert nw_0.source_set == sources, 'Sources were not added correctly.'

    # Assert that initialisation fails if the target is also in the source list
    sources = [0, 1, 2, 3]
    nw_0.target = 0
    with pytest.raises(RuntimeError):
        nw_0._check_source_set(sources=sources,
                               n_processes=data.n_processes)

    # Test if a single source, no list is added correctly: it must end up
    # wrapped in a list holding just that source.
    sources = 1
    nw_0._check_source_set(sources, data.n_processes)
    assert isinstance(nw_0.source_set, list), (
        'Single source was not converted to a list.')
    assert nw_0.source_set == [sources], 'Sources were not added correctly.'

    # Test if 'all' is handled correctly: every process except the target
    # (process 0) becomes a source.
    nw_0.target = 0
    nw_0._check_source_set('all', data.n_processes)
    assert nw_0.source_set == [1, 2, 3, 4], 'Sources were not added correctly.'

    # Test invalid inputs.
    with pytest.raises(RuntimeError):  # sources greater than no. procs
        nw_0._check_source_set(8, data.n_processes)
    with pytest.raises(RuntimeError):  # negative value as source
        nw_0._check_source_set(-3, data.n_processes)