Example #1
    def fit_dev(self, net_data, nnet_cluster='auto', nSubtypes=3):
        self.nnet_cluster = nnet_cluster
        self.nSubtypes = nSubtypes

        if nnet_cluster == 'auto':
            #self.nnet_cluster = self.getClusters(net_data)
            self.valid_cluster, self.valid_net_idx = self.get_match_network(
                net_data, nnet_cluster, algo='meanshift')
        else:
            self.valid_cluster, self.valid_net_idx = self.get_match_network(
                net_data, nnet_cluster, algo='kmeans')

        #self.valid_cluster = self.clust_list
        #self.valid_net_idx = range(len(self.valid_cluster))
        for i in range(net_data.shape[0]):
            if i == 0:
                self.assign_net = self.assigneDist(net_data[i, :, :],
                                                   self.valid_cluster,
                                                   self.valid_net_idx)
            else:
                self.assign_net = np.vstack(
                    ((self.assign_net,
                      self.assigneDist(net_data[i, :, :], self.valid_cluster,
                                       self.valid_net_idx))))
        print('Size of the new data map:', self.assign_net.shape)
        # group subjects that most networks classify together
        # compute the consensus clustering
        self.consensus = cls.hclustering(self.assign_net, self.nSubtypes)
        # save the centroids in a method
        self.clf_subtypes = NearestCentroid()
        self.clf_subtypes.fit(self.assign_net, self.consensus)
        self.consensus = self.clf_subtypes.predict(self.assign_net)
        #print "score: ", self.clf_subtypes.score(self.assign_net,self.consensus)

        return self.consensus
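
A minimal usage sketch for the example above. The class name (`SubtypeModel`) is a stand-in, since the snippet does not show the class that defines `fit_dev`; only the call signature comes from the code itself:

    import numpy as np

    # Hypothetical driver; 'SubtypeModel' names the (unshown) class above.
    net_data = np.random.rand(40, 7, 7)   # nSubjects x nNetwork x nNetwork
    model = SubtypeModel()
    consensus = model.fit_dev(net_data, nnet_cluster=3, nSubtypes=3)
    print(consensus)  # one consensus subtype label per subject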
Example #2
    def fit_network(self, net_data_low, nSubtypes=3, reshape_w=True):
        self.flag_2level = False
        self.nnet_cluster = 1
        self.nSubtypes = nSubtypes
        # self.scalers = []
        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork

        self.normalized_net_template = []
        # average template
        # self.scalers.append(preprocessing.StandardScaler())
        # net_data_low = self.scalers[-1].fit_transform(net_data_low_)
        self.normalized_net_template.append(np.mean(net_data_low, axis=0))
        # self.normalized_net_template.append(np.zeros_like(net_data_low[0,:]))
        # identity matrix of the correlation between subjects
        ind_st = cls.hclustering(net_data_low, nSubtypes)

        for j in range(nSubtypes):
            data_tmp = np.median(net_data_low[ind_st == j + 1, :],
                                 axis=0)[np.newaxis, ...]
            if j == 0:
                st_templates_tmp = data_tmp
            else:
                st_templates_tmp = np.vstack((st_templates_tmp, data_tmp))

        # st_templates --> Dimensions: nNetwork_low,nSubtypes, nNetwork
        self.st_templates = st_templates_tmp[np.newaxis, ...]
        del st_templates_tmp
        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low, self.st_templates)

        if reshape_w:
            return reshapeW(self.W)
        else:
            return self.W
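
Although the commented dimension note mentions a 3D array, the stacking in `fit_network` (st_templates ends up with shape 1 x nSubtypes x nNetwork) is consistent with a 2D subjects-by-features input. A sketch under that assumption, again with the hypothetical `SubtypeModel` name:

    import numpy as np

    data = np.random.rand(60, 12)   # nSubjects x nNetwork
    model = SubtypeModel()          # hypothetical class name
    w = model.fit_network(data, nSubtypes=3, reshape_w=True)
    # w: per-subject subtype weights, flattened by reshapeW when reshape_w=True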
Example #3
    def fit_dev(self, net_data, nnet_cluster='auto', nSubtypes=3):
        self.nnet_cluster = nnet_cluster
        self.nSubtypes = nSubtypes

        if nnet_cluster == 'auto':
            #self.nnet_cluster = self.getClusters(net_data)
            self.valid_cluster, self.valid_net_idx = self.get_match_network(
                net_data, nnet_cluster, algo='meanshift')
        else:
            self.valid_cluster, self.valid_net_idx = self.get_match_network(
                net_data, nnet_cluster, algo='kmeans')

        #self.valid_cluster = self.clust_list
        #self.valid_net_idx = range(len(self.valid_cluster))
        for i in range(net_data.shape[0]):
            if i == 0:
                self.assign_net = self.assigneDist(net_data[i, :, :],
                                                   self.valid_cluster,
                                                   self.valid_net_idx)
            else:
                self.assign_net = np.vstack(
                    (self.assign_net,
                     self.assigneDist(net_data[i, :, :], self.valid_cluster,
                                      self.valid_net_idx)))
        print('Size of the new data map:', self.assign_net.shape)
        # group subjects that most networks classify together
        # compute the consensus clustering
        self.consensus = cls.hclustering(self.assign_net, self.nSubtypes)
        # save the centroids in a method
        self.clf_subtypes = NearestCentroid()
        self.clf_subtypes.fit(self.assign_net, self.consensus)
        self.consensus = self.clf_subtypes.predict(self.assign_net)
        #print("score: ", self.clf_subtypes.score(self.assign_net, self.consensus))

        return self.consensus
Example #4
    def fit_network(self, net_data_low, nSubtypes=3, reshape_w=True):
        self.flag_2level = False
        self.nnet_cluster = 1
        self.nSubtypes = nSubtypes
        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork

        self.normalized_net_template = []
        # average template
        self.normalized_net_template.append(np.mean(net_data_low, axis=0))
        # self.normalized_net_template.append(np.zeros_like(net_data_low[0,:]))
        # identity matrix of the correlation between subjects
        ind_st = cls.hclustering(net_data_low - self.normalized_net_template[-1], nSubtypes)

        for j in range(nSubtypes):
            data_tmp = np.median(net_data_low[ind_st == j + 1, :] - self.normalized_net_template[-1], axis=0)[
                np.newaxis, ...
            ]
            if j == 0:
                st_templates_tmp = data_tmp
            else:
                st_templates_tmp = np.vstack((st_templates_tmp, data_tmp))

        # st_templates --> Dimensions: nNetwork_low,nSubtypes, nNetwork
        self.st_templates = st_templates_tmp[np.newaxis, ...]
        del st_templates_tmp
        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low, self.st_templates)
        if reshape_w:
            return self.reshapeW(self.W)
        else:
            return self.W
Example #5
    def fit(self, net_data_low, nSubtypes=3, reshape_w=True):
        # net_data_low = net_data_low_main.copy()
        self.flag_2level = False
        self.nnet_cluster = net_data_low.shape[1]
        self.nSubtypes = nSubtypes

        # ind_low_scale = cls.get_ind_high2low(low_res_template,orig_template)
        # self.ind_low_scale = ind_low_scale

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        # net_data_low = transform_low_scale(ts_data,self.ind_low_scale)
        # self.net_data_low = net_data_low

        self.normalized_net_template = []
        for i in range(net_data_low.shape[1]):
            # average template
            if nSubtypes < 1:
                self.normalized_net_template.append(
                    np.zeros_like(net_data_low[0, i, :]).astype(float))
            else:
                self.normalized_net_template.append(
                    np.mean(net_data_low[:, i, :], axis=0))
                # self.normalized_net_template.append(np.zeros_like(net_data_low[0,i,:])).astype(float))

            # identity matrix of the correlation between subjects
            # tmp_subj_identity = np.corrcoef(net_data_low[:,i,:])
            # ind_st = cls.hclustering(tmp_subj_identity,nSubtypes)
            # subjects X network_nodes
            ind_st = cls.hclustering(net_data_low[:, i, :], nSubtypes)

            for j in range(nSubtypes):
                if j == 0:
                    st_templates_tmp = np.median(
                        net_data_low[:, i, :][ind_st == j + 1, :],
                        axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (st_templates_tmp,
                         np.median(net_data_low[:, i, :][ind_st == j + 1, :],
                                   axis=0)[np.newaxis, ...]))

            # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
            if i == 0:
                self.st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                self.st_templates = np.vstack(
                    (self.st_templates, st_templates_tmp[np.newaxis, ...]))
            del st_templates_tmp

        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low, self.st_templates)

        if reshape_w:
            return reshapeW(self.W)
        else:
            return self.W
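
Here `fit` loops over the second axis, so the input is genuinely 3D. A usage sketch with the same hypothetical class name:

    import numpy as np

    data = np.random.rand(60, 7, 20)   # nSubjects x nNetwork_low x nNetwork
    model = SubtypeModel()             # hypothetical class name
    w = model.fit(data, nSubtypes=3, reshape_w=False)
    # without reshaping, W has shape (nSubjects, nNetwork_low, nSubtypes)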
Example #6
    def fit(self, net_data_low, nSubtypes=3, reshape_w=True):
        # net_data_low = net_data_low_main.copy()
        self.flag_2level = False
        self.nnet_cluster = net_data_low.shape[1]
        self.nSubtypes = nSubtypes

        # ind_low_scale = cls.get_ind_high2low(low_res_template,orig_template)
        # self.ind_low_scale = ind_low_scale

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        # net_data_low = transform_low_scale(ts_data,self.ind_low_scale)
        # self.net_data_low = net_data_low

        self.normalized_net_template = []
        for i in range(net_data_low.shape[1]):
            # average template
            if nSubtypes < 1:
                self.normalized_net_template.append(np.zeros_like(net_data_low[0, i, :]).astype(float))
            else:
                self.normalized_net_template.append(np.mean(net_data_low[:, i, :], axis=0))
                # self.normalized_net_template.append(np.zeros_like(net_data_low[0,i,:])).astype(float))

            # identity matrix of the correlation between subjects
            # tmp_subj_identity = np.corrcoef(net_data_low[:,i,:])
            # ind_st = cls.hclustering(tmp_subj_identity,nSubtypes)
            # subjects X network_nodes
            ind_st = cls.hclustering(net_data_low[:, i, :] - self.normalized_net_template[-1], nSubtypes)
            # ind_st = cls.hclustering(net_data_low[:,i,:],nSubtypes)

            for j in range(nSubtypes):
                if j == 0:
                    st_templates_tmp = np.median(net_data_low[:, i, :][ind_st == j + 1, :], axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (
                            st_templates_tmp,
                            np.median(net_data_low[:, i, :][ind_st == j + 1, :], axis=0)[np.newaxis, ...],
                        )
                    )

            # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
            if i == 0:
                self.st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                self.st_templates = np.vstack((self.st_templates, st_templates_tmp[np.newaxis, ...]))
            del st_templates_tmp

        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low, self.st_templates)
        if reshape_w:
            return self.reshapeW(self.W)
        else:
            return self.W
Example #7
    def fit(self, net_data_low, nSubtypes=3, reshape_w=True):
        self.nnet_cluster = net_data_low.shape[1]
        self.nSubtypes = nSubtypes

        #ind_low_scale = cls.get_ind_high2low(low_res_template,orig_template)
        #self.ind_low_scale = ind_low_scale

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        #net_data_low = transform_low_scale(ts_data,self.ind_low_scale)
        self.net_data_low = net_data_low

        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        for i in range(net_data_low.shape[1]):
            # identity matrix of the correlation between subjects
            #tmp_subj_identity = np.corrcoef(net_data_low[:,i,:])
            #ind_st = cls.hclustering(tmp_subj_identity,nSubtypes)
            # subjects X network_nodes
            #ind_st = cls.hclustering(net_data_low[:,i,:]-np.mean(net_data_low[:,i,:],axis=0),nSubtypes)
            ind_st = cls.hclustering(net_data_low[:, i, :], nSubtypes)

            for j in range(nSubtypes):
                if j == 0:
                    st_templates_tmp = net_data_low[:, i, :][
                        ind_st == j + 1, :].mean(axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (st_templates_tmp,
                         net_data_low[:, i, :][ind_st == j +
                                               1, :].mean(axis=0)[np.newaxis,
                                                                  ...]))

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                st_templates = np.vstack(
                    (st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates = st_templates

        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low)
        if reshape_w:
            return self.reshapeW(self.W)
        else:
            return self.W
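
The incremental `np.vstack` calls above can be written more idiomatically by collecting the per-subtype rows in a list and stacking once. A behavior-equivalent sketch of just the template-building step; the `build_templates` name and the precomputed `labels_per_network` list are illustration-only assumptions:

    import numpy as np

    def build_templates(net_data_low, labels_per_network, nSubtypes):
        # labels_per_network[i]: 1-based cluster label of each subject for
        # network i, as produced by cls.hclustering in the example above
        templates = []
        for i in range(net_data_low.shape[1]):
            ind_st = labels_per_network[i]
            rows = [net_data_low[:, i, :][ind_st == j + 1, :].mean(axis=0)
                    for j in range(nSubtypes)]
            templates.append(np.stack(rows))
        return np.stack(templates)  # nNetwork_low x nSubtypes x nNetwork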
Example #8
    def robust_st(self, net_data_low, nSubtypes, n_iter=50):
        bs_cluster = []
        n = net_data_low.shape[0]
        stab_ = np.zeros((n, n)).astype(float)
        # assumes sklearn.model_selection.ShuffleSplit; the older
        # cross_validation API took the sample count as first argument
        rs = ShuffleSplit(n_splits=n_iter, test_size=0.05, random_state=1)
        for train, test in rs.split(net_data_low):
            # identity matrix of the correlation between subjects
            ind_st = cls.hclustering(net_data_low[train, :], nSubtypes)
            mat_ = (cls.ind2matrix(ind_st) > 0).astype(float)
            for ii in range(len(train)):
                stab_[train, train[ii]] += mat_[:, ii]

        stab_ = stab_ / n_iter
        ms = KMeans(nSubtypes)
        ind = ms.fit_predict(stab_)
        # row_clusters = linkage(stab_, method='ward')
        # ind = fcluster(row_clusters, nSubtypes, criterion='maxclust')
        return ind + 1, stab_
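
`robust_st` builds a subject-by-subject stability matrix: over repeated subsamples, entry (a, b) accumulates how often subjects a and b land in the same cluster, and the averaged matrix is then clustered with KMeans. A short driver sketch; the class name is a stand-in and the `cls` helpers are assumed importable as in the snippet:

    import numpy as np

    data = np.random.rand(80, 15)   # nSubjects x nFeatures
    model = SubtypeModel()          # hypothetical class name
    labels, stab = model.robust_st(data, nSubtypes=3, n_iter=50)
    # labels: 1-based subtype assignments; stab[a, b] approximates the
    # fraction of subsamples where subjects a and b co-clustered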
Example #9
    def _robust_st(self, net_data_low, nSubtypes, n_iter=50):
        bs_cluster = []
        n = net_data_low.shape[0]
        stab_ = np.zeros((n, n)).astype(float)
        # assumes sklearn.model_selection.ShuffleSplit; the older
        # cross_validation API took the sample count as first argument
        rs = ShuffleSplit(n_splits=n_iter,
                          test_size=0.05,
                          random_state=1)
        for train, test in rs.split(net_data_low):
            # identity matrix of the correlation between subjects
            ind_st = cls.hclustering(net_data_low[train, :], nSubtypes)
            mat_ = (cls.ind2matrix(ind_st) > 0).astype(float)
            for ii in range(len(train)):
                stab_[train, train[ii]] += mat_[:, ii]

        stab_ = stab_ / n_iter
        ms = KMeans(nSubtypes)
        ind = ms.fit_predict(stab_)
        # row_clusters = linkage(stab_, method='ward')
        # ind = fcluster(row_clusters, nSubtypes, criterion='maxclust')
        return ind + 1, stab_
Example #10
    def fit(self, net_data_low, nSubtypes=3, reshape_w=True):
        self.nnet_cluster = net_data_low.shape[1]
        self.nSubtypes = nSubtypes

        #ind_low_scale = cls.get_ind_high2low(low_res_template,orig_template)
        #self.ind_low_scale = ind_low_scale

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        #net_data_low = transform_low_scale(ts_data,self.ind_low_scale)
        self.net_data_low = net_data_low

        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        for i in range(net_data_low.shape[1]):
            # identity matrix of the correlation between subjects
            #tmp_subj_identity = np.corrcoef(net_data_low[:,i,:])
            #ind_st = cls.hclustering(tmp_subj_identity,nSubtypes)
            # subjects X network_nodes
            #ind_st = cls.hclustering(net_data_low[:,i,:]-np.mean(net_data_low[:,i,:],axis=0),nSubtypes)
            ind_st = cls.hclustering(net_data_low[:, i, :], nSubtypes)

            for j in range(nSubtypes):
                if j == 0:
                    st_templates_tmp = net_data_low[:, i, :][
                        ind_st == j + 1, :].mean(axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (st_templates_tmp,
                         net_data_low[:, i, :][ind_st == j + 1, :].mean(
                             axis=0)[np.newaxis, ...]))

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                st_templates = np.vstack(
                    (st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates = st_templates

        # calculate the weights for each subject
        self.W = self.compute_weights(net_data_low)
        if reshape_w:
            return self.reshapeW(self.W)
        else:
            return self.W
Example #11
    def _fit_2level(self,
                    net_data_low_l1,
                    net_data_low_l2,
                    nSubtypes_l1=5,
                    nSubtypes_l2=2,
                    reshape_w=True):

        # Discontinued function

        self.flag_2level = True
        self.nnet_cluster = net_data_low_l1.shape[1]
        self.nSubtypes = nSubtypes_l1 * nSubtypes_l2
        self.nSubtypes_l1 = nSubtypes_l1
        self.nSubtypes_l2 = nSubtypes_l2

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        self.net_data_low = net_data_low_l1
        self.net_data_low_l2 = net_data_low_l2

        ####
        # LEVEL 1
        ####
        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        for i in range(net_data_low_l1.shape[1]):
            # identity matrix of the correlation between subjects
            ind_st = cls.hclustering(net_data_low_l1[:, i, :], nSubtypes_l1)

            for j in range(nSubtypes_l1):
                if j == 0:
                    st_templates_tmp = net_data_low_l1[:, i, :][
                        ind_st == j + 1, :].mean(axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (st_templates_tmp,
                         net_data_low_l1[:,
                                         i, :][ind_st == j +
                                               1, :].mean(axis=0)[np.newaxis,
                                                                  ...]))

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                st_templates = np.vstack(
                    (st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates_l1 = st_templates

        # calculate the weights for each subject
        # W --> Dimensions: nSubjects,nNetwork_low, nSubtypes
        net_data_low_l2_tmp = np.vstack((net_data_low_l1, net_data_low_l2))
        self.W_l1 = self.compute_weights(net_data_low_l2_tmp,
                                         self.st_templates_l1)

        ####
        # LEVEL 2
        ####
        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        # st_templates = self.st_templates_l1.copy()
        # st_templates = st_templates[:,:,np.newaxis,:]
        for i in range(net_data_low_l2.shape[1]):

            # Iterate on all the Level1 subtypes (normal variability subtypes)
            for k in range(self.st_templates_l1.shape[1]):
                # Find the L1 subtype
                max_w = np.max(self.W_l1[:, i, :], axis=1)
                mask_selected_subj = (self.W_l1[:, i, k] == max_w)
                template2substract = self.st_templates_l1[i, k, :]
                if np.sum(mask_selected_subj) <= 3:
                    print('Fewer than 4 subjects for network: ' + str(i) +
                          ' level1 ST: ' + str(k))
                    for j in range(nSubtypes_l2):
                        if (k == 0) & (j == 0):
                            st_templates_tmp = self.st_templates_l1[i, k, :][
                                np.newaxis, ...]
                        else:
                            st_templates_tmp = np.vstack(
                                (st_templates_tmp,
                                 self.st_templates_l1[i, k, :][np.newaxis,
                                                               ...]))

                else:
                    # identity matrix of the correlation between subjects
                    ind_st = cls.hclustering(
                        net_data_low_l2_tmp[:, i, :][mask_selected_subj, ...] -
                        template2substract, nSubtypes_l2)
                    # ind_st = cls.hclustering(net_data_low[:,i,:],nSubtypes)
                    if len(np.unique(ind_st)) < nSubtypes_l2:
                        print(
                            'Clustering produced fewer classes than requested; '
                            + str(len(ind_st)) + ' subjects, network: ' +
                            str(i) + ' level1 ST: ' + str(k))
                        # if (i==6) & (k==3):
                        # print ind_st
                    for j in range(nSubtypes_l2):
                        if (k == 0) & (j == 0):
                            st_templates_tmp = (net_data_low_l2_tmp[:, i, :][
                                mask_selected_subj, ...][ind_st == j + 1, :] -
                                                template2substract).mean(
                                                    axis=0)[np.newaxis, ...]
                        else:
                            st_templates_tmp = np.vstack(
                                (st_templates_tmp,
                                 (net_data_low_l2_tmp[:,
                                                      i, :][mask_selected_subj,
                                                            ...][ind_st == j +
                                                                 1, :] -
                                  template2substract).mean(axis=0)[np.newaxis,
                                                                   ...]))

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                print(st_templates.shape, st_templates_tmp.shape)
                st_templates = np.vstack(
                    (st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates_l2 = st_templates

        # calculate the weights for each subject
        self.W_l2 = self.compute_weights(net_data_low_l2, self.st_templates_l2)
        if reshape_w:
            return reshapeW(self.W_l2)
        else:
            return self.W_l2
Example #12
    def fit_2level(self, net_data_low_l1, net_data_low_l2, nSubtypes_l1=5, nSubtypes_l2=2, reshape_w=True):
        self.flag_2level = True
        self.nnet_cluster = net_data_low_l1.shape[1]
        self.nSubtypes = nSubtypes_l1 * nSubtypes_l2
        self.nSubtypes_l1 = nSubtypes_l1
        self.nSubtypes_l2 = nSubtypes_l2

        # net_data_low --> Dimensions: nSubjects, nNetwork_low, nNetwork
        self.net_data_low = net_data_low_l1
        self.net_data_low_l2 = net_data_low_l2

        ####
        # LEVEL 1
        ####
        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        for i in range(net_data_low_l1.shape[1]):
            # identity matrix of the correlation between subjects
            ind_st = cls.hclustering(net_data_low_l1[:, i, :], nSubtypes_l1)

            for j in range(nSubtypes_l1):
                if j == 0:
                    st_templates_tmp = net_data_low_l1[:, i, :][ind_st == j + 1, :].mean(axis=0)[np.newaxis, ...]
                else:
                    st_templates_tmp = np.vstack(
                        (st_templates_tmp, net_data_low_l1[:, i, :][ind_st == j + 1, :].mean(axis=0)[np.newaxis, ...])
                    )

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                st_templates = np.vstack((st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates_l1 = st_templates

        # calculate the weights for each subject
        # W --> Dimensions: nSubjects,nNetwork_low, nSubtypes
        net_data_low_l2_tmp = np.vstack((net_data_low_l1, net_data_low_l2))
        self.W_l1 = self.compute_weights(net_data_low_l2_tmp, self.st_templates_l1)

        ####
        # LEVEL 2
        ####
        # st_templates --> Dimensions: nNetwork_low, nSubtypes, nNetwork
        st_templates = []
        # st_templates = self.st_templates_l1.copy()
        # st_templates = st_templates[:,:,np.newaxis,:]
        for i in range(net_data_low_l2.shape[1]):

            # Iterate on all the Level1 subtypes (normal variability subtypes)
            for k in range(self.st_templates_l1.shape[1]):
                # Find the L1 subtype
                max_w = np.max(self.W_l1[:, i, :], axis=1)
                mask_selected_subj = self.W_l1[:, i, k] == max_w
                template2substract = self.st_templates_l1[i, k, :]
                if np.sum(mask_selected_subj) <= 3:
                    print ("Less then 2 subjects for network: " + str(i) + " level1 ST: " + str(k))
                    for j in range(nSubtypes_l2):
                        if (k == 0) & (j == 0):
                            st_templates_tmp = self.st_templates_l1[i, k, :][np.newaxis, ...]
                        else:
                            st_templates_tmp = np.vstack(
                                (st_templates_tmp, self.st_templates_l1[i, k, :][np.newaxis, ...])
                            )

                else:
                    # identity matrix of the correlation between subjects
                    ind_st = cls.hclustering(
                        net_data_low_l2_tmp[:, i, :][mask_selected_subj, ...] - template2substract, nSubtypes_l2
                    )
                    # ind_st = cls.hclustering(net_data_low[:,i,:],nSubtypes)
                    if len(np.unique(ind_st)) < nSubtypes_l2:
                        print(
                            "Clustering produced fewer classes than requested; "
                            + str(len(ind_st))
                            + " subjects, network: "
                            + str(i)
                            + " level1 ST: "
                            + str(k)
                        )
                    # if (i==6) & (k==3):
                    # print ind_st
                    for j in range(nSubtypes_l2):
                        if (k == 0) & (j == 0):
                            st_templates_tmp = (
                                net_data_low_l2_tmp[:, i, :][mask_selected_subj, ...][ind_st == j + 1, :]
                                - template2substract
                            ).mean(axis=0)[np.newaxis, ...]
                        else:
                            st_templates_tmp = np.vstack(
                                (
                                    st_templates_tmp,
                                    (
                                        net_data_low_l2_tmp[:, i, :][mask_selected_subj, ...][ind_st == j + 1, :]
                                        - template2substract
                                    ).mean(axis=0)[np.newaxis, ...],
                                )
                            )

            if i == 0:
                st_templates = st_templates_tmp[np.newaxis, ...]
            else:
                print(st_templates.shape, st_templates_tmp.shape)
                st_templates = np.vstack((st_templates, st_templates_tmp[np.newaxis, ...]))

        self.st_templates_l2 = st_templates

        # calculate the weights for each subject
        self.W_l2 = self.compute_weights(net_data_low_l2, self.st_templates_l2)
        if reshape_w:
            return self.reshapeW(self.W_l2)
        else:
            return self.W_l2
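
A two-level driver sketch for the version above: level 1 learns subtypes of normal variability on one cohort, and level 2 subtypes the residuals after subtracting each subject's best-matching level-1 template. The class name is hypothetical; both inputs are 3D with matching trailing dimensions:

    import numpy as np

    cohort_l1 = np.random.rand(100, 7, 20)   # level-1 sample (e.g. reference)
    cohort_l2 = np.random.rand(40, 7, 20)    # level-2 sample
    model = SubtypeModel()                   # hypothetical class name
    w_l2 = model.fit_2level(cohort_l1, cohort_l2,
                            nSubtypes_l1=5, nSubtypes_l2=2)
    # w_l2: level-2 subtype weights (nSubtypes_l1 * nSubtypes_l2 per network)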