예제 #1
0
파일: xlearn.py 프로젝트: memoiry/xlearn
def create_ffm():
    """Create a field-aware factorization machine (FFM).

    Returns
    -------
    XLearn
        Wrapper around a newly created native handle of model type ``'ffm'``.
    """
    ffm_handle = XLearnHandle()
    status = _LIB.XLearnCreate(c_str('ffm'), ctypes.byref(ffm_handle))
    _check_call(status)
    return XLearn(ffm_handle)
예제 #2
0
파일: xlearn.py 프로젝트: memoiry/xlearn
def create_linear():
    """Create a linear model.

    Returns
    -------
    XLearn
        Wrapper around a newly created native handle of model type ``'linear'``.
    """
    linear_handle = XLearnHandle()
    status = _LIB.XLearnCreate(c_str('linear'), ctypes.byref(linear_handle))
    _check_call(status)
    return XLearn(linear_handle)
예제 #3
0
	def setTrain(self, train_path):
		"""Set the file path of the training data.

		Parameters
		----------
		train_path : str
		    Path of the training data.
		"""
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnSetTrain(handle_ref, c_str(train_path))
		_check_call(status)
예제 #4
0
	def setTest(self, test_path):
		"""Set the file path of the test data.

		Parameters
		----------
		test_path : str
		    Path of the test data.
		"""
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnSetTest(handle_ref, c_str(test_path))
		_check_call(status)
예제 #5
0
	def setValidate(self, val_path):
		"""Set the file path of the validation data.

		Parameters
		----------
		val_path : str
		    Path of the validation data.
		"""
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnSetValidate(handle_ref, c_str(val_path))
		_check_call(status)
예제 #6
0
	def cv(self, param):
		"""Run cross-validation.

		The hyper-parameters are applied to the handle first via
		``_set_Param``; then the native ``XLearnCV`` call is made.

		Parameters
		----------
		param : dict
		    Hyper-parameters used by xlearn.
		"""
		self._set_Param(param)
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnCV(handle_ref))
예제 #7
0
	def predict(self, model_path, out_path):
		"""Predict output using a model checkpoint.

		Parameters
		----------
		model_path : str
		    Path of the model checkpoint.
		out_path : str
		    Path of the output result.
		"""
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnPredict(handle_ref, c_str(model_path),
			c_str(out_path))
		_check_call(status)
예제 #8
0
	def fit(self, param, model_path):
		"""Apply hyper-parameters, then train and dump the model.

		Parameters
		----------
		param : dict
		    Hyper-parameters used by xlearn; applied via ``_set_Param``.
		model_path : str
		    Path of the model checkpoint.
		"""
		self._set_Param(param)
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnFit(handle_ref, c_str(model_path))
		_check_call(status)
예제 #9
0
	def setQuiet(self):
		"""Set xlearn to quiet mode (sets the ``'quiet'`` flag to True)."""
		key = 'quiet'
		# Pass a pointer to the handle, as every other XLearnSet* call in this
		# wrapper does; passing ``self.handle`` directly hands the native side
		# the handle value instead of a pointer to it.
		_check_call(_LIB.XLearnSetBool(ctypes.byref(self.handle),
			c_str(key), ctypes.c_bool(True)))
예제 #10
0
	def setOnDisk(self):
		"""Use on-disk training (sets the ``'on_disk'`` flag to True)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('on_disk'), ctypes.c_bool(True)))
예제 #11
0
	def disableNorm(self):
		"""Disable instance-wise normalization (sets ``'norm'`` to False)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('norm'), ctypes.c_bool(False)))
예제 #12
0
	def disableLockFree(self):
		"""Disable lock-free training (sets ``'lock_free'`` to False)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('lock_free'), ctypes.c_bool(False)))
예제 #13
0
	def show(self):
		"""Show model information via the native ``XLearnShow`` call."""
		handle_ref = ctypes.byref(self.handle)
		status = _LIB.XLearnShow(handle_ref)
		_check_call(status)
예제 #14
0
	def _set_Param(self, param):
		"""Set hyper-parameters on the xlearn handle.

		Every key is validated before any value is sent to the native
		handle, so an invalid key no longer leaves the handle partially
		configured.

		Parameters
		----------
		param : dict
		    xlearn hyper-parameters. Accepted keys and the setter used:
		    ``task``/``metric``/``log`` -> string (``XLearnSetStr``);
		    ``lr``/``lambda``/``init`` -> float (``XLearnSetFloat``);
		    ``k``/``epoch``/``fold`` -> unsigned int (``XLearnSetInt``).

		Raises
		------
		ValueError
		    If ``param`` contains an unknown key. ValueError subclasses
		    Exception, so callers catching Exception are unaffected.
		"""
		# Maps each accepted key to the kind of native setter it needs.
		param_kinds = {
			'task': 'str', 'metric': 'str', 'log': 'str',
			'lr': 'float', 'lambda': 'float', 'init': 'float',
			'k': 'int', 'epoch': 'int', 'fold': 'int',
		}
		# Reject unknown keys up front, before touching the native handle.
		for key in param:
			if key not in param_kinds:
				raise ValueError("Invalid key!", key)
		for key, value in param.items():
			kind = param_kinds[key]
			if kind == 'str':
				_check_call(_LIB.XLearnSetStr(ctypes.byref(self.handle),
					c_str(key), c_str(value)))
			elif kind == 'float':
				_check_call(_LIB.XLearnSetFloat(ctypes.byref(self.handle),
					c_str(key), ctypes.c_float(value)))
			else:
				_check_call(_LIB.XLearnSetInt(ctypes.byref(self.handle),
					c_str(key), ctypes.c_uint(value)))
예제 #15
0
def hello():
    """Say hello to the user via the native ``XLearnHello`` call."""
    status = _LIB.XLearnHello()
    _check_call(status)
예제 #16
0
	def __del__(self):
		"""Free the native xlearn handle when this object is finalized.

		If ``handle`` was never assigned (e.g. construction failed before
		the attribute was set), there is nothing to free. The call is then
		skipped instead of raising AttributeError inside the finalizer.
		"""
		handle = getattr(self, 'handle', None)
		if handle is None:
			return
		_check_call(_LIB.XLearnHandleFree(ctypes.byref(handle)))
예제 #17
0
	def disableEarlyStop(self):
		"""Disable early stopping (sets ``'early_stop'`` to False)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('early_stop'), ctypes.c_bool(False)))
예제 #18
0
	def setSigmoid(self):
		"""Convert output using sigmoid (sets the ``'sigmoid'`` flag to True)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('sigmoid'), ctypes.c_bool(True)))
예제 #19
0
	def setSign(self):
		"""Convert output to 0 and 1 (sets the ``'sign'`` flag to True)."""
		handle_ref = ctypes.byref(self.handle)
		_check_call(_LIB.XLearnSetBool(
			handle_ref, c_str('sign'), ctypes.c_bool(True)))