def _fit_on_fly(self, xte, yte): if self.algo_type == 'cd': res = dsvm.coord_descent(xtr=self.xtr, ytr=self.ytr, kernel=self.kernel, xte=xte, yte=yte, verbose=self.verbose, lmda=self.lmda, nsweep=np.int(self.nsweep)) elif self.algo_type == 'scg_da': res = dsvm.stocda_on_fly(xtr=self.xtr, ytr=self.ytr, kernel=self.kernel, xte=xte, yte=yte, lmda=self.lmda, rho=self.rho, verbose=np.int(self.verbose), nsweep=np.int(self.nsweep), b=np.int(self.b), c=np.int(self.c)) self.alpha, self.err_tr, self.err_te, self.obj, self.nker_opers = res
def fit(self, xtr, ytr, xte=None, yte=None):
    """Store the training data, build the dataset and run the chosen solver.

    Args:
        xtr: Training inputs; stored on ``self.xtr``.
        ytr: Training labels; stored on ``self.ytr``.
        xte: Optional test inputs forwarded to ``construct_dataset``.
        yte: Optional test labels forwarded to ``construct_dataset``.

    ``self.construct_dataset`` is called first; the solver then reads
    ``self.dataset`` (presumably populated by ``construct_dataset`` --
    confirm). Results are unpacked onto ``self.alpha``, ``self.err_tr``,
    ``self.err_te``, ``self.obj`` and ``self.nker_opers``. The ``'sbmd'``
    solver additionally sets ``self.err_tr2`` and ``self.obj2``.

    Raises:
        NotImplementedError: If ``self.algo_type`` is not ``'cd'``,
            ``'scg_da'`` or ``'sbmd'``.
    """
    self.xtr = xtr
    self.ytr = ytr
    self.construct_dataset(xtr, ytr, xte, yte)
    # Builtin int() replaces np.int (removed in NumPy 1.24) and matches the
    # int() already used for b and c.
    if self.algo_type == 'cd':
        res = dsvm.coord_descent(
            self.dataset, nsweep=int(self.nsweep), lmda=self.lmda,
            verbose=self.verbose)
        self.alpha, self.err_tr, self.err_te, self.obj, self.nker_opers = res
    elif self.algo_type == 'scg_da':
        res = dsvm.coord_dual_averaging(
            self.dataset, verbose=self.verbose, lmda=self.lmda,
            b=int(self.b), c=int(self.c), nsweep=int(self.nsweep),
            rho=self.rho)
        self.alpha, self.err_tr, self.err_te, self.obj, self.nker_opers = res
    elif self.algo_type == 'sbmd':
        res = dsvm.coord_mirror_descent(
            self.dataset, verbose=self.verbose, lmda=self.lmda,
            b=int(self.b), c=int(self.c), nsweep=int(self.nsweep),
            rho=self.rho)
        # This solver returns two extra diagnostics beyond the common five.
        (self.alpha, self.err_tr, self.err_te, self.obj, self.nker_opers,
         self.err_tr2, self.obj2) = res
    else:
        raise NotImplementedError(
            'unsupported algo_type: {!r}'.format(self.algo_type))
def _fit_precom(self, xte, yte): # ---------------precompute the kernel----------------- self.ktr = self.kernel_matrix(self.xtr) if xte is not None: self.kte = self.kernel_matrix(xte, self.xtr) if self.nsweep is None: if self.algo_type == 'scg_da': self.nsweep = self.xtr.shape[0] elif self.algo_type == 'cd': self.nsweep = self.xtr.shape[0] if self.algo_type == 'cd': res = dsvm.coord_descent(ktr=self.ktr, ytr=self.ytr, kte=self.kte, yte=yte, verbose=self.verbose, lmda=self.lmda, nsweep=np.int(self.nsweep)) elif self.algo_type == 'scg_da': res = dsvm.stoc_dual_averaging(ktr=self.ktr, ytr=self.ytr, kte=self.kte, yte=yte, lmda=self.lmda, rho=self.rho, verbose=self.verbose, nsweep=np.int(self.nsweep), b=np.int(self.b), c=np.int(self.c)) self.alpha, self.err_tr, self.err_te, self.obj, self.nker_opers = res