def test_exclude_targets_combinations():
    """Check NFold partitioning chained with 2-target-combination exclusion.

    Uses a dataset with 4 labels and 3 chunks. The test asserts three things:
    the chain yields 3 * 6 partitionings; the testing portions cover exactly
    6 distinct pairs of targets; and every (target pair, chunk) pairing
    occurs exactly once.
    """
    from mvpa2.misc.data_generators import normal_feature_dataset

    # outer: NFold over the default attribute; inner: exclude combinations
    # of k=2 values of 'targets', both acting on the 'partitions' space
    node = ChainNode(
        [
            NFoldPartitioner(),
            ExcludeTargetsCombinationsPartitioner(
                k=2, targets_attr='targets', space='partitions'),
        ],
        space='partitions')

    dataset = normal_feature_dataset(snr=0., nlabels=4, perlabel=3, nchunks=3,
                                     nonbogus_features=[0, 1, 2, 3],
                                     nfeatures=4)

    generated = list(node.generate(dataset))
    assert_equal(len(generated), 3 * 6)

    split_by_partitions = Splitter('partitions')
    target_pairs = []
    pairs_with_chunks = []
    for partitioned in generated:
        # only the second split (the testing portion) is inspected
        _, testing = list(split_by_partitions.generate(partitioned))[:2]
        pair = tuple(np.unique(testing.targets))
        target_pairs.append(pair)
        pairs_with_chunks.append(pair + tuple(np.unique(testing.chunks)))

    # only 6 ways to choose 2 targets out of 4
    assert_equal(len(set(target_pairs)), 6)
    # every (target pair, chunk) pairing is distinct
    assert_equal(len(set(pairs_with_chunks)), 3 * 6)
def test_exclude_targets_combinations_subjectchunks():
    """Check leave-one-subject-out chained with leave-one-chunk-out.

    Builds a 9-sample dataset with 3 chunks and 3 subjects, crossed so that
    each (subject, chunk) pair occurs once. For every generated partitioning
    the test asserts four things:
    - the testing partition (== 2) holds exactly one subject and one chunk;
    - neither of them appears in the training partition (== 1);
    - there are 9 partitionings in total;
    - together they cover all 3 x 3 (subject, chunk) pairs.
    """
    partitioner = ChainNode([NFoldPartitioner(attr='subjects'),
                             ExcludeTargetsCombinationsPartitioner(
                                 k=1,
                                 targets_attr='chunks',
                                 space='partitions')],
                            space='partitions')
    # targets do not even need to be defined!
    ds = Dataset(np.arange(18).reshape(9, 2),
                 sa={'chunks': np.arange(9) // 3,
                     'subjects': np.arange(9) % 3})
    dss = list(partitioner.generate(ds))
    assert_equal(len(dss), 9)

    testing_subjs, testing_chunks = [], []
    for ds_ in dss:
        testing_partition = ds_.sa.partitions == 2
        training_partition = ds_.sa.partitions == 1
        # ndarray.item() raises ValueError unless exactly one unique value is
        # present, which implicitly checks that the testing partition holds a
        # single subject and a single chunk.  (np.asscalar, used previously,
        # was removed in NumPy 1.23; it merely called .item().)
        testing_subj = np.unique(ds_.sa.subjects[testing_partition]).item()
        testing_subjs.append(testing_subj)
        testing_chunk = np.unique(ds_.sa.chunks[testing_partition]).item()
        testing_chunks.append(testing_chunk)
        # and those must not appear in the training partition
        ok_(testing_subj not in ds_.sa.subjects[training_partition])
        ok_(testing_chunk not in ds_.sa.chunks[training_partition])

    # and we should have gone through all (subject, chunk) pairs
    testing_pairs = set(zip(testing_subjs, testing_chunks))
    assert_equal(len(testing_pairs), 9)
    # yoh: equivalent to set(itertools.product(range(3), range(3))); the
    # np.where form (all index pairs of a 3x3 array) was kept for old
    # Pythons (<2.6) lacking itertools.product
    assert_equal(testing_pairs, set(zip(*np.where(np.ones((3, 3))))))