Example 1
# Imports as in PyMVPA2's test suite (mvpa2 module paths assumed).
import numpy as np

from mvpa2.base.node import ChainNode
from mvpa2.generators.partition import (NFoldPartitioner,
                                        ExcludeTargetsCombinationsPartitioner)
from mvpa2.generators.splitters import Splitter
from mvpa2.misc.data_generators import normal_feature_dataset
from mvpa2.testing.tools import assert_equal


def test_exclude_targets_combinations():
    # Chain a leave-one-chunk-out NFoldPartitioner with an
    # ExcludeTargetsCombinationsPartitioner over pairs (k=2) of targets.
    partitioner = ChainNode(
        [NFoldPartitioner(),
         ExcludeTargetsCombinationsPartitioner(
             k=2, targets_attr='targets', space='partitions')],
        space='partitions')
    ds = normal_feature_dataset(snr=0.,
                                nlabels=4,
                                perlabel=3,
                                nchunks=3,
                                nonbogus_features=[0, 1, 2, 3],
                                nfeatures=4)
    partitions = list(partitioner.generate(ds))
    # 3 chunk folds x C(4, 2) = 6 target pairs -> 18 generated partitionings
    assert_equal(len(partitions), 3 * 6)
    # split each partitioning into its training and testing datasets
    splitter = Splitter('partitions')
    combs = []
    comb_chunks = []
    for p in partitions:
        trds, teds = list(splitter.generate(p))[:2]
        # which pair of targets ended up in the testing dataset
        comb = tuple(np.unique(teds.targets))
        combs.append(comb)
        comb_chunks.append(comb + tuple(np.unique(teds.chunks)))
    # just the 6 possible combinations of 2 targets out of 4
    assert_equal(len(set(combs)), 6)
    assert_equal(len(set(comb_chunks)), 3 * 6)  # all unique
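A quick standalone sanity check of the counts asserted above (plain itertools, no mvpa2 required; a sketch with hypothetical target names, not part of the original test):

# 3 chunk folds x C(4, 2) = 6 target pairs -> 18 partitionings
from itertools import combinations

targets = ['t0', 't1', 't2', 't3']   # hypothetical names for the 4 targets
n_chunks = 3

target_pairs = list(combinations(targets, 2))
assert len(target_pairs) == 6
assert n_chunks * len(target_pairs) == 3 * 6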
Example 2
# Imports as in PyMVPA2's test suite (mvpa2 module paths assumed).
import numpy as np

from mvpa2.base.node import ChainNode
from mvpa2.datasets.base import Dataset
from mvpa2.generators.partition import (NFoldPartitioner,
                                        ExcludeTargetsCombinationsPartitioner)
from mvpa2.testing.tools import assert_equal, ok_


def test_exclude_targets_combinations_subjectchunks():
    # Leave-one-subject-out folds, chained with exclusion of one chunk (k=1)
    # at a time; here 'chunks' plays the role of the targets attribute.
    partitioner = ChainNode(
        [NFoldPartitioner(attr='subjects'),
         ExcludeTargetsCombinationsPartitioner(
             k=1, targets_attr='chunks', space='partitions')],
        space='partitions')
    # targets do not even need to be defined!
    ds = Dataset(np.arange(18).reshape(9, 2),
                 sa={'chunks': np.arange(9) // 3,
                     'subjects': np.arange(9) % 3})
    dss = list(partitioner.generate(ds))
    # 3 subject folds x 3 excluded chunks = 9 generated partitionings
    assert_equal(len(dss), 9)

    testing_subjs, testing_chunks = [], []
    for ds_ in dss:
        # partition value 2 marks the testing samples, 1 the training samples
        testing_partition = ds_.sa.partitions == 2
        training_partition = ds_.sa.partitions == 1
        # must be single values -- .item() (replacing the deprecated
        # np.asscalar) raises if more than one value is present, so this
        # is an implicit test as well
        testing_subj = np.unique(ds_.sa.subjects[testing_partition]).item()
        testing_subjs.append(testing_subj)
        testing_chunk = np.unique(ds_.sa.chunks[testing_partition]).item()
        testing_chunks.append(testing_chunk)
        # and those must not appear for training
        ok_(testing_subj not in ds_.sa.subjects[training_partition])
        ok_(testing_chunk not in ds_.sa.chunks[training_partition])
    # and we should have gone through all chunks/subjs pairs
    testing_pairs = set(zip(testing_subjs, testing_chunks))
    assert_equal(len(testing_pairs), 9)
    # equivalent to set(itertools.product(range(3), range(3))); the np.where
    # form was used originally because itertools.product is unavailable on
    # Python 2.5
    assert_equal(testing_pairs, set(zip(*np.where(np.ones((3, 3))))))
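The last assertion builds the expected (subject, chunk) pairs via np.where only for Python 2.5 compatibility; on any modern Python the same set comes straight from itertools.product. A standalone sketch of the equivalence (not part of the original test):

import itertools
import numpy as np

# every (subject, chunk) combination for 3 subjects x 3 chunks
expected = set(itertools.product(range(3), range(3)))
# the construct used in the test above
via_where = set(zip(*np.where(np.ones((3, 3)))))

assert len(expected) == 9
assert expected == via_where   # numpy integers compare/hash like Python ints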