Exemplo n.º 1
0
class Experiment:
	"""Small helper bundling a Weka load / train / evaluate / filter workflow.

	The Weka wrapper names used here (Loader, Saver, Classifier,
	FilteredClassifier, Filter, Evaluation, Instances, Random, serialization)
	are expected to be imported at module level; they appear to come from
	python-weka-wrapper.

	NOTE(review): nothing in this class starts the JVM; callers presumably
	have to do that before using any method that touches Weka -- confirm.
	"""

	# Class-level defaults, kept so the attributes also exist on instances
	# created without running __init__.
	data = None             # currently loaded dataset (set by loadCSV / filters)
	class_index = -1        # passed as '-C' to RemoveMisclassified; setClassIndex does not change it
	classifier = None       # FilteredClassifier built by _train, or object read by loadClassifier
	base_classifier = None  # bare J48 / JRip configuration set by train_J48 / train_JRip
	attrs = []              # attribute-name patterns used by _train

	def __init__(self):
		"""Create an experiment with no data, no classifier and no attributes."""
		# Give every instance its own list so that in-place edits of `attrs`
		# can never leak into other experiments through the class attribute.
		self.attrs = []

	@staticmethod
	def _ascii(x):
		"""Return the string form of `x` with all non-ASCII characters dropped.

		Works on both Python 2 and Python 3 and always returns text, never a
		bytes object on Python 3.
		"""
		text = x if isinstance(x, type(u'')) else x.__str__()
		if isinstance(text, bytes):
			# Only reachable on Python 2, where str is a byte string.
			text = text.decode('ascii', 'ignore')
		return text.encode('ascii', 'ignore').decode('ascii')

	def out(self, x):
		"""Print `x` converted to a string, silently dropping non-ASCII characters."""
		# Call-form print with a single argument behaves the same on Python 2 and 3.
		print(self._ascii(x))

	def loadCSV(self, filename, path='/home/sbiastoch/Schreibtisch/csv_files/'):
		"""Load `path + filename` with Weka's CSVLoader into `self.data`.

		Args:
			filename: name of the CSV file.
			path: prefix prepended verbatim to `filename` (no separator is added).
		"""
		csv_loader = Loader(classname="weka.core.converters.CSVLoader")
		self.data = csv_loader.load_file(path+filename)

	def setClassIndex(self, index):
		"""Set the class attribute of the loaded dataset.

		Args:
			index: 0-based attribute index; a negative value counts from the
				end (-1 is the last attribute).
		"""
		if index < 0:
			index = self.data.num_attributes + index
		self.data.class_index = index

	def train_J48(self, min_per_rule=20):
		"""Train a J48 decision tree (wrapped by `_train`) on `self.data`.

		Args:
			min_per_rule: value passed as J48's '-M' option (minimum number
				of instances per leaf).
		"""
		params = [
			'-C', '0.3',  # pruning confidence factor
			'-M', str(min_per_rule),
		]
		self.base_classifier = Classifier(classname='weka.classifiers.trees.J48', options=params)
		self._train()

	def train_JRip(self, min_per_rule=20, optimizations=2, folds=3, seed=42):
		"""Train a JRip rule learner (wrapped by `_train`) on `self.data`.

		Args:
			min_per_rule: '-N' option, minimum elements per rule.
			optimizations: '-O' option, number of optimization runs.
			folds: '-F' option, number of folds.
			seed: '-S' option, random seed.
		"""
		params = [
			'-F', str(folds),
			'-N', str(min_per_rule),
			'-O', str(optimizations),
			'-S', str(seed),
		]
		self.base_classifier = Classifier(classname='weka.classifiers.rules.JRip', options=params)
		self._train()

	def _train(self):
		"""Wrap `self.base_classifier` in a FilteredClassifier and build it.

		The filter is RemoveByName with a regex built from `self.attrs` and
		the '-V' (invert) flag, i.e. per the Weka docs the attributes whose
		names match are the ones kept. The entries of `self.attrs` are joined
		unescaped, so they are interpreted as regular-expression fragments.
		Afterwards the second-to-last line of the model's textual output is
		printed (or the only line, if there is just one).
		"""
		name_pattern = '^(' + '|'.join(self.attrs) + ')$'
		params = [
			'-F', 'weka.filters.unsupervised.attribute.RemoveByName -E ' + name_pattern + ' -V',
			'-W', self.base_classifier.classname, '--',
		]
		# Everything after '--' is handed to the wrapped base classifier.
		params.extend(self.base_classifier.options)

		self.classifier = FilteredClassifier(options=params)
		self.classifier.build_classifier(self.data)

		# Split as text (splitting encoded bytes with a str separator fails on
		# Python 3) and do not assume the output has at least two lines.
		summary_lines = self._ascii(self.classifier).split("\n")
		self.out(summary_lines[-2] if len(summary_lines) > 1 else summary_lines[-1])

	def test(self, folds = 10):
		"""Cross-validate `self.classifier` on `self.data` and print the results.

		Args:
			folds: number of cross-validation folds. The random generator is
				always seeded with 42.
		"""
		evaluation = Evaluation(self.data)  # initialize with priors
		evaluation.crossvalidate_model(self.classifier, self.data, folds, Random(42))

		total = str(evaluation.num_instances)
		hit_pct = str(round(evaluation.percent_correct, 2))
		hit_abs = str(round(evaluation.correct, 2))
		miss_pct = str(round(evaluation.percent_incorrect, 2))
		miss_abs = str(round(evaluation.incorrect, 2))

		print('Total number of instances: '+total+'.')
		print(hit_pct+'% / '+hit_abs+' correct.')
		print(miss_pct+'% / '+miss_abs+' incorrect.')

	def saveCSV(self, filename, path='/home/sbiastoch/Schreibtisch/csv_files/'):
		"""Write `self.data` to `path + filename` using Weka's CSVSaver.

		Args:
			filename: name of the CSV file.
			path: prefix prepended verbatim to `filename` (no separator is added).
		"""
		csv_saver = Saver(classname="weka.core.converters.CSVSaver")
		csv_saver.save_file(self.data, path+filename)

	def loadClassifier(self, filename, path='/home/sbiastoch/Schreibtisch/classifiers/'):
		"""Read a serialized classifier from `path + filename` into `self.classifier`.

		Only the first stored object is used; the dataset header that
		`saveClassifier` writes as second object is ignored, and
		`self.base_classifier` is left untouched.
		"""
		objects = serialization.read_all(path+filename)
		self.classifier = Classifier(jobject=objects[0])

	def saveClassifier(self, filename, path='/home/sbiastoch/Schreibtisch/classifiers/'):
		"""Serialize `self.classifier` plus a template of `self.data` to `path + filename`."""
		header = Instances.template_instances(self.data)
		serialization.write_all(path+filename, [self.classifier, header])

	def remove_correct_classified(self, invert = False):
		"""Filter `self.data` with Weka's RemoveMisclassified.

		Args:
			invert: when False (default) the filter's '-V' (invert) flag is
				set, which is how this method drops the correctly classified
				instances; when True the flag is omitted and the filter runs
				in its normal sense (see `remove_incorrect_classified`).
		"""
		options = [
			'-W', self.classifier.to_commandline(),
			'-C', str(self.class_index),  # class index
			'-I', '0',  # max iterations
		]
		if not invert:
			# Only add the flag when needed instead of passing an empty
			# string as a bogus extra option.
			options.append('-V')

		remove = Filter(classname="weka.filters.unsupervised.instance.RemoveMisclassified", options=options)
		remove.inputformat(self.data)
		self.data = remove.filter(self.data)

	def remove_incorrect_classified(self):
		"""Filter `self.data` with RemoveMisclassified in its non-inverted sense."""
		self.remove_correct_classified(True)

	def set_attributes(self, attrs):
		"""Set the attribute-name patterns used by `_train`.

		Args:
			attrs: iterable of strings; they are joined with '|' into a regex.
		"""
		self.attrs = attrs

	def select_missclassified(self):
		"""Reduce `self.data` using the predictions of `self.base_classifier`.

		Steps, each replacing `self.data`:
		  1. AddClassification appends a classification and an error attribute.
		  2. RemoveWithValues filters rows on the last attribute.
		  3. Remove drops helper attributes again.

		Requires `train_J48` / `train_JRip` to have been called so that
		`self.base_classifier` is set.
		"""
		add_prediction = Filter(classname="weka.filters.supervised.attribute.AddClassification", options=['-classification' ,'-error' ,'-W' ,self.base_classifier.to_commandline()])
		add_prediction.inputformat(self.data)
		self.data = add_prediction.filter(self.data)

		# The row filter used to be configured but never applied, so no
		# instance was ever removed; apply it like the other two filters.
		# NOTE(review): presumably this keeps only the rows flagged as errors
		# -- confirm against the RemoveWithValues documentation.
		row_filter = Filter(classname="weka.filters.unsupervised.instance.RemoveWithValues", options=['-S','0.0','-C','last','-L','last','-V'])
		row_filter.inputformat(self.data)
		self.data = row_filter.filter(self.data)

		# NOTE(review): Weka attribute ranges are 1-based, so this range starts
		# at the attribute *before* the two added ones -- verify this is the
		# intended set of columns to drop.
		drop_helpers = Filter(classname="weka.filters.unsupervised.attribute.Remove", options=['-R',str(self.data.num_attributes-2)+',last'])
		drop_helpers.inputformat(self.data)
		self.data = drop_helpers.filter(self.data)

	def merge_nominal_attributes(self, significance=0.01):
		"""Apply Weka's MergeNominalValues filter to all attributes of `self.data`.

		Args:
			significance: value passed as the filter's '-L' option.
		"""
		merge = Filter(classname="weka.filters.supervised.attribute.MergeNominalValues", options=['-L',str(significance),'-R','first-last'])
		merge.inputformat(self.data)
		self.data = merge.filter(self.data)