def _block_proximal_gradient_descent(self, W, P):
        '''
            plain (non-accelerated) proximal gradient training loop

            Starting from (W, P), repeatedly calls self._get_updated_paras
            until one of the following happens:
              * the returned line-search counter equals self.ln,
              * the relative objective decrease drops below self.eps,
              * self.max_iters iterations have been run.

            W, P: initial parameters; they are only passed through to the
                helper methods here, so their shapes are whatever
                _get_XC_prods / _get_updated_paras expect.

            Side effects: stores the per-iteration test RMSE / MAE lists in
            self.rmses / self.maes and hands the final (W, P) to
            self._save_paras. Returns None.
        '''
        # objective on the training data at the starting point
        WX, XP, XSPS = self._get_XC_prods(self.train_X, W, P)
        err = self._cal_err(WX, XP, XSPS, self.train_Y)
        objs = [self._obj(err, W, P)]

        # test metrics at the starting point (index 0 of both lists)
        WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
        test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
        rmses = [cal_rmse(test_err)]
        maes = [cal_mae(test_err)]

        eta = self.eta
        # Pre-bind the loop variable so the summary log below still works
        # (and reports 0 iterations) when max_iters is 0.
        t = -1
        for t in range(self.max_iters):
            start = time.time()

            l_obj, eta, lt, W, P = self._get_updated_paras(eta, W, P)

            # NOTE(review): W and P were already rebound to the values
            # returned by this step, so those are what gets saved below.
            if lt == self.ln:
                logging.info('!!!stopped by line_search, lt=%s!!!', lt)
                break

            objs.append(l_obj)

            WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
            test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
            rmses.append(cal_rmse(test_err))
            maes.append(cal_mae(test_err))
            end = time.time()

            # relative change of the objective produced by this iteration
            dr = abs(objs[t] - objs[t + 1]) / objs[t]
            # NOTE(review): obj / rmse / mae logged here are the entries at
            # index t, i.e. the values from before this iteration's update.
            logging.info(
                'exp_id=%s, iter=%s, lt,eta,dr=(%s,%s, %.7f), obj=%.5f, rmse=%.5f, mae=%.5f, cost=%.2f seconds',
                self.exp_id, t, lt, eta, dr, objs[t], rmses[t], maes[t],
                (end - start))
            if dr < self.eps:
                logging.info(
                    '*************stopping criterion satisfied*********')
                break

        logging.info('train process finished, total iters=%s', t + 1)
        self.rmses, self.maes = rmses, maes
        self._save_paras(W, P)
    def _block_nonmono_acc_proximal_gradient_descent(self, W, P):
        '''
            non-monotone accelerated pg

            W is reshaped to a column and stacked with P into one matrix
            (column 0 = W, remaining columns = P), so P must be 2-D with
            as many rows as W has entries.

            Each iteration:
              1. builds an extrapolated point B from the last two accepted
                 iterates (A0, A1) and the last extrapolated step (C1);
              2. takes a step from B (the "y" step);
              3. accepts it directly if its objective is below the running
                 weighted average c of past objectives; otherwise also
                 takes a step from A1 (the "v" step) and keeps whichever of
                 the two has the smaller objective.

            The loop ends when either line-search counter equals self.ln,
            when the relative objective change drops below self.eps, or
            after self.max_iters iterations.

            Side effects: stores the per-iteration test RMSE / MAE lists in
            self.rmses / self.maes and hands the final (W, P) to
            self._save_paras. Returns None.
        '''
        logging.info(
            'start solving by _block_nonmono_acc_proximal_gradient_descent')
        WX, XP, XSPS = self._get_XC_prods(self.train_X, W, P)
        err = self._cal_err(WX, XP, XSPS, self.train_Y)
        # objs[i] is filled in as iterations complete; objs[0] is the start
        objs = [None] * (self.max_iters + 1)
        objs[0] = self._obj(err, W, P)

        WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
        test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
        rmses = [cal_rmse(test_err)]
        maes = [cal_mae(test_err)]

        # A0 / A1: previous and current accepted iterate,
        # C1: result of the most recent step taken from the extrapolation.
        A = np.hstack((W.reshape(-1, 1), P))
        A0, A1, C1 = A.copy(), A.copy(), A.copy()
        # c: weighted running average of accepted objectives (weights q,
        # decay qeta); r0 / r1: consecutive momentum coefficients.
        c = objs[0]
        r0, r1, q, qeta = 0.0, 1.0, 1.0, 0.5
        # separate step sizes / line-search counters for the y and v steps
        eta1 = eta2 = self.eta
        lt1, lt2 = 0, 0

        # Pre-bind the loop variable so the summary log below still works
        # (and reports 0 iterations) when max_iters is 0.
        t = -1
        for t in range(self.max_iters):
            start = time.time()
            self._update_bias(W, P)
            use_acc = False

            # extrapolated point built from the last two iterates and C1
            B = A1 + r0 / r1 * (C1 - A1) + (r0 - 1) / r1 * (A1 - A0)
            W, P = B[:, 0].flatten(), B[:, 1:]
            y_obj, y_eta, y_lt, yW, yP = self._get_updated_paras(eta1, W, P)
            lt1, eta1 = y_lt, y_eta
            C1 = np.hstack((yW.reshape(-1, 1), yP))

            if y_obj < c:
                # accelerated step is good enough w.r.t. the running average
                objs[t + 1] = y_obj
                W, P = yW, yP
                use_acc = True
            else:
                # fall back: also try a step from the current iterate A1
                W, P = A1[:, 0].flatten(), A1[:, 1:]
                v_obj, v_eta, v_lt, vW, vP = self._get_updated_paras(
                    eta2, W, P)
                lt2, eta2 = v_lt, v_eta

                if y_obj < v_obj:
                    objs[t + 1] = y_obj
                    W, P = yW, yP
                else:
                    objs[t + 1] = v_obj
                    W, P = vW, vP

            # lt2 keeps its value from an earlier iteration when the v step
            # was skipped above.
            if lt1 == self.ln or lt2 == self.ln:
                logging.info('!!!stopped by line_search, lt1=%s, lt2=%s!!!',
                             lt1, lt2)
                break

            A0 = A1
            A1 = np.hstack((W.reshape(-1, 1), P))

            # advance the momentum sequence and the objective average
            r0 = r1
            r1 = (np.sqrt(4 * pow(r0, 2) + 1) + 1) / 2.0
            tq = qeta * q + 1.0
            c = (qeta * q * c + objs[t + 1]) / tq
            q = tq

            WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
            test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
            rmses.append(cal_rmse(test_err))
            maes.append(cal_mae(test_err))
            end = time.time()

            # relative change of the objective produced by this iteration
            dr = abs(objs[t] - objs[t + 1]) / objs[t]
            # NOTE(review): obj / rmse / mae logged here are the entries at
            # index t, i.e. the values from before this iteration's update.
            logging.info(
                'exp_id=%s, iter=%s, use_acc=%s, (lt1,eta1,lt2,eta2)=(%s,%s,%s,%s), obj=%.5f(%.8f), rmse=%.5f, mae=%.5f, cost=%.2f seconds',
                self.exp_id, t, use_acc, lt1, eta1, lt2, eta2, objs[t], dr,
                rmses[t], maes[t], (end - start))
            if dr < self.eps:
                logging.info(
                    '*************stopping criterion satisfied*********')
                break

        logging.info('train process finished, total iters=%s', t + 1)
        self.rmses, self.maes = rmses, maes
        self._save_paras(W, P)
    def _block_mono_acc_proximal_gradient_descent(self, W, P):
        '''
            monotone accelerated pg

            W is reshaped to a column and stacked with P into one matrix
            (column 0 = W, remaining columns = P), so P must be 2-D with
            as many rows as W has entries.

            Each iteration takes two steps with the same step size eta:
              * the "v" step from the current iterate (W, P);
              * the "y" step from an extrapolated point B built from the
                last two accepted iterates (A0, A1) and the previous y
                result (C1).
            The candidate with the smaller objective becomes the new
            iterate, and its eta / line-search counter are carried over.

            The loop ends when the chosen line-search counter equals
            self.ln, when the relative objective change drops below
            self.eps, or after self.max_iters iterations.

            Side effects: stores the per-iteration test RMSE / MAE lists in
            self.rmses / self.maes and hands the final (W, P) to
            self._save_paras. Returns None.
        '''

        logging.info(
            'start solving by _block_mono_acc_proximal_gradient_descent')
        WX, XP, XSPS = self._get_XC_prods(self.train_X, W, P)
        err = self._cal_err(WX, XP, XSPS, self.train_Y)
        objs = [self._obj(err, W, P)]

        WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
        test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
        rmses = [cal_rmse(test_err)]
        maes = [cal_mae(test_err)]

        # A0 / A1: previous and current accepted iterate,
        # C1: result of the most recent step taken from the extrapolation.
        A = np.hstack((W.reshape(-1, 1), P))
        A0, A1, C1 = A.copy(), A.copy(), A.copy()
        # r0 / r1: consecutive momentum coefficients
        r0, r1 = 0.0, 1.0

        eta = self.eta
        # Pre-bind the loop variable so the summary log below still works
        # (and reports 0 iterations) when max_iters is 0.
        t = -1
        for t in range(self.max_iters):
            start = time.time()

            # plain step from the current iterate
            v_obj, v_eta, v_lt, vW, vP = self._get_updated_paras(eta, W, P)

            # step from the extrapolated point
            B = A1 + r0 / r1 * (C1 - A1) + (r0 - 1) / r1 * (A1 - A0)
            W, P = B[:, 0].flatten(), B[:, 1:]
            y_obj, y_eta, y_lt, yW, yP = self._get_updated_paras(eta, W, P)
            C1 = np.hstack((yW.reshape(-1, 1), yP))

            # keep whichever candidate has the lower objective
            if v_obj > y_obj:
                objs.append(y_obj)
                lt = y_lt
                eta = y_eta
                W, P = yW, yP
            else:
                objs.append(v_obj)
                lt = v_lt
                eta = v_eta
                W, P = vW, vP

            if lt == self.ln:
                logging.info('!!!stopped by line_search, lt=%s!!!', lt)
                break

            A0 = A1
            A1 = np.hstack((W.reshape(-1, 1), P))
            # Advance the momentum sequence. r0 has to take the old r1
            # first (same recurrence as in
            # _block_nonmono_acc_proximal_gradient_descent); without it r0
            # stays 0, r1 stays 1 and B above always collapses to A0.
            r0 = r1
            r1 = (np.sqrt(4 * pow(r0, 2) + 1) + 1) / 2.0

            WtX, tXP, tXSPS = self._get_XC_prods(self.test_X, W, P)
            test_err = self._cal_err(WtX, tXP, tXSPS, self.test_Y)
            rmses.append(cal_rmse(test_err))
            maes.append(cal_mae(test_err))
            end = time.time()

            # relative change of the objective produced by this iteration
            dr = abs(objs[t] - objs[t + 1]) / objs[t]
            # NOTE(review): obj / rmse / mae logged here are the entries at
            # index t, i.e. the values from before this iteration's update.
            logging.info(
                'exp_id=%s, iter=%s, lt,eta,dr=(%s,%s, %.7f), obj=%.5f, rmse=%.5f, mae=%.5f, cost=%.2f seconds',
                self.exp_id, t, lt, eta, dr, objs[t], rmses[t], maes[t],
                (end - start))
            if dr < self.eps:
                logging.info(
                    '*************stopping criterion satisfied*********')
                break

        logging.info('train process finished, total iters=%s', t + 1)
        self.rmses, self.maes = rmses, maes
        self._save_paras(W, P)