示例#1
0
def _testRandom():
    """Manual smoke run: cross-validate a tree on random data, then prune it.

    Generates 200 random examples over 10 variables, builds a tree with
    ``CrossValidate.CrossValidationDriver``, prints it and saves it via
    ``Pickle('orig.pkl')``. It then calls ``PruneTree`` with the tree's own
    training and test examples, prints the pruned tree and saves it via
    ``Pickle('prune.pkl')``. The error fraction of each tree is printed.
    """
    from rdkit.ML.DecTree import randomtest

    data, attributes, possibleValCounts = randomtest.GenRandomExamples(
        nVars=10, randScale=0.5, nExamples=200)

    origTree, origErr = CrossValidate.CrossValidationDriver(
        data, attributes, possibleValCounts)
    origTree.Print()
    origTree.Pickle('orig.pkl')
    print('original error is:', origErr)

    print('----Pruning')
    # Same call order as passing these inline: training set first, then test set.
    trainSet = origTree.GetTrainingExamples()
    holdOut = origTree.GetTestExamples()
    prunedTree, prunedErr = PruneTree(origTree, trainSet, holdOut)
    prunedTree.Print()
    print('pruned error is:', prunedErr)
    prunedTree.Pickle('prune.pkl')
示例#2
0
  def testResults(self):
    """ test the results of CrossValidation

    Seeds ``RDRandom`` with ``self.randomSeed``, builds 200 random examples
    (array seed ``self.randomArraySeed``), runs
    ``CrossValidate.CrossValidationDriver`` on them, and asserts that the
    resulting tree equals the reference tree unpickled from
    ``self.origTreeName``.
    """
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom
    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200,
                                                                  seed=self.randomArraySeed)
    # The error fraction is not checked here; only the tree is compared.
    tree, _ = CrossValidate.CrossValidationDriver(examples, attrs,
                                                  nPossibleVals, silent=1)

    from rdkit.six.moves import cPickle
    # The reference file was originally produced with:
    #   cPickle.dump(tree, open(self.origTreeName, 'w+'))
    # Read it as raw bytes. A text-mode read followed by .encode('utf-8')
    # would depend on the locale encoding and could corrupt or reject
    # non-ASCII bytes in the pickle stream. Line endings are normalised
    # exactly as text-mode universal newlines did (CRLF and lone CR -> LF),
    # so files checked out with Windows line endings still load.
    with open(self.origTreeName, 'rb') as inTFile:
      buf = inTFile.read().replace(b'\r\n', b'\n').replace(b'\r', b'\n')
    with io.BytesIO(buf) as inFile:
      oTree = cPickle.load(inFile)

    assert oTree == tree, 'Random CrossValidation test failed'