Beispiel #1
0
    def predict(self,
                K1pred=None,
                K2pred=None,
                row_inds_K1pred=None,
                row_inds_K2pred=None,
                pko=None):
        """Computes predictions for test examples.

        Parameters
        ----------
        K1pred : {array-like, list of equally shaped array-likes}, shape = [n_samples1, n_train_pairs]
            the first part of the test data matrix
        K2pred : {array-like, list of equally shaped array-likes}, shape = [n_samples2, n_train_pairs]
            the second part of the test data matrix
        row_inds_K1pred : list of indices, shape = [n_test_pairs], optional
            maps rows of K1pred to vector of predictions P. If not supplied, predictions are computed for all possible test pair combinations.
        row_inds_K2pred : list of indices, shape = [n_test_pairs], optional
            maps rows of K2pred to vector of predictions P. If not supplied, predictions are computed for all possible test pair combinations.
        pko : PairwiseKernelOperator, optional
            a ready-made test-vs-training pairwise kernel operator. When it is
            supplied, K1pred, K2pred, row_inds_K1pred and row_inds_K2pred are
            ignored.

        Returns
        ----------
        P : array, shape = [n_test_pairs] or [n_samples1*n_samples2]
            predictions, either ordered according to the supplied row indices, or if no such are supplied by default
            prediction for (K1[i], K2[j]) maps to P[i + j*n_samples1].
        """
        # Build the operator mapping the training dual coefficients to test
        # predictions, unless the caller already provided one. Identity
        # comparison avoids invoking any overloaded ``__eq__``.
        if pko is None:
            pko = pairwise_kernel_operator.PairwiseKernelOperator(
                K1pred, K2pred, row_inds_K1pred, row_inds_K2pred,
                self.row_inds_K1training, self.row_inds_K2training,
                self.weights)
        # Apply the operator to the dual coefficients. ``matvec`` is the
        # operator method used for the same product by the training code.
        return pko.matvec(self.A)
        '''
Beispiel #2
0
    def predict(self,
                X1pred=None,
                X2pred=None,
                inds_X1pred=None,
                inds_X2pred=None,
                pko=None):
        """Computes predictions for test examples.

        Parameters
        ----------
        X1pred : array-like, shape = [n_samples1, n_features1]
            the first part of the test data matrix (not needed when pko is given)
        X2pred : array-like, shape = [n_samples2, n_features2]
            the second part of the test data matrix (not needed when pko is given)
        inds_X1pred : list of indices, shape = [n_test_pairs], optional
            maps rows of X1pred to vector of predictions P. If not supplied, predictions are computed for all possible test pair combinations.
        inds_X2pred : list of indices, shape = [n_test_pairs], optional
            maps rows of X2pred to vector of predictions P. If not supplied, predictions are computed for all possible test pair combinations.
        pko : PairwiseKernelOperator, optional
            a ready-made pairwise feature operator for the test pairs. When it
            is supplied, X1pred, X2pred, inds_X1pred and inds_X2pred are
            ignored.

        Returns
        ----------
        P : array, shape = [n_test_pairs] or [n_samples1*n_samples2]
            predictions, either ordered according to the supplied row indices, or if no such are supplied by default
            prediction for (X1[i], X2[j]) maps to P[i + j*n_samples1].
        """
        # Build the test pair feature operator (no column indexing in the
        # primal case) unless the caller already provided one.
        if pko is None:
            pko = pairwise_kernel_operator.PairwiseKernelOperator(
                X1pred, X2pred, inds_X1pred, inds_X2pred, None, None,
                self.weights)
        # Predictions are the operator applied to the primal weights.
        return pko.matvec(self.W)
Beispiel #3
0
    def __init__(self, **kwargs):
        """Trains the pairwise (Kronecker) RLS model with MINRES.

        Dual (kernel) mode is used when ``K1`` or ``pko`` is supplied: the
        system ``(K + regparam * I) A = Y`` is solved for the dual
        coefficients ``self.A``, where ``K`` is the pairwise kernel operator
        over the labeled training pairs. Otherwise primal (linear) mode
        solves ``(P^T P + regparam * I) W = P^T Y`` for the weights
        ``self.W``, where ``P`` is the pairwise feature operator built from
        ``X1`` and ``X2`` (``P^T`` is applied through ``pko.rmatvec``).

        Keyword arguments
        -----------------
        Y : array-like
            labels of the labeled training pairs; flattened in Fortran order.
        regparam : float, optional (default 0.)
            regularization parameter.
        compute_risk : bool, optional (default False)
            if True, every MINRES iteration evaluates the regularized training
            loss and the best iterate is remembered. In primal mode that
            iterate is kept as the model; in dual mode the final MINRES
            solution is stored in ``self.A`` afterwards.
        maxiter : int, optional
            maximum number of MINRES iterations.
        K1, K2 : {array-like, list of array-likes}
            kernel matrices (dual mode).
        pko : PairwiseKernelOperator, optional
            ready-made training kernel operator (dual mode); when given,
            K1, K2, label_row_inds, label_col_inds and weights are not read.
        X1, X2 : {array-like, list of array-likes}
            feature matrices (primal mode).
        label_row_inds, label_col_inds : list of indices, shape = [n_train_pairs]
            indices into K1/X1 and K2/X2 of each labeled pair.
        weights : list of floats, optional
            weights of the terms when lists of matrices are given.

        The object stored under the ``CALLBACK_FUNCTION`` key, if any, must
        provide ``callback(learner)`` (called after every MINRES iteration)
        and ``finished(learner)`` (called once after training).
        """
        self.Y = kwargs["Y"]
        self.trained = False
        self.regparam = kwargs.get("regparam", 0.)
        regparam = self.regparam
        self.callbackfun = kwargs.get(CALLBACK_FUNCTION)
        self.compute_risk = kwargs.get("compute_risk", False)
        maxiter = int(kwargs['maxiter']) if 'maxiter' in kwargs else None
        # Right-hand side for both solvers; raveling also accepts lists and
        # (n, 1) column vectors.
        Y = np.array(self.Y).ravel(order='F')
        self.bestloss = float("inf")

        if 'K1' in kwargs or 'pko' in kwargs:
            if 'pko' in kwargs:
                pko = kwargs['pko']
            else:
                self.input1_inds = np.array(kwargs["label_row_inds"],
                                            dtype=np.int32)
                self.input2_inds = np.array(kwargs["label_col_inds"],
                                            dtype=np.int32)
                # Training kernel: labeled pairs against themselves.
                pko = pairwise_kernel_operator.PairwiseKernelOperator(
                    kwargs['K1'], kwargs['K2'], self.input1_inds,
                    self.input2_inds, self.input1_inds, self.input2_inds,
                    kwargs.get('weights'))
            self.pko = pko

            def mv(v):
                # (K + regparam * I) v
                return pko.matvec(v) + regparam * v

            def mvr(v):
                raise Exception('This function should not be called!')

            def cgcb(v):
                if self.compute_risk:
                    # pko represents K on the training pairs, so a single
                    # product gives both the fitted values and the term
                    # needed for the regularizer a^T K a. This also works
                    # when the operator was passed in via 'pko'.
                    Ka = pko.matvec(v)
                    z = Y - Ka
                    loss = np.dot(z, z) + regparam * np.dot(v, Ka)
                    print("loss", 0.5 * loss)
                    if loss < self.bestloss:
                        self.A = v.copy()
                        self.bestloss = loss
                else:
                    self.A = v
                if self.callbackfun is not None:
                    self.callbackfun.callback(self)

            n_pairs = Y.shape[0]
            G = LinearOperator((n_pairs, n_pairs),
                               matvec=mv,
                               rmatvec=mvr,
                               dtype=np.float64)
            self.A = minres(G, Y, maxiter=maxiter, callback=cgcb,
                            tol=1e-20)[0]
            self.predictor = KernelPairwisePredictor(
                self.A, self.pko.original_col_inds_K1,
                self.pko.original_col_inds_K2, self.pko.weights)
            if self.callbackfun is not None:
                self.callbackfun.finished(self)
        else:
            self.input1_inds = np.array(kwargs["label_row_inds"],
                                        dtype=np.int32)
            self.input2_inds = np.array(kwargs["label_col_inds"],
                                        dtype=np.int32)
            X1 = kwargs['X1']
            X2 = kwargs['X2']
            self.X1, self.X2 = X1, X2
            weights = kwargs.get('weights')
            # Feature operator for the labeled pairs (no column indexing).
            pko = pairwise_kernel_operator.PairwiseKernelOperator(
                X1, X2, self.input1_inds, self.input2_inds, None, None,
                weights)

            def mv(v):
                # Normal-equation operator (P^T P + regparam * I) v
                return pko.rmatvec(pko.matvec(v)) + regparam * v

            def cgcb(v):
                if self.compute_risk:
                    # Use the same (weighted) operator as training so the
                    # monitored loss matches the optimized objective.
                    P = pko.matvec(v)
                    z = Y - P
                    loss = np.dot(z, z) + regparam * np.dot(v, v)
                    if loss < self.bestloss:
                        # Keep the flat weight vector, like the branch below;
                        # it cannot be reshaped to pko.shape (pairs x features).
                        self.W = v.copy()
                        self.bestloss = loss
                else:
                    self.W = v
                if self.callbackfun is not None:
                    self.predictor = LinearPairwisePredictor(self.W)
                    self.callbackfun.callback(self)

            n_features = pko.shape[1]
            G = LinearOperator((n_features, n_features),
                               matvec=mv,
                               dtype=np.float64)
            v_init = pko.rmatvec(Y)
            solution = minres(G, v_init, maxiter=maxiter, callback=cgcb,
                              tol=1e-20)[0]
            # Without risk tracking the model is the final solution; with it,
            # keep the best iterate unless none was recorded (e.g. MINRES
            # returned before invoking the callback).
            if not self.compute_risk or self.bestloss == float("inf"):
                self.W = solution
            self.predictor = LinearPairwisePredictor(self.W, self.input1_inds,
                                                     self.input2_inds, weights)
            if self.callbackfun is not None:
                self.callbackfun.finished(self)
Beispiel #4
0
    def test_cg_kron_rls(self):
        """Compare CGKronRLS variants against ordinary RLS on the xor task.

        A random subset of 40 pairs of the training grid is used as labeled
        data. The following are trained on it:

        * an ordinary RLS with an explicit Kronecker product kernel (reference);
        * linear (primal) CGKronRLS;
        * kernel (dual) CGKronRLS;
        * a two-term weighted kernel CGKronRLS;
        * a three-term pairwise "polynomial" kernel CGKronRLS, compared with
          ordinary RLS using a PolynomialKernel.

        Only the linear and kernel CGKronRLS predictions are asserted to match
        the reference. The other comparisons are printed only.
        """

        # Regularization parameter shared by every learner trained below.
        regparam = 0.0001

        K_train1, K_train2, Y_train, K_test1, K_test2, Y_test, X_train1, X_train2, X_test1, X_test2 = self.generate_xortask(
        )
        #K_train1, K_train2, Y_train, K_test1, K_test2, Y_test, X_train1, X_train2, X_test1, X_test2 = self.generate_xortask(trainpos1 = 1, trainneg1 = 1, trainpos2 = 1, trainneg2 = 1, testpos1 = 1, testneg1 = 1, testpos2 = 1, testneg2 = 1)
        # Pairs are indexed column-major throughout: flat index k corresponds
        # to (k % n_rows, k // n_rows), matching the order='F' unravel below.
        Y_train = Y_train.ravel(order='F')
        Y_test = Y_test.ravel(order='F')
        train_rows, train_columns = K_train1.shape[0], K_train2.shape[0]
        test_rows, test_columns = K_test1.shape[0], K_test2.shape[0]
        rowstimescols = train_rows * train_columns
        allindices = np.arange(rowstimescols)
        all_label_row_inds, all_label_col_inds = np.unravel_index(
            allindices, (train_rows, train_columns), order='F')
        #incinds = np.random.permutation(allindices)
        #incinds = np.random.choice(allindices, 50, replace = False)
        # Sample 40 distinct training pairs as the labeled set. No seed is set
        # in this method, so the sample varies between runs unless the global
        # NumPy RNG is seeded elsewhere.
        incinds = np.random.choice(allindices, 40, replace=False)
        label_row_inds, label_col_inds = all_label_row_inds[
            incinds], all_label_col_inds[incinds]
        Y_train_known_outputs = Y_train.reshape(rowstimescols,
                                                order='F')[incinds]

        # Row/column indices of every pair of the test grid.
        alltestindices = np.arange(test_rows * test_columns)
        all_test_label_row_inds, all_test_label_col_inds = np.unravel_index(
            alltestindices, (test_rows, test_columns), order='F')

        #Train an ordinary RLS regressor for reference
        # Explicit Kronecker product kernel restricted to the sampled pairs.
        params = {}
        params["X"] = np.kron(K_train2, K_train1)[np.ix_(incinds, incinds)]
        params["kernel"] = "PrecomputedKernel"
        params["Y"] = Y_train_known_outputs
        params["regparam"] = regparam
        ordrls_learner = RLS(**params)
        ordrls_model = ordrls_learner.predictor
        # Test kernel: every test-grid pair against the sampled training pairs.
        K_Kron_test = np.kron(K_test2, K_test1)[:, incinds]
        ordrls_testpred = ordrls_model.predict(K_Kron_test)
        ordrls_testpred = ordrls_testpred.reshape((test_rows, test_columns),
                                                  order='F')

        #Train linear Kronecker RLS
        # Per-iteration progress: mean absolute difference between the current
        # linear model's test predictions and the reference predictions.
        class TestCallback():
            def __init__(self):
                self.round = 0

            def callback(self, learner):
                self.round = self.round + 1
                tp = LinearPairwisePredictor(learner.W).predict(
                    X_test1, X_test2)
                print(
                    str(self.round) + ' ' +
                    str(np.mean(np.abs(tp -
                                       ordrls_testpred.ravel(order='F')))))

            def finished(self, learner):
                print('finished')

        params = {}
        params["regparam"] = regparam
        params["X1"] = X_train1
        params["X2"] = X_train2
        params["Y"] = Y_train_known_outputs
        params["label_row_inds"] = label_row_inds
        params["label_col_inds"] = label_col_inds
        tcb = TestCallback()
        # NOTE(review): 'callback' is presumably the key CGKronRLS looks up
        # for its callback object; confirm against its definition.
        params['callback'] = tcb
        linear_kron_learner = CGKronRLS(**params)
        linear_kron_testpred = linear_kron_learner.predict(
            X_test1, X_test2).reshape((test_rows, test_columns), order='F')
        # Predictions for the explicit test pairs (0, 0), (0, 1), (1, 0);
        # computed but not checked below.
        linear_kron_testpred_alt = linear_kron_learner.predict(
            X_test1, X_test2, [0, 0, 1], [0, 1, 0])

        #Train kernel Kronecker RLS
        params = {}
        params["regparam"] = regparam
        params["K1"] = K_train1
        params["K2"] = K_train2
        params["Y"] = Y_train_known_outputs
        params["label_row_inds"] = label_row_inds
        params["label_col_inds"] = label_col_inds

        # Same progress report as above, for the dual coefficients.
        class KernelCallback():
            def __init__(self):
                self.round = 0

            def callback(self, learner):
                self.round = self.round + 1
                tp = KernelPairwisePredictor(learner.A, learner.input1_inds,
                                             learner.input2_inds).predict(
                                                 K_test1, K_test2)
                print(
                    str(self.round) + ' ' +
                    str(np.mean(np.abs(tp -
                                       ordrls_testpred.ravel(order='F')))))

            def finished(self, learner):
                print('finished')

        tcb = KernelCallback()
        params['callback'] = tcb
        kernel_kron_learner = CGKronRLS(**params)
        kernel_kron_testpred = kernel_kron_learner.predict(
            K_test1, K_test2).reshape((test_rows, test_columns), order='F')
        # Predictions for the explicit test pairs (0, 0), (0, 1), (1, 0);
        # computed but not checked below.
        kernel_kron_testpred_alt = kernel_kron_learner.predict(
            K_test1, K_test2, [0, 0, 1], [0, 1, 0])

        print('Predictions: Linear CgKronRLS, Kernel CgKronRLS, ordinary RLS')
        print('[0, 0]: ' + str(linear_kron_testpred[0, 0]) + ' ' +
              str(kernel_kron_testpred[0, 0]) + ' ' +
              str(ordrls_testpred[0, 0])
              )  #, linear_kron_testpred_alt[0], kernel_kron_testpred_alt[0]
        print('[0, 1]: ' + str(linear_kron_testpred[0, 1]) + ' ' +
              str(kernel_kron_testpred[0, 1]) + ' ' +
              str(ordrls_testpred[0, 1])
              )  #, linear_kron_testpred_alt[1], kernel_kron_testpred_alt[1]
        print('[1, 0]: ' + str(linear_kron_testpred[1, 0]) + ' ' +
              str(kernel_kron_testpred[1, 0]) + ' ' +
              str(ordrls_testpred[1, 0])
              )  #, linear_kron_testpred_alt[2], kernel_kron_testpred_alt[2]
        print(
            'Meanabsdiff: linear KronRLS - ordinary RLS, kernel KronRLS - ordinary RLS'
        )
        print(
            str(np.mean(np.abs(linear_kron_testpred - ordrls_testpred))) +
            ' ' + str(np.mean(np.abs(kernel_kron_testpred - ordrls_testpred))))
        # The actual checks: both Kronecker learners must reproduce the
        # ordinary RLS reference predictions.
        np.testing.assert_almost_equal(linear_kron_testpred,
                                       ordrls_testpred,
                                       decimal=5)
        np.testing.assert_almost_equal(kernel_kron_testpred,
                                       ordrls_testpred,
                                       decimal=4)

        #Train multiple kernel Kronecker RLS
        # Two copies of the same kernels weighted 1/3 and 2/3; presumably
        # equivalent to the single-kernel model if the operator forms a
        # weighted sum of the terms.
        params = {}
        params["regparam"] = regparam
        params["K1"] = [K_train1, K_train1]
        params["K2"] = [K_train2, K_train2]
        params["weights"] = [1. / 3, 2. / 3]
        params["Y"] = Y_train_known_outputs
        params["label_row_inds"] = [label_row_inds, label_row_inds]
        params["label_col_inds"] = [label_col_inds, label_col_inds]

        # Note: callback() reads the name `params` when it runs, i.e. during
        # CGKronRLS(**params) below, while it still refers to this dict.
        class KernelCallback():
            def __init__(self):
                self.round = 0

            def callback(self, learner):
                self.round = self.round + 1
                tp = KernelPairwisePredictor(
                    learner.A, learner.input1_inds, learner.input2_inds,
                    params["weights"]).predict([K_test1, K_test1],
                                               [K_test2, K_test2])
                print(
                    str(self.round) + ' ' +
                    str(np.mean(np.abs(tp -
                                       ordrls_testpred.ravel(order='F')))))

            def finished(self, learner):
                print('finished')

        tcb = KernelCallback()
        params['callback'] = tcb
        mkl_kernel_kron_learner = CGKronRLS(**params)
        # Only checked inside the disabled block that follows.
        mkl_kernel_kron_testpred = mkl_kernel_kron_learner.predict(
            [K_test1, K_test1], [K_test2, K_test2]).reshape(
                (test_rows, test_columns), order='F')
        #kernel_kron_testpred_alt = kernel_kron_learner.predict(K_test1, K_test2, [0, 0, 1], [0, 1, 0])
        # The triple-quoted string below is disabled code (linear multiple
        # kernel variant and the MKL checks); it is a no-op expression.
        '''
        #Train linear multiple kernel Kronecker RLS
        params = {}
        params["regparam"] = regparam
        params["X1"] = [X_train1, X_train1]
        params["X2"] = [X_train2, X_train2]
        params["weights"] = [1. / 3, 2. / 3]
        params["Y"] = Y_train_known_outputs
        params["label_row_inds"] = [label_row_inds, label_row_inds]
        params["label_col_inds"] = [label_col_inds, label_col_inds]
        mkl_linear_kron_learner = CGKronRLS(**params)
        mkl_linear_kron_testpred = mkl_linear_kron_learner.predict([X_test1, X_test1], [X_test2, X_test2]).reshape((test_rows, test_columns), order = 'F')
        #kernel_kron_testpred_alt = kernel_kron_learner.predict(K_test1, K_test2, [0, 0, 1], [0, 1, 0])
        
        print('Predictions: Linear CgKronRLS, MKL Kernel CgKronRLS, MKL linear CgKronRLS, ordinary RLS')
        print('[0, 0]: ' + str(linear_kron_testpred[0, 0]) + ' ' + str(mkl_kernel_kron_testpred[0, 0]) + ' ' + str(mkl_linear_kron_testpred[0, 0]) + ' ' + str(ordrls_testpred[0, 0]))#, linear_kron_testpred_alt[0], kernel_kron_testpred_alt[0]
        print('[0, 1]: ' + str(linear_kron_testpred[0, 1]) + ' ' + str(mkl_kernel_kron_testpred[0, 1]) + ' ' + str(mkl_linear_kron_testpred[0, 1]) + ' ' + str(ordrls_testpred[0, 1]))#, linear_kron_testpred_alt[1], kernel_kron_testpred_alt[1]
        print('[1, 0]: ' + str(linear_kron_testpred[1, 0]) + ' ' + str(mkl_kernel_kron_testpred[1, 0]) + ' ' + str(mkl_linear_kron_testpred[1, 0]) + ' ' + str(ordrls_testpred[1, 0]))#, linear_kron_testpred_alt[2], kernel_kron_testpred_alt[2]
        print('Meanabsdiff: MKL kernel KronRLS - ordinary RLS')
        print(str(np.mean(np.abs(mkl_kernel_kron_testpred - ordrls_testpred))))
        np.testing.assert_almost_equal(mkl_kernel_kron_testpred, ordrls_testpred, decimal=3)
        '''

        #Train polynomial kernel Kronecker RLS
        # Three pairwise terms built from the kernel pairs
        # (K_train1, K_train1), (K_train1, K_train2), (K_train2, K_train2)
        # with weights 1, 2, 1, passed to CGKronRLS as a ready-made operator.
        params = {}
        params["regparam"] = regparam
        #params["K1"] = [K_train1, K_train1, K_train2]
        #params["K2"] = [K_train1, K_train2, K_train2]
        #params["weights"] = [1., 2., 1.]
        params["pko"] = pairwise_kernel_operator.PairwiseKernelOperator(
            [K_train1, K_train1, K_train2], [K_train1, K_train2, K_train2],
            [label_row_inds, label_row_inds, label_col_inds],
            [label_row_inds, label_col_inds, label_col_inds],
            [label_row_inds, label_row_inds, label_col_inds],
            [label_row_inds, label_col_inds, label_col_inds], [1., 2., 1.])
        params["Y"] = Y_train_known_outputs

        #params["label_row_inds"] = [label_row_inds, label_row_inds, label_col_inds]
        #params["label_col_inds"] = [label_row_inds, label_col_inds, label_col_inds]
        # Callback that only counts iterations (its report is commented out).
        class KernelCallback():
            def __init__(self):
                self.round = 0

            def callback(self, learner):
                self.round = self.round + 1
                #tp = KernelPairwisePredictor(learner.A, learner.input1_inds, learner.input2_inds, params["weights"]).predict([K_test1, K_test1], [K_test2, K_test2])
                #print(str(self.round) + ' ' + str(np.mean(np.abs(tp - ordrls_testpred.ravel(order = 'F')))))
            def finished(self, learner):
                print('finished')

        tcb = KernelCallback()
        params['callback'] = tcb
        poly_kernel_kron_learner = CGKronRLS(**params)
        # Test-grid operator; only used by the commented-out
        # predict(pko = pko) call below.
        pko = pairwise_kernel_operator.PairwiseKernelOperator(
            [K_test1, K_test1, K_test2], [K_test1, K_test2, K_test2], [
                all_test_label_row_inds, all_test_label_row_inds,
                all_test_label_col_inds
            ], [
                all_test_label_row_inds, all_test_label_col_inds,
                all_test_label_col_inds
            ], [label_row_inds, label_row_inds, label_col_inds],
            [label_row_inds, label_col_inds, label_col_inds], [1., 2., 1.])
        #poly_kernel_kron_testpred = poly_kernel_kron_learner.predict(pko = pko)
        poly_kernel_kron_testpred = poly_kernel_kron_learner.predict(
            [K_test1, K_test1, K_test2], [K_test1, K_test2, K_test2], [
                all_test_label_row_inds, all_test_label_row_inds,
                all_test_label_col_inds
            ], [
                all_test_label_row_inds, all_test_label_col_inds,
                all_test_label_col_inds
            ])
        #print(poly_kernel_kron_testpred, 'Polynomial kernel via CGKronRLS')

        #Train an ordinary RLS regressor with polynomial kernel for reference
        # Each row is [x1_i, x2_j] for pair index i + j * n_rows: X1 is tiled
        # once per X2 row, and every X2 row is repeated once per X1 row.
        params = {}
        params["X"] = np.hstack([
            np.kron(np.ones((X_train2.shape[0], 1)), X_train1),
            np.kron(X_train2, np.ones((X_train1.shape[0], 1)))
        ])[incinds]
        #params["X"] = np.hstack([np.kron(X_train1, np.ones((X_train2.shape[0], 1))), np.kron(np.ones((X_train1.shape[0], 1)), X_train2)])[incinds]
        params["kernel"] = "PolynomialKernel"
        params["Y"] = Y_train_known_outputs
        params["regparam"] = regparam
        ordrls_poly_kernel_learner = RLS(**params)
        X_dir_test = np.hstack([
            np.kron(np.ones((X_test2.shape[0], 1)), X_test1),
            np.kron(X_test2, np.ones((X_test1.shape[0], 1)))
        ])
        #X_dir_test = np.hstack([np.kron(X_test1, np.ones((X_test2.shape[0], 1))), np.kron(np.ones((X_test1.shape[0], 1)), X_test2)])
        ordrls_poly_kernel_testpred = ordrls_poly_kernel_learner.predict(
            X_dir_test)
        #print(ordrls_poly_kernel_testpred, 'Ord. poly RLS')
        # Printed only; there is no assertion for the polynomial comparison.
        print(
            'Meanabsdiff: Polynomial kernel KronRLS - Ordinary polynomial kernel RLS'
        )
        print(
            str(
                np.mean(
                    np.abs(poly_kernel_kron_testpred -
                           ordrls_poly_kernel_testpred))))
        # The triple-quoted string below is disabled code (an alternative
        # polynomial kernel setup); it is a no-op expression.
        '''
        #Train polynomial kernel Kronecker RLS
        params = {}
        params["regparam"] = regparam
        #params["X1"] = [X_train1, X_train1, X_train2]
        #params["X2"] = [X_train1, X_train2, X_train2]
        params["K1"] = [K_train1, K_train1, K_train2]
        params["K2"] = [K_train1, K_train2, K_train2]
        params["weights"] = [1., 2., 1.]
        params["Y"] = Y_train_known_outputs
        params["label_row_inds"] = [label_row_inds, label_row_inds, label_col_inds]
        params["label_col_inds"] = [label_row_inds, label_col_inds, label_col_inds]
        class KernelCallback():
            def __init__(self):
                self.round = 0
            def callback(self, learner):
                self.round = self.round + 1
                #tp = KernelPairwisePredictor(learner.A, learner.input1_inds, learner.input2_inds, params["weights"]).predict([K_test1, K_test1], [K_test2, K_test2])
                #print(str(self.round) + ' ' + str(np.mean(np.abs(tp - ordrls_testpred.ravel(order = 'F')))))
            def finished(self, learner):
                print('finished')
        tcb = KernelCallback()
        params['callback'] = tcb
        poly_kernel_linear_kron_learner = CGKronRLS(**params)
        #poly_kernel_linear_kron_testpred = poly_kernel_linear_kron_learner.predict([X_test1, X_test1, X_test2], [X_test1, X_test2, X_test2], [all_test_label_row_inds, all_test_label_row_inds, all_test_label_col_inds], [all_test_label_row_inds, all_test_label_col_inds, all_test_label_col_inds])
        poly_kernel_linear_kron_testpred = poly_kernel_linear_kron_learner.predict([K_test1, K_test1, K_test2], [K_test1, K_test2, K_test2], [all_test_label_row_inds, all_test_label_row_inds, all_test_label_col_inds], [all_test_label_row_inds, all_test_label_col_inds, all_test_label_col_inds])
        #print(poly_kernel_kron_testpred, 'Polynomial kernel via CGKronRLS (linear)')
        print('Meanabsdiff: Polynomial kernel KronRLS (linear) - Ordinary polynomial kernel RLS')
        print(str(np.mean(np.abs(poly_kernel_linear_kron_testpred - ordrls_poly_kernel_testpred))))
        '''
        '''
Beispiel #5
0
    def __init__(self, **kwargs):
        """Trains the pairwise (Kronecker) RLS model with MINRES.

        Dual (kernel) mode is used when ``K1`` or ``pko`` is supplied: the
        system ``(K + regparam * I) A = Y`` is solved for the dual
        coefficients ``self.A``, where ``K`` is the pairwise kernel operator
        over the labeled training pairs. Otherwise primal (linear) mode
        solves the normal equations for the weight matrix ``self.W`` of shape
        ``(n_features1, n_features2)``. The sampled Kronecker feature matrix
        is applied with ``sampled_vec_trick``.

        Keyword arguments
        -----------------
        Y : array-like
            labels of the labeled training pairs; flattened in Fortran order.
        regparam : float, optional (default 0.)
            regularization parameter.
        compute_risk : bool, optional (default False)
            if True, every MINRES iteration evaluates the regularized training
            loss and the best iterate is remembered. In primal mode that
            iterate is kept as the model; in dual mode the final MINRES
            solution is stored in ``self.A`` afterwards.
        maxiter : int, optional
            maximum number of MINRES iterations.
        K1, K2 : {array-like, list of array-likes}
            kernel matrices (dual mode).
        pko : PairwiseKernelOperator, optional
            ready-made training kernel operator (dual mode); when given,
            K1, K2, label_row_inds, label_col_inds and weights are not read.
        X1, X2 : array-like, shape = [n_samples, n_features]
            feature matrices (primal mode). Lists/tuples (multiple kernel
            learning) raise NotImplementedError in the primal case.
        label_row_inds, label_col_inds : list of indices, shape = [n_train_pairs]
            indices into K1/X1 and K2/X2 of each labeled pair.
        weights : list of floats, optional
            weights of the kernel terms when lists are given (dual mode).
        warm_start : array-like, optional
            initial primal weights, flattened in Fortran order (primal mode).

        The object stored under the ``CALLBACK_FUNCTION`` key, if any, must
        provide ``callback(learner)`` (called after every MINRES iteration)
        and ``finished(learner)`` (called once after training).
        """
        self.Y = kwargs["Y"]
        self.trained = False
        self.regparam = kwargs.get("regparam", 0.)
        regparam = self.regparam
        self.callbackfun = kwargs.get(CALLBACK_FUNCTION)
        self.compute_risk = kwargs.get("compute_risk", False)
        maxiter = int(kwargs['maxiter']) if 'maxiter' in kwargs else None
        # Right-hand side for both solvers; raveling also accepts lists and
        # (n, 1) column vectors.
        Y = np.array(self.Y).ravel(order='F')
        self.bestloss = float("inf")

        if 'K1' in kwargs or 'pko' in kwargs:
            if 'pko' in kwargs:
                pko = kwargs['pko']
            else:
                self.input1_inds = np.array(kwargs["label_row_inds"],
                                            dtype=np.int32)
                self.input2_inds = np.array(kwargs["label_col_inds"],
                                            dtype=np.int32)
                # Training kernel: labeled pairs against themselves.
                pko = pairwise_kernel_operator.PairwiseKernelOperator(
                    kwargs['K1'], kwargs['K2'], self.input1_inds,
                    self.input2_inds, self.input1_inds, self.input2_inds,
                    kwargs.get('weights'))
            self.pko = pko

            def mv(v):
                # (K + regparam * I) v
                return pko.matvec(v) + regparam * v

            def mvr(v):
                raise Exception('This function should not be called!')

            def cgcb(v):
                if self.compute_risk:
                    # pko represents K on the training pairs, so a single
                    # product gives both the fitted values and the term
                    # needed for the regularizer a^T K a. This also works
                    # when the operator was passed in via 'pko'.
                    Ka = pko.matvec(v)
                    z = Y - Ka
                    loss = np.dot(z, z) + regparam * np.dot(v, Ka)
                    print("loss", 0.5 * loss)
                    if loss < self.bestloss:
                        self.A = v.copy()
                        self.bestloss = loss
                else:
                    self.A = v
                if self.callbackfun is not None:
                    self.predictor = KernelPairwisePredictor(
                        self.A, self.pko.col_inds_K1, self.pko.col_inds_K2,
                        self.pko.weights)
                    self.callbackfun.callback(self)

            n_pairs = Y.shape[0]
            G = LinearOperator((n_pairs, n_pairs),
                               matvec=mv,
                               rmatvec=mvr,
                               dtype=np.float64)
            self.A = minres(G, Y, maxiter=maxiter, callback=cgcb,
                            tol=1e-20)[0]
            self.predictor = KernelPairwisePredictor(self.A,
                                                     self.pko.col_inds_K1,
                                                     self.pko.col_inds_K2,
                                                     self.pko.weights)
            # Notify the callback of completion, as the primal branch does.
            if self.callbackfun is not None:
                self.callbackfun.finished(self)
        else:  #Primal case. Does not work with the operator interface yet.
            self.input1_inds = np.array(kwargs["label_row_inds"],
                                        dtype=np.int32)
            self.input2_inds = np.array(kwargs["label_col_inds"],
                                        dtype=np.int32)
            X1 = kwargs['X1']
            X2 = kwargs['X2']
            self.X1, self.X2 = X1, X2

            if isinstance(X1, (list, tuple)):
                raise NotImplementedError(
                    "Got list or tuple as X1 but multiple kernel learning has not been implemented for the primal case yet."
                )
            weights = None
            _, x1fsize = X1.shape  # d features
            _, x2fsize = X2.shape  # r features
            kronfcount = x1fsize * x2fsize
            w_shape = (x1fsize, x2fsize)

            def mv(v):
                # Normal-equation operator (X^T X + regparam * I) v, where X
                # is the sampled Kronecker feature matrix of the labeled pairs.
                v_after = sampled_kronecker_products.sampled_vec_trick(
                    v, X2, X1, self.input2_inds, self.input1_inds)
                return sampled_kronecker_products.sampled_vec_trick(
                    v_after, X2.T, X1.T, None, None, self.input2_inds,
                    self.input1_inds) + regparam * v

            def mvr(v):
                raise Exception('This function should not be called!')

            def cgcb(v):
                if self.compute_risk:
                    P = sampled_kronecker_products.sampled_vec_trick(
                        v, X2, X1, self.input2_inds, self.input1_inds)
                    z = Y - P
                    loss = np.dot(z, z) + regparam * np.dot(v, v)
                    if loss < self.bestloss:
                        self.W = v.copy().reshape(w_shape, order='F')
                        self.bestloss = loss
                else:
                    self.W = v.reshape(w_shape, order='F')
                if self.callbackfun is not None:
                    self.predictor = LinearPairwisePredictor(self.W)
                    self.callbackfun.callback(self)

            G = LinearOperator((kronfcount, kronfcount),
                               matvec=mv,
                               rmatvec=mvr,
                               dtype=np.float64)
            # Right-hand side X^T Y of the normal equations.
            v_init = sampled_kronecker_products.sampled_vec_trick(
                Y, X2.T, X1.T, None, None, self.input2_inds,
                self.input1_inds)
            v_init = np.array(v_init).reshape(kronfcount)
            if 'warm_start' in kwargs:
                x0 = np.array(kwargs['warm_start']).reshape(kronfcount,
                                                            order='F')
            else:
                x0 = None
            solution = minres(G, v_init, x0=x0, maxiter=maxiter,
                              callback=cgcb, tol=1e-20)[0]
            # Without risk tracking the model is the final solution; with it,
            # keep the best iterate unless none was recorded (e.g. MINRES
            # returned before invoking the callback).
            if not self.compute_risk or self.bestloss == float("inf"):
                self.W = solution.reshape(w_shape, order='F')
            self.predictor = LinearPairwisePredictor(self.W, self.input1_inds,
                                                     self.input2_inds, weights)
            if self.callbackfun is not None:
                self.callbackfun.finished(self)