def test_RandomizeActivities(self):
    # Exercises DataUtils.RandomizeActivities() twice: once with shuffle=True
    # and once with shuffle=False.  The second pass tolerates a NameError.

    class RunDetails(object):
        # Bare-bones stand-in for a run-details object; only the two flags
        # inspected by the assertions below are provided.
        shuffled = False
        randomized = False

    def activities(rows):
        # The activity value is the last entry of every data point.
        return [row[-1] for row in rows]

    random.seed(0)

    # --- pass 1: shuffle the existing activities ---
    details = RunDetails()
    self.setUpGeneralLoad()
    dataSet = self.d
    orgActivities = activities(dataSet)
    DataUtils.RandomizeActivities(dataSet, shuffle=True, runDetails=details)
    # The values must be reordered, but still be the same multiset.
    self.assertNotEqual(orgActivities, activities(dataSet))
    self.assertEqual(sorted(orgActivities), sorted(activities(dataSet)))
    self.assertTrue(details.shuffled)
    self.assertFalse(details.randomized)

    # --- pass 2: shuffle=False ---
    try:
        details = RunDetails()
        self.setUpGeneralLoad()
        dataSet = self.d
        orgActivities = activities(dataSet)
        DataUtils.RandomizeActivities(dataSet, shuffle=False, runDetails=details)
        self.assertNotEqual(orgActivities, activities(dataSet))
        self.assertEqual(sorted(orgActivities), sorted(activities(dataSet)))
        self.assertFalse(details.randomized)
        self.assertTrue(details.shuffled)
    except NameError:
        # This code branch is not working.
        pass
def test5(self):
    """ indicesToUse """
    # Each probe is (fraction, expected number kept, expected number rejected).
    expectations = [
        (.5, 4, 2),
        (.7, 3, 3),
        (.75, 3, 3),
        (.333, 6, 0),
        (.25, 4, 2),
    ]
    allIndices = len(self.d1)
    for frac, nKeep, nRej in expectations:
        # Filter with an explicit index list covering every point ...
        DataUtils.InitRandomNumbers((23, 42))
        keep, rej = DataUtils.FilterData(self.d1, 1, frac, indicesToUse=range(allIndices))
        assert len(keep) == nKeep, 'bad nKeep (%d != %d)' % (len(keep), nKeep)
        assert len(rej) == nRej, 'bad nRej (%d != %d)' % (len(rej), nRej)

        # ... and compare with the same call, re-seeded, without indicesToUse.
        DataUtils.InitRandomNumbers((23, 42))
        tgtKeep, tgtRej = DataUtils.FilterData(self.d1, 1, frac)
        assert keep == tgtKeep, '%.2f: %s!=%s' % (frac, str(keep), str(tgtKeep))
        assert rej == tgtRej, '%.2f: %s!=%s' % (frac, str(rej), str(tgtRej))
def test4_indicesOnly_indicesToUse(self):
    # """ indicesOnly with indicesToUse """
    # Each probe is (fraction, expected number kept, expected number rejected).
    expectations = [
        (.5, 4, 2),
        (.7, 3, 3),
        (.75, 3, 3),
        (.333, 6, 0),
        (.25, 4, 2),
    ]
    nPts = len(self.d1)
    for frac, nKeep, nRej in expectations:
        DataUtils.InitRandomNumbers((23, 42))
        keepIdx, rejIdx = DataUtils.FilterData(self.d1, 1, frac, indicesToUse=range(nPts),
                                               indicesOnly=1)
        assert len(keepIdx) == nKeep, 'bad nKeep (%d != %d)' % (len(keepIdx), nKeep)
        assert len(rejIdx) == nRej, 'bad nRej (%d != %d)' % (len(rejIdx), nRej)

        # Map the returned indices back onto the data and compare with the
        # points returned by a re-seeded call that hands back the data itself.
        keep = [self.d1[idx] for idx in keepIdx]
        rej = [self.d1[idx] for idx in rejIdx]
        DataUtils.InitRandomNumbers((23, 42))
        tgtKeep, tgtRej = DataUtils.FilterData(self.d1, 1, frac)
        assert keep == tgtKeep, '%.2f: %s!=%s' % (frac, str(keep), str(tgtKeep))
        assert rej == tgtRej, '%.2f: %s!=%s' % (frac, str(rej), str(tgtRej))
def RunIt(details, progressCallback=None, saveIt=1, setDescNames=0):
    """ loads the data set described by _details_ and builds a composite model

    The run date is stamped onto _details_, a default output name is filled in
    if none was given, the data set is loaded (from a text file or from a
    database, depending on _details_) and the real work is handed off to
    _RunOnData()_.

    **Arguments**

      - details:  a _CompositeRun.CompositeRun_ object containing details
        (options, parameters, etc.) about the run

      - progressCallback: (optional) a function which is called with a single
        argument (the number of models built so far) after each model is built.
        It is passed through to _RunOnData()_ unchanged.

      - saveIt: (optional) if this is nonzero, the resulting model will be
        pickled and dumped to the filename specified in _details.outName_.
        It is passed through to _RunOnData()_ unchanged.

      - setDescNames: (optional) if nonzero, the composite's _SetInputOrder()_
        method will be called using the results of the data set's
        _GetVarNames()_ method; it is assumed that the details object has a
        _descNames attribute which is passed to the composites
        _SetDescriptorNames()_ method.  Otherwise (the default),
        _SetDescriptorNames()_ gets the results of _GetVarNames()_.
        It is passed through to _RunOnData()_ unchanged.

    **Returns**

      the composite model constructed

    **Side effects**

      - _details.rundate_ is always set
      - _details.outName_ is set to the table name + '.pkl' when it is empty
      - _details.tableName_ is replaced by its stripped form in the branches
        that go through _details.GetDataSet()_

    """
    details.rundate = time.asctime()

    tableName = details.tableName.strip()
    if details.outName == '':
        details.outName = tableName + '.pkl'

    if details.dbName:
        # Database-backed data
        if details.useSigTrees or details.useSigBayes:
            details.tableName = tableName
            data = details.GetDataSet(pickleCol=0, pickleClass=DataStructs.ExplicitBitVect)
        elif details.qBounds != [] or not details.useTrees:
            details.tableName = tableName
            data = details.GetDataSet()
        else:
            data = DataUtils.DBToQuantData(  # Function no longer defined
                details.dbName, tableName, quantName=details.qTableName, user=details.dbUser,
                password=details.dbPassword)
    elif details.qBounds != []:
        # No database: the table name is really a file name
        data = DataUtils.TextFileToData(tableName)
    else:
        data = DataUtils.BuildQuantDataSet(tableName)

    return RunOnData(details, data, progressCallback=progressCallback, saveIt=saveIt,
                     setDescNames=setDescNames)
def test1NaiveBayes(self):
    # Cross-validates a naive Bayes model on the stddata.csv test set and
    # checks the class probabilities, quantization bounds and error against
    # stored reference values.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML', 'NaiveBayes', 'test_data', 'stddata.csv')
    data = DataUtils.TextFileToData(fName)

    examples = data.GetNamedData()

    nvars = data.GetNVars()
    # Materialise the attribute indices as a real list so the value behaves
    # the same on Python 2 and Python 3 (matches the other tests in this file).
    attrs = list(range(1, nvars + 1))
    npvals = [0] + [3] * nvars + [2]
    qBounds = [0] + [2] * nvars + [0]
    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds, silent=True)

    self.assertAlmostEqual(mod._classProbs[0], 0.5000, 4)
    self.assertAlmostEqual(mod._classProbs[1], 0.5000, 4)
    self.assertAlmostEqual(mod._QBoundVals[1][0], -0.0360, 4)
    self.assertAlmostEqual(mod._QBoundVals[1][1], 0.114)
    self.assertAlmostEqual(mod._QBoundVals[2][0], -0.7022, 4)
    self.assertAlmostEqual(mod._QBoundVals[2][1], -0.16635, 4)
    self.assertAlmostEqual(mod._QBoundVals[3][0], -0.3659, 4)
    self.assertAlmostEqual(mod._QBoundVals[3][1], 0.4305, 4)
    self.assertAlmostEqual(err, 0.2121, 4)
def test2XValClass(self):
    # Cross-validates a KNN classification model on random_pts.csv, then
    # checks classification of the first example and the naming API.
    csvPath = os.path.join(RDConfig.RDCodeDir, 'ML', 'KNN', 'test_data', 'random_pts.csv')
    dataSet = DataUtils.TextFileToData(csvPath)
    examples = dataSet.GetNamedData()
    npvals = dataSet.GetNPossibleVals()
    nvars = dataSet.GetNVars()
    attrs = list(range(1, nvars + 1))
    numNeigh = 11

    model, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, numNeigh, silent=1)
    self.assertAlmostEqual(err, 0.01075, 4)

    # neighborList is filled in by ClassifyExample()
    neighborList = []
    prediction = model.ClassifyExample(examples[0], neighborList=neighborList)
    self.assertEqual(prediction, 1)
    self.assertEqual(neighborList[0][1], examples[0])

    # Naming: empty by default, settable, and NameModel() falls back to the
    # model type regardless of its argument.
    self.assertEqual(model.GetName(), '')
    model.SetName('name')
    self.assertEqual(model.GetName(), 'name')
    self.assertEqual(model.type(), 'Classification Model')
    model.NameModel('this argument is ignored')
    self.assertEqual(model.GetName(), 'Classification Model')
def testPerm1(self):
    """ tests the descriptor remapping stuff in a packager

    Descriptor values computed in the calculator's own order are compared
    with values computed in five randomly permuted orders.
    """
    from rdkit.Chem import Descriptors

    # Use a context manager so the pickle file is closed deterministically
    # instead of being left to the garbage collector.
    # NOTE: the pickle is a test fixture shipped with the code; never
    # unpickle data from an untrusted source.
    with open(os.path.join(self.dataDir, 'Jan9_build3_pkg.pkl'), 'rb') as pkgF:
        pkg = cPickle.load(pkgF)

    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    ref = {}
    DataUtils.InitRandomNumbers((23, 42))
    for smi, pred, conf in self.testD:
        # Reference values, one freshly parsed molecule per descriptor.
        # Unknown descriptor names fall back to a constant function.
        for desc in names:
            fn = getattr(Descriptors, desc, lambda x: 777)
            m = Chem.MolFromSmiles(smi)
            ref[desc] = fn(m)

        for i in range(5):
            perm = list(names)
            random.shuffle(perm)

            m = Chem.MolFromSmiles(smi)
            for desc in perm:
                fn = getattr(Descriptors, desc, lambda x: 777)
                val = fn(m)
                assert feq(val, ref[desc], 1e-4), '%s: %s(%s): %f!=%f' % (str(perm), smi, desc,
                                                                         val, ref[desc])
def testPerm1(self):
    """ tests the descriptor remapping stuff in a packager

    Descriptor values computed in the calculator's own order are compared
    with values computed in five randomly permuted orders.
    """
    from rdkit.Chem import Descriptors

    # The pickle is read as text so that CRLF line endings can be normalised
    # before unpickling; the `with` block closes the file, so no explicit
    # close() is needed.
    with open(os.path.join(self.dataDir, 'Jan9_build3_pkg.pkl'), 'r') as pkgTF:
        buf = pkgTF.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as pkgF:
        pkg = cPickle.load(pkgF)

    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    ref = {}
    DataUtils.InitRandomNumbers((23, 42))
    for smi, pred, conf in self.testD:
        # Reference values, one freshly parsed molecule per descriptor.
        # Unknown descriptor names fall back to a constant function.
        for desc in names:
            fn = getattr(Descriptors, desc, lambda x: 777)
            m = Chem.MolFromSmiles(smi)
            ref[desc] = fn(m)

        for i in range(5):
            perm = list(names)
            # random.shuffle() lost its `random` keyword in Python 3.11; the
            # assertion below does not depend on which permutation is drawn.
            random.shuffle(perm)

            m = Chem.MolFromSmiles(smi)
            for desc in perm:
                fn = getattr(Descriptors, desc, lambda x: 777)
                val = fn(m)
                assert feq(val, ref[desc], 1e-4), '%s: %s(%s): %f!=%f' % (str(perm), smi, desc,
                                                                         val, ref[desc])
def testQuantPickle(self):
    # " testing QuantDataSet pickling "
    # Round trip: write the loaded data set with WritePickledData(), read the
    # four pickled pieces back, rebuild an MLQuantDataSet and compare it with
    # the original through its public accessors.
    self.setUpQuantLoad()
    DataUtils.WritePickledData(RDConfig.RDCodeDir + '/ML/Data/test_data/testquant.qdat.pkl',
                               self.d)

    # The pieces are stored as consecutive pickles, in this order.
    with open(RDConfig.RDCodeDir + '/ML/Data/test_data/testquant.qdat.pkl', 'rb') as inF:
        vNames = pickle.load(inF)
        qBounds = pickle.load(inF)
        ptNames = pickle.load(inF)
        examples = pickle.load(inF)

    reloaded = MLData.MLQuantDataSet(examples, varNames=vNames, qBounds=qBounds,
                                     ptNames=ptNames)

    assert self.d.GetNPts() == reloaded.GetNPts(), 'nPts wrong'
    assert self.d.GetNVars() == reloaded.GetNVars(), 'nVars wrong'
    assert self.d.GetNResults() == reloaded.GetNResults(), 'nResults wrong'
    assert self.d.GetVarNames() == reloaded.GetVarNames(), 'varNames wrong'
    assert self.d.GetPtNames() == reloaded.GetPtNames(), 'ptNames wrong'
    assert self.d.GetNPossibleVals() == reloaded.GetNPossibleVals(), 'nPossible Wrong'
    assert self.d.GetQuantBounds() == reloaded.GetQuantBounds(), 'quantBounds Wrong'
    assert self.d.GetResults() == reloaded.GetResults(), 'GetResults wrong'
    # Spot checks of individual rows
    assert self.d.GetAllData()[1] == reloaded.GetAllData()[1], 'GetAllData wrong'
    assert self.d.GetInputData()[3] == reloaded.GetInputData()[3], 'GetInputData wrong'
    assert self.d.GetNamedData()[2] == reloaded.GetNamedData()[2], 'GetNamedData wrong'
def test2NaiveBayes(self):
    # Cross-validates a naive Bayes model with an m-estimate value of 20,
    # then checks the error, the naming API and the example bookkeeping.
    csvPath = os.path.join(RDConfig.RDCodeDir, 'ML', 'NaiveBayes', 'test_data', 'stddata.csv')
    dataSet = DataUtils.TextFileToData(csvPath)
    examples = dataSet.GetNamedData()

    nvars = dataSet.GetNVars()
    attrs = list(range(1, nvars + 1))
    npvals = [0] + [3] * nvars + [2]
    qBounds = [0] + [2] * nvars + [0]
    model, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds,
                                                     mEstimateVal=20.0, silent=True)
    self.assertTrue(isinstance(model, NaiveBayesClassifier))
    self.assertAlmostEqual(err, 0.1818, 4)

    # Naming: empty by default, settable, NameModel() resets to the class name.
    self.assertEqual(model.GetName(), '')
    model.SetName('modelName')
    self.assertEqual(model.GetName(), 'modelName')
    model.NameModel(None)
    self.assertEqual(model.GetName(), 'NaiveBayesClassifier')

    # The two example lists are non-empty and together make up the input.
    self.assertGreater(len(model.GetExamples()), 0)
    self.assertGreater(len(model.GetTrainingExamples()), 0)
    self.assertEqual(sorted(model.GetTrainingExamples() + model.GetExamples()), sorted(examples))
def testPerm1(self):
    # """ tests the descriptor remapping stuff in a packager """
    # Descriptor values computed in the calculator's own order are compared
    # with values computed in five randomly permuted orders.
    pkg = self._loadPackage()
    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    ref = {}
    DataUtils.InitRandomNumbers((23, 42))
    for smi, _, _ in self.testD:
        # Reference values, one freshly parsed molecule per descriptor.
        # Unknown descriptor names fall back to a constant function.
        for desc in names:
            fn = getattr(Descriptors, desc, lambda x: 777)
            m = Chem.MolFromSmiles(smi)
            ref[desc] = fn(m)

        for _ in range(5):
            perm = list(names)
            # random.shuffle() lost its `random` keyword in Python 3.11; the
            # assertion below does not depend on which permutation is drawn.
            random.shuffle(perm)

            m = Chem.MolFromSmiles(smi)
            for desc in perm:
                fn = getattr(Descriptors, desc, lambda x: 777)
                val = fn(m)
                assert feq(val, ref[desc], 1e-4), '%s: %s(%s): %f!=%f' % (str(perm), smi, desc,
                                                                         val, ref[desc])
def test_CalcNPossibleUsingMap(self):
    # With the identity ordering over five columns, CalcNPossibleUsingMap()
    # must report the stored number of possible values per column.
    self.setUpQuantLoad()
    identityOrder = list(range(5))
    nPossible = DataUtils.CalcNPossibleUsingMap(self.d.data, identityOrder,
                                                self.d.GetQuantBounds())
    self.assertEqual(nPossible, [3, 3, 2, 2, 2])
def setUp(self):
    # Here is what we are going to do to test this out:
    #  - generate bit vectors of length nbits
    #  - turn on a fraction of the first nbits/2 bits at random
    #  - for each bit i turned on in the range (0, nbits/2) also turn on the
    #    bit nbits/2 + i
    #  - basically the first half of a fingerprint is the same as the second
    #    half of the fingerprint
    #  - if we repeat this process often enough we should see strong
    #    correlation between the bits i (i < nbits/2) and (nbits/2 + i)
    DataUtils.InitRandomNumbers((100, 23))
    self.nbits = 200
    self.d = 40  # number of bits switched on in each half
    self.nfp = 1000  # number of fingerprints generated
    self.blist = list(range(self.nbits))

    # Integer division: a float here breaks range() and the bit index
    # arithmetic on Python 3.
    half = self.nbits // 2

    self.fps = []
    for _ in range(self.nfp):
        fp = DataStructs.ExplicitBitVect(self.nbits)
        # random.shuffle() needs a mutable sequence, so build a real list.
        obits = list(range(half))
        random.shuffle(obits)
        obits = obits[0:self.d]
        for bit in obits:
            fp.SetBit(bit)
            fp.SetBit(bit + half)
        self.fps.append(fp)
def testQuantLoad(self):
    """ testing QuantDataSet load

    Loads test.qdat with DataUtils.BuildQuantDataSet() and stores the result
    on self.d; any exception raised while loading is reported as a test
    failure.
    """
    fName = RDConfig.RDCodeDir + '/ML/Data/test_data/test.qdat'
    # Catch Exception rather than using a bare except, so KeyboardInterrupt
    # and SystemExit are not swallowed.
    try:
        self.d = DataUtils.BuildQuantDataSet(fName)
    except Exception:
        ok = 0
    else:
        ok = 1
    assert ok, 'BuildQuantDataSet failed'
def test_WriteData(self):
    # Writes the quantized test data set to an in-memory buffer and looks for
    # a few expected substrings in the output.
    self.setUpQuantLoad()
    with contextlib.closing(StringIO()) as outF:
        DataUtils.WriteData(outF, self.d.GetVarNames(), self.d.GetQuantBounds(), self.d.data)
        # Must be read before the buffer is closed on leaving the block.
        written = outF.getvalue()
    for expected in ('DataUtils', 'foo1', '2 2 1 0 1'):
        self.assertIn(expected, written)
def testGeneralLoad(self):
    """ testing DataSet load

    Loads test.dat with DataUtils.BuildDataSet() and stores the result on
    self.d; any exception raised while loading is reported as a test failure.
    """
    fName = RDConfig.RDCodeDir + '/ML/Data/test_data/test.dat'
    # Catch Exception rather than using a bare except, so KeyboardInterrupt
    # and SystemExit are not swallowed.
    try:
        self.d = DataUtils.BuildDataSet(fName)
    except Exception:
        ok = 0
    else:
        ok = 1
    assert ok, 'BuildDataSet failed'
def _balanced_parallel_build_trees(n_trees, forest, X, y, sample_weight, sample_mask, X_argsorted,
                                   seed, verbose):
    """Private function used to build a batch of trees within a job

    Each tree is created with ``forest._make_estimator(append=False)`` and
    gets its own random state derived from ``seed``.

    When ``forest.bootstrap`` is set, the points used for a tree are chosen
    in two steps: ``DataUtils.FilterData`` is called on the enumerated labels
    (``val=1, frac=0.5, col=1``) and its first return value is taken as the
    candidate indices; those candidates are then resampled with replacement.
    The resulting per-sample counts scale the sample weights, and samples
    with a count of zero are masked out.  The mask used is stored on the
    tree as ``tree.indices``.  Without bootstrapping every tree is fit on the
    data as passed in.

    Returns the list of fitted trees.

    NOTE(review): relies on a scikit-learn API with ``compute_importances``,
    ``sample_mask`` and ``X_argsorted``; confirm the installed version still
    provides them.
    """
    from sklearn.utils import check_random_state
    from sklearn.utils.fixes import bincount

    MAX_INT = numpy.iinfo(numpy.int32).max

    random_state = check_random_state(seed)

    trees = []
    # range() instead of xrange() so the loop also runs on Python 3.
    for i in range(n_trees):
        if verbose > 1:
            print("building tree %d of %d" % (i + 1, n_trees))
        seed = random_state.randint(MAX_INT)

        tree = forest._make_estimator(append=False)
        tree.set_params(compute_importances=forest.compute_importances)
        tree.set_params(random_state=check_random_state(seed))

        if forest.bootstrap:
            n_samples = X.shape[0]
            if sample_weight is None:
                curr_sample_weight = numpy.ones((n_samples, ), dtype=numpy.float64)
            else:
                curr_sample_weight = sample_weight.copy()

            # Candidate indices from FilterData, then a bootstrap draw (with
            # replacement) from those candidates.
            ty = list(enumerate(y))
            indices = DataUtils.FilterData(ty, val=1, frac=0.5, col=1, indicesToUse=0,
                                           indicesOnly=1)[0]
            indices2 = random_state.randint(0, len(indices), len(indices))
            indices = [indices[j] for j in indices2]
            sample_counts = bincount(indices, minlength=n_samples)

            curr_sample_weight *= sample_counts
            curr_sample_mask = sample_mask.copy()
            curr_sample_mask[sample_counts == 0] = False

            tree.fit(X, y, sample_weight=curr_sample_weight, sample_mask=curr_sample_mask,
                     X_argsorted=X_argsorted, check_input=False)
            tree.indices = curr_sample_mask
        else:
            tree.fit(X, y, sample_weight=sample_weight, sample_mask=sample_mask,
                     X_argsorted=X_argsorted, check_input=False)

        trees.append(tree)

    return trees
def test1NaiveBayes(self):
    # Cross-validates a naive Bayes model on stddata.csv three times (default
    # settings, calcTotalError=True, replacementSelection=True) and checks
    # class probabilities, quantization bounds and error each time against
    # stored reference values.
    csvPath = os.path.join(RDConfig.RDCodeDir, 'ML', 'NaiveBayes', 'test_data', 'stddata.csv')
    dataSet = DataUtils.TextFileToData(csvPath)
    examples = dataSet.GetNamedData()

    nvars = dataSet.GetNVars()
    attrs = list(range(1, nvars + 1))
    npvals = [0] + [3] * nvars + [2]
    qBounds = [0] + [2] * nvars + [0]

    close = self.assertAlmostEqual

    # --- default settings ---
    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds, silent=True)
    close(mod._classProbs[0], 0.5000, 4)
    close(mod._classProbs[1], 0.5000, 4)
    bounds = mod._QBoundVals
    close(bounds[1][0], -0.0360, 4)
    close(bounds[1][1], 0.114)
    close(bounds[2][0], -0.7022, 4)
    close(bounds[2][1], -0.16635, 4)
    close(bounds[3][0], -0.3659, 4)
    close(bounds[3][1], 0.4305, 4)
    close(err, 0.2121, 4)

    # --- calcTotalError=True ---
    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds, silent=True,
                                                   calcTotalError=True)
    close(mod._classProbs[0], 0.515151, 4)
    close(mod._classProbs[1], 0.484848, 4)
    bounds = mod._QBoundVals
    close(bounds[1][0], -0.40315, 4)
    close(bounds[1][1], 0.114)
    close(bounds[2][0], -0.62185, 4)
    close(bounds[2][1], -0.19965, 4)
    close(bounds[3][0], 0.4305, 4)
    close(bounds[3][1], 0.80305, 4)
    close(err, 0.14563, 4)

    # --- replacementSelection=True ---
    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds, silent=True,
                                                   replacementSelection=True)
    close(mod._classProbs[0], 0.5131578, 4)
    close(mod._classProbs[1], 0.4868421, 4)
    bounds = mod._QBoundVals
    close(bounds[1][0], -0.036, 4)
    close(bounds[1][1], 0.93465, 4)
    close(bounds[2][0], -0.6696, 4)
    close(bounds[2][1], -0.19965, 4)
    close(bounds[3][0], -1.06785, 4)
    close(bounds[3][1], 0.4305, 4)
    close(err, 0.3, 4)
def test2XValClass(self):
    # Cross-validates a KNN classification model on random_pts.csv and checks
    # the error against a stored reference value.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML', 'KNN', 'test_data', 'random_pts.csv')
    data = DataUtils.TextFileToData(fName)
    examples = data.GetNamedData()
    npvals = data.GetNPossibleVals()
    nvars = data.GetNVars()
    # Materialise the attribute indices as a real list so the value behaves
    # the same on Python 2 and Python 3 (matches the other tests in this file).
    attrs = list(range(1, nvars + 1))
    numNeigh = 11

    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, numNeigh, silent=1)
    self.assertAlmostEqual(err, 0.01075, 4)
def testPerm2(self):
    # """ tests the descriptor remapping stuff in a packager """
    # The calculator's descriptor order is replaced by a random permutation,
    # the package is re-initialised, and the package is then verified against
    # the stored test data.
    pkg = self._loadPackage()
    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    DataUtils.InitRandomNumbers((23, 42))
    perm = list(names)
    # random.shuffle() lost its `random` keyword in Python 3.11.
    # NOTE(review): this may draw a different permutation than before on
    # older interpreters; confirm _verify() is independent of the order.
    random.shuffle(perm)
    calc.simpleList = perm
    calc.descriptorNames = perm
    pkg.Init()
    self._verify(pkg, self.testD)
def test4XValRegress(self):
    # Cross-validates a KNN regression model on random_pts.csv and checks the
    # error against a stored reference value.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML', 'KNN', 'test_data', 'random_pts.csv')
    data = DataUtils.TextFileToData(fName)
    examples = data.GetNamedData()
    npvals = data.GetNPossibleVals()
    nvars = data.GetNVars()
    # Materialise the attribute indices as a real list so the value behaves
    # the same on Python 2 and Python 3 (matches the other tests in this file).
    attrs = list(range(1, nvars + 1))
    numNeigh = 11

    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, numNeigh, silent=1,
                                                   modelBuilder=CrossValidate.makeRegressionModel)
    # NOTE: this number hasn't been extensively checked
    self.assertAlmostEqual(err, 0.0777, 4)
def test1(self):
    """ basics """
    # Each probe is (fraction, expected number kept, expected number rejected).
    expectations = [
        (.5, 4, 2),
        (.7, 3, 3),
        (.75, 3, 3),
        (.333, 6, 0),
        (.25, 4, 2),
    ]
    for frac, nKeep, nRej in expectations:
        kept, rejected = DataUtils.FilterData(self.d1, 1, frac)
        assert len(kept) == nKeep, 'bad nKeep (%d != %d)' % (len(kept), nKeep)
        assert len(rejected) == nRej, 'bad nRej (%d != %d)' % (len(rejected), nRej)
def testPerm2(self):
    """ tests the descriptor remapping stuff in a packager

    The calculator's descriptor order is replaced by a random permutation,
    the package is re-initialised, and the package is then verified against
    the stored test data.
    """
    # Use a context manager so the pickle file is closed deterministically
    # instead of being left to the garbage collector.
    # NOTE: the pickle is a test fixture shipped with the code; never
    # unpickle data from an untrusted source.
    with open(os.path.join(self.dataDir, 'Jan9_build3_pkg.pkl'), 'rb') as pkgF:
        pkg = cPickle.load(pkgF)

    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    DataUtils.InitRandomNumbers((23, 42))
    perm = list(names)
    random.shuffle(perm)
    calc.simpleList = perm
    calc.descriptorNames = perm
    pkg.Init()
    self._verify(pkg, self.testD)
def test_SplitData(self):
    # Out-of-range fractions must be rejected with a ValueError.
    for badFraction in (-1.1, 1.1):
        self.assertRaises(ValueError, SplitData.SplitDataSet, None, badFraction)

    points = list(range(10))
    DataUtils.InitRandomNumbers((23, 42))

    # Capture what SplitDataSet() prints so it can be inspected below.
    captured = StringIO()
    with redirect_stdout(captured):
        parts = SplitData.SplitDataSet(points, 0.5)

    # The two parts are disjoint and the first holds half of the points.
    self.assertEqual(set(parts[0]).intersection(parts[1]), set())
    self.assertEqual(len(parts[0]), 5)

    printed = captured.getvalue()
    self.assertIn('Training', printed)
    self.assertIn('hold-out', printed)
def setUp(self):
    # Loads the pickled 'ferro' examples and sets up the variable names,
    # quantization bounds, possible-value counts and attribute indices that
    # go with them, then seeds the random number generator.

    # Use a context manager so the pickle file is closed deterministically
    # instead of being left to the garbage collector.
    # NOTE: the pickle is a test fixture shipped with the code; never
    # unpickle data from an untrusted source.
    with open(RDConfig.RDCodeDir + '/ML/Composite/test_data/ferro.pkl', 'rb') as pklF:
        self.examples = cPickle.load(pklF)

    self.varNames = [
        'composition', 'max_atomic', 'has3d', 'has4d', 'has5d', 'elconc', 'atvol', 'isferro'
    ]
    self.qBounds = [[], [1.89, 3.53], [], [], [], [0.55, 0.73], [11.81, 14.52], []]
    self.nPoss = [0, 3, 2, 2, 2, 3, 3, 2]
    # Indices 1 .. len(varNames)-2: everything except the first and last
    # columns.  A real list keeps the value identical on Python 2 and 3.
    self.attrs = list(range(1, len(self.varNames) - 1))

    from rdkit.ML.Data import DataUtils
    DataUtils.InitRandomNumbers((23, 43))
def testPerm2(self):
    """ tests the descriptor remapping stuff in a packager

    The calculator's descriptor order is replaced by a random permutation,
    the package is re-initialised, and the package is then verified against
    the stored test data.
    """
    # The pickle is read as text so that CRLF line endings can be normalised
    # before unpickling; the `with` block closes the file, so no explicit
    # close() is needed.
    with open(os.path.join(self.dataDir, 'Jan9_build3_pkg.pkl'), 'r') as pkgTF:
        buf = pkgTF.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as pkgF:
        pkg = cPickle.load(pkgF)

    calc = pkg.GetCalculator()
    names = calc.GetDescriptorNames()
    DataUtils.InitRandomNumbers((23, 42))
    perm = list(names)
    # random.shuffle() lost its `random` keyword in Python 3.11.
    # NOTE(review): this may draw a different permutation than before on
    # older interpreters; confirm _verify() is independent of the order.
    random.shuffle(perm)
    calc.simpleList = perm
    calc.descriptorNames = perm
    pkg.Init()
    self._verify(pkg, self.testD)
def GetDataSet(self, **kwargs):
    """ Returns a MLDataSet pulled from a database using our stored values.

    The database name, table name, credentials and the what/where/join
    clauses are taken from this object; any keyword arguments are handed on
    to _DataUtils.DBToData()_ unchanged.

    """
    from rdkit.ML.Data import DataUtils
    return DataUtils.DBToData(
        self.dbName,
        self.tableName,
        user=self.dbUser,
        password=self.dbPassword,
        what=self.dbWhat,
        where=self.dbWhere,
        join=self.dbJoin,
        **kwargs)
def setUp(self):
    # Loads the pickled 'ferro' examples and sets up the variable names,
    # quantization bounds, possible-value counts and attribute indices that
    # go with them, then seeds the random number generator.

    # The pickle is read as text so that CRLF line endings can be normalised
    # before unpickling; the `with` block closes the file, so no explicit
    # close() is needed.
    with open(RDConfig.RDCodeDir + '/ML/Composite/test_data/ferro.pkl', 'r') as pklTF:
        buf = pklTF.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as pklF:
        self.examples = cPickle.load(pklF)

    self.varNames = [
        'composition', 'max_atomic', 'has3d', 'has4d', 'has5d', 'elconc', 'atvol', 'isferro'
    ]
    self.qBounds = [[], [1.89, 3.53], [], [], [], [0.55, 0.73], [11.81, 14.52], []]
    self.nPoss = [0, 3, 2, 2, 2, 3, 3, 2]
    # Indices 1 .. len(varNames)-2: everything except the first and last columns.
    self.attrs = list(range(1, len(self.varNames) - 1))

    from rdkit.ML.Data import DataUtils
    DataUtils.InitRandomNumbers((23, 43))
def test2NaiveBayes(self):
    # Cross-validates a naive Bayes model with an m-estimate value of 20 and
    # checks the error against a stored reference value.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML', 'NaiveBayes', 'test_data', 'stddata.csv')
    data = DataUtils.TextFileToData(fName)
    examples = data.GetNamedData()

    nvars = data.GetNVars()
    # Materialise the attribute indices as a real list so the value behaves
    # the same on Python 2 and Python 3 (matches the other tests in this file).
    attrs = list(range(1, nvars + 1))
    npvals = [0] + [3] * nvars + [2]
    qBounds = [0] + [2] * nvars + [0]
    mod, err = CrossValidate.CrossValidationDriver(examples, attrs, npvals, qBounds,
                                                   mEstimateVal=20.0)

    assert feq(err, 0.19354)
def test1Neighbors(self):
    # Removes the first example from the data, ranks the remaining examples
    # by Euclidean distance to it, and compares that ranking with what
    # KNNModel.GetNeighbors() returns for the same point.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML', 'KNN', 'test_data', 'random_pts.csv')
    data = DataUtils.TextFileToData(fName)
    examples = data.GetNamedData()
    npvals = data.GetNPossibleVals()
    nvars = data.GetNVars()
    # Materialise the attribute indices as a real list so the value behaves
    # the same on Python 2 and Python 3 (matches the other tests in this file).
    attrs = list(range(1, nvars + 1))
    numNeigh = 11
    metric = DistFunctions.EuclideanDist
    mdl = KNNModel.KNNModel(numNeigh, attrs, metric)

    # The query point is taken out of the training examples.
    pt = examples.pop(0)
    tgt = [(metric(pt, ex, attrs), ex) for ex in examples]
    tgt.sort()

    mdl.SetTrainingExamples(examples)
    neighbors = mdl.GetNeighbors(pt)
    for i in range(numNeigh):
        # The comparison expects the model to report the negated distance.
        assert feq(-tgt[i][0], neighbors[i][0])
        # The first entry of each example is compared to identify the point.
        assert tgt[i][1][0] == neighbors[i][1][0]