def test_confusion_call(self):
    """Check calling a ConfusionMatrix instance and the labels it tracks.

    Labels are either handed to the constructor or collected by the
    ConfusionMatrix over its lifetime; both routes must stay consistent.
    """
    # Calling is expected to raise when no labels are known, or when the
    # given labels do not match the data.
    self.assertRaises(RuntimeError, ConfusionMatrix(), [1], [1])
    self.assertRaises(ValueError, ConfusionMatrix(labels=[2]), [1], [1])

    # Proper matrices: the label order given at construction decides layout
    seq_t = ['ho', 'ho', 'ho', 'fa', 'fa', 'ho', 'ho']
    seq_p = ['ho', 'ho', 'ho', 'ho', 'fa', 'fa', 'fa']
    cm1 = ConfusionMatrix(labels=['ho', 'fa'])
    cm2 = ConfusionMatrix(labels=['fa', 'ho'])
    assert_array_equal(cm1(seq_p, seq_t), [[3, 1], [2, 1]])
    # same content with the order of labels reversed
    assert_array_equal(cm2(seq_p, seq_t), [[1, 2], [1, 3]])

    # A matrix built from stored sets must be identical to the called one
    cm1_ = ConfusionMatrix(labels=['ho', 'fa'], sets=[(seq_t, seq_p)])
    called = cm1(seq_p, seq_t)
    assert_array_equal(called, cm1_.matrix)

    # Provoke the "mother" CM into learning more labels, which could get
    # ahead of the already known ones
    cm1.add(['ho', 'aa'], ['ho', 'aa'])
    # comparing causes recomputation so .__labels get reassigned
    assert_equal(cm1.labels, ['ho', 'fa', 'aa'])

    padded = [[3, 1, 0], [2, 1, 0], [0, 0, 0]]
    assert_array_equal(cm1(seq_p, seq_t), padded)
    # only the set from the add() above must be known at this point
    assert_equal(len(cm1.sets), 1)
    assert_array_equal(cm1(seq_p, seq_t, store=True), padded)
    # ... and now there are two
    assert_equal(len(cm1.sets), 2)

    combined = cm1(seq_p + ['ho', 'aa'], seq_t + ['ho', 'aa'])
    assert_array_equal(combined, cm1.matrix)
def cross_decoding_confusion(predictions, targets, map_list):
    """Relabel a cross-decoding result into two classes and compute its confusion.

    map_list is a two-element list stating what the classifier is expected
    to have done when it cross-decodes: the first element is the prediction,
    the second element is the label of the second dataset to be predicted.

    Returns the ConfusionMatrix built from the relabeled predictions and
    targets, after compute() has been called on it.
    """
    expected_pred = map_list[0]
    expected_targ = map_list[1]

    # Samples matching the expected prediction / expected target label
    mask_predictions = np.array(predictions) == expected_pred
    mask_targets = np.array(targets) == expected_targ

    # Work on fixed-width (20 byte) string copies of the inputs
    relabeled_pred = np.array(predictions, dtype='|S20').copy()
    relabeled_targ = np.array(targets, dtype='|S20').copy()

    # Joint label for the matching samples, and another one assembled from
    # the first remaining label on each side
    matched_label = expected_pred + '-' + expected_targ
    other_label = (np.unique(relabeled_pred[~mask_predictions])[0] + '-' +
                   np.unique(relabeled_targ[~mask_targets])[0])

    relabeled_pred[mask_predictions] = matched_label
    relabeled_targ[mask_targets] = matched_label
    relabeled_pred[~mask_predictions] = other_label
    relabeled_targ[~mask_targets] = other_label

    c_matrix = ConfusionMatrix(predictions=relabeled_pred,
                               targets=relabeled_targ)
    c_matrix.compute()
    return c_matrix
def cross_decoding_confusion(predictions, targets, map_list):
    """Compute a two-class confusion matrix for a cross-decoding outcome.

    map_list holds two elements describing the expected cross-decoding
    behaviour of the classifier: first the prediction, second the label of
    the second dataset that is to be predicted.

    The ConfusionMatrix of the relabeled data is computed and returned.
    """
    pred_hits = np.array(predictions) == map_list[0]
    targ_hits = np.array(targets) == map_list[1]
    pred_miss = ~pred_hits
    targ_miss = ~targ_hits

    # 20-byte string copies which get overwritten with the merged labels
    new_predictions = np.array(predictions, dtype='|S20').copy()
    new_targets = np.array(targets, dtype='|S20').copy()

    label_hit = map_list[0] + '-' + map_list[1]
    # must be derived before any of the entries are overwritten below
    label_miss = (np.unique(new_predictions[pred_miss])[0] + '-' +
                  np.unique(new_targets[targ_miss])[0])

    new_predictions[pred_hits] = label_hit
    new_targets[targ_hits] = label_hit
    new_predictions[pred_miss] = label_miss
    new_targets[targ_miss] = label_miss

    c_matrix = ConfusionMatrix(predictions=new_predictions,
                               targets=new_targets)
    c_matrix.compute()
    return c_matrix
def test_confusion_plot2(self):
    """Build a ConfusionMatrix from three-element sets and plot it.

    Every set consists of two label arrays followed by a 2-D float array
    (presumably per-sample estimates -- confirm against ConfusionMatrix).
    Construction must succeed; if plotting is available, the plotted
    matrix must equal the computed one.
    """
    array = np.array
    sets = [
        (array([1, 2]), array([1, 1]),
         array([[0.54343765, 0.45656235], [0.92395853, 0.07604147]])),
        (array([1, 2]), array([1, 1]),
         array([[0.98030832, 0.01969168], [0.78998763, 0.21001237]])),
        (array([1, 2]), array([1, 1]),
         array([[0.86125263, 0.13874737], [0.83674113, 0.16325887]])),
        (array([1, 2]), array([1, 1]),
         array([[0.57870383, 0.42129617], [0.59702509, 0.40297491]])),
        (array([1, 2]), array([1, 1]),
         array([[0.89530255, 0.10469745], [0.69373919, 0.30626081]])),
        (array([1, 2]), array([1, 1]),
         array([[0.75015218, 0.24984782], [0.9339767, 0.0660233]])),
        (array([1, 2]), array([1, 2]),
         array([[0.97826616, 0.02173384], [0.38620638, 0.61379362]])),
        (array([2]), array([2]),
         array([[0.46893776, 0.53106224]])),
    ]
    try:
        cm = ConfusionMatrix(sets=sets)
    except Exception as exc:
        # Only ordinary errors are converted into a test failure, and the
        # cause is reported; KeyboardInterrupt/SystemExit pass through.
        self.fail("ConfusionMatrix(sets=...) raised %r" % (exc,))

    if externals.exists("pylab plottable"):
        import pylab as pl
        fig, im, cb = cm.plot(origin='lower', numbers=True)
        self.assertTrue((cm._plotted_confusionmatrix == cm.matrix).all())
        pl.close(fig)
def test_confusion_call(self):
    """Test __call__ of ConfusionMatrix together with its label bookkeeping.

    Also covers consistency of the labels, whether provided explicitly or
    collected by the ConfusionMatrix through its lifetime.
    """
    self.assertRaises(RuntimeError, ConfusionMatrix(), [1], [1])
    self.assertRaises(ValueError, ConfusionMatrix(labels=[2]), [1], [1])

    # Now test proper matrices and that both ways give the same result
    t = ['ho', 'ho', 'ho', 'fa', 'fa', 'ho', 'ho']
    p = ['ho', 'ho', 'ho', 'ho', 'fa', 'fa', 'fa']
    expected_2x2 = [[3, 1], [2, 1]]
    expected_2x2_reversed = [[1, 2], [1, 3]]   # reverse order of labels
    expected_3x3 = [[3, 1, 0], [2, 1, 0], [0, 0, 0]]

    cm1 = ConfusionMatrix(labels=['ho', 'fa'])
    cm2 = ConfusionMatrix(labels=['fa', 'ho'])
    assert_array_equal(cm1(p, t), expected_2x2)
    assert_array_equal(cm2(p, t), expected_2x2_reversed)

    cm1_ = ConfusionMatrix(labels=['ho', 'fa'], sets=[(t, p)])
    # both should be identical
    assert_array_equal(cm1(p, t), cm1_.matrix)

    # Make the "mother" CM learn more labels, which could get ahead of the
    # known ones
    extra = ['ho', 'aa']
    cm1.add(list(extra), list(extra))
    # compare and cause recomputation so .__labels get reassigned
    assert_equal(cm1.labels, ['ho', 'fa', 'aa'])
    assert_array_equal(cm1(p, t), expected_3x3)
    n_sets = len(cm1.sets)
    assert_equal(n_sets, 1)       # just 1 must be known atm from above add
    assert_array_equal(cm1(p, t, store=True), expected_3x3)
    n_sets = len(cm1.sets)
    assert_equal(n_sets, 2)       # and now 2
    assert_array_equal(cm1(p + extra, t + extra), cm1.matrix)
def test_confusion_matrix_addition(self):
    """Regression test: adding confusions gave inconsistent results (GH #51)

    The fix was to deepcopy instead of copy in __add__.
    """
    left = ConfusionMatrix(sets=[[np.array((1, 2)), np.array((1, 2))]])
    right = ConfusionMatrix(sets=[[np.array((3, 2)), np.array((3, 2))]])

    # each operand on its own has one sample per label
    assert_array_equal(left.stats['P'], [1, 1])
    assert_array_equal(right.stats['P'], [1, 1])

    # actual bug scenario -- repeating the same sum would differ
    first_sum = left + right
    first_p = first_sum.stats['P']
    second_sum = left + right
    second_p = second_sum.stats['P']
    assert_array_equal(first_p, second_p)
    assert_array_equal(first_p, [1, 2, 1])
def test_conditional_attr():
    """Conditional attributes must survive deep copying and pickling."""
    import cPickle
    import copy

    nodes = (TestNodeOnDefault(enable_ca=['test', 'stats']),
             TestNodeOffDefault(enable_ca=['test', 'stats']))
    for node in nodes:
        # populate both enabled conditional attributes
        node.ca.test = range(5)
        node.ca.stats = ConfusionMatrix(labels=['one', 'two'])
        node.ca.stats.add(('one', 'two', 'one', 'two'),
                          ('one', 'two', 'two', 'one'))
        node.ca.stats.compute()

        # a deep copy keeps the enabled set and the stored values
        clone = copy.deepcopy(node)
        assert_equal(set(node.ca.enabled), set(clone.ca.enabled))
        assert (node.ca['test'].enabled)
        assert (node.ca['stats'].enabled)
        assert_array_equal(node.ca['test'].value, clone.ca['test'].value)
        assert_array_equal(node.ca['stats'].value.matrix,
                           clone.ca['stats'].value.matrix)

        # check whether values survive pickling
        restored = cPickle.loads(cPickle.dumps(node))
        assert_array_equal(restored.ca['test'].value, range(5))
        assert_array_equal(restored.ca['stats'].value.matrix,
                           node.ca['stats'].value.matrix)
def test_confusion_matrix_acc(self):
    """Half-correct predictions must be reported as 50% accuracy."""
    cm = ConfusionMatrix(targets=[0, 0, 1, 1], predictions=[1, 0, 1, 0])
    report = str(cm)
    self.assertTrue('ACC% 50' in report)

    # the chi-square statistic is only checked when scipy is present
    skip_if_no_external('scipy')
    chi_square = cm.stats['CHI^2']
    self.assertTrue(chi_square == (0., 1.))
def test_degenerate_confusion(self):
    """A confusion with just a single target label must still work.

    Some testing splits might contain only one label; neither building,
    printing nor querying the accuracy may blow up in that case.
    """
    single_label_inputs = ([1], [1, 1], [0], [0, 0])
    for values in single_label_inputs:
        cm = ConfusionMatrix(targets=values,
                             predictions=values,
                             estimates=values)
        # rendering is exercised only to verify that it does not raise
        str(cm)
        accuracy = cm.stats['ACC%']
        self.assertTrue(accuracy == 100)
def test_confusion_plot2(self):
    """Construct a ConfusionMatrix from sets carrying a third array and plot it.

    Each set is a tuple of two label arrays plus a 2-D float array
    (presumably estimates per sample -- confirm against ConfusionMatrix).
    The matrix has to be constructible, and, when pylab can plot, the
    plotted matrix has to match the computed one.
    """
    array = np.array
    label_pairs = [
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 1])),
        (array([1, 2]), array([1, 2])),
        (array([2]), array([2])),
    ]
    float_arrays = [
        array([[0.54343765, 0.45656235], [0.92395853, 0.07604147]]),
        array([[0.98030832, 0.01969168], [0.78998763, 0.21001237]]),
        array([[0.86125263, 0.13874737], [0.83674113, 0.16325887]]),
        array([[0.57870383, 0.42129617], [0.59702509, 0.40297491]]),
        array([[0.89530255, 0.10469745], [0.69373919, 0.30626081]]),
        array([[0.75015218, 0.24984782], [0.9339767, 0.0660233]]),
        array([[0.97826616, 0.02173384], [0.38620638, 0.61379362]]),
        array([[0.46893776, 0.53106224]]),
    ]
    # reassemble the (labels, labels, floats) triples expected as sets
    sets = [(first, second, floats)
            for (first, second), floats in zip(label_pairs, float_arrays)]

    try:
        cm = ConfusionMatrix(sets=sets)
    except Exception as exc:
        # Report the real cause rather than failing silently, and let
        # KeyboardInterrupt/SystemExit propagate untouched.
        self.fail("ConfusionMatrix(sets=...) raised %r" % (exc,))

    if externals.exists("pylab plottable"):
        import pylab as pl
        fig, im, cb = cm.plot(origin='lower', numbers=True)
        self.assertTrue((cm._plotted_confusionmatrix == cm.matrix).all())
        pl.close(fig)
def test_confusion_matrix_with_mappings(self):
    """A labels_map must not alter the counts and must show up in the output."""
    # column vector which is not used by any of the checks below
    data = np.array([1, 2, 1, 2, 2, 2, 3, 2, 1], ndmin=2).T
    targets = [1, 1, 1, 2, 2, 2, 3, 3, 3]
    predictions = [1, 2, 1, 2, 2, 2, 3, 2, 1]
    expected = [[2, 0, 1],
                [1, 3, 1],
                [0, 0, 1]]
    # two names deliberately map onto the same numeric label
    names = {'apple': 1,
             'orange': 2,
             'shitty apple': 1,
             'candy': 3}
    cm = ConfusionMatrix(targets=targets, predictions=predictions,
                         labels_map=names)

    # check table content
    self.assertTrue((cm.matrix == expected).all())

    # assure that all labels are listed somewhere in the printed form
    printed = str(cm)
    for name in names:
        self.assertTrue(name in printed)
def test_confusion_plot(self):
    """Basic test of confusion plot

    Based on existing cell dataset results.

    Left in for possible future testing, but is not a part of the
    unittests suite.

    Builds a ConfusionMatrix from eight recorded pairs of uint8 label
    arrays plus a name -> numeric label mapping, checks pieces of its
    string form, and, if plotting is available, checks that plotting with
    a custom label order permutes the plotted matrix accordingly.
    """
    #from matplotlib import rc as rcmpl
    #rcmpl('font',**{'family':'sans-serif','sans-serif':['DejaVu Sans']})
    ##rcmpl('text', usetex=True)
    ##rcmpl('font', family='sans', style='normal', variant='normal',
    ##   weight='bold',  stretch='normal', size='large')
    #import numpy as np
    #from mvpa2.clfs.transerror import \
    #     TransferError, ConfusionMatrix, ConfusionBasedError

    array = np.array
    uint8 = np.uint8
    # Recorded fixture: eight sets, each a pair of equally long uint8 label
    # arrays.  The exact values matter -- the string assertions further
    # down depend on them.
    sets = [
        (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
         array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46, 45, 38, 44, 39, 46, 38,
                39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 40, 47, 43, 45, 41, 44, 40, 46, 42,
                38, 39, 40, 43, 45, 41, 44, 39, 46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42,
                47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46, 42,
                47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41, 44, 47, 46, 42,
                47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 38, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 41, 47, 39, 38, 46, 45, 41, 44, 40, 46, 42,
                40, 38, 38, 43, 45, 41, 44, 40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42,
                47, 38, 42, 43, 45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
        (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
         array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40, 46, 42, 43, 39, 38, 43, 45, 41, 44,
                38, 38, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44,
                40, 40, 42, 47, 40, 40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44,
                40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41, 44,
                40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 40, 44,
                40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                40, 46, 46, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44,
                40, 46, 42, 39, 39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
        (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
         array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43,
                45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45, 40, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 47, 40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43,
                45, 41, 44, 40, 40, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38, 43,
                45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 40, 39, 43,
                45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44, 40, 40, 42, 47, 40, 38, 43,
                39, 41, 44, 41, 46, 42, 39, 39, 38, 38, 45, 41, 44, 38, 46, 40, 46, 46, 46, 43,
                45, 38, 44, 40, 46, 42, 39, 39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43,
                45, 41, 38, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
        (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
         array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 38, 43, 47,
                38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44, 43, 46, 42, 47,
                39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39,
                38, 38, 43, 45, 40, 44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42, 38,
                39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39, 40, 46, 42, 47,
                39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43, 45, 41, 44, 47, 46, 42, 47,
                40, 47, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
                39, 38, 43, 45, 46, 44, 38, 46, 42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47,
                47, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
        (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
         array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45, 42, 44, 40,
                46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 40, 43, 45, 41, 44, 41,
                46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 46, 38, 38, 43, 45, 41, 44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39,
                46, 42, 47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40,
                38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46, 46, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 47, 38, 38, 45, 41, 44, 38, 38, 42, 43, 39, 40, 43, 45, 41, 44, 38,
                46, 42, 47, 38, 39, 47, 45, 46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 40, 38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
        (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
         array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40, 38, 43, 45,
                41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38, 42, 40, 38, 39, 43, 45,
                41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 39, 43, 45,
                41, 44, 38, 46, 42, 47, 43, 39, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45,
                41, 44, 40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43, 45,
                41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40, 38, 38, 40, 45,
                41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40, 40, 42, 47, 38, 46, 43, 45,
                41, 44, 47, 41, 42, 43, 40, 47, 43, 45, 41, 44, 41, 38, 42, 40, 39, 40, 43, 45,
                41, 44, 39, 43, 42, 47, 39, 40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45,
                41, 44, 38, 46, 42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
        (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
         array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 47, 47,
                39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38,
                39, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39,
                47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 46, 47,
                39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 42, 44, 39, 47, 42, 39, 39,
                47, 43, 47, 40, 44, 40, 46, 42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38,
                38, 43, 45, 41, 44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
        (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
         array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41, 44, 39, 38,
                42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46, 43, 45, 41, 44, 41, 41,
                42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42, 39, 39, 40, 43, 45, 41, 44, 40, 46,
                42, 40, 44, 38, 43, 41, 41, 44, 39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43,
                42, 47, 39, 39, 43, 45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46,
                42, 47, 41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41, 46,
                42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 38, 46,
                42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38, 38, 43, 45, 41, 44, 41, 40,
                42, 39, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 40, 43, 45, 41, 44, 40, 46,
                42, 41, 39, 39, 43, 45, 41, 44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
    # Human-readable names for the numeric labels 38..47 used in `sets`
    labels_map = {'12kHz': 40, '20kHz': 41, '30kHz': 42, '3kHz': 38,
                  '7kHz': 39, 'song1': 43, 'song2': 44, 'song3': 45,
                  'song4': 46, 'song5': 47}
    # NOTE(review): the bare except also traps KeyboardInterrupt/SystemExit
    # and hides the real cause of a failure.
    try:
        cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
    except:
        self.fail()

    # The printed form must pair label names with their numeric values
    cms = str(cm)
    self.assertTrue('3kHz / 38' in cms)
    # The ACC(i) summary line is only checked when scipy is available
    if externals.exists("scipy"):
        self.assertTrue('ACC(i) = 0.82-0.012*i p=0.12 r=-0.59 r^2=0.35' in cms)

    if externals.exists("pylab plottable"):
        import pylab as pl
        pl.figure()
        # None sits between the kHz names and the song names
        # (presumably rendered as a separator -- confirm in cm.plot)
        labels_order = ("3kHz", "7kHz", "12kHz", "20kHz","30kHz", None,
                        "song1","song2","song3","song4","song5")
        #print cm
        #fig, im, cb = cm.plot(origin='lower', labels=labels_order)
        # Plot with the first two labels swapped ...
        fig, im, cb = cm.plot(labels=labels_order[1:2] + labels_order[:1]
                                     + labels_order[2:], numbers=True)
        # ... so the top-left 2x2 corner of the plotted matrix must be the
        # correspondingly permuted corner of cm.matrix
        self.assertTrue(cm._plotted_confusionmatrix[0,0] == cm.matrix[1,1])
        self.assertTrue(cm._plotted_confusionmatrix[0,1] == cm.matrix[1,0])
        self.assertTrue(cm._plotted_confusionmatrix[1,1] == cm.matrix[0,0])
        self.assertTrue(cm._plotted_confusionmatrix[1,0] == cm.matrix[0,1])
        pl.close(fig)
        # Plotting in the plain label order is only exercised, not checked
        fig, im, cb = cm.plot(labels=labels_order, numbers=True)
        pl.close(fig)
def test_confusion_matrix(self):
    """Exercise construction, accumulation, printing and addition of confusions."""
    # column vector which is not used by any of the checks below
    data = np.array([1, 2, 1, 2, 2, 2, 3, 2, 1], ndmin=2).T
    reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
    regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
    correct_cm = [[2, 0, 1],
                  [1, 3, 1],
                  [0, 0, 1]]

    # Any input type must be fine -- either list, or np.array, or tuple
    for targets in [reg, tuple(reg), list(reg), np.array(reg)]:
        for predictions in [regl, tuple(regl), list(regl), np.array(regl)]:
            cm = ConfusionMatrix(targets=targets, predictions=predictions)
            # check table content
            self.assertTrue((cm.matrix == correct_cm).all())

    # Do a bit more thorough checking
    cm = ConfusionMatrix()
    # No samples -- raise exception
    self.assertRaises(ZeroDivisionError, lambda x: x.percent_correct, cm)

    cm.add(reg, regl)
    n_sets = len(cm.sets)
    self.assertEqual(n_sets, 1, msg="Should have a single set so far")
    shape = cm.matrix.shape
    self.assertEqual(
        shape, (3, 3),
        msg="should be square matrix (len(reglabels) x len(reglabels)")

    # ConfusionMatrix must complain if the number of samples differs
    self.assertRaises(ValueError, cm.add, reg, np.array([1]))

    # check table content
    self.assertTrue((cm.matrix == correct_cm).all())

    # lets add with new labels (not yet known)
    cm.add(reg, np.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))
    self.assertEqual(cm.labels, [1, 2, 3, 4],
                     msg="We should have gotten 4th label")

    # separate CM per each given set
    matrices = cm.matrices
    self.assertEqual(len(matrices), 2, msg="Have gotten two splits")
    split_sum = matrices[0].matrix + matrices[1].matrix
    self.assertTrue(
        (split_sum == cm.matrix).all(),
        msg="Total votes should match the sum across split CMs")

    # check pretty print -- just a silly test to make sure printing works
    full_text = cm.as_string(header=True, summary=True, description=True)
    self.assertTrue(len(full_text) > 100)
    self.assertTrue(len(str(cm)) > 100)
    # and that it knows some parameters for printing
    summary_text = cm.as_string(summary=True, header=False)
    self.assertTrue(len(summary_text) > 100)

    # lets check iadd -- just itself to itself
    cm += cm
    self.assertEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

    # lets check add -- just itself to itself
    cm2 = cm + cm
    self.assertEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
    self.assertEqual(cm2.percent_correct, cm.percent_correct,
                     msg="Percent of corrrect should remain the same ;-)")
    self.assertEqual(cm2.error, 1.0 - cm.percent_correct / 100.0,
                     msg="Test if we get proper error value")
def test_confusion_plot(self):
    """Basic test of confusion plot

    Based on existing cell dataset results.

    Left in for possible future testing, but is not a part of the
    unittests suite
    """
    array = np.array
    uint8 = np.uint8
    # Each entry is a pair of equally long uint8 arrays; the values are the
    # numeric codes listed in ``labels_map`` below.
    sets = [
        (array([
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44
        ], dtype=uint8),
         array([
            40, 39, 47, 43, 45, 41, 44, 41, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 38, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 46, 45, 38, 44, 39, 46, 38,
            39, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            38, 40, 47, 43, 45, 41, 44, 40, 46, 42,
            38, 39, 40, 43, 45, 41, 44, 39, 46, 42,
            47, 38, 38, 43, 45, 41, 44, 38, 46, 42,
            47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
            47, 38, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 38, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 47, 43, 45, 41, 44, 40, 46, 42,
            47, 38, 38, 43, 45, 41, 44, 40, 46, 42,
            39, 39, 38, 43, 45, 41, 44, 47, 46, 42,
            47, 38, 39, 43, 45, 40, 44, 40, 46, 42,
            47, 39, 40, 43, 45, 41, 44, 38, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
            47, 39, 38, 46, 45, 41, 44, 40, 46, 42,
            40, 38, 38, 43, 45, 41, 44, 40, 45, 42,
            47, 39, 39, 43, 45, 41, 44, 38, 46, 42,
            47, 38, 42, 43, 45, 41, 44, 39, 46, 42,
            39, 39, 39, 47, 45, 41, 44
        ], dtype=uint8)),
        (array([
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43
        ], dtype=uint8),
         array([
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            47, 46, 42, 47, 39, 40, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            39, 46, 42, 47, 47, 47, 43, 45, 41, 44,
            40, 46, 42, 43, 39, 38, 43, 45, 41, 44,
            38, 38, 42, 38, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 40, 38, 43, 45, 41, 44,
            40, 40, 42, 47, 40, 40, 43, 45, 41, 44,
            38, 38, 42, 47, 38, 38, 47, 45, 41, 44,
            40, 46, 42, 47, 39, 40, 43, 45, 41, 44,
            40, 46, 42, 47, 47, 39, 43, 45, 41, 44,
            40, 46, 42, 39, 39, 42, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
            47, 46, 42, 40, 39, 39, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 40, 44,
            40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
            38, 46, 42, 47, 39, 39, 43, 45, 41, 44,
            40, 46, 46, 47, 38, 39, 43, 45, 41, 44,
            40, 46, 42, 47, 38, 39, 43, 45, 41, 44,
            40, 46, 42, 39, 39, 38, 47, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43
        ], dtype=uint8)),
        (array([
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47
        ], dtype=uint8),
         array([
            45, 41, 44, 40, 46, 42, 47, 39, 46, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 39, 43,
            45, 41, 44, 38, 46, 42, 47, 38, 39, 43,
            45, 41, 44, 40, 46, 42, 47, 38, 39, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 43, 43,
            45, 40, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 47, 40, 43,
            45, 41, 44, 40, 47, 42, 38, 47, 38, 43,
            45, 41, 44, 40, 40, 42, 47, 39, 39, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 38, 46, 42, 47, 39, 39, 43,
            45, 41, 44, 40, 46, 42, 47, 40, 38, 43,
            45, 41, 44, 40, 46, 38, 38, 39, 38, 43,
            45, 41, 44, 39, 46, 42, 47, 40, 39, 43,
            45, 38, 44, 38, 46, 42, 47, 47, 40, 43,
            45, 41, 44, 40, 40, 42, 47, 40, 38, 43,
            39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
            45, 41, 44, 38, 46, 40, 46, 46, 46, 43,
            45, 38, 44, 40, 46, 42, 39, 39, 45, 43,
            45, 41, 44, 38, 46, 42, 38, 39, 39, 43,
            45, 41, 38, 40, 46, 42, 47, 38, 39, 43,
            45, 41, 44, 40, 46, 42, 40
        ], dtype=uint8)),
        (array([
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40
        ], dtype=uint8),
         array([
            39, 38, 43, 45, 41, 44, 40, 46, 38, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 41, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 38, 43, 47,
            38, 38, 43, 45, 41, 44, 39, 46, 42, 39,
            39, 38, 43, 45, 41, 44, 43, 46, 42, 47,
            39, 39, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 40, 43, 45, 41, 44, 40, 46, 42, 39,
            38, 38, 43, 45, 40, 44, 47, 46, 38, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
            39, 38, 43, 45, 47, 44, 45, 46, 42, 38,
            39, 41, 43, 45, 41, 44, 38, 38, 42, 39,
            40, 40, 43, 45, 41, 39, 40, 46, 42, 47,
            39, 40, 43, 45, 41, 44, 40, 47, 42, 47,
            38, 38, 43, 45, 41, 44, 47, 46, 42, 47,
            40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
            38, 39, 43, 45, 41, 44, 40, 46, 42,
            39, 38, 43, 45, 46, 44, 38, 46, 42, 47,
            38, 44, 43, 45, 42, 44, 41, 46, 42, 47,
            47, 38, 43, 45, 41, 44, 38, 46, 42, 39,
            39, 38, 43, 45, 41, 44, 40
        ], dtype=uint8)),
        (array([
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45
        ], dtype=uint8),
         array([
            46, 42, 39, 38, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 47, 42, 43, 45, 42, 44, 40,
            46, 42, 38, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 47, 40, 43, 45, 41, 44, 41,
            46, 42, 38, 39, 38, 43, 45, 41, 44, 38,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 46, 38, 38, 43, 45, 41, 44, 39,
            46, 42, 47, 39, 40, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
            47, 42, 47, 38, 39, 43, 45, 41, 44, 39,
            46, 42, 47, 39, 46, 43, 45, 41, 44, 39,
            46, 42, 39, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 38, 38, 43, 45, 41, 44, 40,
            46, 42, 39, 39, 38, 43, 45, 41, 44, 40,
            38, 42, 46, 39, 38, 43, 45, 41, 44, 38,
            46, 42, 46, 46, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
            38, 42, 43, 39, 40, 43, 45, 41, 44, 38,
            46, 42, 47, 38, 39, 47, 45, 46, 44, 40,
            46, 42, 47, 40, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 40, 38, 43, 45, 41, 44, 38,
            46, 42, 38, 39, 38, 47, 45
        ], dtype=uint8)),
        (array([
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39
        ], dtype=uint8),
         array([
            41, 44, 38, 46, 42, 47, 39, 47, 40, 45,
            41, 44, 40, 46, 42, 38, 40, 38, 43, 45,
            41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
            41, 44, 46, 38, 42, 40, 38, 39, 43, 45,
            41, 44, 41, 46, 42, 47, 47, 38, 43, 45,
            41, 44, 40, 46, 42, 38, 39, 39, 43, 45,
            41, 44, 38, 46, 42, 47, 43, 39, 43, 45,
            41, 44, 40, 46, 42, 38, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 40, 39, 38, 43, 45,
            41, 44, 38, 46, 42, 39, 39, 39, 43, 45,
            41, 44, 40, 46, 42, 39, 38, 47, 43, 45,
            38, 44, 40, 38, 42, 47, 38, 38, 43, 45,
            41, 44, 40, 38, 46, 47, 38, 38, 43, 45,
            41, 44, 41, 46, 42, 40, 38, 38, 40, 45,
            41, 44, 40, 40, 42, 43, 38, 40, 43, 39,
            41, 44, 40, 40, 42, 47, 38, 46, 43, 45,
            41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
            41, 44, 41, 38, 42, 40, 39, 40, 43, 45,
            41, 44, 39, 43, 42, 47, 39, 40, 43, 45,
            41, 44, 42, 46, 42, 47, 40, 46, 43, 45,
            41, 44, 38, 46, 42, 47, 47, 38, 43, 45,
            41, 44, 40, 38, 39, 47, 38
        ], dtype=uint8)),
        (array([
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46
        ], dtype=uint8),
         array([
            39, 43, 45, 41, 44, 40, 46, 42, 47, 38,
            38, 43, 45, 41, 44, 41, 46, 42, 47, 47,
            39, 43, 45, 41, 44, 40, 46, 42, 47, 38,
            39, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            40, 43, 45, 41, 44, 40, 46, 42, 47, 45,
            38, 43, 45, 41, 44, 38, 46, 42, 47, 38,
            39, 43, 45, 41, 44, 40, 46, 42, 39, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
            39, 43, 45, 41, 44, 40, 39, 42, 40, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 38, 46, 42, 39, 39,
            47, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            39, 43, 45, 41, 44, 40, 46, 42, 46, 47,
            39, 47, 45, 41, 44, 40, 46, 42, 47, 39,
            39, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
            38, 43, 45, 42, 44, 39, 47, 42, 39, 39,
            47, 43, 47, 40, 44, 40, 46, 42, 39, 39,
            38, 39, 45, 41, 44, 40, 46, 42, 47, 38,
            38, 43, 45, 41, 44, 46, 38, 42, 47, 39,
            43, 43, 45, 41, 44, 40, 46
        ], dtype=uint8)),
        (array([
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45
        ], dtype=uint8),
         array([
            42, 38, 38, 40, 43, 45, 41, 44, 39, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 39, 38,
            42, 47, 41, 40, 43, 45, 41, 44, 40, 41,
            42, 47, 38, 46, 43, 45, 41, 44, 41, 41,
            42, 40, 39, 39, 43, 45, 41, 44, 46, 45,
            42, 39, 39, 40, 43, 45, 41, 44, 40, 46,
            42, 40, 44, 38, 43, 41, 41, 44, 39, 46,
            42, 39, 39, 39, 43, 45, 41, 44, 40, 43,
            42, 47, 39, 39, 43, 45, 41, 44, 40, 47,
            42, 38, 46, 39, 47, 45, 41, 44, 39, 46,
            42, 47, 41, 38, 43, 45, 41, 44, 42, 46,
            42, 46, 39, 38, 43, 45, 41, 44, 41, 46,
            42, 46, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 38, 38, 38, 43, 45, 41, 44, 38, 46,
            42, 39, 40, 43, 43, 45, 41, 44, 39, 38,
            40, 40, 38, 38, 43, 45, 41, 44, 41, 40,
            42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
            42, 47, 40, 40, 43, 45, 41, 44, 40, 46,
            42, 41, 39, 39, 43, 45, 41, 44, 40, 38,
            42, 40, 39, 46, 43, 45, 41, 44, 47, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 41, 46,
            42, 43, 39, 39, 43, 45
        ], dtype=uint8)),
    ]
    labels_map = {
        '12kHz': 40, '20kHz': 41, '30kHz': 42, '3kHz': 38, '7kHz': 39,
        'song1': 43, 'song2': 44, 'song3': 45, 'song4': 46, 'song5': 47,
    }
    try:
        cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
    except Exception as exc:
        # Report the actual problem; a bare ``except`` would also trap
        # KeyboardInterrupt/SystemExit and hide the cause of the failure.
        self.fail("Could not construct ConfusionMatrix from sets: %r"
                  % (exc,))

    cms = str(cm)
    self.assertTrue('3kHz / 38' in cms)
    if externals.exists("scipy"):
        self.assertTrue(
            'ACC(i) = 0.82-0.012*i p=0.12 r=-0.59 r^2=0.35' in cms)

    if externals.exists("pylab plottable"):
        import pylab as pl
        pl.figure()
        labels_order = ("3kHz", "7kHz", "12kHz", "20kHz", "30kHz", None,
                        "song1", "song2", "song3", "song4", "song5")
        # Plot with the first two labels swapped and verify that the
        # plotted matrix has the corresponding rows/columns swapped.
        fig, im, cb = cm.plot(
            labels=labels_order[1:2] + labels_order[:1] + labels_order[2:],
            numbers=True)
        self.assertTrue(cm._plotted_confusionmatrix[0, 0] == cm.matrix[1, 1])
        self.assertTrue(cm._plotted_confusionmatrix[0, 1] == cm.matrix[1, 0])
        self.assertTrue(cm._plotted_confusionmatrix[1, 1] == cm.matrix[0, 0])
        self.assertTrue(cm._plotted_confusionmatrix[1, 0] == cm.matrix[0, 1])
        pl.close(fig)
        fig, im, cb = cm.plot(labels=labels_order, numbers=True)
        pl.close(fig)
def _group_transfer_learning(path, subjects, analysis, conf_file,
                             source='task', analysis_type='single', **kwargs):
    """Train on the merged source data, then predict each subject's target data.

    A classifier is obtained from ``spatial`` run on the source dataset and
    then applied to every subject's target dataset.  For each subject the
    predictions, confusion matrices, the result of
    ``similarity_measure_mahalanobis`` and the signal detection measures
    are collected.

    Parameters
    ----------
    path, subjects, conf_file
        For ``analysis_type == 'group'`` both ``path`` and ``conf_file``
        must be lists; the three are iterated in lockstep, and each entry
        of ``subjects`` is itself iterated for subject names.
    analysis
        Not used by this function.
    source : {'task', 'rest', 'saccade', 'face'}
        Condition used for training; the target is its counterpart
        (task <-> rest, saccade <-> face).
    analysis_type : str
        ``'group'`` merges the sources via ``sources_merged_ds``.
    **kwargs
        Passed to ``sources_merged_ds`` and overlaid on each target
        configuration.

    Returns
    -------
    dict or int
        Results keyed by subject name plus ``'group'``; ``0`` when a group
        analysis is requested with non-list ``path``/``conf_file``.

    Raises
    ------
    ValueError
        If ``source`` is not one of the supported conditions.
    """
    counterpart = {'task': 'rest', 'rest': 'task',
                   'saccade': 'face', 'face': 'saccade'}
    if source not in counterpart:
        # Previously an unknown source left ``target`` unbound and failed
        # much later with a NameError.
        raise ValueError("Unsupported source %r; expected one of %s"
                         % (source, sorted(counterpart)))
    target = counterpart[source]

    if analysis_type == 'group':
        if path.__class__ == conf_file.__class__ == list:
            ds_src, _, conf_src = sources_merged_ds(path, subjects, conf_file,
                                                    source, **kwargs)
            conf_src['permutations'] = 0
        else:
            print('In group analysis path, subjects and conf_file must be '
                  'lists: Check configuration file and/or parameters!!')
            return 0
    else:
        # NOTE(review): this branch never defines ``ds_src``, so the call
        # to ``spatial`` below fails for non-group analyses -- confirm
        # whether single-subject mode is meant to be supported here.
        conf_src = read_configuration(path, conf_file, source)

    # NOTE(review): ``map_list`` stays undefined when the configuration has
    # no 'map_list' entry; it is required further down.
    if 'map_list' in conf_src:
        map_list = conf_src['map_list'].split(',')

    r_group = spatial(ds_src, **conf_src)

    total_results = dict()
    total_results['group'] = r_group

    clf = r_group['classifier']

    for subj_, conf_, path_ in zip(subjects, conf_file, path):
        for subj in subj_:
            print('-----------')
            r = dict()
            if len(subj_) > 1:
                conf_tar = read_configuration(path_, conf_, target)
                for arg in kwargs:
                    conf_tar[arg] = kwargs[arg]

                data_path = conf_tar['data_path']
                try:
                    ds_tar = load_dataset(data_path, subj, target, **conf_tar)
                except Exception as err:
                    # Best effort: report and move on to the next subject.
                    print(err)
                    continue

            ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

            if conf_src['label_included'] == 'all' and \
               conf_src['label_dropped'] != 'fixation':
                print('Balancing dataset...')
                ds_src = balance_dataset_timewise(ds_src, 'fixation')

            predictions = clf.predict(ds_tar)
            pred = np.array(predictions)
            targets = ds_tar.targets

            # Start from a shallow copy of every group-level result.
            for key, value in r_group.items():
                r[key] = copy.copy(value)

            r['targets'] = targets
            r['predictions'] = predictions
            r['fclf'] = clf

            c_m = ConfusionMatrix(predictions=pred, targets=targets)
            c_m.compute()
            r['confusion_target'] = c_m
            print(c_m)

            tr_pred = similarity_measure_mahalanobis(ds_tar, ds_src, r)
            r['mahalanobis_similarity'] = tr_pred

            c_mat_mahala = ConfusionMatrix(predictions=tr_pred.T[1],
                                           targets=tr_pred.T[0])
            c_mat_mahala.compute()
            r['confusion_mahala'] = c_mat_mahala

            d_prime, beta, c, c_new = signal_detection_measures(
                pred, targets, map_list)
            r['d_prime'] = d_prime
            print(d_prime)
            r['beta'] = beta
            r['c'] = c
            r['confusion_total'] = c_new

            total_results[subj] = r

    # Hand the accumulated results back instead of discarding them.
    return total_results
''' Get some data ''' dt = np.dtype([('labels', np.str_, 20), ('predictions', np.str_, 20), ('distances', np.float32), ('pvalues', np.float32)]) full_data = np.array(zip(ds_tar.targets, classifier_prediction_tar, mahalanobis_values, p_values), dtype=dt) classifier_prediction_tar = np.array(classifier_prediction_tar) ds_tar_targets = np.array(ds_tar.targets) c_mat_mahala = ConfusionMatrix( predictions=classifier_prediction_tar[similarity_mask], targets=ds_tar_targets[similarity_mask]) c_mat_mahala.compute() header = [ 'similarity_data', 'similarity_mask', 'threshold_value', 'pvalues', 'distances', 'confusion_mahalanobis' ] result_data = [ full_data, similarity_mask, threshold, p_values, distances, c_mat_mahala ] results = dict(zip(header, result_data)) return results
''' dt = np.dtype([('labels', np.str_, 20), ('predictions', np.str_, 20), ('distances', np.float32), ('pvalues', np.float32)]) full_data = np.array(zip(ds_tar.targets, classifier_prediction_tar, mahalanobis_values, p_values), dtype=dt) classifier_prediction_tar = np.array(classifier_prediction_tar) ds_tar_targets = np.array(ds_tar.targets) c_mat_mahala = ConfusionMatrix(predictions=classifier_prediction_tar[similarity_mask], targets=ds_tar_targets[similarity_mask]) c_mat_mahala.compute() header = ['similarity_data', 'similarity_mask', 'threshold_value', 'pvalues', 'distances', 'confusion_mahalanobis'] result_data = [full_data, similarity_mask, threshold, p_values, distances, c_mat_mahala] results = dict(zip(header, result_data)) return results def similarity_measure_correlation (ds_tar, ds_src, results, p_value): print 'Computing Mahalanobis similarity...' #classifier = results['classifier']
def test_partial_searchlight_with_confusion_matrix(self):
    """Searchlight whose per-sphere measure is a full confusion matrix.

    Runs a cross-validated searchlight that yields confusion matrices,
    compares it with a plain error searchlight over the same spheres and
    finally repeats it with a Monte-Carlo null distribution attached.
    """
    ds = self.dataset
    from mvpa2.clfs.stats import MCNullDist
    from mvpa2.mappers.fx import mean_sample, sum_sample

    centers = (3, 5, 50)
    confusion = ConfusionMatrix(labels=ds.UT)

    def confusion_errorfx(*args):
        # Two leading singleton axes keep the matrix from being flattened
        # by the vstack in the cross-validation and the hstack in the
        # searchlight.
        # TODO: RF? make searchlight/crossval smarter?
        return confusion(*args)[None, None, :]

    # N-1 cross-validation within each sphere
    cv_confusion = CrossValidation(sample_clf_lin,
                                   NFoldPartitioner(),
                                   errorfx=confusion_errorfx)
    # diameter 2 (i.e. radius 1) searchlight
    sl_confusion = sphere_searchlight(cv_confusion, radius=1,
                                      center_ids=list(centers))

    # reference searchlight yielding plain errors
    cv_plain = CrossValidation(sample_clf_lin, NFoldPartitioner())
    sl_plain = sphere_searchlight(cv_plain, radius=1,
                                  center_ids=list(centers))

    # run both
    res_confusion = sl_confusion(ds)
    res_plain = sl_plain(ds)

    # one entry per CV fold and per sphere; the confusion variant also
    # carries a complete matrix for each of them
    assert_equal(res_confusion.shape,
                 (len(ds.UC), 3, len(ds.UT), len(ds.UT)))
    assert_equal(res_plain.shape, (len(ds.UC), 3))

    # look into the confusion matrices themselves
    matrices = res_confusion.samples
    # The input dataset is presumably balanced (otherwise this would have
    # to be done per label), so summing within columns (axis=-2) must give
    # the number of samples per class and chunk.
    per_class_chunk = len(ds) / (len(ds.UT) * len(ds.UC))
    column_sums = np.sum(matrices, axis=-2)
    ok_(np.all(column_sums == per_class_chunk))

    # errors derived by hand from the matrix diagonal have to agree with
    # the reference searchlight
    diagonal_hits = matrices[..., 0, 0] + matrices[..., 1, 1]
    assert_array_almost_equal(
        res_plain.samples,
        1 - diagonal_hits.astype(float) / (2 * per_class_chunk))

    # Now H0 Monte-Carlo testing of the same searchlight, with only a
    # minimal number of permutations.
    no_permutations = 10
    permutator = AttributePermutator('targets', count=no_permutations)

    def summed_with_leading_axis(x):
        # explicit leading dimension once more, to avoid vstacking during
        # cross-validation
        return sum_sample()(x)[None, :]

    cv_confusion.postproc = summed_with_leading_axis
    null_dist = MCNullDist(permutator, tail='right',
                           enable_ca=['dist_samples'])
    sl_confusion = sphere_searchlight(cv_confusion, radius=1,
                                      center_ids=list(centers),
                                      null_dist=null_dist)
    res_perm = sl_confusion(ds)

    # XXX res_perm, sl.ca.null_prob and sl.null_dist.ca.dist_samples all
    # carry a degenerate leading dimension, probably due to the new axis
    # introduced in the postproc above
    perm_shape = res_perm.shape
    assert_equal(perm_shape, (1, 3, 2, 2))
    assert_equal(sl_confusion.null_dist.ca.dist_samples.shape,
                 perm_shape + (no_permutations, ))
    assert_equal(sl_confusion.ca.null_prob.shape, perm_shape)
    # probabilities have to stay within [0, 1]
    ok_(np.all(sl_confusion.ca.null_prob.samples >= 0))
    ok_(np.all(sl_confusion.ca.null_prob.samples <= 1))

    # the postproc should have summed the hits across the splits
    assert_array_equal(np.sum(matrices, axis=0), res_perm.samples[0])
# test_transfer_learning: cross-decoding driver (train on `source`, test on
# its counterpart `target`), run per subject.
#
# What the code below does:
#   * Derives `target` from `source`: task <-> rest, saccade <-> face.
#   * Reads `p` from kwargs['p']; a 'p_dist' entry in the source
#     configuration replaces it (converted with float()).
#   * analysis_type == 'group': `path` and `conf_file` must both be lists;
#     source and target datasets come from sources_merged_ds() and
#     'permutations' is set to 0 in both configurations.  Otherwise a
#     message is printed and 0 is returned.
#   * Any other analysis_type: configurations are read with
#     read_configuration().
#   * Per subject: datasets are loaded (a load failure is printed and the
#     subject skipped), detrended, optionally balanced on 'fixation', and
#     passed to transfer_learning().  A ConfusionMatrix of the classifier
#     predictions and the result of cross_decoding_confusion() are stored
#     in the result dict; similarity_measure(..., method='mahalanobis') is
#     added when calculateSimilarity == 'True' (compared as a string), and
#     the items returned by signal_detection_measures() are merged in.
#     Each result is stored in `total_results` and added to an
#     rs.ResultsCollection as an rs.SubjectResult.
#
# NOTE(review): an unsupported `source` leaves `target` unbound.
# NOTE(review): `map_list` is only bound when the source configuration
#   contains 'map_list'; cross_decoding_confusion() needs it.
# NOTE(review): kwargs['p'] raises KeyError when 'p' is not supplied, and
#   the later `'p' not in locals()` check can never be true once `p` has
#   been assigned.
# NOTE(review): uses Python 2 only syntax (print statement,
#   `except Exception, err`, `analysis.func_name`).
# NOTE(review): no return statement is visible on the success path --
#   confirm whether callers expect `total_results`.
def test_transfer_learning(path, subjects, analysis, conf_file, source='task', \ analysis_type='single', calculateSimilarity='True', **kwargs): if source == 'task': target = 'rest' else: if source == 'rest': target = 'task' if source == 'saccade': target = 'face' else: if source == 'face': target = 'saccade' p = kwargs['p'] ############################################## ############################################## ## conf_src['label_included'] = 'all' ## ## conf_src['label_dropped'] = 'none' ## ## conf_src['mean_samples'] = 'False' ## ############################################## ############################################## if analysis_type == 'group': if path.__class__ == conf_file.__class__ == list: ds_src, _, conf_src = sources_merged_ds(path, subjects, conf_file, source, **kwargs) ds_tar, subjects, conf_tar = sources_merged_ds(path, subjects, conf_file, target, **kwargs) conf_src['permutations'] = 0 conf_tar['permutations'] = 0 else: print 'In group analysis path, subjects and conf_file must be lists: \ Check configuration file and/or parameters!!' 
return 0 else: conf_src = read_configuration(path, conf_file, source) conf_tar = read_configuration(path, conf_file, target) for arg in kwargs: conf_src[arg] = kwargs[arg] conf_tar[arg] = kwargs[arg] data_path = conf_src['data_path'] conf_src['analysis_type'] = 'transfer_learning' conf_src['analysis_task'] = source conf_src['analysis_func'] = analysis.func_name for arg in conf_src: if arg == 'map_list': map_list = conf_src[arg].split(',') if arg == 'p_dist': p = float(conf_src[arg]) print p total_results = dict() summarizers = [rs.CrossDecodingSummarizer(), rs.SimilaritySummarizer(), rs.DecodingSummarizer(), rs.SignalDetectionSummarizer(), ] savers = [rs.CrossDecodingSaver(), rs.SimilaritySaver(), rs.DecodingSaver(), rs.SignalDetectionSaver(), ] collection = rs.ResultsCollection(conf_src, path, summarizers) for subj in subjects: print '-------------------' if (len(subjects) > 1) or (subj != 'group'): try: ds_src = load_dataset(data_path, subj, source, **conf_src) ds_tar = load_dataset(data_path, subj, target, **conf_tar) except Exception, err: print err continue # Evaluate if is correct to do further normalization after merging two ds. ds_src = detrend_dataset(ds_src, source, **conf_src) ds_tar = detrend_dataset(ds_tar, target, **conf_tar) if conf_src['label_included'] == 'all' and \ conf_src['label_dropped'] != 'fixation': print 'Balancing dataset...' ds_src = balance_dataset_timewise(ds_src, 'fixation') # Make cross-decoding r = transfer_learning(ds_src, ds_tar, analysis, **conf_src) # Now we have cross-decoding results we could process it pred = np.array(r['classifier'].ca.predictions) targets = r['targets'] c_m = ConfusionMatrix(predictions=pred, targets=targets) c_m.compute() r['confusion_target'] = c_m c_new = cross_decoding_confusion(pred, targets, map_list) r['confusion_total'] = c_new print c_new # Similarity Analysis if calculateSimilarity == 'True': if 'p' not in locals(): print 'Ciao!' 
mahala_data = similarity_measure(r['ds_tar'], r['ds_src'], r, p_value=p, method='mahalanobis') #r['mahalanobis_similarity'] = mahala_data for k_,v_ in mahala_data.items(): r[k_] = v_ r['confusion_mahala'] = mahala_data['confusion_mahalanobis'] else: #r['mahalanobis_similarity'] = [] r['confusion_mahala'] = 'Null' # Signal Detection Theory Analysis sdt_res = signal_detection_measures(c_new) for k_,v_ in sdt_res.items(): r[k_] = v_ ''' Same code of: r['d_prime'] = d_prime r['beta'] = beta r['c'] = c ''' total_results[subj] = r subj_result = rs.SubjectResult(subj, r, savers=savers) collection.add(subj_result)
def test_gideon_weird_case(self):
    """Test if MappedClassifier could handle a mapper altering number of samples

    'The utter collapse' -- communicated by Peter J. Kohler

    Desire to collapse all samples per each category in training
    and testing sets, thus resulting only in a single
    sample/category per training and per testing.

    It is a peculiar scenario which pinpoints the problem that so
    far mappers were assumed not to change the number of samples.

    The reference errors are computed "by hand" (averaging each
    partition explicitly) and compared against a CrossValidation run
    over a ChainMapper that does the averaging itself.
    """
    from mvpa2.mappers.fx import mean_group_sample
    from mvpa2.clfs.knn import kNN
    from mvpa2.mappers.base import ChainMapper
    ds = datasets['uni2large'].copy()
    #ds = ds[ds.sa.chunks < 9]
    accs = []
    k = 1  # for kNN
    nf = 1  # for NFoldPartitioner
    for i in xrange(1):  # number of random runs
        # replace the data with pure noise of the same shape
        ds.samples = np.random.randn(*ds.shape)
        #
        # There are 3 ways to accomplish needed goal
        #

        # 0. Hard way: overcome the problem by manually
        #    pre-splitting/meaning in a loop
        from mvpa2.clfs.transerror import ConfusionMatrix
        partitioner = NFoldPartitioner(nf)
        meaner = mean_group_sample(['targets', 'partitions'])
        cm = ConfusionMatrix()
        te = TransferMeasure(kNN(k), Splitter('partitions'),
                             postproc=BinaryFxNode(mean_mismatch_error,
                                                   'targets'),
                             enable_ca=['stats'])
        errors = []
        for part in partitioner.generate(ds):
            # average samples per (target, partition) group, then measure
            # the transfer error and accumulate its confusion stats
            ds_meaned = meaner(part)
            errors.append(np.asscalar(te(ds_meaned)))
            cm += te.ca.stats
        #print i, cm.stats['ACC']
        accs.append(cm.stats['ACC'])

        if False:  # not yet working -- see _tent/allow_ch_nsamples
            # branch for attempt to make it work
            # 1. This is a "native way" IF we allow change of number
            #    of samples via _call to be done by MappedClassifier
            #    while operating solely on the mapped dataset
            clf2 = MappedClassifier(
                clf=kNN(k),  #clf,
                mapper=mean_group_sample(['targets', 'partitions']))
            cv = CrossValidation(clf2, NFoldPartitioner(nf), postproc=None,
                                 enable_ca=['stats'])
            # meaning all should be ok since we should have balanced
            # sets across all chunks here
            errors_native = cv(ds)

            self.assertEqual(
                np.max(np.abs(errors_native.samples[:, 0] - errors)), 0)

        # 2. Work without fixes to MappedClassifier allowing
        #    change of # of samples
        #
        # CrossValidation will operate on a chain mapper which
        # would perform necessary meaning first before dealing with
        # kNN cons: .stats would not be exposed since ChainMapper
        # doesn't expose them from ChainMapper (yet)
        if __debug__ and 'ENFORCE_CA_ENABLED' in debug.active:
            raise SkipTest("Known to fail while trying to enable "
                           "training_stats for the ChainMapper")
        cv2 = CrossValidation(ChainMapper(
            [mean_group_sample(['targets', 'partitions']), kNN(k)],
            space='targets'),
                              NFoldPartitioner(nf),
                              postproc=None)
        errors_native2 = cv2(ds)

        # the chained variant must reproduce the manually computed errors
        self.assertEqual(
            np.max(np.abs(errors_native2.samples[:, 0] - errors)), 0)

        # All of the ways should provide the same results
        #print i, np.max(np.abs(errors_native.samples[:,0] - errors)), \
        #      np.max(np.abs(errors_native2.samples[:,0] - errors))

    if False:  # just to investigate the distribution if we have enough iterations
        import pylab as pl
        uaccs = np.unique(accs)
        # spacing between the distinct accuracy values observed
        step = np.asscalar(np.unique(np.round(uaccs[1:] - uaccs[:-1], 4)))
        bins = np.linspace(0., 1., np.round(1. / step + 1))
        xx = pl.hist(accs, bins=bins, align='left')
        pl.xlim((0. - step / 2, 1. + step / 2))
def _group_transfer_learning(path, subjects, analysis, conf_file,
                             source='task', analysis_type='single', **kwargs):
    """Train on the merged source data, then predict each subject's target data.

    A classifier is obtained from ``spatial`` run on the source dataset and
    then applied to every subject's target dataset.  For each subject the
    predictions, confusion matrices, the result of
    ``similarity_measure_mahalanobis`` and the signal detection measures
    are collected.

    Parameters
    ----------
    path, subjects, conf_file
        For ``analysis_type == 'group'`` both ``path`` and ``conf_file``
        must be lists; the three are iterated in lockstep, and each entry
        of ``subjects`` is itself iterated for subject names.
    analysis
        Not used by this function.
    source : {'task', 'rest', 'saccade', 'face'}
        Condition used for training; the target is its counterpart
        (task <-> rest, saccade <-> face).
    analysis_type : str
        ``'group'`` merges the sources via ``sources_merged_ds``.
    **kwargs
        Passed to ``sources_merged_ds`` and overlaid on each target
        configuration.

    Returns
    -------
    dict or int
        Results keyed by subject name plus ``'group'``; ``0`` when a group
        analysis is requested with non-list ``path``/``conf_file``.

    Raises
    ------
    ValueError
        If ``source`` is not one of the supported conditions.
    """
    opposite = {'task': 'rest', 'rest': 'task',
                'saccade': 'face', 'face': 'saccade'}
    if source not in opposite:
        # Fail early with a clear message; before, ``target`` simply stayed
        # unbound and a NameError surfaced deep inside the subject loop.
        raise ValueError("Unsupported source %r; expected one of %s"
                         % (source, sorted(opposite)))
    target = opposite[source]

    if analysis_type == 'group':
        if path.__class__ == conf_file.__class__ == list:
            ds_src, _, conf_src = sources_merged_ds(path, subjects, conf_file,
                                                    source, **kwargs)
            conf_src['permutations'] = 0
        else:
            print('In group analysis path, subjects and conf_file must be '
                  'lists: Check configuration file and/or parameters!!')
            return 0
    else:
        # NOTE(review): ``ds_src`` is never assigned on this branch, so
        # ``spatial`` below cannot run for non-group analyses -- confirm
        # the intended behavior.
        conf_src = read_configuration(path, conf_file, source)

    # NOTE(review): without a 'map_list' configuration entry ``map_list``
    # remains undefined although it is needed below.
    if 'map_list' in conf_src:
        map_list = conf_src['map_list'].split(',')

    r_group = spatial(ds_src, **conf_src)

    total_results = dict()
    total_results['group'] = r_group

    clf = r_group['classifier']

    for subj_, conf_, path_ in zip(subjects, conf_file, path):
        for subj in subj_:
            print('-----------')
            r = dict()
            if len(subj_) > 1:
                conf_tar = read_configuration(path_, conf_, target)
                for arg in kwargs:
                    conf_tar[arg] = kwargs[arg]

                data_path = conf_tar['data_path']
                try:
                    ds_tar = load_dataset(data_path, subj, target, **conf_tar)
                except Exception as err:
                    # Best effort: report and continue with the next subject.
                    print(err)
                    continue

            ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

            if conf_src['label_included'] == 'all' and \
               conf_src['label_dropped'] != 'fixation':
                print('Balancing dataset...')
                ds_src = balance_dataset_timewise(ds_src, 'fixation')

            predictions = clf.predict(ds_tar)
            pred = np.array(predictions)
            targets = ds_tar.targets

            # Seed the subject result with shallow copies of the group one.
            for key, value in r_group.items():
                r[key] = copy.copy(value)

            r['targets'] = targets
            r['predictions'] = predictions
            r['fclf'] = clf

            c_m = ConfusionMatrix(predictions=pred, targets=targets)
            c_m.compute()
            r['confusion_target'] = c_m
            print(c_m)

            tr_pred = similarity_measure_mahalanobis(ds_tar, ds_src, r)
            r['mahalanobis_similarity'] = tr_pred

            c_mat_mahala = ConfusionMatrix(predictions=tr_pred.T[1],
                                           targets=tr_pred.T[0])
            c_mat_mahala.compute()
            r['confusion_mahala'] = c_mat_mahala

            d_prime, beta, c, c_new = signal_detection_measures(
                pred, targets, map_list)
            r['d_prime'] = d_prime
            print(d_prime)
            r['beta'] = beta
            r['c'] = c
            r['confusion_total'] = c_new

            total_results[subj] = r

    # Return the collected results rather than dropping them.
    return total_results
def test_confusion_matrix(self):
    """Exercise ConfusionMatrix construction, accumulation, printing and addition."""
    reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
    regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
    correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]

    # Check that any input type is accepted: list, tuple or np.array
    target_variants = [reg, tuple(reg), list(reg), np.array(reg)]
    prediction_variants = [regl, tuple(regl), list(regl), np.array(regl)]
    for t in target_variants:
        for p in prediction_variants:
            cm = ConfusionMatrix(targets=t, predictions=p)
            # check table content
            self.assertTrue((cm.matrix == correct_cm).all())

    # Do a bit more thorough checking
    cm = ConfusionMatrix()
    # No samples -- must raise an exception
    self.assertRaises(ZeroDivisionError, lambda x: x.percent_correct, cm)

    cm.add(reg, regl)

    self.assertEqual(len(cm.sets), 1,
                     msg="Should have a single set so far")
    self.assertEqual(cm.matrix.shape, (3, 3),
                     msg="should be square matrix (len(reglabels) x len(reglabels)")

    # ConfusionMatrix must complain if the number of samples differs
    self.assertRaises(ValueError, cm.add, reg, np.array([1]))

    # check table content
    self.assertTrue((cm.matrix == correct_cm).all())

    # let's add a set with new labels (not yet known)
    cm.add(reg, np.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

    self.assertEqual(cm.labels, [1, 2, 3, 4],
                     msg="We should have gotten 4th label")

    matrices = cm.matrices  # separate CM per each given set
    self.assertEqual(len(matrices), 2,
                     msg="Have gotten two splits")

    self.assertTrue(
        (matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
        msg="Total votes should match the sum across split CMs")

    # check pretty print
    # just a silly test to make sure that printing works
    self.assertTrue(
        len(cm.as_string(header=True, summary=True, description=True)) > 100)
    self.assertTrue(len(str(cm)) > 100)
    # and that it knows some parameters for printing
    self.assertTrue(len(cm.as_string(summary=True, header=False)) > 100)

    # let's check iadd -- just itself to itself
    cm += cm
    self.assertEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

    # let's check add -- just itself to itself
    cm2 = cm + cm
    self.assertEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
    self.assertEqual(cm2.percent_correct, cm.percent_correct,
                     msg="Percent of corrrect should remain the same ;-)")

    self.assertEqual(cm2.error, 1.0 - cm.percent_correct / 100.0,
                     msg="Test if we get proper error value")
def test_transfer_learning(path, subjects, analysis, conf_file, source='task',
                           analysis_type='single', calculateSimilarity='True',
                           **kwargs):
    """Run cross-decoding from ``source`` to its paired condition per subject.

    For every subject the source and target datasets are detrended, passed
    to ``transfer_learning`` and the outcome is enriched with confusion
    matrices, an optional Mahalanobis similarity analysis and the
    signal-detection measures.  Each subject's result is also wrapped in an
    ``rs.SubjectResult`` and added to an ``rs.ResultsCollection``.

    Parameters
    ----------
    path, conf_file
        Must both be lists when ``analysis_type == 'group'``; otherwise they
        are handed to ``read_configuration``.
    subjects
        Iterable of subject identifiers.  In group mode it is replaced by
        the subjects returned from ``sources_merged_ds``.
    analysis
        Analysis callable forwarded to ``transfer_learning``; its name is
        recorded in the configuration under ``'analysis_func'``.
    source : {'task', 'rest', 'saccade', 'face'}
        Condition to train on.  The target is the paired condition
        (task<->rest, saccade<->face).
    analysis_type : str
        ``'group'`` works on merged datasets.
    calculateSimilarity : str or bool
        The similarity analysis runs when this is ``'True'`` or ``True``.
    **kwargs
        Must contain ``'p'`` (p-value for the similarity measure; overridden
        by a ``'p_dist'`` configuration entry).  In non-group mode all
        entries are also written over both configurations.

    Returns
    -------
    dict or int
        ``0`` when a group analysis is requested but ``path`` and
        ``conf_file`` are not both lists; otherwise a dict mapping each
        processed subject to its result dict.

    Raises
    ------
    KeyError
        If ``'p'`` is not passed in ``kwargs``.
    ValueError
        If ``source`` is not one of the supported conditions, or the source
        configuration has no ``'map_list'`` entry.
    """
    p = kwargs['p']

    # Each supported source condition is decoded onto its paired target.
    paired_target = {
        'task': 'rest',
        'rest': 'task',
        'saccade': 'face',
        'face': 'saccade',
    }
    if source not in paired_target:
        raise ValueError(
            "unknown source %r: expected one of %s"
            % (source, sorted(paired_target)))
    target = paired_target[source]

    # Settings the source configuration was once forced to (kept for reference):
    #   conf_src['label_included'] = 'all'
    #   conf_src['label_dropped'] = 'none'
    #   conf_src['mean_samples'] = 'False'

    if analysis_type == 'group':
        if not (isinstance(path, list) and isinstance(conf_file, list)):
            print('In group analysis path, subjects and conf_file must be '
                  'lists: Check configuration file and/or parameters!!')
            return 0
        ds_src, _, conf_src = sources_merged_ds(
            path, subjects, conf_file, source, **kwargs)
        ds_tar, subjects, conf_tar = sources_merged_ds(
            path, subjects, conf_file, target, **kwargs)
        conf_src['permutations'] = 0
        conf_tar['permutations'] = 0
        # NOTE(review): data_path is only set on the non-group branch, so
        # group mode relies on the per-subject load below being skipped
        # (i.e. subjects == ['group']) -- confirm against sources_merged_ds.
    else:
        conf_src = read_configuration(path, conf_file, source)
        conf_tar = read_configuration(path, conf_file, target)

        for key, value in kwargs.items():
            conf_src[key] = value
            conf_tar[key] = value

        data_path = conf_src['data_path']

    conf_src['analysis_type'] = 'transfer_learning'
    conf_src['analysis_task'] = source
    # __name__ is available on both Python 2 and 3 functions (func_name is
    # Python 2 only).
    conf_src['analysis_func'] = analysis.__name__

    if 'map_list' not in conf_src:
        raise ValueError("'map_list' is missing from the source configuration")
    map_list = conf_src['map_list'].split(',')

    # A p-value in the configuration takes precedence over the keyword one.
    if 'p_dist' in conf_src:
        p = float(conf_src['p_dist'])
        print(p)

    total_results = dict()

    summarizers = [
        rs.CrossDecodingSummarizer(),
        rs.SimilaritySummarizer(),
        rs.DecodingSummarizer(),
        rs.SignalDetectionSummarizer(),
    ]

    savers = [
        rs.CrossDecodingSaver(),
        rs.SimilaritySaver(),
        rs.DecodingSaver(),
        rs.SignalDetectionSaver(),
    ]

    collection = rs.ResultsCollection(conf_src, path, summarizers)

    # Accept the historical string flag as well as a real boolean.
    run_similarity = (calculateSimilarity is True
                      or calculateSimilarity == 'True')

    for subj in subjects:
        print('-------------------')

        # The merged 'group' pseudo-subject already has its datasets.
        if (len(subjects) > 1) or (subj != 'group'):
            try:
                ds_src = load_dataset(data_path, subj, source, **conf_src)
                ds_tar = load_dataset(data_path, subj, target, **conf_tar)
            except Exception as err:
                # Best effort: report the failure and skip this subject.
                print(err)
                continue

        # Evaluate if is correct to do further normalization after merging
        # two ds.
        ds_src = detrend_dataset(ds_src, source, **conf_src)
        ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

        if conf_src['label_included'] == 'all' and \
           conf_src['label_dropped'] != 'fixation':
            print('Balancing dataset...')
            ds_src = balance_dataset_timewise(ds_src, 'fixation')

        # Make cross-decoding
        r = transfer_learning(ds_src, ds_tar, analysis, **conf_src)

        # Now that we have cross-decoding results we can process them
        pred = np.array(r['classifier'].ca.predictions)
        targets = r['targets']

        c_m = ConfusionMatrix(predictions=pred, targets=targets)
        c_m.compute()
        r['confusion_target'] = c_m

        c_new = cross_decoding_confusion(pred, targets, map_list)
        r['confusion_total'] = c_new
        print(c_new)

        # Similarity Analysis
        if run_similarity:
            mahala_data = similarity_measure(r['ds_tar'], r['ds_src'], r,
                                             p_value=p, method='mahalanobis')
            for key, value in mahala_data.items():
                r[key] = value
            r['confusion_mahala'] = mahala_data['confusion_mahalanobis']
        else:
            r['confusion_mahala'] = 'Null'

        # Signal Detection Theory Analysis: every returned measure is
        # copied into the subject result under its own key.
        sdt_res = signal_detection_measures(c_new)
        for key, value in sdt_res.items():
            r[key] = value

        total_results[subj] = r

        subj_result = rs.SubjectResult(subj, r, savers=savers)
        collection.add(subj_result)

    return total_results