def test_define_candidates():
    """Test candidate definition from a list of procs and a list of samples.

    Conditionals added manually via ``add_conditionals`` are passed through
    ``_lag_to_idx`` before being compared against the candidate set, i.e. they
    are treated as (process, lag) tuples rather than absolute sample indices.

    NOTE(review): this module defines ``test_define_candidates`` a second time
    further down. The later ``def`` rebinds the name, so pytest never collects
    this copy; consider deleting one of the two definitions.
    """
    target = 1
    tau_sources = 3
    max_lag_sources = 10
    current_val = (target, 10)
    procs = [target]
    # Absolute sample indices 9, 6, 3 (stop value 0 is exclusive).
    samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
                        -tau_sources)
    expected_candidates = [(1, 9), (1, 6), (1, 3)]

    # Test if candidates that are added manually to the conditioning set are
    # removed from the candidate set.
    nw = MultivariateMI()
    # The lag-to-index conversion presumably needs the current value; the
    # other definition of this test sets it too.
    nw.current_value = current_val
    settings = [
        {'add_conditionals': None},
        {'add_conditionals': (2, 3)},
        {'add_conditionals': [(2, 3), (4, 1)]},
        {'add_conditionals': [(1, 9)]},
        {'add_conditionals': [(1, 9), (2, 3), (4, 1)]},
    ]
    for s in settings:
        nw.settings = s
        candidates = nw._define_candidates(procs, samples)
        for sample in expected_candidates:
            assert sample in candidates, (
                'Sample missing from candidates: {}.'.format(sample))

        conditionals = s['add_conditionals']
        if conditionals is None:
            continue
        # A single conditional may be given as a bare tuple instead of a list.
        if isinstance(conditionals, tuple):
            conditionals = [conditionals]
        for c in nw._lag_to_idx(conditionals):
            assert c not in candidates, (
                'Sample added erroneously to candidates: {}.'.format(c))
def test_define_candidates():
    """Test candidate definition from a list of procs and a list of samples.

    Every (process, sample) combination must be a candidate. Conditionals
    given in ``add_conditionals`` are converted with ``_lag_to_idx`` and the
    resulting indices must not appear in the candidate set.
    """
    target = 1
    tau_sources = 3
    max_lag_sources = 10
    current_val = (target, 10)
    procs = [target]
    # Absolute sample indices 9, 6, 3 (stop value 0 is exclusive).
    samples = np.arange(current_val[1] - 1, current_val[1] - max_lag_sources,
                        -tau_sources)
    expected_candidates = [(1, 9), (1, 6), (1, 3)]

    # Test if candidates that are added manually to the conditioning set are
    # removed from the candidate set.
    # NOTE(review): if _lag_to_idx maps (p, lag) to (p, current_value[1] - lag),
    # none of the conditionals below lands on a candidate index, so the
    # exclusion check can never fail. Consider adding a case such as (1, 1),
    # after confirming _define_candidates is meant to perform that removal.
    nw = MultivariateMI()
    # Set before converting lags to indices below.
    nw.current_value = current_val
    settings = [
        {'add_conditionals': None},
        {'add_conditionals': (2, 3)},
        {'add_conditionals': [(2, 3), (4, 1)]},
        {'add_conditionals': [(1, 9)]},
        {'add_conditionals': [(1, 9), (2, 3), (4, 1)]},
    ]
    for s in settings:
        nw.settings = s
        candidates = nw._define_candidates(procs, samples)
        for sample in expected_candidates:
            assert sample in candidates, (
                'Sample missing from candidates: {}.'.format(sample))

        conditionals = s['add_conditionals']
        if conditionals is None:
            continue
        # A single conditional may be given as a bare tuple instead of a list.
        if isinstance(conditionals, tuple):
            conditionals = [conditionals]
        for c in nw._lag_to_idx(conditionals):
            assert c not in candidates, (
                'Sample added erroneously to candidates: {}.'.format(c))
def test_check_source_set():
    """Test the method _check_source_set.

    This method sets the list of source processes from which candidates are
    taken for multivariate MI estimation.
    """
    data = Data()
    data.generate_mute_data(100, 5)
    nw_0 = MultivariateMI()
    nw_0.settings = {'verbose': True}

    # Add list of sources.
    sources = [1, 2, 3]
    nw_0._check_source_set(sources, data.n_processes)
    assert nw_0.source_set == sources, 'Sources were not added correctly.'

    # Assert that initialisation fails if the target is also in the source
    # list.
    sources = [0, 1, 2, 3]
    nw_0.target = 0
    with pytest.raises(RuntimeError):
        nw_0._check_source_set(sources=sources, n_processes=data.n_processes)

    # Test if a single source, no list is added correctly: it should be
    # wrapped into a one-element list.
    sources = 1
    nw_0._check_source_set(sources, data.n_processes)
    assert isinstance(nw_0.source_set, list), 'Single source not converted to list.'
    assert nw_0.source_set == [sources], 'Sources were not added correctly.'

    # Test if 'all' is handled correctly: every process except the target
    # (kept explicitly at 0 here because the expected list depends on it).
    nw_0.target = 0
    nw_0._check_source_set('all', data.n_processes)
    assert nw_0.source_set == [1, 2, 3, 4], 'Sources were not added correctly.'

    # Test invalid inputs.
    with pytest.raises(RuntimeError):  # sources greater than no. procs
        nw_0._check_source_set(8, data.n_processes)
    with pytest.raises(RuntimeError):  # negative value as source
        nw_0._check_source_set(-3, data.n_processes)