Пример #1
0
    def test_confusion_call(self):
        """Check ConfusionMatrix.__call__ and its label bookkeeping.

        Labels may be given explicitly at construction time or collected
        by the instance over its lifetime (via ``add`` or ``store=True``);
        in both cases the resulting matrices must be consistent.
        """
        # Calling without labels, or with labels that do not cover the
        # data, must be refused.
        unlabeled = ConfusionMatrix()
        self.assertRaises(RuntimeError, unlabeled, [1], [1])
        mislabeled = ConfusionMatrix(labels=[2])
        self.assertRaises(ValueError, mislabeled, [1], [1])

        targets = ['ho', 'ho', 'ho', 'fa', 'fa', 'ho', 'ho']
        predictions = ['ho', 'ho', 'ho', 'ho', 'fa', 'fa', 'fa']
        ho_fa = ConfusionMatrix(labels=['ho', 'fa'])
        fa_ho = ConfusionMatrix(labels=['fa', 'ho'])
        assert_array_equal(ho_fa(predictions, targets), [[3, 1], [2, 1]])
        # Reversing the label order reverses both rows and columns.
        assert_array_equal(fa_ho(predictions, targets), [[1, 2], [1, 3]])

        # A matrix built from stored sets must match a direct call.
        from_sets = ConfusionMatrix(labels=['ho', 'fa'],
                                    sets=[(targets, predictions)])
        assert_array_equal(ho_fa(predictions, targets), from_sets.matrix)

        # Make the "mother" matrix learn a label ('aa') beyond the ones it
        # was created with, so later calls must expand to include it.
        ho_fa.add(['ho', 'aa'], ['ho', 'aa'])
        # Per the original note, this comparison causes recomputation so
        # the internal label list gets reassigned.
        assert_equal(ho_fa.labels, ['ho', 'fa', 'aa'])
        expanded = [[3, 1, 0], [2, 1, 0], [0, 0, 0]]
        assert_array_equal(ho_fa(predictions, targets), expanded)
        # Only the set given to add() is known so far ...
        assert_equal(len(ho_fa.sets), 1)
        assert_array_equal(ho_fa(predictions, targets, store=True), expanded)
        # ... and store=True kept a second one.
        assert_equal(len(ho_fa.sets), 2)
        assert_array_equal(
            ho_fa(predictions + ['ho', 'aa'], targets + ['ho', 'aa']),
            ho_fa.matrix)
Пример #2
0
def cross_decoding_confusion(predictions, targets, map_list):
    '''Summarize a cross-decoding result as a two-label confusion matrix.

    map_list = two-element list, is what we expect the classifier did when
               it cross-decodes:
               first element is the prediction,
               second element is the label of second ds to be predicted.

    Predictions equal to ``map_list[0]`` and targets equal to
    ``map_list[1]`` are both relabelled ``'<map_list[0]>-<map_list[1]>'``.
    Every other prediction and target is relabelled ``'<p>-<t>'``. Here
    ``p`` and ``t`` are the sorted-first prediction and target values that
    ``map_list`` does not match. Any further distinct values are merged
    into that same label.

    Parameters
    ----------
    predictions : sequence
        Labels predicted by the classifier.
    targets : sequence
        Labels of the second dataset that were to be predicted.
    map_list : sequence of length 2
        ``[expected prediction, corresponding target]``.

    Returns
    -------
    ConfusionMatrix
        Built from the relabelled predictions/targets, with ``compute()``
        already called.

    Raises
    ------
    IndexError
        If every prediction equals ``map_list[0]`` or every target equals
        ``map_list[1]``, because the second label then cannot be named.
    '''
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    mask_predictions = predictions == map_list[0]
    mask_targets = targets == map_list[1]

    # Use '%s' formatting rather than '+': this works for non-str labels
    # and never mixes bytes with str (a TypeError on Python 3).
    new_label_1 = '%s-%s' % (map_list[0], map_list[1])
    new_label_2 = '%s-%s' % (np.unique(predictions[~mask_predictions])[0],
                             np.unique(targets[~mask_targets])[0])

    # np.where sizes the string dtype to fit the longer label; a fixed-width
    # buffer such as '|S20' would silently truncate long combined labels.
    new_predictions = np.where(mask_predictions, new_label_1, new_label_2)
    new_targets = np.where(mask_targets, new_label_1, new_label_2)

    c_matrix = ConfusionMatrix(predictions=new_predictions,
                               targets=new_targets)
    c_matrix.compute()

    return c_matrix
Пример #3
0
def cross_decoding_confusion(predictions, targets, map_list):
    '''Summarize a cross-decoding result as a two-label confusion matrix.

    map_list = two-element list, is what we expect the classifier did when
               it cross-decodes:
               first element is the prediction,
               second element is the label of second ds to be predicted.

    Predictions equal to ``map_list[0]`` and targets equal to
    ``map_list[1]`` are both relabelled ``'<map_list[0]>-<map_list[1]>'``.
    Every other prediction and target is relabelled ``'<p>-<t>'``. Here
    ``p`` and ``t`` are the sorted-first prediction and target values that
    ``map_list`` does not match. Any further distinct values are merged
    into that same label.

    Parameters
    ----------
    predictions : sequence
        Labels predicted by the classifier.
    targets : sequence
        Labels of the second dataset that were to be predicted.
    map_list : sequence of length 2
        ``[expected prediction, corresponding target]``.

    Returns
    -------
    ConfusionMatrix
        Built from the relabelled predictions/targets, with ``compute()``
        already called.

    Raises
    ------
    IndexError
        If every prediction equals ``map_list[0]`` or every target equals
        ``map_list[1]``, because the second label then cannot be named.
    '''
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    mask_predictions = predictions == map_list[0]
    mask_targets = targets == map_list[1]

    # Use '%s' formatting rather than '+': this works for non-str labels
    # and never mixes bytes with str (a TypeError on Python 3).
    new_label_1 = '%s-%s' % (map_list[0], map_list[1])
    new_label_2 = '%s-%s' % (np.unique(predictions[~mask_predictions])[0],
                             np.unique(targets[~mask_targets])[0])

    # np.where sizes the string dtype to fit the longer label; a fixed-width
    # buffer such as '|S20' would silently truncate long combined labels.
    new_predictions = np.where(mask_predictions, new_label_1, new_label_2)
    new_targets = np.where(mask_targets, new_label_1, new_label_2)

    c_matrix = ConfusionMatrix(predictions=new_predictions,
                               targets=new_targets)
    c_matrix.compute()

    return c_matrix
Пример #4
0
    def test_confusion_plot2(self):
        """Construct a ConfusionMatrix from sets with estimates and plot it.

        Each set is presumably a (targets, predictions, estimates) triple.
        The third element holds two values per sample, and the last set is
        a single-sample split. Construction must succeed. When a plottable
        pylab backend is available, the plotted matrix must equal
        ``cm.matrix``.
        """
        array = np.array
        sets = [(array([1, 2]), array([1, 1]),
                 array([[0.54343765, 0.45656235], [0.92395853, 0.07604147]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.98030832, 0.01969168], [0.78998763, 0.21001237]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.86125263, 0.13874737], [0.83674113, 0.16325887]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.57870383, 0.42129617], [0.59702509, 0.40297491]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.89530255, 0.10469745], [0.69373919, 0.30626081]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.75015218, 0.24984782], [0.9339767, 0.0660233]])),
                (array([1, 2]), array([1, 2]),
                 array([[0.97826616, 0.02173384], [0.38620638, 0.61379362]])),
                (array([2]), array([2]), array([[0.46893776, 0.53106224]]))]
        try:
            cm = ConfusionMatrix(sets=sets)
        except Exception as e:
            # Report the cause instead of a bare, message-less failure.
            self.fail("ConfusionMatrix(sets=...) raised %r" % (e,))

        if externals.exists("pylab plottable"):
            import pylab as pl
            fig, _, _ = cm.plot(origin='lower', numbers=True)
            try:
                self.assertTrue(
                    (cm._plotted_confusionmatrix == cm.matrix).all())
            finally:
                # Close the figure even when the assertion fails.
                pl.close(fig)
Пример #5
0
    def test_confusion_call(self):
        """Exercise ConfusionMatrix.__call__ with given and learnt labels.

        Labels can be supplied up front or accumulated by the instance
        during its lifetime; both paths must yield consistent matrices.
        """
        # Refuse to compute without labels, or with labels not covering
        # the data.
        no_labels = ConfusionMatrix()
        self.assertRaises(RuntimeError, no_labels, [1], [1])
        wrong_labels = ConfusionMatrix(labels=[2])
        self.assertRaises(ValueError, wrong_labels, [1], [1])

        tgt = ['ho', 'ho', 'ho', 'fa', 'fa', 'ho', 'ho']
        prd = ['ho', 'ho', 'ho', 'ho', 'fa', 'fa', 'fa']
        two_class = [[3, 1], [2, 1]]
        cm_hf = ConfusionMatrix(labels=['ho', 'fa'])
        cm_fh = ConfusionMatrix(labels=['fa', 'ho'])
        assert_array_equal(cm_hf(prd, tgt), two_class)
        # Reversed label order: both rows and columns come out reversed.
        reversed_two_class = [row[::-1] for row in two_class[::-1]]
        assert_array_equal(cm_fh(prd, tgt), reversed_two_class)

        # A matrix created from stored sets equals a direct call.
        cm_sets = ConfusionMatrix(labels=['ho', 'fa'], sets=[(tgt, prd)])
        assert_array_equal(cm_hf(prd, tgt), cm_sets.matrix)

        # Let the "mother" matrix learn an extra label ('aa') that is not
        # among the ones it was constructed with.
        cm_hf.add(['ho', 'aa'], ['ho', 'aa'])
        # Per the original note, comparing causes recomputation so the
        # internal label list gets reassigned.
        assert_equal(cm_hf.labels, ['ho', 'fa', 'aa'])
        three_class = [[3, 1, 0], [2, 1, 0], [0, 0, 0]]
        assert_array_equal(cm_hf(prd, tgt), three_class)
        assert_equal(len(cm_hf.sets), 1)  # only the set from add() so far
        assert_array_equal(cm_hf(prd, tgt, store=True), three_class)
        assert_equal(len(cm_hf.sets), 2)  # store=True kept a second set
        assert_array_equal(cm_hf(prd + ['ho', 'aa'], tgt + ['ho', 'aa']),
                           cm_hf.matrix)
Пример #6
0
    def test_confusion_matrix_addition(self):
        """Adding confusion matrices must be reproducible (GH #51).

        The inconsistency was fixed by deep-copying instead of
        shallow-copying in ``__add__``.
        """
        first = ConfusionMatrix(sets=[[np.array((1, 2)), np.array((1, 2))]])
        second = ConfusionMatrix(sets=[[np.array((3, 2)), np.array((3, 2))]])
        for single in (first, second):
            assert_array_equal(single.stats['P'], [1, 1])

        # The reported bug: summing the same pair twice gave different stats.
        summed_once = (first + second).stats['P']
        summed_twice = (first + second).stats['P']
        assert_array_equal(summed_once, summed_twice)
        assert_array_equal(summed_once, [1, 2, 1])
Пример #7
0
def test_conditional_attr():
    """Conditional attributes must survive deepcopy and pickling.

    The test covers nodes whose conditional attributes are on or off by
    default. The explicitly enabled ``test`` and ``stats`` attributes must
    keep their enabled state and values through ``copy.deepcopy`` and
    through a pickle round-trip.
    """
    import copy
    try:
        import cPickle as pickle  # Python 2: C-accelerated pickle
    except ImportError:
        import pickle  # Python 3: cPickle no longer exists as a module

    for node in (TestNodeOnDefault(enable_ca=['test', 'stats']),
                 TestNodeOffDefault(enable_ca=['test', 'stats'])):
        node.ca.test = range(5)
        node.ca.stats = ConfusionMatrix(labels=['one', 'two'])
        node.ca.stats.add(('one', 'two', 'one', 'two'),
                          ('one', 'two', 'two', 'one'))
        node.ca.stats.compute()

        dc_node = copy.deepcopy(node)
        assert_equal(set(node.ca.enabled), set(dc_node.ca.enabled))
        assert (node.ca['test'].enabled)
        assert (node.ca['stats'].enabled)
        assert_array_equal(node.ca['test'].value, dc_node.ca['test'].value)
        assert_array_equal(node.ca['stats'].value.matrix,
                           dc_node.ca['stats'].value.matrix)

        # check whether values survive pickling
        pickled = pickle.dumps(node)
        up_node = pickle.loads(pickled)
        assert_array_equal(up_node.ca['test'].value, range(5))
        assert_array_equal(up_node.ca['stats'].value.matrix,
                           node.ca['stats'].value.matrix)
Пример #8
0
 def test_confusion_matrix_acc(self):
     """Check the accuracy line of the report and the CHI^2 statistic.

     Targets and predictions agree on exactly half of the samples, so the
     text report must show 50% accuracy. If scipy is available, the CHI^2
     entry of ``stats`` must equal ``(0., 1.)``.
     """
     cm = ConfusionMatrix(targets=[0, 0, 1, 1], predictions=[1, 0, 1, 0])
     self.assertIn('ACC%         50', str(cm))
     skip_if_no_external('scipy')
     self.assertTrue(cm.stats['CHI^2'] == (0., 1.))
Пример #9
0
    def test_degenerate_confusion(self):
        """Single-label splits must be handled gracefully.

        Some testing splits may contain just one target label. Building the
        matrix, rendering it as text and reading its accuracy must still
        work, and accuracy must be 100% when predictions equal targets.
        """
        degenerate_cases = ([1], [1, 1], [0], [0, 0])
        for labels in degenerate_cases:
            cm = ConfusionMatrix(targets=labels, predictions=labels,
                                 estimates=labels)
            # Rendering must not raise; the text itself is not inspected.
            str(cm)
            self.assertTrue(cm.stats['ACC%'] == 100)
Пример #10
0
    def test_confusion_plot2(self):
        """Construct a ConfusionMatrix from sets with estimates and plot it.

        Each set is presumably a (targets, predictions, estimates) triple.
        The third element holds two values per sample, and the last set is
        a single-sample split. Construction must succeed. When a plottable
        pylab backend is available, the plotted matrix must equal
        ``cm.matrix``.
        """
        array = np.array
        sets = [(array([1, 2]), array([1, 1]),
                 array([[0.54343765, 0.45656235],
                        [0.92395853, 0.07604147]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.98030832, 0.01969168],
                        [0.78998763, 0.21001237]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.86125263, 0.13874737],
                        [0.83674113, 0.16325887]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.57870383, 0.42129617],
                        [0.59702509, 0.40297491]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.89530255, 0.10469745],
                        [0.69373919, 0.30626081]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.75015218, 0.24984782],
                        [0.9339767, 0.0660233]])),
                (array([1, 2]), array([1, 2]),
                 array([[0.97826616, 0.02173384],
                        [0.38620638, 0.61379362]])),
                (array([2]), array([2]),
                 array([[0.46893776, 0.53106224]]))]
        try:
            cm = ConfusionMatrix(sets=sets)
        except Exception as e:
            # Report the cause instead of a bare, message-less failure.
            self.fail("ConfusionMatrix(sets=...) raised %r" % (e,))

        if externals.exists("pylab plottable"):
            import pylab as pl
            fig, _, _ = cm.plot(origin='lower', numbers=True)
            try:
                self.assertTrue(
                    (cm._plotted_confusionmatrix == cm.matrix).all())
            finally:
                # Close the figure even when the assertion fails.
                pl.close(fig)
Пример #11
0
 def test_confusion_matrix_with_mappings(self):
     """Check counts and label rendering when ``labels_map`` is given.

     ``labels_map`` maps display names to numeric labels and may map
     several names to the same label ('apple' and 'shitty apple' both map
     to 1). Every name must still appear in the text report.
     """
     targets = [1, 1, 1, 2, 2, 2, 3, 3, 3]
     predictions = [1, 2, 1, 2, 2, 2, 3, 2, 1]
     # Judging by the counts, rows index predictions and columns targets.
     correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
     labels_map = {'apple': 1, 'orange': 2, 'shitty apple': 1, 'candy': 3}
     cm = ConfusionMatrix(targets=targets, predictions=predictions,
                          labels_map=labels_map)
     # check table content
     self.assertTrue((cm.matrix == correct_cm).all())
     # assure that all labels are somewhere listed
     report = str(cm)
     for name in labels_map:
         self.assertTrue(name in report)
Пример #12
0
    def test_confusion_plot(self):
        """Basic test of confusion-matrix printing and plotting.

        Based on existing cell dataset results. There are eight
        (targets, predictions) splits over the numeric labels 38-47, which
        ``labels_map`` names as tones ('3kHz' ... '30kHz') and songs
        ('song1' ... 'song5').

        NOTE(review): an earlier docstring said this test is kept for
        possible future testing and "is not a part of the unittests
        suite", yet its ``test_`` name makes it collectable -- confirm.
        """
        array = np.array
        uint8 = np.uint8
        sets = [
           (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
            array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
                 45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                 40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
                 46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
                 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
                 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
                 44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
                 43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
                 47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
                 40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
                 45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
           (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
            array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
                 46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
                 40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
                 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
                 44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                 43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
                 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
                 39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
           (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
            array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
                 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
                 40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                 40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
                 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
                 43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
                 47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
                 40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
                 45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
                 39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
                 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
           (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
            array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
                 38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
                 44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
                 38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
                 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
                 45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
                 38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
                 42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                 44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
           (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
            array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
                 42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                 40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
                 44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                 43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
                 47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
                 45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
                 46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
                 38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
                 46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
                 38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
           (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
            array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
                 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
                 42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                 44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
                 43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
                 40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
                 45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
                 38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
                 40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
                 41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
                 40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
                 42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
           (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
            array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
                 42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
                 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
                 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
                 39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
                 46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
                 38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
                 42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                 44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
           (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
            array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
                 44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
                 43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
                 39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
                 39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
                 45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
                 41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
                 46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
                 41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
                 38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
                 42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
                 44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
                 43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
        labels_map = {'12kHz': 40,
                      '20kHz': 41,
                      '30kHz': 42,
                      '3kHz': 38,
                      '7kHz': 39,
                      'song1': 43,
                      'song2': 44,
                      'song3': 45,
                      'song4': 46,
                      'song5': 47}
        try:
            cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
        except Exception as e:
            # Report the cause instead of a bare, message-less failure.
            self.fail("ConfusionMatrix(sets=..., labels_map=...) raised %r"
                      % (e,))

        cms = str(cm)
        self.assertTrue('3kHz / 38' in cms)
        if externals.exists("scipy"):
            # This summary line is only checked when scipy is available.
            self.assertTrue(
                'ACC(i) = 0.82-0.012*i p=0.12 r=-0.59 r^2=0.35' in cms)

        if externals.exists("pylab plottable"):
            import pylab as pl
            pl.figure()
            # None presumably inserts a visual gap between tones and songs.
            labels_order = ("3kHz", "7kHz", "12kHz", "20kHz", "30kHz", None,
                            "song1", "song2", "song3", "song4", "song5")
            # Plot with the first two labels swapped ('7kHz' before '3kHz').
            # cm.matrix orders labels 38 ('3kHz'), 39 ('7kHz'), ..., so the
            # plotted top-left 2x2 corner must be the swapped original one.
            swapped_order = (labels_order[1:2] + labels_order[:1]
                             + labels_order[2:])
            fig, _, _ = cm.plot(labels=swapped_order, numbers=True)
            try:
                plotted = cm._plotted_confusionmatrix
                self.assertTrue(plotted[0, 0] == cm.matrix[1, 1])
                self.assertTrue(plotted[0, 1] == cm.matrix[1, 0])
                self.assertTrue(plotted[1, 1] == cm.matrix[0, 0])
                self.assertTrue(plotted[1, 0] == cm.matrix[0, 1])
            finally:
                # Close the figure even when an assertion fails.
                pl.close(fig)
            fig, _, _ = cm.plot(labels=labels_order, numbers=True)
            pl.close(fig)
Пример #13
0
    def test_confusion_matrix(self):
        """Check ConfusionMatrix construction, accumulation and arithmetic.

        Covers acceptance of list/tuple/ndarray inputs, the error raised by
        an empty matrix, rejection of sets with mismatched lengths, sets that
        introduce previously unknown labels, per-set matrices, string
        rendering, and the ``+`` / ``+=`` operators.
        """
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]  # targets
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]  # predictions
        # Rows are indexed by prediction, columns by target.
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        # Any input sequence type must be accepted -- list, tuple or np.array.
        for t in [reg, tuple(reg), list(reg), np.array(reg)]:
            for p in [regl, tuple(regl), list(regl), np.array(regl)]:
                cm = ConfusionMatrix(targets=t, predictions=p)
                # check table content
                self.assertTrue((cm.matrix == correct_cm).all())

        # Do a bit more thorough checking.
        cm = ConfusionMatrix()
        # No samples yet -- percent_correct must raise.
        self.assertRaises(ZeroDivisionError, lambda x: x.percent_correct, cm)

        cm.add(reg, regl)

        self.assertEqual(len(cm.sets),
                         1,
                         msg="Should have a single set so far")
        self.assertEqual(
            cm.matrix.shape, (3, 3),
            msg="should be square matrix (len(reglabels) x len(reglabels))")

        # ConfusionMatrix must complain if the number of samples differs.
        self.assertRaises(ValueError, cm.add, reg, np.array([1]))

        # The table still reflects only the single valid set added above.
        self.assertTrue((cm.matrix == correct_cm).all())

        # Add a set whose predictions contain a label not seen before.
        cm.add(reg, np.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

        self.assertEqual(cm.labels, [1, 2, 3, 4],
                         msg="We should have gotten 4th label")

        matrices = cm.matrices  # separate CM per each given set
        self.assertEqual(len(matrices), 2, msg="Have gotten two splits")

        self.assertTrue(
            (matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
            msg="Total votes should match the sum across split CMs")

        # Printing must work and produce non-trivial output, including
        # with the various printing parameters.
        self.assertTrue(
            len(cm.as_string(header=True, summary=True, description=True)) >
            100)
        self.assertTrue(len(str(cm)) > 100)
        self.assertTrue(len(cm.as_string(summary=True, header=False)) > 100)

        # In-place addition of a matrix to itself doubles the sets.
        cm += cm
        self.assertEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

        # Regular addition of a matrix to itself doubles them again.
        cm2 = cm + cm
        self.assertEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
        self.assertEqual(cm2.percent_correct,
                         cm.percent_correct,
                         msg="Percent of correct should remain the same ;-)")

        # Floating-point values computed along different paths: compare
        # approximately rather than bit-for-bit.
        self.assertAlmostEqual(cm2.error,
                               1.0 - cm.percent_correct / 100.0,
                               msg="Test if we get proper error value")
Пример #14
0
    def test_confusion_plot(self):
        """Basic test of confusion plot.

        Based on existing cell dataset results: a list of per-split
        (targets, predictions) uint8 arrays plus a mapping from label names
        to their numeric codes. Checks the textual summary and, when pylab
        can plot, that ``plot`` honours a custom label order.
        """
        array = np.array
        uint8 = np.uint8
        sets = [(array([
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44
        ],
                       dtype=uint8),
                 array([
                     40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43,
                     45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40,
                     46, 42, 47, 39, 38, 46, 45, 38, 44, 39, 46, 38, 39, 39,
                     38, 43, 45, 41, 44, 40, 46, 42, 38, 40, 47, 43, 45, 41,
                     44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39, 46, 42,
                     47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43,
                     45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40,
                     46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                     47, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                     44, 40, 46, 42, 39, 39, 38, 43, 45, 41, 44, 47, 46, 42,
                     47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40, 43,
                     45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                     46, 41, 47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38,
                     38, 43, 45, 41, 44, 40, 45, 42, 47, 39, 39, 43, 45, 41,
                     44, 38, 46, 42, 47, 38, 42, 43, 45, 41, 44, 39, 46, 42,
                     39, 39, 39, 47, 45, 41, 44
                 ],
                       dtype=uint8)),
                (array([
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43
                ],
                       dtype=uint8),
                 array([
                     40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47,
                     39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                     41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 39, 46,
                     42, 47, 47, 47, 43, 45, 41, 44, 40, 46, 42, 43, 39, 38,
                     43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45, 41, 44,
                     40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47,
                     40, 40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45,
                     41, 44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46,
                     42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 39, 39, 42,
                     43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                     47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47,
                     39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45,
                     41, 44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
                     46, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39,
                     43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 47, 45, 41, 44,
                     40, 46, 42, 47, 39, 38, 43
                 ],
                       dtype=uint8)),
                (array([
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47
                ],
                       dtype=uint8),
                 array([
                     45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40,
                     46, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38,
                     39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
                     44, 40, 46, 42, 47, 39, 43, 43, 45, 40, 44, 40, 46, 42,
                     47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 40, 43,
                     45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40,
                     40, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                     38, 43, 45, 41, 44, 38, 46, 42, 47, 39, 39, 43, 45, 41,
                     44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 38,
                     38, 39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 40, 39, 43,
                     45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44, 40,
                     40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39,
                     38, 38, 45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38,
                     44, 40, 46, 42, 39, 39, 45, 43, 45, 41, 44, 38, 46, 42,
                     38, 39, 39, 43, 45, 41, 38, 40, 46, 42, 47, 38, 39, 43,
                     45, 41, 44, 40, 46, 42, 40
                 ],
                       dtype=uint8)),
                (array([
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40
                ],
                       dtype=uint8),
                 array([
                     39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45,
                     41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46,
                     42, 47, 39, 38, 43, 45, 41, 44, 40, 38, 43, 47, 38, 38,
                     43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
                     43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47,
                     39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45,
                     40, 44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                     42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
                     43, 45, 47, 44, 45, 46, 42, 38, 39, 41, 43, 45, 41, 44,
                     38, 38, 42, 39, 40, 40, 43, 45, 41, 39, 40, 46, 42, 47,
                     39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43, 45,
                     41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46,
                     42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43,
                     45, 46, 44, 38, 46, 42, 47, 38, 44, 43, 45, 42, 44, 41,
                     46, 42, 47, 47, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39,
                     38, 43, 45, 41, 44, 40
                 ],
                       dtype=uint8)),
                (array([
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45
                ],
                       dtype=uint8),
                 array([
                     46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                     42, 43, 45, 42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41,
                     44, 40, 46, 42, 47, 47, 40, 43, 45, 41, 44, 41, 46, 42,
                     38, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43,
                     45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41, 44, 39,
                     46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                     39, 43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41,
                     44, 39, 46, 42, 47, 39, 46, 43, 45, 41, 44, 39, 46, 42,
                     39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43,
                     45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40,
                     38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46, 46,
                     38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41,
                     44, 38, 38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42,
                     47, 38, 39, 47, 45, 46, 44, 40, 46, 42, 47, 40, 38, 43,
                     45, 41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 38,
                     46, 42, 38, 39, 38, 47, 45
                 ],
                       dtype=uint8)),
                (array([
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39
                ],
                       dtype=uint8),
                 array([
                     41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46,
                     42, 38, 40, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38,
                     43, 45, 41, 44, 46, 38, 42, 40, 38, 39, 43, 45, 41, 44,
                     41, 46, 42, 47, 47, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                     39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39, 43, 45,
                     41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46,
                     42, 40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39,
                     43, 45, 41, 44, 40, 46, 42, 39, 38, 47, 43, 45, 38, 44,
                     40, 38, 42, 47, 38, 38, 43, 45, 41, 44, 40, 38, 46, 47,
                     38, 38, 43, 45, 41, 44, 41, 46, 42, 40, 38, 38, 40, 45,
                     41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40, 40,
                     42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47,
                     43, 45, 41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44,
                     39, 43, 42, 47, 39, 40, 43, 45, 41, 44, 42, 46, 42, 47,
                     40, 46, 43, 45, 41, 44, 38, 46, 42, 47, 47, 38, 43, 45,
                     41, 44, 40, 38, 39, 47, 38
                 ],
                       dtype=uint8)),
                (array([
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46
                ],
                       dtype=uint8),
                 array([
                     39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                     44, 41, 46, 42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42,
                     47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 40, 43,
                     45, 41, 44, 40, 46, 42, 47, 45, 38, 43, 45, 41, 44, 38,
                     46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 39,
                     38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                     44, 40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42,
                     40, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                     45, 41, 44, 38, 46, 42, 39, 39, 47, 43, 45, 41, 44, 40,
                     46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 46, 47,
                     39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41,
                     44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42,
                     47, 39, 38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43,
                     47, 40, 44, 40, 46, 42, 39, 39, 38, 39, 45, 41, 44, 40,
                     46, 42, 47, 38, 38, 43, 45, 41, 44, 46, 38, 42, 47, 39,
                     43, 43, 45, 41, 44, 40, 46
                 ],
                       dtype=uint8)),
                (array([
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45
                ],
                       dtype=uint8),
                 array([
                     42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38,
                     43, 45, 41, 44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44,
                     40, 41, 42, 47, 38, 46, 43, 45, 41, 44, 41, 41, 42, 40,
                     39, 39, 43, 45, 41, 44, 46, 45, 42, 39, 39, 40, 43, 45,
                     41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44, 39, 46,
                     42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39,
                     43, 45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44,
                     39, 46, 42, 47, 41, 38, 43, 45, 41, 44, 42, 46, 42, 46,
                     39, 38, 43, 45, 41, 44, 41, 46, 42, 46, 39, 38, 43, 45,
                     41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 38, 46,
                     42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38, 38,
                     43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44,
                     40, 46, 42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41,
                     39, 39, 43, 45, 41, 44, 40, 38, 42, 40, 39, 46, 43, 45,
                     41, 44, 47, 46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46,
                     42, 43, 39, 39, 43, 45
                 ],
                       dtype=uint8))]
        labels_map = {
            '12kHz': 40,
            '20kHz': 41,
            '30kHz': 42,
            '3kHz': 38,
            '7kHz': 39,
            'song1': 43,
            'song2': 44,
            'song3': 45,
            'song4': 46,
            'song5': 47
        }
        # Narrow catch (a bare except would also trap SystemExit and
        # KeyboardInterrupt) and keep the cause in the failure message.
        try:
            cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
        except Exception as e:
            self.fail("ConfusionMatrix construction failed: %s" % e)

        cms = str(cm)
        self.assertTrue('3kHz / 38' in cms)
        if externals.exists("scipy"):
            self.assertTrue(
                'ACC(i) = 0.82-0.012*i p=0.12 r=-0.59 r^2=0.35' in cms)

        if externals.exists("pylab plottable"):
            import pylab as pl
            pl.figure()
            # NOTE(review): None presumably acts as a separator in the
            # plotted label list -- confirm against ConfusionMatrix.plot.
            labels_order = ("3kHz", "7kHz", "12kHz", "20kHz", "30kHz", None,
                            "song1", "song2", "song3", "song4", "song5")
            # Swap the first two labels so the plotted matrix must be the
            # stored matrix with rows/columns 0 and 1 exchanged.
            swapped_order = (labels_order[1:2] + labels_order[:1] +
                             labels_order[2:])
            fig, im, cb = cm.plot(labels=swapped_order, numbers=True)
            try:
                plotted = cm._plotted_confusionmatrix
                self.assertEqual(plotted[0, 0], cm.matrix[1, 1])
                self.assertEqual(plotted[0, 1], cm.matrix[1, 0])
                self.assertEqual(plotted[1, 1], cm.matrix[0, 0])
                self.assertEqual(plotted[1, 0], cm.matrix[0, 1])
            finally:
                # Close the figure even when an assertion above fails.
                pl.close(fig)
            fig, im, cb = cm.plot(labels=labels_order, numbers=True)
            pl.close(fig)
Пример #15
0
def _group_transfer_learning(path,
                             subjects,
                             analysis,
                             conf_file,
                             source='task',
                             analysis_type='single',
                             **kwargs):

    if source == 'task':
        target = 'rest'
    else:
        if source == 'rest':
            target = 'task'

    if source == 'saccade':
        target = 'face'
    else:
        if source == 'face':
            target = 'saccade'

    ##############################################
    ##############################################
    ##    conf_src['label_included'] = 'all'    ##
    ##    conf_src['label_dropped'] = 'none'    ##
    ##    conf_src['mean_samples'] = 'False'    ##
    ##############################################
    ##############################################

    if analysis_type == 'group':

        if path.__class__ == conf_file.__class__ == list:
            ds_src, s, conf_src = sources_merged_ds(path, subjects, conf_file,
                                                    source, **kwargs)

            conf_src['permutations'] = 0

        else:
            print 'In group analysis path, subjects and conf_file must be lists: \
                    Check configuration file and/or parameters!!'

            return 0

    else:

        conf_src = read_configuration(path, conf_file, source)

    for arg in conf_src:
        if arg == 'map_list':
            map_list = conf_src[arg].split(',')

    r_group = spatial(ds_src, **conf_src)

    total_results = dict()
    total_results['group'] = r_group

    clf = r_group['classifier']

    for subj_, conf_, path_ in zip(subjects, conf_file, path):
        for subj in subj_:
            print '-----------'
            r = dict()
            if len(subj_) > 1:
                conf_tar = read_configuration(path_, conf_, target)

                for arg in kwargs:

                    conf_tar[arg] = kwargs[arg]

                data_path = conf_tar['data_path']
                try:
                    ds_tar = load_dataset(data_path, subj, target, **conf_tar)
                except Exception, err:
                    print err
                    continue

            ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

            if conf_src['label_included'] == 'all' and \
               conf_src['label_dropped'] != 'fixation':
                print 'Balancing dataset...'
                ds_src = balance_dataset_timewise(ds_src, 'fixation')

            predictions = clf.predict(ds_tar)

            pred = np.array(predictions)
            targets = ds_tar.targets

            for arg in r_group.keys():
                r[arg] = copy.copy(r_group[arg])

            r['targets'] = targets
            r['predictions'] = predictions

            r['fclf'] = clf

            c_m = ConfusionMatrix(predictions=pred, targets=targets)
            c_m.compute()
            r['confusion_target'] = c_m
            print c_m

            tr_pred = similarity_measure_mahalanobis(ds_tar, ds_src, r)
            r['mahalanobis_similarity'] = tr_pred

            #print tr_pred

            c_mat_mahala = ConfusionMatrix(predictions=tr_pred.T[1],
                                           targets=tr_pred.T[0])
            c_mat_mahala.compute()
            r['confusion_mahala'] = c_mat_mahala

            d_prime, beta, c, c_new = signal_detection_measures(
                pred, targets, map_list)
            r['d_prime'] = d_prime
            print d_prime
            r['beta'] = beta
            r['c'] = c
            r['confusion_total'] = c_new
            '''
            d_prime_maha, c_new_maha = d_prime_statistics(tr_pred.T[1], tr_pred.T[0], map_list)
            r['d_prime_maha'] = d_prime_maha
            r['confusion_tot_maha'] = c_new_maha
            '''

            total_results[subj] = r
Пример #16
0
    '''
    Get some data
    '''

    dt = np.dtype([('labels', np.str_, 20), ('predictions', np.str_, 20),
                   ('distances', np.float32), ('pvalues', np.float32)])

    full_data = np.array(zip(ds_tar.targets, classifier_prediction_tar,
                             mahalanobis_values, p_values),
                         dtype=dt)

    classifier_prediction_tar = np.array(classifier_prediction_tar)
    ds_tar_targets = np.array(ds_tar.targets)

    c_mat_mahala = ConfusionMatrix(
        predictions=classifier_prediction_tar[similarity_mask],
        targets=ds_tar_targets[similarity_mask])
    c_mat_mahala.compute()

    header = [
        'similarity_data', 'similarity_mask', 'threshold_value', 'pvalues',
        'distances', 'confusion_mahalanobis'
    ]
    result_data = [
        full_data, similarity_mask, threshold, p_values, distances,
        c_mat_mahala
    ]
    results = dict(zip(header, result_data))

    return results
Пример #17
0
    '''
    
    dt = np.dtype([('labels', np.str_, 20), 
                   ('predictions', np.str_, 20),
                   ('distances', np.float32),
                   ('pvalues', np.float32)])
    
    full_data = np.array(zip(ds_tar.targets, 
                             classifier_prediction_tar, 
                             mahalanobis_values, 
                             p_values), dtype=dt)
    
    classifier_prediction_tar = np.array(classifier_prediction_tar)
    ds_tar_targets = np.array(ds_tar.targets)
    
    c_mat_mahala = ConfusionMatrix(predictions=classifier_prediction_tar[similarity_mask], 
                                    targets=ds_tar_targets[similarity_mask])
    c_mat_mahala.compute()
    
    header = ['similarity_data', 'similarity_mask', 'threshold_value', 
              'pvalues', 'distances', 'confusion_mahalanobis']
    result_data = [full_data, similarity_mask, threshold, 
                   p_values, distances, c_mat_mahala]
    results = dict(zip(header, result_data))
    
    return results
    
def similarity_measure_correlation (ds_tar, ds_src, results, p_value):
    """Presumably a correlation-based similarity between ds_tar and ds_src.

    NOTE(review): the visible body only prints a progress message (which
    mentions "Mahalanobis" despite the function name) and then returns
    None. The implementation looks truncated; confirm against the full
    source before relying on this function.
    """


    print 'Computing Mahalanobis similarity...'
    #classifier = results['classifier']
Пример #18
0
    def test_partial_searchlight_with_confusion_matrix(self):
        """Searchlight whose cross-validation yields per-fold confusion tables.

        A sphere searchlight over three centers runs a CrossValidation whose
        ``errorfx`` returns the ConfusionMatrix table for each fold. The test:

        * checks the result shape and per-column sample counts;
        * compares accuracies derived from the tables against a regular
          error-returning searchlight;
        * runs a small Monte-Carlo permutation test (``MCNullDist``) on the
          fold-summed tables.
        """
        ds = self.dataset
        from mvpa2.clfs.stats import MCNullDist
        from mvpa2.mappers.fx import mean_sample, sum_sample

        # compute N-1 cross-validation for each sphere
        cm = ConfusionMatrix(labels=ds.UT)
        cv = CrossValidation(
            sample_clf_lin,
            NFoldPartitioner(),
            # we have to ensure that the matrix does not get flattened by
            # the first vstack in cv and then the hstack in searchlight --
            # thus 2 leading dimensions
            # TODO: RF? make searchlight/crossval smarter?
            errorfx=lambda *a: cm(*a)[None, None, :])
        # construct a radius-1 sphere searchlight over three centers
        sl = sphere_searchlight(cv, radius=1, center_ids=[3, 5, 50])

        # our regular searchlight -- to compare results
        cv_gross = CrossValidation(sample_clf_lin, NFoldPartitioner())
        sl_gross = sphere_searchlight(cv_gross,
                                      radius=1,
                                      center_ids=[3, 5, 50])

        # run searchlights
        res = sl(ds)
        res_gross = sl_gross(ds)

        # three spheres, with one full confusion table per CV fold
        assert_equal(res.shape, (len(ds.UC), 3, len(ds.UT), len(ds.UT)))
        assert_equal(res_gross.shape, (len(ds.UC), 3))

        # briefly inspect the confusion matrices
        mat = res.samples
        # since input dataset is probably balanced (otherwise adjust
        # to be per label): sum within columns (thus axis=-2) should
        # be identical to per-class/chunk number of samples
        # NOTE: relies on Python 2 integer division for `/` on ints; under
        # Python 3 this would be a float (still equal if evenly divisible)
        samples_per_classchunk = len(ds) / (len(ds.UT) * len(ds.UC))
        ok_(np.all(np.sum(mat, axis=-2) == samples_per_classchunk))
        # and if we compute accuracies manually -- they should
        # correspond to the one from sl_gross
        # (diagonal of a 2x2 table: two targets, as the (1, 3, 2, 2)
        # shape asserted below implies)
        assert_array_almost_equal(
            res_gross.samples,
            # from accuracies to errors
            1 - (mat[..., 0, 0] + mat[..., 1, 1]).astype(float) /
            (2 * samples_per_classchunk))

        # and now for those who remained seated -- lets perform H0 MC
        # testing of this searchlight... just a silly one with minimal
        # number of permutations
        no_permutations = 10
        permutator = AttributePermutator('targets', count=no_permutations)

        # once again -- need explicit leading dimension to avoid
        # vstacking during cross-validation
        cv.postproc = lambda x: sum_sample()(x)[None, :]

        sl = sphere_searchlight(cv,
                                radius=1,
                                center_ids=[3, 5, 50],
                                null_dist=MCNullDist(
                                    permutator,
                                    tail='right',
                                    enable_ca=['dist_samples']))
        res_perm = sl(ds)
        # XXX all of the res_perm, sl.ca.null_prob and
        #     sl.null_dist.ca.dist_samples carry a degenerate leading
        #     dimension which was probably due to introduced new axis
        #     above within cv.postproc
        assert_equal(res_perm.shape, (1, 3, 2, 2))
        assert_equal(sl.null_dist.ca.dist_samples.shape,
                     res_perm.shape + (no_permutations, ))
        assert_equal(sl.ca.null_prob.shape, res_perm.shape)
        # just to make sure ;)
        ok_(np.all(sl.ca.null_prob.samples >= 0))
        ok_(np.all(sl.ca.null_prob.samples <= 1))

        # we should have got sums of hits across the splits
        assert_array_equal(np.sum(mat, axis=0), res_perm.samples[0])
Пример #19
0
def test_transfer_learning(path, subjects, analysis, conf_file, source='task',
                           analysis_type='single', calculateSimilarity='True',
                           **kwargs):
    """Cross-decode from a ``source`` task to its paired ``target`` task.

    For each subject:

    * a classifier is trained on the source dataset and tested on the
      target dataset (``transfer_learning``);
    * the predictions are summarised as a target confusion matrix and a
      cross-decoding confusion matrix (``cross_decoding_confusion`` with
      the configured ``map_list``);
    * an optional Mahalanobis similarity analysis is run, followed by
      signal-detection measures;
    * the result is added to an ``rs.ResultsCollection``.

    Source/target pairs: 'task' <-> 'rest', 'saccade' <-> 'face'.

    Parameters
    ----------
    path, conf_file : str or list
        Data path(s) and configuration file(s). In 'group' mode both must
        be lists; they are forwarded to ``sources_merged_ds``.
    subjects : list
        Subject identifiers to process. In 'group' mode this is replaced by
        the subjects returned from ``sources_merged_ds``.
    analysis : callable
        Analysis forwarded to ``transfer_learning``. Its name is stored in
        ``conf_src['analysis_func']``.
    source : str
        Source task: 'task', 'rest', 'saccade' or 'face'.
    analysis_type : str
        'group' to use merged datasets; anything else loads per subject.
    calculateSimilarity : str
        The string 'True' enables the Mahalanobis similarity step.
    **kwargs
        In single mode, copied into both configurations. ``p`` is the
        p-value threshold for the similarity step; a 'p_dist' entry in the
        source configuration overrides it.

    Returns
    -------
    dict or int
        Per-subject result dicts keyed by subject, or 0 when the group-mode
        arguments are not lists.

    Raises
    ------
    ValueError
        If ``source`` is unknown, if the configuration has no 'map_list',
        or if similarity is requested without any p-value.
    """
    # Each source task is decoded onto its paired target task.
    paired_tasks = {'task': 'rest', 'rest': 'task',
                    'saccade': 'face', 'face': 'saccade'}
    target = paired_tasks.get(source)
    if target is None:
        raise ValueError("Unknown source task %r; expected one of: %s"
                         % (source, ', '.join(sorted(paired_tasks))))

    # Similarity p-value threshold; may be overridden by conf 'p_dist' below.
    p = kwargs.get('p')

    # Typical source settings: label_included='all', label_dropped='none',
    # mean_samples='False'.

    if analysis_type == 'group':
        if isinstance(path, list) and isinstance(conf_file, list):
            ds_src, _, conf_src = sources_merged_ds(path, subjects, conf_file,
                                                    source, **kwargs)
            ds_tar, subjects, conf_tar = sources_merged_ds(
                path, subjects, conf_file, target, **kwargs)

            conf_src['permutations'] = 0
            conf_tar['permutations'] = 0
        else:
            print('In group analysis path, subjects and conf_file must be lists: \
                    Check configuration file and/or parameters!!')
            return 0
    else:
        conf_src = read_configuration(path, conf_file, source)
        conf_tar = read_configuration(path, conf_file, target)

        for arg in kwargs:
            conf_src[arg] = kwargs[arg]
            conf_tar[arg] = kwargs[arg]

        # NOTE(review): only defined in single mode; group mode relies on
        # the merged datasets and must not reach load_dataset below.
        data_path = conf_src['data_path']

    conf_src['analysis_type'] = 'transfer_learning'
    conf_src['analysis_task'] = source
    conf_src['analysis_func'] = analysis.__name__

    # Validate configuration up front instead of failing with a NameError
    # after the first datasets have already been loaded.
    if 'map_list' not in conf_src:
        raise ValueError("Configuration is missing 'map_list', required for "
                         "the cross-decoding confusion matrix")
    map_list = conf_src['map_list'].split(',')

    if 'p_dist' in conf_src:
        p = float(conf_src['p_dist'])
        print(p)

    if calculateSimilarity == 'True' and p is None:
        raise ValueError("Similarity analysis needs a p-value: pass p=... "
                         "or set 'p_dist' in the configuration")

    total_results = dict()

    summarizers = [
        rs.CrossDecodingSummarizer(),
        rs.SimilaritySummarizer(),
        rs.DecodingSummarizer(),
        rs.SignalDetectionSummarizer(),
    ]

    savers = [
        rs.CrossDecodingSaver(),
        rs.SimilaritySaver(),
        rs.DecodingSaver(),
        rs.SignalDetectionSaver(),
    ]

    collection = rs.ResultsCollection(conf_src, path, summarizers)

    for subj in subjects:
        print('-------------------')

        if (len(subjects) > 1) or (subj != 'group'):
            try:
                ds_src = load_dataset(data_path, subj, source, **conf_src)
                ds_tar = load_dataset(data_path, subj, target, **conf_tar)
            except Exception as err:
                # Best effort: report and skip subjects that cannot be loaded.
                print(err)
                continue

        # Evaluate if is correct to do further normalization after merging two ds.
        ds_src = detrend_dataset(ds_src, source, **conf_src)
        ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

        if conf_src['label_included'] == 'all' and \
           conf_src['label_dropped'] != 'fixation':
            print('Balancing dataset...')
            ds_src = balance_dataset_timewise(ds_src, 'fixation')

        # Make cross-decoding
        r = transfer_learning(ds_src, ds_tar, analysis, **conf_src)

        # Now we have cross-decoding results we could process it
        pred = np.array(r['classifier'].ca.predictions)

        targets = r['targets']

        c_m = ConfusionMatrix(predictions=pred, targets=targets)
        c_m.compute()
        r['confusion_target'] = c_m

        c_new = cross_decoding_confusion(pred, targets, map_list)
        r['confusion_total'] = c_new

        print(c_new)

        # Similarity Analysis
        if calculateSimilarity == 'True':
            mahala_data = similarity_measure(r['ds_tar'],
                                             r['ds_src'],
                                             r,
                                             p_value=p,
                                             method='mahalanobis')

            for k_, v_ in mahala_data.items():
                r[k_] = v_
            r['confusion_mahala'] = mahala_data['confusion_mahalanobis']
        else:
            r['confusion_mahala'] = 'Null'

        # Signal Detection Theory Analysis: copies the returned measures
        # (e.g. d_prime, beta, c) into the subject's result dict.
        sdt_res = signal_detection_measures(c_new)
        for k_, v_ in sdt_res.items():
            r[k_] = v_

        total_results[subj] = r
        subj_result = rs.SubjectResult(subj, r, savers=savers)

        collection.add(subj_result)

    return total_results
Пример #20
0
    def test_gideon_weird_case(self):
        """Test whether a mapper that changes the number of samples is usable

        'The utter collapse' -- communicated by Peter J. Kohler

        Goal: collapse all samples of each category in the training and
        testing sets, leaving a single sample per category per training
        and per testing set.

        This peculiar scenario pinpoints a limitation: until now mappers
        were assumed not to change the number of samples.
        """
        from mvpa2.mappers.fx import mean_group_sample
        from mvpa2.clfs.knn import kNN
        from mvpa2.mappers.base import ChainMapper
        ds = datasets['uni2large'].copy()
        #ds = ds[ds.sa.chunks < 9]
        accs = []
        k = 1  # for kNN
        nf = 1  # for NFoldPartitioner
        for i in xrange(1):  # # of random runs
            # replace samples with pure noise of the same shape
            ds.samples = np.random.randn(*ds.shape)
            #
            # There are 3 ways to accomplish needed goal
            #

            # 0. Hard way: overcome the problem by manually
            #    pre-splitting/averaging (mean_group_sample) in a loop
            from mvpa2.clfs.transerror import ConfusionMatrix
            partitioner = NFoldPartitioner(nf)
            meaner = mean_group_sample(['targets', 'partitions'])
            cm = ConfusionMatrix()
            te = TransferMeasure(kNN(k),
                                 Splitter('partitions'),
                                 postproc=BinaryFxNode(mean_mismatch_error,
                                                       'targets'),
                                 enable_ca=['stats'])
            errors = []
            for part in partitioner.generate(ds):
                ds_meaned = meaner(part)
                # NOTE(review): np.asscalar was deprecated in NumPy 1.16 and
                # removed in 1.23; newer NumPy needs `.item()` here
                errors.append(np.asscalar(te(ds_meaned)))
                cm += te.ca.stats
            #print i, cm.stats['ACC']
            accs.append(cm.stats['ACC'])

            if False:  # not yet working -- see _tent/allow_ch_nsamples
                # branch for attempt to make it work
                # 1. This is a "native way" IF we allow change of number
                #    of samples via _call to be done by MappedClassifier
                #    while operating solely on the mapped dataset
                clf2 = MappedClassifier(
                    clf=kNN(k),  #clf,
                    mapper=mean_group_sample(['targets', 'partitions']))
                cv = CrossValidation(clf2,
                                     NFoldPartitioner(nf),
                                     postproc=None,
                                     enable_ca=['stats'])
                # averaging should be fine since we should have balanced
                # sets across all chunks here
                errors_native = cv(ds)

                self.assertEqual(
                    np.max(np.abs(errors_native.samples[:, 0] - errors)), 0)

            # 2. Work without fixes to MappedClassifier allowing
            #    change of # of samples
            #
            # CrossValidation will operate on a chain mapper which
            # would perform necessary averaging first before dealing with
            # kNN. Cons: .stats would not be exposed since ChainMapper
            # doesn't expose them (yet)
            if __debug__ and 'ENFORCE_CA_ENABLED' in debug.active:
                raise SkipTest("Known to fail while trying to enable "
                               "training_stats for the ChainMapper")
            cv2 = CrossValidation(ChainMapper(
                [mean_group_sample(['targets', 'partitions']),
                 kNN(k)],
                space='targets'),
                                  NFoldPartitioner(nf),
                                  postproc=None)
            errors_native2 = cv2(ds)

            self.assertEqual(
                np.max(np.abs(errors_native2.samples[:, 0] - errors)), 0)

            # All of the ways should provide the same results
            #print i, np.max(np.abs(errors_native.samples[:,0] - errors)), \
            #      np.max(np.abs(errors_native2.samples[:,0] - errors))

        # Disabled diagnostic: histogram of accuracies across runs
        if False:  # just to investigate the distribution if we have enough iterations
            import pylab as pl
            uaccs = np.unique(accs)
            step = np.asscalar(np.unique(np.round(uaccs[1:] - uaccs[:-1], 4)))
            bins = np.linspace(0., 1., np.round(1. / step + 1))
            xx = pl.hist(accs, bins=bins, align='left')
            pl.xlim((0. - step / 2, 1. + step / 2))
Пример #21
0
def _group_transfer_learning(path, subjects, analysis,  conf_file, source='task', analysis_type='single', **kwargs):
    """Fit one classifier on a merged source dataset and test it per subject.

    The source task's datasets are merged with ``sources_merged_ds`` and
    fitted once with ``spatial``. The fitted classifier
    (``r_group['classifier']``) is then applied to each subject's
    target-task dataset.

    For every subject, the group result dict is shallow-copied and
    extended with:

    * target predictions;
    * a target confusion matrix;
    * a Mahalanobis similarity result and its confusion matrix;
    * signal-detection measures.

    Source/target pairs: 'task' <-> 'rest', 'saccade' <-> 'face'.

    Parameters
    ----------
    path, conf_file : list
        Data paths and configuration files, iterated in lockstep with
        ``subjects`` by the per-subject loop.
    subjects : list
        Presumably one inner list of subject ids per path/conf entry
        (each element is iterated again) -- TODO confirm with callers.
    analysis : callable
        Not used in this body.
    source : str
        Source task: 'task', 'rest', 'saccade' or 'face'.
    analysis_type : str
        'group' merges the source datasets; see notes for other values.
    **kwargs
        Forwarded to ``sources_merged_ds`` and copied into each target
        configuration.

    Returns
    -------
    int or None
        0 when the group-mode arguments are not lists. Otherwise None:
        ``total_results`` is built but never returned.

    NOTE(review): issues visible in this body:

    * ``target`` is unbound for any other ``source`` value (NameError).
    * Outside 'group' mode ``ds_src`` is never assigned, so
      ``spatial(ds_src, ...)`` raises NameError.
    * The ``len(subj_) > 1`` guard skips loading for single-subject
      entries, which then reuse ``ds_tar``/``conf_tar`` from a previous
      iteration (or raise NameError on the first one).
    * ``ds_src`` is re-balanced inside the loop, after the classifier was
      already trained on it, and again on every iteration.
    * ``map_list`` is unbound if the configuration lacks 'map_list'.
    * ``signal_detection_measures`` is called here with
      (pred, targets, map_list); other snippets in this file pass a single
      confusion matrix -- confirm which signature is current.
    * ``copy`` must be imported at module level (not visible here).
    """
    
    # Pick the paired target task for the given source task.
    if source == 'task':
        target = 'rest'
    else:
        if source == 'rest':
            target = 'task'
    
    
    if source == 'saccade':
        target = 'face'
    else:
        if source == 'face':
            target = 'saccade'
    
   
    ##############################################    
    ##############################################
    ##    conf_src['label_included'] = 'all'    ##   
    ##    conf_src['label_dropped'] = 'none'    ##
    ##    conf_src['mean_samples'] = 'False'    ##
    ##############################################
    ##############################################

    if analysis_type == 'group':
        
        if path.__class__ == conf_file.__class__ == list:  
            # Merge the source datasets of all groups into one dataset.
            ds_src, s, conf_src = sources_merged_ds(path, subjects, conf_file, source, **kwargs)
            
            conf_src['permutations'] = 0
            
        else:
            print 'In group analysis path, subjects and conf_file must be lists: \
                    Check configuration file and/or parameters!!'
            return 0
    
    else:
        
        conf_src = read_configuration(path, conf_file, source)
        
    
    
    # map_list: comma-separated mapping from conf, used by the SDT measures.
    for arg in conf_src:
        if arg == 'map_list':
            map_list = conf_src[arg].split(',')
    
    
    # Train once on the merged source data.
    r_group = spatial(ds_src, **conf_src)
    
    total_results = dict()
    total_results['group'] = r_group
    
    clf = r_group['classifier']
    
    for subj_, conf_, path_ in zip(subjects, conf_file, path):
        for subj in subj_:
            print '-----------'
            r = dict()
            if len(subj_) > 1:
                conf_tar = read_configuration(path_, conf_, target)
        
                for arg in kwargs:
                    
                    conf_tar[arg] = kwargs[arg]
            
            
                data_path = conf_tar['data_path']
                try:
                    ds_tar = load_dataset(data_path, subj, target, **conf_tar)
                except Exception, err:
                    # Best effort: report and skip this subject.
                    print err
                    continue
    
            
            ds_tar = detrend_dataset(ds_tar, target, **conf_tar) 
    
            # NOTE(review): balancing happens after training (see docstring).
            if conf_src['label_included'] == 'all' and \
               conf_src['label_dropped'] != 'fixation':
                    print 'Balancing dataset...'
                    ds_src = balance_dataset_timewise(ds_src, 'fixation')       
            
            # Apply the group-trained classifier to this subject's target data.
            predictions = clf.predict(ds_tar)
           
            pred = np.array(predictions)
            targets = ds_tar.targets
            
            
            # Start from a shallow copy of every group-level result entry.
            for arg in r_group.keys():
                r[arg] = copy.copy(r_group[arg])
            
            r['targets'] = targets
            r['predictions'] = predictions
            
            r['fclf'] = clf
            
            c_m = ConfusionMatrix(predictions=pred, targets=targets)
            c_m.compute()
            r['confusion_target'] = c_m
            print c_m
            
            tr_pred = similarity_measure_mahalanobis(ds_tar, ds_src, r)
            r['mahalanobis_similarity'] = tr_pred
            
            #print tr_pred
            
            # Column 1 of tr_pred is used as predictions, column 0 as targets
            # (layout assumed from this indexing -- TODO confirm).
            c_mat_mahala = ConfusionMatrix(predictions=tr_pred.T[1], targets=tr_pred.T[0])
            c_mat_mahala.compute()
            r['confusion_mahala'] = c_mat_mahala
            
            d_prime, beta, c, c_new = signal_detection_measures(pred, targets, map_list)
            r['d_prime'] = d_prime
            print d_prime
            r['beta'] = beta
            r['c'] = c
            r['confusion_total'] = c_new
            
            '''
            d_prime_maha, c_new_maha = d_prime_statistics(tr_pred.T[1], tr_pred.T[0], map_list)
            r['d_prime_maha'] = d_prime_maha
            r['confusion_tot_maha'] = c_new_maha
            '''
            
            total_results[subj] = r
Пример #22
0
    def test_confusion_matrix(self):
        """Exercise ConfusionMatrix construction, accumulation and arithmetic.

        Checks that:

        * list/tuple/ndarray inputs produce the same table;
        * an empty matrix raises on ``percent_correct``;
        * ``add`` tracks sets and grows ``labels`` when unseen labels
          appear;
        * per-set ``matrices`` sum to the total;
        * string rendering works;
        * ``+=`` and ``+`` concatenate the underlying sets.
        """
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        # Check if we are ok with any input type - either list, or np.array, or tuple
        for t in [reg, tuple(reg), list(reg), np.array(reg)]:
            for p in [regl, tuple(regl), list(regl), np.array(regl)]:
                cm = ConfusionMatrix(targets=t, predictions=p)
                # check table content
                self.assertTrue((cm.matrix == correct_cm).all())

        # Do a bit more thorough checking
        cm = ConfusionMatrix()
        # No samples yet -- computing the percentage must raise
        self.assertRaises(ZeroDivisionError, lambda x: x.percent_correct, cm)

        cm.add(reg, regl)

        self.assertEqual(len(cm.sets), 1,
                         msg="Should have a single set so far")
        self.assertEqual(cm.matrix.shape, (3, 3),
                         msg="should be square matrix (len(reglabels) x len(reglabels)")

        # ConfusionMatrix must complain if the number of samples differs
        self.assertRaises(ValueError, cm.add, reg, np.array([1]))

        # check table content
        self.assertTrue((cm.matrix == correct_cm).all())

        # lets add with new labels (not yet known)
        cm.add(reg, np.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

        self.assertEqual(cm.labels, [1, 2, 3, 4],
                         msg="We should have gotten 4th label")

        matrices = cm.matrices  # separate CM per each given set
        self.assertEqual(len(matrices), 2,
                         msg="Have gotten two splits")

        self.assertTrue((matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
                        msg="Total votes should match the sum across split CMs")

        # check pretty print: just a silly test to make sure printing works
        self.assertTrue(len(cm.as_string(
            header=True, summary=True,
            description=True)) > 100)
        self.assertTrue(len(str(cm)) > 100)
        # and that it knows some parameters for printing
        self.assertTrue(len(cm.as_string(summary=True,
                                         header=False)) > 100)

        # lets check iadd -- just itself to itself
        cm += cm
        self.assertEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

        # lets check add -- just itself to itself
        cm2 = cm + cm
        self.assertEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
        self.assertEqual(cm2.percent_correct, cm.percent_correct,
                         msg="Percent of corrrect should remain the same ;-)")

        self.assertEqual(cm2.error, 1.0 - cm.percent_correct / 100.0,
                         msg="Test if we get proper error value")
Пример #23
0
def test_transfer_learning(path, subjects, analysis, conf_file, source='task',
                           analysis_type='single', calculateSimilarity='True',
                           **kwargs):
    """Cross-decode from a ``source`` task to its paired ``target`` task.

    For each subject:

    * a classifier is trained on the source dataset and tested on the
      target dataset (``transfer_learning``);
    * the predictions are summarised as a target confusion matrix and a
      cross-decoding confusion matrix (``cross_decoding_confusion`` with
      the configured ``map_list``);
    * an optional Mahalanobis similarity analysis is run, followed by
      signal-detection measures;
    * the result is added to an ``rs.ResultsCollection``.

    Source/target pairs: 'task' <-> 'rest', 'saccade' <-> 'face'.

    Parameters
    ----------
    path, conf_file : str or list
        Data path(s) and configuration file(s). In 'group' mode both must
        be lists; they are forwarded to ``sources_merged_ds``.
    subjects : list
        Subject identifiers to process. In 'group' mode this is replaced by
        the subjects returned from ``sources_merged_ds``.
    analysis : callable
        Analysis forwarded to ``transfer_learning``. Its name is stored in
        ``conf_src['analysis_func']``.
    source : str
        Source task: 'task', 'rest', 'saccade' or 'face'.
    analysis_type : str
        'group' to use merged datasets; anything else loads per subject.
    calculateSimilarity : str
        The string 'True' enables the Mahalanobis similarity step.
    **kwargs
        In single mode, copied into both configurations. ``p`` is the
        p-value threshold for the similarity step; a 'p_dist' entry in the
        source configuration overrides it.

    Returns
    -------
    dict or int
        Per-subject result dicts keyed by subject, or 0 when the group-mode
        arguments are not lists.

    Raises
    ------
    ValueError
        If ``source`` is unknown, if the configuration has no 'map_list',
        or if similarity is requested without any p-value.
    """
    # Each source task is decoded onto its paired target task.
    paired_tasks = {'task': 'rest', 'rest': 'task',
                    'saccade': 'face', 'face': 'saccade'}
    target = paired_tasks.get(source)
    if target is None:
        raise ValueError("Unknown source task %r; expected one of: %s"
                         % (source, ', '.join(sorted(paired_tasks))))

    # Similarity p-value threshold; may be overridden by conf 'p_dist' below.
    p = kwargs.get('p')

    # Typical source settings: label_included='all', label_dropped='none',
    # mean_samples='False'.

    if analysis_type == 'group':
        if isinstance(path, list) and isinstance(conf_file, list):
            ds_src, _, conf_src = sources_merged_ds(path, subjects, conf_file,
                                                    source, **kwargs)
            ds_tar, subjects, conf_tar = sources_merged_ds(
                path, subjects, conf_file, target, **kwargs)

            conf_src['permutations'] = 0
            conf_tar['permutations'] = 0
        else:
            print('In group analysis path, subjects and conf_file must be lists: \
                    Check configuration file and/or parameters!!')
            return 0
    else:
        conf_src = read_configuration(path, conf_file, source)
        conf_tar = read_configuration(path, conf_file, target)

        for arg in kwargs:
            conf_src[arg] = kwargs[arg]
            conf_tar[arg] = kwargs[arg]

        # NOTE(review): only defined in single mode; group mode relies on
        # the merged datasets and must not reach load_dataset below.
        data_path = conf_src['data_path']

    conf_src['analysis_type'] = 'transfer_learning'
    conf_src['analysis_task'] = source
    conf_src['analysis_func'] = analysis.__name__

    # Validate configuration up front instead of failing with a NameError
    # after the first datasets have already been loaded.
    if 'map_list' not in conf_src:
        raise ValueError("Configuration is missing 'map_list', required for "
                         "the cross-decoding confusion matrix")
    map_list = conf_src['map_list'].split(',')

    if 'p_dist' in conf_src:
        p = float(conf_src['p_dist'])
        print(p)

    if calculateSimilarity == 'True' and p is None:
        raise ValueError("Similarity analysis needs a p-value: pass p=... "
                         "or set 'p_dist' in the configuration")

    total_results = dict()

    summarizers = [
        rs.CrossDecodingSummarizer(),
        rs.SimilaritySummarizer(),
        rs.DecodingSummarizer(),
        rs.SignalDetectionSummarizer(),
    ]

    savers = [
        rs.CrossDecodingSaver(),
        rs.SimilaritySaver(),
        rs.DecodingSaver(),
        rs.SignalDetectionSaver(),
    ]

    collection = rs.ResultsCollection(conf_src, path, summarizers)

    for subj in subjects:
        print('-------------------')

        if (len(subjects) > 1) or (subj != 'group'):
            try:
                ds_src = load_dataset(data_path, subj, source, **conf_src)
                ds_tar = load_dataset(data_path, subj, target, **conf_tar)
            except Exception as err:
                # Best effort: report and skip subjects that cannot be loaded.
                print(err)
                continue

        # Evaluate if is correct to do further normalization after merging two ds.
        ds_src = detrend_dataset(ds_src, source, **conf_src)
        ds_tar = detrend_dataset(ds_tar, target, **conf_tar)

        if conf_src['label_included'] == 'all' and \
           conf_src['label_dropped'] != 'fixation':
            print('Balancing dataset...')
            ds_src = balance_dataset_timewise(ds_src, 'fixation')

        # Make cross-decoding
        r = transfer_learning(ds_src, ds_tar, analysis, **conf_src)

        # Now we have cross-decoding results we could process it
        pred = np.array(r['classifier'].ca.predictions)

        targets = r['targets']

        c_m = ConfusionMatrix(predictions=pred, targets=targets)
        c_m.compute()
        r['confusion_target'] = c_m

        c_new = cross_decoding_confusion(pred, targets, map_list)
        r['confusion_total'] = c_new

        print(c_new)

        # Similarity Analysis
        if calculateSimilarity == 'True':
            mahala_data = similarity_measure(r['ds_tar'],
                                             r['ds_src'],
                                             r,
                                             p_value=p,
                                             method='mahalanobis')

            for k_, v_ in mahala_data.items():
                r[k_] = v_
            r['confusion_mahala'] = mahala_data['confusion_mahalanobis']
        else:
            r['confusion_mahala'] = 'Null'

        # Signal Detection Theory Analysis: copies the returned measures
        # (e.g. d_prime, beta, c) into the subject's result dict.
        sdt_res = signal_detection_measures(c_new)
        for k_, v_ in sdt_res.items():
            r[k_] = v_

        total_results[subj] = r
        subj_result = rs.SubjectResult(subj, r, savers=savers)

        collection.add(subj_result)

    return total_results