Ejemplo n.º 1
0
 def testReplacementSelection(self):
     # Run the cross-validation driver with replacementSelection enabled on
     # seeded random examples, then check that a truthy tree comes back and
     # that the returned fraction matches the pinned reference value.
     RDRandom.seed(self.randomSeed)
     generated = randomtest.GenRandomExamples(nExamples=200,
                                              seed=self.randomArraySeed)
     samples, attributes, valueCounts = generated
     tree, frac = CrossValidate.CrossValidationDriver(samples,
                                                      attributes,
                                                      valueCounts,
                                                      silent=1,
                                                      replacementSelection=1)
     self.assertTrue(tree)
     # Compared after rounding the difference to 4 decimal places.
     self.assertAlmostEqual(frac, 0.01666, places=4)
Ejemplo n.º 2
0
 def testRun(self):
     """Smoke test: CrossValidationDriver completes on 200 random examples."""
     from rdkit.ML.DecTree import randomtest
     samples, attributes, valueCounts = randomtest.GenRandomExamples(nExamples=200)
     # No assertions on the values; unpacking into two names still requires
     # the driver to return exactly two items.
     _tree, _frac = CrossValidate.CrossValidationDriver(
         samples, attributes, valueCounts, silent=1)
Ejemplo n.º 3
0
 def testReplacementSelection(self):
     """Cross-validate with replacementSelection on and check the result.

     Seeds RDRandom and the example generator from the fixture's seeds, then
     requires a truthy tree and a fraction that `feq` accepts as 0.0833.
     """
     from rdkit import RDRandom
     from rdkit.ML.DecTree import randomtest
     RDRandom.seed(self.randomSeed)
     generated = randomtest.GenRandomExamples(nExamples=200,
                                              seed=self.randomArraySeed)
     samples, attributes, valueCounts = generated
     tree, frac = CrossValidate.CrossValidationDriver(samples,
                                                      attributes,
                                                      valueCounts,
                                                      silent=1,
                                                      replacementSelection=1)
     assert tree
     # feq is defined outside this block; its tolerance is whatever its default is.
     assert feq(frac, 0.0833)
Ejemplo n.º 4
0
def TestRun():
    """Build a tree from random examples, pickle it, and print copy checks.

    Calls ``tree.Pickle('save.pkl')`` and then prints whether a deep copy of
    the tree compares equal to the original and whether it can be found in a
    list holding the original.
    """
    import copy

    samples, attributes, valueCounts = randomtest.GenRandomExamples(nExamples=200)
    driverResult = CrossValidationDriver(samples, attributes, valueCounts)
    tree, _ = driverResult

    tree.Pickle('save.pkl')

    clone = copy.deepcopy(tree)
    print('t1 == t2', tree == clone)

    # Membership and index lookups both go through the tree's equality.
    holder = [tree]
    isMember = clone in holder
    position = holder.index(clone)
    print('t2 in [tree]', isMember, position)
Ejemplo n.º 5
0
def _testRandom():
    """Grow a tree on random examples, prune it, and print/pickle both trees.

    The unpruned tree is pickled via ``Pickle('orig.pkl')`` and the pruned one
    via ``Pickle('prune.pkl')``; the fraction returned alongside each tree is
    printed as its error.
    """
    from rdkit.ML.DecTree import randomtest

    # An alternative setup (nVars=20, randScale=0.25, nExamples=200) was
    # previously used here.
    generatorArgs = dict(nVars=10, randScale=0.5, nExamples=200)
    samples, attributes, valueCounts = randomtest.GenRandomExamples(**generatorArgs)

    tree, frac = CrossValidate.CrossValidationDriver(samples, attributes, valueCounts)
    tree.Print()
    tree.Pickle('orig.pkl')
    print('original error is:', frac)

    print('----Pruning')
    trainingSet = tree.GetTrainingExamples()
    testSet = tree.GetTestExamples()
    prunedTree, prunedFrac = PruneTree(tree, trainingSet, testSet)
    prunedTree.Print()
    print('pruned error is:', prunedFrac)
    prunedTree.Pickle('prune.pkl')
Ejemplo n.º 6
0
    def testResults(self):
        # " test the results of CrossValidation "
        # Builds a tree from seeded random examples and compares it with the
        # reference tree unpickled from self.origTreeName.
        RDRandom.seed(self.randomSeed)
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200, seed=self.randomArraySeed)
        tree, frac = CrossValidate.CrossValidationDriver(examples,
                                                         attrs,
                                                         nPossibleVals,
                                                         silent=1)
        self.assertGreater(frac, 0)

        # The reference pickle is read as text so CRLF line endings can be
        # normalised before the bytes are handed to pickle. The with-statement
        # closes the file; no explicit close() is needed.
        with open(self.origTreeName, 'r') as inTFile:
            buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
        inFile = BytesIO(buf)
        oTree = pickle.load(inFile)

        # Use the TestCase assertion (like assertGreater above) so the check is
        # not stripped when Python runs with -O.
        self.assertEqual(oTree, tree, 'Random CrossValidation test failed')
Ejemplo n.º 7
0
    def testResults(self):
        """Compare a freshly built tree with the pickled reference tree.

        Seeds RDRandom and the example generator from the fixture's seeds,
        runs the cross-validation driver, and requires the resulting tree to
        compare equal to the one stored in self.origTreeName.
        """
        from rdkit.ML.DecTree import randomtest
        from rdkit import RDRandom
        RDRandom.seed(self.randomSeed)
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200, seed=self.randomArraySeed)
        tree, frac = CrossValidate.CrossValidationDriver(examples,
                                                         attrs,
                                                         nPossibleVals,
                                                         silent=1)

        import cPickle
        # To regenerate the reference pickle:
        #   cPickle.dump(tree, file(self.origTreeName, 'w+'))
        # Open via a context manager so the handle is closed even when
        # unpickling raises.
        with open(self.origTreeName, 'r') as inFile:
            oTree = cPickle.load(inFile)

        assert oTree == tree, 'Random CrossValidation test failed'
Ejemplo n.º 8
0
  def testResults(self):
    """Compare a freshly built tree with the pickled reference tree.

    Seeds RDRandom and the example generator from the fixture's seeds, runs
    the cross-validation driver, and requires the resulting tree to compare
    equal to the one stored in self.origTreeName.
    """
    from rdkit.ML.DecTree import randomtest
    from rdkit import RDRandom
    RDRandom.seed(self.randomSeed)
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200,
                                                                  seed=self.randomArraySeed)
    # Only the tree is compared below; the second return value is unused.
    tree, _ = CrossValidate.CrossValidationDriver(examples, attrs,
                                                  nPossibleVals, silent=1)

    from rdkit.six.moves import cPickle
    # To regenerate the reference pickle:
    #   cPickle.dump(tree, open(self.origTreeName, 'w+'))
    # The file is read as text so CRLF line endings can be normalised before
    # unpickling. The with-statement closes it; no explicit close() is needed.
    with open(self.origTreeName, 'r') as inTFile:
      buf = inTFile.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as inFile:
      oTree = cPickle.load(inFile)

    assert oTree == tree, 'Random CrossValidation test failed'
Ejemplo n.º 9
0
    def testRun(self):
        # Smoke test for CrossValidationDriver: a non-silent run must print a
        # validation-error line and return a tree named 'Var: 1'; a second run
        # exercises the lessGreedy / calcTotalError options.
        samples, attributes, valueCounts = randomtest.GenRandomExamples(nExamples=200)

        # Capture stdout so the driver's printed report can be inspected.
        captured = StringIO()
        with redirect_stdout(captured):
            tree, frac = CrossValidate.CrossValidationDriver(
                samples, attributes, valueCounts, silent=False)
        self.assertGreater(frac, 0)
        self.assertEqual('Var: 1', tree.GetName())
        self.assertIn('Validation error', captured.getvalue())

        # Only checks that these options run without raising.
        CrossValidate.CrossValidationDriver(
            samples, attributes, valueCounts,
            lessGreedy=True, calcTotalError=True, silent=True)