Пример #1
0
def test_exceptions():
    """Check that ``ndl.ndl`` and ``ndl.dict_ndl`` reject invalid arguments.

    Every case triggers exactly one error and then verifies the exception
    message.  The message checks are placed *after* the ``pytest.raises``
    block on purpose: statements that follow the raising call inside the
    block are never executed, and the ``ExceptionInfo`` object bound by
    ``as`` does not compare equal to a string, so the message has to be read
    from ``e_info.value``.

    NOTE(review): the expected messages were previously unreachable and
    therefore never verified; a failure here points at a stale expected
    string rather than at the learner.
    """
    with pytest.raises(ValueError) as e_info:
        ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, method='threading', weights=1)
    assert str(e_info.value) == 'weights need to be None or xarray.DataArray with method=threading'

    with pytest.raises(ValueError) as e_info:
        ndl.ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, method='magic')
    assert str(e_info.value) == 'method needs to be either "threading" or "openmp"'

    with pytest.raises(ValueError) as e_info:
        ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, weights=1)
    assert str(e_info.value) == 'weights needs to be either defaultdict or None'

    with pytest.raises(ValueError) as e_info:
        ndl.dict_ndl(FILE_PATH_MULTIPLE_CUES,
                     ALPHA,
                     BETAS,
                     remove_duplicates=None)
    assert str(e_info.value) == 'cues or outcomes needs to be unique: cues "a a"; outcomes "A"; use remove_duplicates=True'

    with pytest.raises(ValueError) as e_info:
        ndl.ndl(FILE_PATH_SIMPLE,
                ALPHA,
                BETAS,
                method='threading',
                len_sublists=-1)
    assert str(e_info.value) == "'len_sublists' must be larger then one"

    with pytest.raises(ValueError) as e_info:
        ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, make_data_array="magic")
    assert str(e_info.value) == "make_data_array must be True or False"

    with pytest.raises(ValueError) as e_info:
        ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS, remove_duplicates="magic")
    assert str(e_info.value) == "remove_duplicates must be None, True or False"

    with pytest.raises(ValueError) as e_info:
        ndl.ndl(FILE_PATH_SIMPLE,
                ALPHA,
                BETAS,
                method='threading',
                remove_duplicates="magic")
    assert str(e_info.value) == "remove_duplicates must be None, True or False"

    # The remaining cases already verify the message through ``match``, so
    # the exception info does not need to be bound.
    with pytest.raises(FileNotFoundError,
                       match="No such file or directory"):
        ndl.ndl(FILE_PATH_SIMPLE,
                ALPHA,
                BETAS,
                method='threading',
                temporary_directory="./magic")

    with pytest.raises(
            ValueError,
            match="events_per_file has to be larger than 1"):
        ndl.ndl(FILE_PATH_SIMPLE,
                ALPHA,
                BETAS,
                method='threading',
                events_per_temporary_file=1)
Пример #2
0
def test_continue_learning_dict_ndl_data_array(result_dict_ndl,
                                               result_dict_ndl_data_array):
    """Continue learning from both weight fixtures and compare the results.

    ``ndl.dict_ndl`` is run twice on the simple event file, once seeded with
    ``result_dict_ndl`` and once seeded with ``result_dict_ndl_data_array``;
    ``compare_arrays`` must report no unequal entries between the two runs.
    """
    learn_args = (FILE_PATH_SIMPLE, ALPHA, BETAS)
    from_dict = ndl.dict_ndl(*learn_args, weights=result_dict_ndl)
    from_data_array = ndl.dict_ndl(*learn_args,
                                   weights=result_dict_ndl_data_array)

    mismatches, mismatch_ratio = compare_arrays(FILE_PATH_SIMPLE,
                                                from_dict,
                                                from_data_array)
    # Printed output is only shown by pytest when the test fails.
    print(from_data_array)
    print('%.2f ratio unequal' % mismatch_ratio)
    assert len(mismatches) == 0
Пример #3
0
def test_dict_ndl_data_array_vs_ndl_threading(result_ndl_threading):
    """Compare ``dict_ndl`` run with ``make_data_array=True`` to the threading result.

    The weights are learned on the simple event file and handed to
    ``compare_arrays`` together with the ``result_ndl_threading`` fixture;
    no unequal entries may be reported.
    """
    weights_as_data_array = ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS,
                                         make_data_array=True)

    comparison = compare_arrays(FILE_PATH_SIMPLE,
                                weights_as_data_array,
                                result_ndl_threading)
    mismatches, mismatch_ratio = comparison
    print('%.2f ratio unequal' % mismatch_ratio)
    assert len(mismatches) == 0
Пример #4
0
def test_continue_learning_dict():
    """Check in-place and copying continuation of ``dict_ndl`` learning.

    The simple event file is split at ``CONTINUE_SPLIT_POINT`` into two
    gzipped event files under ``TMP_PATH``.  Weights learned on the first
    file are then used as the starting point for the second file:

    * with ``inplace=True`` the returned object must be the very object that
      was passed in, and it must differ from a deep copy taken beforehand;
    * without ``inplace`` the returned weights must differ from the weights
      that were passed in.
    """
    events = pd.read_csv(FILE_PATH_SIMPLE, sep="\t")
    first_part = events.head(CONTINUE_SPLIT_POINT)
    second_part = events.tail(len(events) - CONTINUE_SPLIT_POINT)

    # Both halves need content, otherwise nothing is continued.
    assert len(first_part) > 0
    assert len(second_part) > 0

    first_path = os.path.join(TMP_PATH, "event_file_simple_1.tab.gz")
    second_path = os.path.join(TMP_PATH, "event_file_simple_2.tab.gz")

    csv_options = dict(header=True,
                       index=None,
                       sep='\t',
                       columns=["cues", "outcomes"],
                       compression='gzip')
    for frame, target in ((first_part, first_path),
                          (second_part, second_path)):
        frame.to_csv(target, **csv_options)

    # Drop every reference to the data frames before learning starts.
    del events, first_part, second_part, frame

    weights_first = ndl.dict_ndl(first_path, ALPHA, BETAS)
    snapshot = copy.deepcopy(weights_first)

    weights_inplace = ndl.dict_ndl(second_path,
                                   ALPHA,
                                   BETAS,
                                   weights=weights_first,
                                   inplace=True)

    assert weights_first is weights_inplace
    assert weights_first != snapshot

    # Learn the first part again to start from fresh weights.
    weights_first = ndl.dict_ndl(first_path, ALPHA, BETAS)

    weights_continued = ndl.dict_ndl(second_path, ALPHA, BETAS,
                                     weights=weights_first)

    assert weights_first != weights_continued
Пример #5
0
def test_exceptions():
    """Check that ``activation`` rejects duplicated cues and unsupported weights.

    The weights are learned *outside* of the ``pytest.raises`` block so that
    only the ``activation`` call is expected to raise.  The message checks
    follow each block because statements placed after the raising call
    inside the block are never executed, and the ``ExceptionInfo`` object
    does not compare equal to a string; the message is read from
    ``e_info.value`` instead.
    """
    wm = ndl.dict_ndl(FILE_PATH_SIMPLE,
                      ALPHA,
                      BETAS,
                      remove_duplicates=None)

    with pytest.raises(ValueError) as e_info:
        activation(FILE_PATH_MULTIPLE_CUES, wm)
    # NOTE(review): the former, unreachable assertion expected the complete
    # message 'cues or outcomes needs to be unique: cues "a a"; outcomes "A";
    # use remove_duplicates=True'.  That string was never executed, so only
    # its closing hint is checked here -- tighten once the exact wording
    # raised by ``activation`` is confirmed.
    assert 'use remove_duplicates=True' in str(e_info.value)

    with pytest.raises(ValueError) as e_info:
        activation(FILE_PATH_MULTIPLE_CUES, "magic")
    assert str(e_info.value) == "Weights other than xarray.DataArray or dicts are not supported."
Пример #6
0
def test_multiple_cues_dict_ndl_vs_ndl_threading():
    """Compare ``dict_ndl`` and threaded ``ndl`` on the multiple-cues event file.

    Both learners run with ``remove_duplicates=True``; ``compare_arrays``
    must not report any unequal entries between their results.
    """
    learn_args = (FILE_PATH_MULTIPLE_CUES, ALPHA, BETAS)
    weights_dict = ndl.dict_ndl(*learn_args, remove_duplicates=True)
    weights_threading = ndl.ndl(*learn_args,
                                remove_duplicates=True,
                                method='threading')

    mismatches, mismatch_ratio = compare_arrays(FILE_PATH_MULTIPLE_CUES,
                                                weights_dict,
                                                weights_threading)
    print('%.2f ratio unequal' % mismatch_ratio)
    assert len(mismatches) == 0
Пример #7
0
def test_multiple_cues_dict_ndl_vs_ndl2():
    """
    Compare the python ``dict_ndl`` weights with reference weights produced
    by the R learner implemented in ndl2.

    The reference file is read as comma separated text: the first line holds
    the outcome names (after an initial row-label column), every following
    line holds a cue followed by one weight per outcome.

    R code to generate the results::

        library(ndl2)
        learner <- learnWeightsTabular('tests/resources/event_file_multiple_cues.tab.gz',
                                       alpha=0.1, beta=0.1, lambda=1.0, removeDuplicates=FALSE)
        wm <- learner$getWeights()
        wm <- wm[order(rownames(wm)), order(colnames(wm))]
        write.csv(wm, 'tests/reference/weights_event_file_multiple_cues_ndl2.csv')

    """
    reference_weights = defaultdict(lambda: defaultdict(float))

    with open(REFERENCE_PATH_MULTIPLE_CUES_NDL2, 'rt') as reference_file:
        header_fields = reference_file.readline().strip().split(',')
        # The first header field labels the row-name column, skip it.
        outcomes = [field.strip('"') for field in header_fields[1:]]
        for row in reference_file:
            quoted_cue, *weight_fields = row.strip().split(',')
            cue = quoted_cue.strip('"')
            for position, outcome in enumerate(outcomes):
                weight = float(weight_fields[position])
                reference_weights[outcome][cue] = weight

    python_weights = ndl.dict_ndl(FILE_PATH_MULTIPLE_CUES,
                                  ALPHA,
                                  BETAS,
                                  remove_duplicates=False)

    mismatches, mismatch_ratio = compare_arrays(FILE_PATH_MULTIPLE_CUES,
                                                reference_weights,
                                                python_weights)
    # Show which outcomes are affected when the comparison fails.
    print({outcome for outcome, *_ in mismatches})
    print('%.2f ratio unequal' % mismatch_ratio)
    assert len(mismatches) == 0
Пример #8
0
def result_dict_ndl_data_array():
    """Return ``dict_ndl`` weights for the simple event file, learned with ``make_data_array=True``.

    NOTE(review): tests request this function by name as an argument, so it
    is presumably registered as a pytest fixture -- confirm the decorator.
    """
    weights = ndl.dict_ndl(FILE_PATH_SIMPLE,
                           ALPHA,
                           BETAS,
                           make_data_array=True)
    return weights
Пример #9
0
def result_dict_ndl_generator():
    """Return ``dict_ndl`` weights learned from ``ndl.events_from_file`` output.

    Instead of the file path, the events produced by
    ``ndl.events_from_file(FILE_PATH_SIMPLE)`` are passed to the learner.
    """
    events = ndl.events_from_file(FILE_PATH_SIMPLE)
    weights = ndl.dict_ndl(events, ALPHA, BETAS)
    return weights
Пример #10
0
def result_dict_ndl():
    """Return the weights ``dict_ndl`` learns on the simple event file.

    NOTE(review): tests request this function by name as an argument, so it
    is presumably registered as a pytest fixture -- confirm the decorator.
    """
    weights = ndl.dict_ndl(FILE_PATH_SIMPLE, ALPHA, BETAS)
    return weights
Пример #11
0
def test_dict_ndl_vs_ndl_openmp(result_dict_ndl, result_ndl_openmp):
    """Compare the ``dict_ndl`` weights with the openmp ``ndl`` weights.

    Both arguments are supplied by name (``result_dict_ndl`` and
    ``result_ndl_openmp``) and are compared with ``compare_arrays`` on the
    simple event file; no unequal entries may be reported.

    The ``result_dict_ndl`` argument is used as provided.  Earlier the test
    overwrote it with a second, identical ``ndl.dict_ndl`` call, which made
    the argument pointless and learned the same weights twice.
    """
    unequal, unequal_ratio = compare_arrays(FILE_PATH_SIMPLE, result_dict_ndl,
                                            result_ndl_openmp)
    print('%.2f ratio unequal' % unequal_ratio)
    assert len(unequal) == 0