Example #1
0
	def valid(self,datasets,opt,opp,method = fold,part_ids = None,seed = None,test_data = None):
		"""Run cross-validation over *datasets*, logging predictions per fold.

		datasets  -- dataset container providing pids / mkTest / mkTrain
		opt       -- libsvm training options (forwarded to svmutil.svm_train)
		opp       -- libsvm prediction options (forwarded to svmutil.svm_predict)
		method    -- partitioner yielding (test_ids, train_ids) pairs
		part_ids  -- ids to partition; defaults to datasets.pids
		seed      -- RNG seed for the partitioner; defaults to current UNIX time
		test_data -- optional held-out dataset used as the test set for every fold
		"""
		if seed is None:
			# If seed is not set, UNIX time is used as seed.
			seed = time.time()
		saving_seed = "%s/log/%s.log.seed" % (self._dir,self._name)
		with open(saving_seed,"w") as fp:
			# Persist the seed so the partition can be reproduced later.
			fp.write("seed:%f\n" % seed)

		if part_ids is None:
			part_ids = datasets.pids
		# method() already yields (test, train) pairs; just materialize them.
		groups = list(method(part_ids, seed = seed))

		for cnt,pdtsts in enumerate(groups):
			# cnt is the fold index.
			if test_data is None:
				test = False
				ltest,dtest,itest = test2svm_prob(datasets.mkTest(pdtsts[0]))
			else:
				test = True
				ltest,dtest,itest = test2svm_prob(test_data.mkTest(test_data.pids))

			print("start %s validation" % (cnt))
			ptrn,itrain = train2svm_prob(datasets.mkTrain(pdtsts[1]))
			model = svmutil.svm_train(ptrn,opt)

			plbl,pacc,pval = svmutil.svm_predict(ltest,dtest,model,opp)

			# Write the per-fold prediction log.
			self._save_log(itest,plbl,pval,cnt,test)
			# NOTE(review): model_name is built but the model is never saved
			# here — presumably a truncated snippet; compare create_model().
			model_name = "%s/model/%s.model.%s" % (self._dir,self._name,cnt)
Example #2
0
 def __init__(self, train_feature_file=TRAIN_FEATURE_FILE):
     """Load a cached SVM model if one exists; otherwise train and cache it."""
     if os.path.exists(SAVED_MODEL):
         # Reuse the previously saved model.
         self.model = svmutil.svm_load_model(SAVED_MODEL)
         return
     # No cached model: train from the feature file and persist the result.
     labels, features = svmutil.svm_read_problem(train_feature_file)
     self.model = svmutil.svm_train(labels, features, '-c 4')
     svmutil.svm_save_model(SAVED_MODEL, self.model)
Example #3
0
def svm():
    """Train an RBF-kernel SVM on sparse-coding data, save the model, and
    plot each sample's predicted activation as a bar chart."""
    # Training parameters.
    # COST trades misclassification penalty against decision-plane
    # complexity: a low COST keeps decisions simple but allows errors.
    COST = 0.9
    # GAMMA controls generalisation: low GAMMA generalises broadly, a high
    # GAMMA fits closer to the original dataset.
    GAMMA = 6
    KERNEL = RBF

    # Attach a convenience single-sample predict method to svm_model.
    svm_model.predict = lambda self, x: svm_predict([0], [x], self)[0][0]

    # Build the training problem; only the problem itself is used here,
    # so the remaining tuple members are discarded.
    SPARSE_LENGTH = 16
    sparseCodings = sparse_coding.generateFull(SPARSE_LENGTH)
    dataset, _, _, _ = sparse_coding.toSVMProblem(sparseCodings)

    # Configure and train the SVM.
    parameters = svm_parameter()
    parameters.kernel_type = KERNEL
    parameters.C = COST
    parameters.gamma = GAMMA
    solver = svm_train(dataset, parameters)

    # Create the output path if it doesn't exist, then save the model.
    generated_dir = path.abspath(
        path.join(
            "generated",
            "Q2Task1-TrainedSVM-{}".format(strftime("%Y-%m-%d_%H-%M-%S"))))
    if not path.exists(generated_dir):
        makedirs(generated_dir)
    uniqueFileName = path.normpath(path.join(generated_dir, "data.pkl"))
    svm_save_model(uniqueFileName, solver)

    # Compare the results to the expected values: one bar per sample,
    # coloured by its classifier index.
    figure = plot.figure()
    axis = figure.add_subplot(111)
    colors = ['r', 'y', 'g', 'c', 'b', 'k']
    for sample in sparseCodings:
        classifier = sparse_coding.getClassifier(sample)
        activationResult = svm_predict([0.], [sample], solver, '-q')[0][0]
        axis.bar(classifier,
                 activationResult,
                 color=colors[classifier % len(colors)])

    plot.savefig(path.normpath(path.join(generated_dir, "activations.png")))
    plot.show()
Example #4
0
	def create_model(self,datasets,opt,opp,part_ids = None):
		"""Train one SVM on the full training partition and save it to disk."""
		if part_ids is None:
			part_ids = datasets.pids
		# Convert the training partition into a libsvm problem.
		problem, _train_ids = train2svm_prob(datasets.mkTrain(part_ids))
		print("create model ...")
		trained = svmutil.svm_train(problem, opt)
		# Persist the model under this instance's model directory.
		save_path = "%s/model/%s.model" % (self._dir, self._name)
		svmutil.svm_save_model(save_path, trained)
Example #5
0
	def valid(self,datasets,opt,opp,method = fold,seed = None):
		"""Cross-validate over paired positive (pids) and negative (nids) folds.

		datasets -- container providing pids, nids, mkTest and mkTrain
		opt      -- libsvm training options (forwarded to svmutil.svm_train)
		opp      -- libsvm prediction options (forwarded to svmutil.svm_predict)
		method   -- partitioner yielding (test_ids, train_ids) pairs
		seed     -- RNG seed forwarded to the partitioner
		"""
		# method() already yields (test, train) pairs; materialize them once
		# for the positive and the negative id sets.
		groups = list(method(datasets.pids, seed = seed))
		ngroups = list(method(datasets.nids, seed = seed))

		for cnt,(pdtsts,ndtsts) in enumerate(zip(groups,ngroups)):
			# cnt is the fold index.
			ltest,dtest,itest = test2svm_prob(datasets.mkTest(pdtsts[0],ndtsts[0]))
			ptrn,itrain = train2svm_prob(datasets.mkTrain(pdtsts[1],ndtsts[1]))

			print("start %s validation" % (cnt))
			model = svmutil.svm_train(ptrn,opt)
			plbl,pacc,pval = svmutil.svm_predict(ltest,dtest,model,opp)

			# Write the per-fold prediction log.
			self._save_log(itest,plbl,pval,cnt)
			# NOTE(review): model_name is built but never used to save the
			# model — presumably a truncated snippet; compare create_model().
			model_name = "%s/model/%s.model.%s" % (self._dir,self._name,cnt)
Example #6
0
def taska():
    """Train an SVM on the spiral dataset, time the training, save the
    model, and display the decision-surface plot."""
    # Training parameters.
    # COST: penalty for misclassification vs. decision-plane complexity;
    # a low COST keeps decisions simple at the price of errors.
    COST = 0.9
    # GAMMA: low values generalise broadly, high values hug the dataset.
    GAMMA = 6

    # Load the spiral training problem.
    dataset, data, outputs, classes = csv.loadSVMProblem(
        "spirals\\SpiralOut.txt")

    # Configure the SVM.
    # NOTE(review): KERNEL is not defined in this function — presumably a
    # module-level constant; confirm.
    parameters = svm_parameter()
    parameters.kernel_type = KERNEL
    parameters.C = COST
    parameters.gamma = GAMMA

    # Train, measuring wall-clock time.
    start = time()
    solver = svm_train(dataset, parameters)
    trainingTime = time() - start
    print(trainingTime)

    # Save the trained model under a timestamped file name.
    uniqueFileName = "generated\\Q1DTaskA-TrainedSVM-" + strftime(
        "%Y-%m-%d_%H-%M-%S") + '.pkl'
    svm_save_model(uniqueFileName, solver)

    # Compare the results to the expected values.
    plot = plotSVM(solver, -6.0, 6.0, 0.2)
    plot.show()
Example #7
0
def taska():
    """Train an SVM on the spiral data set, report the training time, save
    the model, and show the decision-surface plot."""
    # COST: misclassification penalty vs. plane complexity (low = simple
    # decisions, more errors). GAMMA: low = broad generalisation, high =
    # closer fit to the original data.
    COST = 0.9
    GAMMA = 6

    # Load the training problem from the spiral data file.
    dataset, data, outputs, classes = csv.loadSVMProblem("spirals\\SpiralOut.txt")

    # SVM configuration.
    # NOTE(review): KERNEL is not defined locally — presumably module-level.
    parameters = svm_parameter()
    parameters.kernel_type = KERNEL
    parameters.C = COST
    parameters.gamma = GAMMA

    # Train and time the solver.
    t0 = time()
    solver = svm_train(dataset, parameters)
    t1 = time()
    trainingTime = t1 - t0
    print(trainingTime)

    # Persist the model with a timestamp in the file name.
    uniqueFileName = "generated\\Q1DTaskA-TrainedSVM-" + strftime("%Y-%m-%d_%H-%M-%S") + '.pkl'
    svm_save_model(uniqueFileName,solver)

    # Compare the results to the expected values.
    plot = plotSVM(solver, -6.0, 6.0, 0.2)
    plot.show()