# 예제 #1 (Example #1)
# 0
class SC(object):
    # Raw docstring: "\dagger" would otherwise be an invalid "\d" escape sequence.
    r"""
    Wrapper for sklearn package.  Performs sparse coding

    Sparse Coding, or Dictionary Learning has 5 methods:
       - fit(waveforms)
       update class instance with Sparse Coding fit

       - fit_transform(waveforms)
       do what fit() does, but additionally return the projection onto new basis space

       - inverse_transform(A)
       inverts the decomposition, returns waveforms for an input A, using Z^\dagger

       - get_basis()
       returns the basis vectors Z^\dagger

       - get_params()
       returns metadata used for fits.
    """
    def __init__(self, num_components=10,
                 catalog_name='unknown',
                 alpha=0.001,
                 transform_alpha=0.01,
                 max_iter=2000,
                 tol=1e-9,
                 n_jobs=1,
                 verbose=True,
                 random_state=None):
        """Configure the wrapped ``DictionaryLearning`` estimator.

        Parameters
        ----------
        num_components : int
            Number of dictionary atoms (forwarded as ``n_components``).
        catalog_name : str
            Name of the waveform catalog; stored on the instance only.
        alpha : float
            Forwarded as ``alpha`` (sparsity control while learning the dictionary).
        transform_alpha : float
            Forwarded as ``transform_alpha`` (sparsity control when encoding).
        max_iter : int
            Forwarded as ``max_iter``.
        tol : float
            Forwarded as ``tol``.
        n_jobs : int
            Forwarded as ``n_jobs``.
        verbose : bool
            Forwarded as ``verbose``.
        random_state : int, RandomState instance or None
            Forwarded as ``random_state``.
        """
        self._decomposition = 'Sparse Coding'
        self._num_components = num_components
        self._catalog_name = catalog_name
        self._alpha = alpha
        # Use the caller's value; this was hard-coded to 0.001, ignoring the argument.
        self._transform_alpha = transform_alpha
        self._max_iter = max_iter
        self._tol = tol
        self._n_jobs = n_jobs
        self._random_state = random_state

        # max_iter and tol were previously accepted but never passed on.
        self._DL = DictionaryLearning(n_components=self._num_components,
                                      alpha=self._alpha,
                                      transform_alpha=self._transform_alpha,
                                      max_iter=self._max_iter,
                                      tol=self._tol,
                                      n_jobs=self._n_jobs,
                                      verbose=verbose,
                                      random_state=self._random_state)

    def fit(self, waveforms):
        """Learn the dictionary from ``waveforms`` (kept on ``self._waveforms``)."""
        # TODO make sure there are more columns than rows (transpose if not)
        # normalize waveforms
        self._waveforms = waveforms
        self._DL.fit(self._waveforms)

    def fit_transform(self, waveforms):
        """Learn the dictionary and return the sparse codes of ``waveforms``.

        The codes are also stored on ``self._A``.
        """
        # TODO make sure there are more columns than rows (transpose if not)
        # normalize waveforms
        self._waveforms = waveforms
        self._A = self._DL.fit_transform(self._waveforms)
        return self._A

    def inverse_transform(self, A):
        """Map sparse codes ``A`` back to waveform space using the learned dictionary."""
        # NOTE(review): DictionaryLearning.inverse_transform only exists in recent
        # scikit-learn releases -- confirm the version this project pins.
        new_waveforms = self._DL.inverse_transform(A)
        return new_waveforms

    def get_params(self):
        """Return the estimator's parameters plus decomposition metadata.

        ``n_components`` is renamed to ``num_components`` and the decomposition name
        is added under the key ``'Decompositon'`` (spelling kept for existing readers).
        """
        # TODO know what catalog was used! (include waveform metadata)
        params = self._DL.get_params()
        params['num_components'] = params.pop('n_components')
        params['Decompositon'] = self._decomposition
        return params

    def get_basis(self):
        r"""Return the learned dictionary atoms (Z^\dagger), i.e. ``components_``."""
        return self._DL.components_
# 예제 #2 (Example #2)
# 0
class SparseCoding:
    """Sparse-coding feature extractor backed by a scikit-learn ``DictionaryLearning`` model.

    The combined pipeline (``get_pooled_features_from_whitened_patches``) encodes
    whitened patches with the learned dictionary, splits the codes by sign, and
    max-pools them over non-overlapping windows.
    """

    DEFAULT_MODEL_PARAMS = {
        'n_components': 10,
        'n_features': 64,
        'max_iter': 5,
        'random_state': 1,
        'dict_init': None,
        'code_init': None
    }

    def __init__(self, model_filename=None):
        """Load a saved model from ``model_filename``, or build a fresh one from the defaults."""
        if model_filename is not None:
            self.load_model(model_filename)
        else:
            # Mirror every default model param as an instance attribute.
            for param, value in SparseCoding.DEFAULT_MODEL_PARAMS.items():
                setattr(self, param, value)

            # initialize Dictionary Learning object with default params and weights
            self.DL_obj = DictionaryLearning(n_components=self.n_components,
                                             alpha=1,
                                             max_iter=self.max_iter,
                                             tol=1e-08,
                                             fit_algorithm='lars',
                                             transform_algorithm='omp',
                                             transform_n_nonzero_coefs=None,
                                             transform_alpha=None,
                                             n_jobs=1,
                                             code_init=self.code_init,
                                             dict_init=self.dict_init,
                                             verbose=False,
                                             split_sign=False,
                                             random_state=self.random_state)

    def save_model(self, filename):
        """Persist the ``DictionaryLearning`` object to ``filename``."""
        # compress=3 writes a single compressed file instead of multiple model files.
        joblib.dump(self.DL_obj, filename, compress=3)

    def load_model(self, filename):
        """Load a model written by ``save_model`` and mirror its params as attributes.

        Each key of ``DEFAULT_MODEL_PARAMS`` is taken from the loaded object's
        ``get_params()`` when present, otherwise from the default.
        """
        # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
        # only load model files from trusted sources.
        self.DL_obj = joblib.load(filename)

        DL_params = self.DL_obj.get_params()
        for param, default in SparseCoding.DEFAULT_MODEL_PARAMS.items():
            setattr(self, param, DL_params.get(param, default))

    @staticmethod
    def _as_2d_patches(whitened_patches):
        """Flatten a 3-D patch stack to ``(n_patches, -1)`` and assert the result is 2-D."""
        if whitened_patches.ndim == 3:
            whitened_patches = whitened_patches.reshape((whitened_patches.shape[0], -1))
        assert whitened_patches.ndim == 2, "Whitened patches ndim is %d instead of 2" % whitened_patches.ndim
        return whitened_patches

    def learn_dictionary(self, whitened_patches):
        """Fit the dictionary on 2-D (or 3-D, flattened per patch) whitened patches."""
        self.DL_obj.fit(self._as_2d_patches(whitened_patches))

    def get_dictionary(self):
        """Return the learned dictionary (``components_``).

        Raises
        ------
        AttributeError
            If the model has not been trained yet.
        """
        try:
            return self.DL_obj.components_
        except AttributeError as err:
            raise AttributeError("Feature extraction dictionary has not yet been learnt for this model. "
                                 + "Train the feature extraction model at least once to prevent this error.") from err

    def get_sparse_features(self, whitened_patches):
        """Encode whitened patches as sparse codes with the learned dictionary.

        Raises
        ------
        NotFittedError
            If the model has not been trained yet.
        """
        whitened_patches = self._as_2d_patches(whitened_patches)
        try:
            return self.DL_obj.transform(whitened_patches)
        except NotFittedError as err:
            raise NotFittedError("Feature extraction dictionary has not yet been learnt for this model, "
                                 + "therefore Sparse Codes cannot be extracted. Train the feature extraction model "
                                 + "at least once to prevent this error.") from err

    def get_sign_split_features(self, sparse_features):
        """Split codes into ``[max(x, 0), -min(x, 0)]``, doubling the number of columns."""
        n_samples, n_components = sparse_features.shape
        sign_split_features = np.empty((n_samples, 2 * n_components))
        sign_split_features[:, :n_components] = np.maximum(sparse_features, 0)
        sign_split_features[:, n_components:] = -np.minimum(sparse_features, 0)
        return sign_split_features

    def get_pooled_features(self, input_feature_map, filter_size=(19, 19)):
        """Max-pool a feature map over non-overlapping ``filter_size`` windows.

        Within each window the position whose feature vector has the largest L2
        norm is kept, giving one feature vector per window.

        Parameters
        ----------
        input_feature_map : ndarray
            ``(height, width, n_channels)``, or flattened ``(n_positions, n_channels)``;
            the flattened form is assumed to come from a square image.
        filter_size : tuple of int
            ``(window_height, window_width)``; also used as the stride.

        Returns
        -------
        ndarray
            ``(n_windows, n_channels)`` pooled features.
        """
        # Flattened input, e.g. (3249, 20) -> (57, 57, 20); assumes a square image.
        if input_feature_map.ndim == 2:
            side = int(np.sqrt(input_feature_map.shape[0]))
            input_feature_map = input_feature_map.reshape((side, side, input_feature_map.shape[-1]))
        assert input_feature_map.ndim == 3, "Input features dimension is %d instead of 3" % input_feature_map.ndim

        window_h, window_w = filter_size[0], filter_size[1]
        n_channels = input_feature_map.shape[-1]

        # Non-overlapping windows, e.g. (57, 57, 20) -> (3, 3, 1, 19, 19, 20).
        # A per-axis step supports rectangular filters; for square filters it matches
        # the former scalar step (the channel axis yields a single window either way).
        windows = view_as_windows(input_feature_map,
                                  window_shape=(window_h, window_w, n_channels),
                                  step=(window_h, window_w, 1))

        # (n_windows, positions_per_window, n_channels), e.g. (9, 361, 20).
        # -1 replaces shape[0]**2, which only held for square grids of windows.
        windows = windows.reshape((-1, window_h * window_w, n_channels))

        # L2 norm of each position's feature vector: (9, 361, 20) -> (9, 361).
        norms = np.linalg.norm(windows, ord=2, axis=-1)
        # Position of the largest norm within each window: (9, 361) -> (9,).
        max_norm_indexes = np.argmax(norms, axis=-1)
        # Keep that position's feature vector: (9, 361, 20) -> (9, 20).
        return windows[np.arange(windows.shape[0]), max_norm_indexes]

    # Combined Pipeline
    def get_pooled_features_from_whitened_patches(self, whitened_patches):
        """Run sparse coding, sign splitting and max pooling on whitened patches."""
        sparse_features = self.get_sparse_features(whitened_patches)
        sign_split_features = self.get_sign_split_features(sparse_features)
        return self.get_pooled_features(sign_split_features)
# 예제 #3 (Example #3)
# 0
class SC(object):
    # Raw docstring: "\dagger" would otherwise be an invalid "\d" escape sequence.
    r"""
    Wrapper for sklearn package.  Performs sparse coding

    Sparse Coding, or Dictionary Learning has 5 methods:
       - fit(waveforms)
       update class instance with Sparse Coding fit

       - fit_transform(waveforms)
       do what fit() does, but additionally return the projection onto new basis space

       - inverse_transform(A)
       inverts the decomposition, returns waveforms for an input A, using Z^\dagger

       - get_basis()
       returns the basis vectors Z^\dagger

       - get_params()
       returns metadata used for fits.
    """
    def __init__(self,
                 num_components=10,
                 catalog_name='unknown',
                 alpha=0.001,
                 transform_alpha=0.01,
                 max_iter=2000,
                 tol=1e-9,
                 n_jobs=1,
                 verbose=True,
                 random_state=None):
        """Configure the wrapped ``DictionaryLearning`` estimator.

        ``num_components`` is forwarded as ``n_components``; ``alpha``,
        ``transform_alpha``, ``max_iter``, ``tol``, ``n_jobs``, ``verbose`` and
        ``random_state`` are forwarded under their own names. ``catalog_name`` is
        only stored on the instance.
        """
        self._decomposition = 'Sparse Coding'
        self._num_components = num_components
        self._catalog_name = catalog_name
        self._alpha = alpha
        # Use the caller's value; this was hard-coded to 0.001, ignoring the argument.
        self._transform_alpha = transform_alpha
        self._max_iter = max_iter
        self._tol = tol
        self._n_jobs = n_jobs
        self._random_state = random_state

        # max_iter and tol were previously accepted but never passed on.
        self._DL = DictionaryLearning(n_components=self._num_components,
                                      alpha=self._alpha,
                                      transform_alpha=self._transform_alpha,
                                      max_iter=self._max_iter,
                                      tol=self._tol,
                                      n_jobs=self._n_jobs,
                                      verbose=verbose,
                                      random_state=self._random_state)

    def fit(self, waveforms):
        """Learn the dictionary from ``waveforms`` (kept on ``self._waveforms``)."""
        # TODO make sure there are more columns than rows (transpose if not)
        # normalize waveforms
        self._waveforms = waveforms
        self._DL.fit(self._waveforms)

    def fit_transform(self, waveforms):
        """Learn the dictionary and return the sparse codes (also stored on ``self._A``)."""
        # TODO make sure there are more columns than rows (transpose if not)
        # normalize waveforms
        self._waveforms = waveforms
        self._A = self._DL.fit_transform(self._waveforms)
        return self._A

    def inverse_transform(self, A):
        """Map sparse codes ``A`` back to waveform space using the learned dictionary."""
        # NOTE(review): DictionaryLearning.inverse_transform only exists in recent
        # scikit-learn releases -- confirm the version this project pins.
        new_waveforms = self._DL.inverse_transform(A)
        return new_waveforms

    def get_params(self):
        """Return estimator params with ``n_components`` renamed to ``num_components``.

        The decomposition name is added under ``'Decompositon'`` (spelling kept for
        existing readers).
        """
        # TODO know what catalog was used! (include waveform metadata)
        params = self._DL.get_params()
        params['num_components'] = params.pop('n_components')
        params['Decompositon'] = self._decomposition
        return params

    def get_basis(self):
        r"""Return the learned dictionary atoms (Z^\dagger), i.e. ``components_``."""
        return self._DL.components_

    def btnConvert_click(self):
        """Validate the form, run Dictionary Learning on a .mat file and save the result.

        All settings are read from the module-level ``ui`` form. Invalid input shows a
        critical message box and returns ``False``; if the data or subject variable
        cannot be read, a message is printed and ``None`` is returned. On success the
        transformed data, the copied metadata variables and the stringified model
        parameters are written with ``io.savemat`` and an information box is shown.
        """
        msgBox = QMessageBox()

        def fail(text):
            # Show a critical message box and report failure to the caller.
            msgBox.setText(text)
            msgBox.setIcon(QMessageBox.Critical)
            msgBox.setStandardButtons(QMessageBox.Ok)
            msgBox.exec_()
            return False

        Fit = ui.cbFit.currentText()
        Transform = ui.cbTransform.currentText()

        # Builtin float/int: np.float was removed in NumPy 1.24, and the old bare
        # except turned that AttributeError into a bogus "... is wrong!" error.
        # Tol
        try:
            Tol = float(ui.txtTole.text())
        except (TypeError, ValueError):
            return fail("Tolerance is wrong!")

        # MaxIter
        try:
            MaxIter = int(ui.txtMaxIter.text())
        except (TypeError, ValueError):
            return fail("Maximum number of iterations is wrong!")
        if MaxIter < 1:
            return fail("Maximum number of iterations is wrong!")

        # Alpha
        try:
            Alpha = float(ui.txtAlpha.text())
        except (TypeError, ValueError):
            return fail("Alpha is wrong!")

        # Number of Job
        try:
            NJob = int(ui.txtJobs.text())
        except (TypeError, ValueError):
            return fail("The number of parallel jobs is wrong!")
        if NJob < 1:
            # The check accepts 1, so the old "greater than 1" wording was wrong.
            return fail("The number of parallel jobs must be at least 1!")

        # OutFile
        OutFile = ui.txtOutFile.text()
        if not OutFile:
            return fail("Please enter out file!")

        # InFile
        InFile = ui.txtInFile.text()
        if not InFile:
            return fail("Please enter input file!")
        if not os.path.isfile(InFile):
            return fail("Input file not found!")

        if ui.rbScale.isChecked() and not ui.rbALScale.isChecked():
            return fail("Subject Level Normalization is just available for Subject Level Analysis!")

        InData = io.loadmat(InFile)
        OutData = dict()
        OutData["imgShape"] = InData["imgShape"]

        DataName = ui.txtData.currentText()
        if not DataName:
            return fail("Please enter Data variable name!")

        try:
            X = InData[DataName]
            if ui.cbScale.isChecked() and not ui.rbScale.isChecked():
                X = preprocessing.scale(X)
                print("Whole of data is scaled X~N(0,1).")
        except Exception:
            # Missing variable (KeyError) or data that preprocessing.scale rejects.
            print("Cannot load data")
            return

        try:
            NumFea = int(ui.txtNumFea.text())
        except (TypeError, ValueError):
            return fail("Number of features is wrong!")
        if NumFea < 1:
            return fail("Number of features must be greater than zero!")
        if NumFea > np.shape(X)[1]:
            return fail("Number of features is wrong!")

        # Subject
        SubjectName = ui.txtSubject.currentText()
        if not SubjectName:
            return fail("Please enter Subject variable name!")
        try:
            Subject = InData[SubjectName]
        except KeyError:
            print("Cannot load Subject ID")
            return
        OutData[ui.txtOSubject.text()] = Subject

        # Label (always required)
        LabelName = ui.txtLabel.currentText()
        if not LabelName:
            return fail("Please enter Label variable name!")
        OutData[ui.txtOLabel.text()] = InData[LabelName]

        # Optional variables copied to the output when their checkbox is ticked:
        # (checkbox, input-name combo, output-name field, label used in the message)
        optional_fields = (
            (ui.cbTask, ui.txtTask, ui.txtOTask, "Task"),
            (ui.cbRun, ui.txtRun, ui.txtORun, "Run"),
            (ui.cbCounter, ui.txtCounter, ui.txtOCounter, "Counter"),
            (ui.cbmLabel, ui.txtmLabel, ui.txtOmLabel, "Matrix Label"),
            (ui.cbDM, ui.txtDM, ui.txtODM, "Design Matrix"),
            (ui.cbCol, ui.txtCol, ui.txtOCol, "Coordinator"),
            (ui.cbCond, ui.txtCond, ui.txtOCond, "Condition"),
            (ui.cbNScan, ui.txtScan, ui.txtOScan, "Number of Scan"),
        )
        for checkbox, in_field, out_field, label in optional_fields:
            if not checkbox.isChecked():
                continue
            in_name = in_field.currentText()
            if not in_name:
                return fail("Please enter " + label + " variable name!")
            OutData[out_field.text()] = InData[in_name]

        Models = dict()
        Models["Name"] = "DictionaryLearning"

        def make_model():
            # Transform is combo-box text (an algorithm name, like Fit -> fit_algorithm);
            # it was passed as transform_alpha, which must be a float.
            return DictionaryLearning(n_components=NumFea, alpha=Alpha, max_iter=MaxIter,
                                      tol=Tol, fit_algorithm=Fit,
                                      transform_algorithm=Transform, n_jobs=NJob)

        if ui.rbALScale.isChecked():
            print("Partition data to subject level ...")
            SubjectUniq = np.unique(Subject)
            ScaleSubjects = ui.cbScale.isChecked() and ui.rbScale.isChecked()
            X_Sub = list()
            for subj in SubjectUniq:
                # Sample indices are taken from axis 1 of the match mask; presumably
                # Subject is a 1 x n_samples row vector as stored by MATLAB -- confirm.
                X_subj = X[np.where(Subject == subj)[1], :]
                if ScaleSubjects:
                    X_subj = preprocessing.scale(X_subj)
                    print("Data in subject level is scaled, X_" + str(subj) + "~N(0,1).")
                X_Sub.append(X_subj)
                print("Subject ", subj, " is extracted from data.")

            print("Running Dictionary Learning in subject level ...")
            X_Sub_DL = list()
            NumSub = len(X_Sub)
            for xsubindx, xsub in enumerate(X_Sub):
                model = make_model()
                X_Sub_DL.append(model.fit_transform(xsub))
                Models["Model" + str(xsubindx + 1)] = str(model.get_params(deep=True))
                print("Dictionary Learning: ", xsubindx + 1, " of ", NumSub, " is done.")

            print("Data integration ... ")
            # One concatenation instead of re-copying a growing array per subject.
            OutData[ui.txtOData.text()] = np.concatenate(X_Sub_DL) if X_Sub_DL else None
            print("Integration of ", NumSub, " subjects is done.")
        else:
            print("Running Dictionary Learning ...")
            model = make_model()
            OutData[ui.txtOData.text()] = model.fit_transform(X)
            Models["Model"] = str(model.get_params(deep=True))

        OutData["ModelParameter"] = Models

        print("Saving ...")
        io.savemat(OutFile, mdict=OutData)
        print("DONE.")
        msgBox.setText("Dictionary Learning is done.")
        msgBox.setIcon(QMessageBox.Information)
        msgBox.setStandardButtons(QMessageBox.Ok)
        msgBox.exec_()
# 예제 #5 (Example #5)
# 0
class SparseCoding:
    """Sparse-coding feature extractor backed by a scikit-learn ``DictionaryLearning`` model.

    The combined pipeline (``get_pooled_features_from_whitened_patches``) encodes
    whitened patches with the learned dictionary, splits the codes by sign, and
    max-pools them over non-overlapping windows.
    """

    DEFAULT_MODEL_PARAMS = {
        'n_components': 10,
        'n_features': 64,
        'max_iter': 5,
        'random_state': 1,
        'dict_init': None,
        'code_init': None
    }

    def __init__(self, model_filename=None):
        """Load a saved model from ``model_filename``, or build a fresh one from the defaults."""
        if model_filename is not None:
            self.load_model(model_filename)
        else:
            # Mirror every default model param as an instance attribute.
            for param, value in SparseCoding.DEFAULT_MODEL_PARAMS.items():
                setattr(self, param, value)

            # initialize Dictionary Learning object with default params and weights
            self.DL_obj = DictionaryLearning(n_components=self.n_components,
                                             alpha=1,
                                             max_iter=self.max_iter,
                                             tol=1e-08,
                                             fit_algorithm='lars',
                                             transform_algorithm='omp',
                                             transform_n_nonzero_coefs=None,
                                             transform_alpha=None,
                                             n_jobs=1,
                                             code_init=self.code_init,
                                             dict_init=self.dict_init,
                                             verbose=False,
                                             split_sign=False,
                                             random_state=self.random_state)

    def save_model(self, filename):
        """Persist the ``DictionaryLearning`` object to ``filename``."""
        # compress=3 writes a single compressed file instead of multiple model files.
        joblib.dump(self.DL_obj, filename, compress=3)

    def load_model(self, filename):
        """Load a model written by ``save_model`` and mirror its params as attributes.

        Each key of ``DEFAULT_MODEL_PARAMS`` is taken from the loaded object's
        ``get_params()`` when present, otherwise from the default.
        """
        # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
        # only load model files from trusted sources.
        self.DL_obj = joblib.load(filename)

        DL_params = self.DL_obj.get_params()
        for param, default in SparseCoding.DEFAULT_MODEL_PARAMS.items():
            setattr(self, param, DL_params.get(param, default))

    @staticmethod
    def _as_2d_patches(whitened_patches):
        """Flatten a 3-D patch stack to ``(n_patches, -1)`` and assert the result is 2-D."""
        if whitened_patches.ndim == 3:
            whitened_patches = whitened_patches.reshape((whitened_patches.shape[0], -1))
        assert whitened_patches.ndim == 2, "Whitened patches ndim is %d instead of 2" % whitened_patches.ndim
        return whitened_patches

    def learn_dictionary(self, whitened_patches):
        """Fit the dictionary on 2-D (or 3-D, flattened per patch) whitened patches."""
        self.DL_obj.fit(self._as_2d_patches(whitened_patches))

    def get_dictionary(self):
        """Return the learned dictionary (``components_``).

        Raises
        ------
        AttributeError
            If the model has not been trained yet.
        """
        try:
            return self.DL_obj.components_
        except AttributeError as err:
            raise AttributeError("Feature extraction dictionary has not yet been learnt for this model. "
                                 + "Train the feature extraction model at least once to prevent this error.") from err

    def get_sparse_features(self, whitened_patches):
        """Encode whitened patches as sparse codes with the learned dictionary.

        Raises
        ------
        NotFittedError
            If the model has not been trained yet.
        """
        whitened_patches = self._as_2d_patches(whitened_patches)
        try:
            return self.DL_obj.transform(whitened_patches)
        except NotFittedError as err:
            raise NotFittedError("Feature extraction dictionary has not yet been learnt for this model, "
                                 + "therefore Sparse Codes cannot be extracted. Train the feature extraction model "
                                 + "at least once to prevent this error.") from err

    def get_sign_split_features(self, sparse_features):
        """Split codes into ``[max(x, 0), -min(x, 0)]``, doubling the number of columns."""
        n_samples, n_components = sparse_features.shape
        sign_split_features = np.empty((n_samples, 2 * n_components))
        sign_split_features[:, :n_components] = np.maximum(sparse_features, 0)
        sign_split_features[:, n_components:] = -np.minimum(sparse_features, 0)
        return sign_split_features

    def get_pooled_features(self, input_feature_map, filter_size=(19, 19)):
        """Max-pool a feature map over non-overlapping ``filter_size`` windows.

        Within each window the position whose feature vector has the largest L2
        norm is kept, giving one feature vector per window.

        Parameters
        ----------
        input_feature_map : ndarray
            ``(height, width, n_channels)``, or flattened ``(n_positions, n_channels)``;
            the flattened form is assumed to come from a square image.
        filter_size : tuple of int
            ``(window_height, window_width)``; also used as the stride.

        Returns
        -------
        ndarray
            ``(n_windows, n_channels)`` pooled features.
        """
        # Flattened input, e.g. (3249, 20) -> (57, 57, 20); assumes a square image.
        if input_feature_map.ndim == 2:
            side = int(np.sqrt(input_feature_map.shape[0]))
            input_feature_map = input_feature_map.reshape((side, side, input_feature_map.shape[-1]))
        assert input_feature_map.ndim == 3, "Input features dimension is %d instead of 3" % input_feature_map.ndim

        window_h, window_w = filter_size[0], filter_size[1]
        n_channels = input_feature_map.shape[-1]

        # Non-overlapping windows, e.g. (57, 57, 20) -> (3, 3, 1, 19, 19, 20).
        # A per-axis step supports rectangular filters; for square filters it matches
        # the former scalar step (the channel axis yields a single window either way).
        windows = view_as_windows(input_feature_map,
                                  window_shape=(window_h, window_w, n_channels),
                                  step=(window_h, window_w, 1))

        # (n_windows, positions_per_window, n_channels), e.g. (9, 361, 20).
        # -1 replaces shape[0]**2, which only held for square grids of windows.
        windows = windows.reshape((-1, window_h * window_w, n_channels))

        # L2 norm of each position's feature vector: (9, 361, 20) -> (9, 361).
        norms = np.linalg.norm(windows, ord=2, axis=-1)
        # Position of the largest norm within each window: (9, 361) -> (9,).
        max_norm_indexes = np.argmax(norms, axis=-1)
        # Keep that position's feature vector: (9, 361, 20) -> (9, 20).
        return windows[np.arange(windows.shape[0]), max_norm_indexes]

    # Combined Pipeline
    def get_pooled_features_from_whitened_patches(self, whitened_patches):
        """Run sparse coding, sign splitting and max pooling on whitened patches."""
        sparse_features = self.get_sparse_features(whitened_patches)
        sign_split_features = self.get_sign_split_features(sparse_features)
        return self.get_pooled_features(sign_split_features)