Exemplo n.º 1
0
    def test_Tree(self):
        # Walk through the TreeNode API: construction, adding children,
        # string rendering, renaming, pruning, printing and destruction.
        root = Tree.TreeNode(None, 'root', label=0)
        self.assertEqual(root.GetLevel(), 0)
        self.assertEqual(root.GetName(), 'root')
        self.assertEqual(root.GetData(), None)
        self.assertEqual(root.GetTerminal(), False)
        self.assertEqual(root.GetLabel(), 0)
        self.assertEqual(root.GetParent(), None)
        self.assertEqual(root.GetChildren(), [])

        # each new child must report the name, label and data it was given
        # and sit one level below the root
        for idx in range(3):
            childName = 'child {0}'.format(idx)
            childLabel = idx + 1
            node = root.AddChild(childName, childLabel, data={'key': 'value'})
            self.assertEqual(node.GetLevel(), 1)
            self.assertEqual(node.GetName(), childName)
            self.assertEqual(node.GetData(), {'key': 'value'})
            self.assertEqual(node.GetLabel(), childLabel)
            self.assertEqual(node.GetParent(), root)
            self.assertEqual(node.GetChildren(), [])
        kids = root.GetChildren()
        self.assertEqual(len(kids), 3)
        kids[0].AddChild('terminal', 4, isTerminal=True)

        # the string form indents nodes according to their depth
        rendered = str(root)
        for fragment in ('root', '    terminal', '  child 2'):
            self.assertIn(fragment, rendered)

        root.NameTree(['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(str(root), 'a\n  b\n    terminal\n  c\n  d\n')

        root.PruneChild(kids[1])
        self.assertEqual(str(root), 'a\n  b\n    terminal\n  d\n')

        # Print(showData=True) is expected to include the node data ...
        withData = StringIO()
        with redirect_stdout(withData):
            root.Print(showData=True)
        printed = withData.getvalue()
        self.assertIn('value', printed)
        self.assertIn('None', printed)

        # ... while a plain Print() must leave it out
        withoutData = StringIO()
        with redirect_stdout(withoutData):
            root.Print()
        printed = withoutData.getvalue()
        self.assertNotIn('value', printed)
        self.assertNotIn('None', printed)

        root.Destroy()
        self.assertEqual(str(root), 'a\n')
Exemplo n.º 2
0
  def test_Tree(self):
    # Walk through the TreeNode API: construction, adding children,
    # string rendering, renaming, pruning, printing and destruction.
    root = Tree.TreeNode(None, 'root', label=0)
    self.assertEqual(root.GetLevel(), 0)
    self.assertEqual(root.GetName(), 'root')
    self.assertEqual(root.GetData(), None)
    self.assertEqual(root.GetTerminal(), False)
    self.assertEqual(root.GetLabel(), 0)
    self.assertEqual(root.GetParent(), None)
    self.assertEqual(root.GetChildren(), [])

    # each new child must report the name, label and data it was given
    # and sit one level below the root
    for idx in range(3):
      childName = 'child {0}'.format(idx)
      childLabel = idx + 1
      node = root.AddChild(childName, childLabel, data={'key': 'value'})
      self.assertEqual(node.GetLevel(), 1)
      self.assertEqual(node.GetName(), childName)
      self.assertEqual(node.GetData(), {'key': 'value'})
      self.assertEqual(node.GetLabel(), childLabel)
      self.assertEqual(node.GetParent(), root)
      self.assertEqual(node.GetChildren(), [])
    kids = root.GetChildren()
    self.assertEqual(len(kids), 3)
    kids[0].AddChild('terminal', 4, isTerminal=True)

    # the string form indents nodes according to their depth
    rendered = str(root)
    for fragment in ('root', '    terminal', '  child 2'):
      self.assertIn(fragment, rendered)

    root.NameTree(['a', 'b', 'c', 'd', 'e'])
    self.assertEqual(str(root), 'a\n  b\n    terminal\n  c\n  d\n')

    root.PruneChild(kids[1])
    self.assertEqual(str(root), 'a\n  b\n    terminal\n  d\n')

    # Print(showData=True) is expected to include the node data ...
    withData = StringIO()
    with redirect_stdout(withData):
      root.Print(showData=True)
    printed = withData.getvalue()
    self.assertIn('value', printed)
    self.assertIn('None', printed)

    # ... while a plain Print() must leave it out
    withoutData = StringIO()
    with redirect_stdout(withoutData):
      root.Print()
    printed = withoutData.getvalue()
    self.assertNotIn('value', printed)
    self.assertNotIn('None', printed)

    root.Destroy()
    self.assertEqual(str(root), 'a\n')
Exemplo n.º 3
0
  def testSaveState(self):
    # Unpickle a stored descriptor calculator and check that it reports the
    # expected names, versions, values, summaries and descriptor functions.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML/Descriptors/test_data', 'molcalc.dsc')
    # Windows line endings are normalised before unpickling; the with block
    # closes the file, so no explicit close() is needed.
    with open(fName, 'r') as inTF:
      buf = inTF.read().replace('\r\n', '\n').encode('utf-8')
    # NOTE: unpickling is only acceptable because this is a test-data file
    # located under RDCodeDir; never unpickle untrusted input.
    calc = cPickle.load(BytesIO(buf))
    self.assertEqual(calc.GetDescriptorNames(), tuple(self.descs))
    self.assertEqual(calc.GetDescriptorVersions(), tuple(self.vers))
    self._testVals(calc, self.testD)

    # every descriptor name must appear in the printed listing
    captured = StringIO()
    with redirect_stdout(captured):
      calc.ShowDescriptors()
    shown = captured.getvalue()
    for name in calc.GetDescriptorNames():
      self.assertIn(name, shown)

    summaries = calc.GetDescriptorSummaries()
    self.assertIn('Wildman-Crippen LogP value', summaries)
    self.assertIn('N/A', summaries)

    # one callable per descriptor
    funcs = calc.GetDescriptorFuncs()
    self.assertEqual(len(funcs), len(self.descs))
    for func in funcs:
      self.assertTrue(callable(func))
Exemplo n.º 4
0
    def test1(self):
        # " testing pruning with known results "
        oPts = [
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 0],
            [1, 1, 1, 1],
        ]
        tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
        tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
        err, badEx = CrossValidate.CrossValidate(tree, oPts)
        # unittest assertions are used instead of bare asserts so the checks
        # still run under ``python -O``
        self.assertEqual(err, 0.0, 'bad initial error')
        self.assertEqual(len(badEx), 0, 'bad initial error')

        # prune with original data, shouldn't do anything
        f = StringIO()
        PruneTree._verbose = True
        try:
            with redirect_stdout(f):
                newTree, err = PruneTree.PruneTree(tree, [], oPts)
        finally:
            # always switch verbose output off again, even if pruning raised,
            # so that the flag does not leak into other tests
            PruneTree._verbose = False
        self.assertIn('Pruner', f.getvalue())
        self.assertEqual(newTree, tree, 'improper pruning')

        # prune with train data
        newTree, err = PruneTree.PruneTree(tree, [], tPts)
        self.assertNotEqual(newTree, tree, 'bad pruning')
        self.assertTrue(feq(err, 0.14286), 'bad error result')
Exemplo n.º 5
0
 def test_exampleCode(self):
   # Smoke test: the example code has to run and print the expected list
   captured = StringIO()
   with redirect_stdout(captured):
     _exampleCode()
   self.assertIn('[58, 75, 78, 84]', captured.getvalue())
Exemplo n.º 6
0
    def test1(self):
        # " testing pruning with known results "
        oPts = [
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 1],
            [1, 1, 0, 0],
            [1, 1, 1, 1],
        ]
        tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]
        tree = ID3.ID3Boot(oPts, attrs=range(3), nPossibleVals=[2] * 4)
        err, badEx = CrossValidate.CrossValidate(tree, oPts)
        # unittest assertions are used instead of bare asserts so the checks
        # still run under ``python -O``
        self.assertEqual(err, 0.0, 'bad initial error')
        self.assertEqual(len(badEx), 0, 'bad initial error')

        # prune with original data, shouldn't do anything
        f = StringIO()
        PruneTree._verbose = True
        try:
            with redirect_stdout(f):
                newTree, err = PruneTree.PruneTree(tree, [], oPts)
        finally:
            # always switch verbose output off again, even if pruning raised,
            # so that the flag does not leak into other tests
            PruneTree._verbose = False
        self.assertIn('Pruner', f.getvalue())
        self.assertEqual(newTree, tree, 'improper pruning')

        # prune with train data
        newTree, err = PruneTree.PruneTree(tree, [], tPts)
        self.assertNotEqual(newTree, tree, 'bad pruning')
        self.assertTrue(feq(err, 0.14286), 'bad error result')
Exemplo n.º 7
0
    def test_CrossValidate(self):
        # We just check here that the code works
        trainedNet, _ = self._trainExamples(self.orExamples)

        # with a tolerance of 0.2 one example is expected to be misclassified
        fracBad, bad = CrossValidate(trainedNet, self.orExamples, 0.2)
        self.assertEqual(fracBad, 1.0 / 4)
        self.assertEqual(len(bad), 1)

        # with the training tolerance nothing should be flagged
        fracBad, bad = CrossValidate(trainedNet, self.orExamples,
                                     self.trainTol)
        self.assertEqual(fracBad, 0.0)
        self.assertEqual(len(bad), 0)

        # silent driver runs: extra keyword arguments -> expected error
        driverCases = (
            ({}, 0.5),
            ({'replacementSelection': True}, 0.0),
            ({'calcTotalError': True}, 0.25),
        )
        for extraArgs, expectedError in driverCases:
            _, cvError = CrossValidationDriver(
                self.orExamples + self.orExamples, silent=True, **extraArgs)
            self.assertEqual(cvError, expectedError)

        # the non-silent run only has to complete; its output is discarded
        sink = StringIO()
        with redirect_stdout(sink):
            CrossValidationDriver(self.orExamples + self.orExamples)
Exemplo n.º 8
0
 def test_exampleCode(self):
     # Smoke test: the example code has to run and print the expected list
     captured = StringIO()
     with redirect_stdout(captured):
         _exampleCode()
     self.assertIn('[58, 75, 78, 84]', captured.getvalue())
Exemplo n.º 9
0
  def test_CrossValidate(self):
    # We just check here that the code works
    trainedNet, _ = self._trainExamples(self.orExamples)

    # with a tolerance of 0.2 one example is expected to be misclassified
    fracBad, bad = CrossValidate(trainedNet, self.orExamples, 0.2)
    self.assertEqual(fracBad, 1.0 / 4)
    self.assertEqual(len(bad), 1)

    # with the training tolerance nothing should be flagged
    fracBad, bad = CrossValidate(trainedNet, self.orExamples, self.trainTol)
    self.assertEqual(fracBad, 0.0)
    self.assertEqual(len(bad), 0)

    # silent driver runs: extra keyword arguments -> expected error
    driverCases = (
      ({}, 0.5),
      ({'replacementSelection': True}, 0.0),
      ({'calcTotalError': True}, 0.25),
    )
    for extraArgs, expectedError in driverCases:
      _, cvError = CrossValidationDriver(self.orExamples + self.orExamples, silent=True,
                                         **extraArgs)
      self.assertEqual(cvError, expectedError)

    # the non-silent run only has to complete; its output is discarded
    sink = StringIO()
    with redirect_stdout(sink):
      CrossValidationDriver(self.orExamples + self.orExamples)
Exemplo n.º 10
0
  def testSaveState(self):
    # Unpickle a stored descriptor calculator and check that it reports the
    # expected names, versions, values, summaries and descriptor functions.
    fName = os.path.join(RDConfig.RDCodeDir, 'ML/Descriptors/test_data', 'molcalc.dsc')
    # Windows line endings are normalised before unpickling; the with block
    # closes the file, so no explicit close() is needed.
    with open(fName, 'r') as inTF:
      buf = inTF.read().replace('\r\n', '\n').encode('utf-8')
    # NOTE: unpickling is only acceptable because this is a test-data file
    # located under RDCodeDir; never unpickle untrusted input.
    calc = cPickle.load(BytesIO(buf))
    self.assertEqual(calc.GetDescriptorNames(), tuple(self.descs))
    self.assertEqual(calc.GetDescriptorVersions(), tuple(self.vers))
    self._testVals(calc, self.testD)

    # every descriptor name must appear in the printed listing
    captured = StringIO()
    with redirect_stdout(captured):
      calc.ShowDescriptors()
    shown = captured.getvalue()
    for name in calc.GetDescriptorNames():
      self.assertIn(name, shown)

    summaries = calc.GetDescriptorSummaries()
    self.assertIn('Wildman-Crippen LogP value', summaries)
    self.assertIn('N/A', summaries)

    # one callable per descriptor
    funcs = calc.GetDescriptorFuncs()
    self.assertEqual(len(funcs), len(self.descs))
    for func in funcs:
      self.assertTrue(callable(func))
Exemplo n.º 11
0
    def test_Cluster(self):
        """ tests the Cluster class functionality """
        # build a small tree: root -> c1 (30, 31, 32) and c2 (40, 41)
        root = Clusters.Cluster(index=1, position=1)
        c1 = Clusters.Cluster(index=10, position=10)
        c1.AddChild(Clusters.Cluster(index=30, position=30))
        c1.AddChild(Clusters.Cluster(index=31, position=31))
        t32 = Clusters.Cluster(index=32, position=32)
        c1.AddChild(t32)

        c2 = Clusters.Cluster(index=11)
        c2.AddChildren(
            [Clusters.Cluster(index=40),
             Clusters.Cluster(index=41)])

        root.AddChild(c1)
        root.AddChild(c2)
        nodes = ClusterUtils.GetNodeList(root)

        # the node list is expected children-first, with the root last
        indices = [x.GetIndex() for x in nodes]
        self.assertEqual(indices, [30, 31, 32, 10, 40, 41, 11, 1],
                         'bad indices')
        subtree = root.FindSubtree(11)
        self.assertEqual(
            [x.GetIndex() for x in ClusterUtils.GetNodeList(subtree)],
            [40, 41, 11])

        self.assertFalse(root.IsTerminal())
        self.assertTrue(t32.IsTerminal())

        # simple getter / setter round trips
        self.assertEqual(root.GetData(), None)
        root.SetData(3.14)
        self.assertEqual(root.GetData(), 3.14)

        self.assertEqual(root.GetMetric(), 0.0)
        root.SetMetric(0.1)
        self.assertEqual(root.GetMetric(), 0.1)

        self.assertEqual(root.GetIndex(), 1)
        root.SetIndex(100)
        self.assertEqual(root.GetIndex(), 100)

        self.assertEqual(root.GetPointsPositions(), [30, 31, 32, []])

        # after removing c1 only the c2 branch and the root remain
        root.RemoveChild(c1)
        self.assertEqual(
            [x.GetIndex() for x in ClusterUtils.GetNodeList(root)],
            [40, 41, 11, 100])

        # the default name is derived from the index until one is set
        self.assertEqual(root.GetName(), 'Cluster(100)')
        root.SetName('abc')
        self.assertEqual(root.GetName(), 'abc')

        f = StringIO()
        with redirect_stdout(f):
            root.Print(showData=True)
        self.assertIn('abc', f.getvalue())
        self.assertIn('Cluster(41)', f.getvalue())
        self.assertIn('Metric', f.getvalue())
Exemplo n.º 12
0
 def test_exampleCode(self):
     # Smoke test: the EState example code has to run and print the
     # expected SMILES
     from rdkit.TestRunner import redirect_stdout
     captured = StringIO()
     with redirect_stdout(captured):
         EState.EState._exampleCode()
     self.assertIn('CC(N)C(=O)O', captured.getvalue())
Exemplo n.º 13
0
 def test_exampleCode(self):
     # Smoke test: the Fingerprinter example code has to run and print the
     # expected SMILES
     from rdkit.TestRunner import redirect_stdout
     captured = StringIO()
     with redirect_stdout(captured):
         Fingerprinter._exampleCode()
     self.assertIn('NCCc1ccc(O)c(O)c1', captured.getvalue())
Exemplo n.º 14
0
 def test_exampleCode(self):
   # Smoke test: the EState example code has to run and print the
   # expected SMILES
   from rdkit.TestRunner import redirect_stdout
   captured = StringIO()
   with redirect_stdout(captured):
     EState.EState._exampleCode()
   self.assertIn('CC(N)C(=O)O', captured.getvalue())
Exemplo n.º 15
0
 def test_exampleCode(self):
   # Smoke test: the Fingerprinter example code has to run and print the
   # expected SMILES
   from rdkit.TestRunner import redirect_stdout
   captured = StringIO()
   with redirect_stdout(captured):
     Fingerprinter._exampleCode()
   self.assertIn('NCCc1ccc(O)c(O)c1', captured.getvalue())
Exemplo n.º 16
0
    def test_exceptions(self):
        # the fractions -0.1 and 1.1 must be rejected with a ValueError
        for badFrac in (-0.1, 1.1):
            with self.assertRaises(ValueError):
                SplitData.SplitIndices(10, badFrac)

        # a non-silent split reports both partitions on stdout
        report = StringIO()
        with redirect_stdout(report):
            SplitData.SplitIndices(10, 0.5, replacement=True, silent=False)
        text = report.getvalue()
        for keyword in ('Training', 'hold-out'):
            self.assertIn(keyword, text)
Exemplo n.º 17
0
 def test_exampleCode(self):
   # Run the Tree example code with stdout captured. The example is expected
   # to create save.pkl in the working directory and to print its final
   # comparison line.
   try:
     f = StringIO()
     with redirect_stdout(f):
       Tree._exampleCode()
     self.assertTrue(os.path.isfile('save.pkl'))
     # the failure message used to be two adjacent literals that
     # concatenated to "Example didnt run to end"
     self.assertIn('tree==tree2 False', f.getvalue(), "Example didn't run to end")
   finally:
     # remove the pickle file whether or not the assertions passed
     if os.path.isfile('save.pkl'):
       os.remove('save.pkl')
Exemplo n.º 18
0
    def test_exceptions(self):
        # the fractions -0.1 and 1.1 must be rejected with a ValueError
        for badFrac in (-0.1, 1.1):
            with self.assertRaises(ValueError):
                SplitData.SplitIndices(10, badFrac)

        # a non-silent split reports both partitions on stdout
        report = StringIO()
        with redirect_stdout(report):
            SplitData.SplitIndices(10, 0.5, replacement=True, silent=False)
        text = report.getvalue()
        for keyword in ('Training', 'hold-out'):
            self.assertIn(keyword, text)
Exemplo n.º 19
0
 def test_TestRun(self):
     # Run CrossValidate.TestRun() with stdout captured; it is expected to
     # leave save.pkl behind, which is removed again afterwards.
     pklName = 'save.pkl'
     try:
         captured = StringIO()
         with redirect_stdout(captured):
             CrossValidate.TestRun()
         self.assertTrue(os.path.isfile(pklName))
         self.assertIn('t1 == t2 True', captured.getvalue())
     finally:
         if os.path.isfile(pklName):
             os.remove(pklName)
Exemplo n.º 20
0
 def test_TestRun(self):
   # Run CrossValidate.TestRun() with stdout captured; it is expected to
   # leave save.pkl behind, which is removed again afterwards.
   pklName = 'save.pkl'
   try:
     captured = StringIO()
     with redirect_stdout(captured):
       CrossValidate.TestRun()
     self.assertTrue(os.path.isfile(pklName))
     self.assertIn('t1 == t2 True', captured.getvalue())
   finally:
     if os.path.isfile(pklName):
       os.remove(pklName)
Exemplo n.º 21
0
 def test_exampleCode(self):
     # Run the Tree example code with stdout captured. The example is
     # expected to create save.pkl in the working directory and to print its
     # final comparison line.
     try:
         f = StringIO()
         with redirect_stdout(f):
             Tree._exampleCode()
         self.assertTrue(os.path.isfile('save.pkl'))
         # the failure message used to be two adjacent literals that
         # concatenated to "Example didnt run to end"
         self.assertIn('tree==tree2 False', f.getvalue(),
                       "Example didn't run to end")
     finally:
         # remove the pickle file whether or not the assertions passed
         if os.path.isfile('save.pkl'):
             os.remove('save.pkl')
Exemplo n.º 22
0
  def test_Cluster(self):
    """ tests the Cluster class functionality """
    # build a small tree: root -> c1 (30, 31, 32) and c2 (40, 41)
    root = Clusters.Cluster(index=1, position=1)
    c1 = Clusters.Cluster(index=10, position=10)
    c1.AddChild(Clusters.Cluster(index=30, position=30))
    c1.AddChild(Clusters.Cluster(index=31, position=31))
    t32 = Clusters.Cluster(index=32, position=32)
    c1.AddChild(t32)

    c2 = Clusters.Cluster(index=11)
    c2.AddChildren([Clusters.Cluster(index=40), Clusters.Cluster(index=41)])

    root.AddChild(c1)
    root.AddChild(c2)
    nodes = ClusterUtils.GetNodeList(root)

    # the node list is expected children-first, with the root last
    indices = [x.GetIndex() for x in nodes]
    self.assertEqual(indices, [30, 31, 32, 10, 40, 41, 11, 1], 'bad indices')
    subtree = root.FindSubtree(11)
    self.assertEqual([x.GetIndex() for x in ClusterUtils.GetNodeList(subtree)], [40, 41, 11])

    self.assertFalse(root.IsTerminal())
    self.assertTrue(t32.IsTerminal())

    # simple getter / setter round trips
    self.assertEqual(root.GetData(), None)
    root.SetData(3.14)
    self.assertEqual(root.GetData(), 3.14)

    self.assertEqual(root.GetMetric(), 0.0)
    root.SetMetric(0.1)
    self.assertEqual(root.GetMetric(), 0.1)

    self.assertEqual(root.GetIndex(), 1)
    root.SetIndex(100)
    self.assertEqual(root.GetIndex(), 100)

    self.assertEqual(root.GetPointsPositions(), [30, 31, 32, []])

    # after removing c1 only the c2 branch and the root remain
    root.RemoveChild(c1)
    self.assertEqual([x.GetIndex() for x in ClusterUtils.GetNodeList(root)], [40, 41, 11, 100])

    # the default name is derived from the index until one is set
    self.assertEqual(root.GetName(), 'Cluster(100)')
    root.SetName('abc')
    self.assertEqual(root.GetName(), 'abc')

    f = StringIO()
    with redirect_stdout(f):
      root.Print(showData=True)
    self.assertIn('abc', f.getvalue())
    self.assertIn('Cluster(41)', f.getvalue())
    self.assertIn('Metric', f.getvalue())
Exemplo n.º 23
0
  def testRun(self):
    # " test that the CrossValidationDriver runs "
    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)

    # verbose run: check the returned tree, error fraction and printed report
    capture = StringIO()
    with redirect_stdout(capture):
      builtTree, errFrac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                               silent=False)
    self.assertGreater(errFrac, 0)
    self.assertEqual('Var: 1', builtTree.GetName())
    self.assertIn('Validation error', capture.getvalue())

    # silent run with the alternative options only has to complete
    quietOptions = {'lessGreedy': True, 'calcTotalError': True, 'silent': True}
    CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals, **quietOptions)
Exemplo n.º 24
0
    def test_exampleCode(self):
        # Run the three PruneTree self-test entry points with stdout captured
        # and look for the expected messages in their output.
        randomOut = StringIO()
        with redirect_stdout(randomOut):
            try:
                PruneTree._testRandom()
                self.assertTrue(os.path.isfile('prune.pkl'))
            finally:
                # remove whichever pickle files are present afterwards
                for leftover in ('orig.pkl', 'prune.pkl'):
                    if os.path.isfile(leftover):
                        os.remove(leftover)
        self.assertIn('pruned error', randomOut.getvalue())

        # the remaining two entry points are checked for the same message;
        # each is looked up only when its turn comes
        for entryPoint in ('_testSpecific', '_testChain'):
            holdoutOut = StringIO()
            with redirect_stdout(holdoutOut):
                getattr(PruneTree, entryPoint)()
            self.assertIn('pruned holdout error', holdoutOut.getvalue())
Exemplo n.º 25
0
    def test_exampleCode(self):
        # Run the three PruneTree self-test entry points with stdout captured
        # and look for the expected messages in their output.
        randomOut = StringIO()
        with redirect_stdout(randomOut):
            try:
                PruneTree._testRandom()
                self.assertTrue(os.path.isfile('prune.pkl'))
            finally:
                # remove whichever pickle files are present afterwards
                for leftover in ('orig.pkl', 'prune.pkl'):
                    if os.path.isfile(leftover):
                        os.remove(leftover)
        self.assertIn('pruned error', randomOut.getvalue())

        # the remaining two entry points are checked for the same message;
        # each is looked up only when its turn comes
        for entryPoint in ('_testSpecific', '_testChain'):
            holdoutOut = StringIO()
            with redirect_stdout(holdoutOut):
                getattr(PruneTree, entryPoint)()
            self.assertIn('pruned holdout error', holdoutOut.getvalue())
Exemplo n.º 26
0
    def test_SplitData(self):
        # the fractions -1.1 and 1.1 must be rejected with a ValueError
        for badFrac in (-1.1, 1.1):
            with self.assertRaises(ValueError):
                SplitData.SplitDataSet(None, badFrac)

        # seed the random numbers, then split ten points in half
        DataUtils.InitRandomNumbers((23, 42))
        data = list(range(10))
        report = StringIO()
        with redirect_stdout(report):
            split = SplitData.SplitDataSet(data, 0.5)

        # the two parts must not share any point
        overlap = set(split[0]).intersection(split[1])
        self.assertEqual(overlap, set())
        self.assertEqual(len(split[0]), 5)
        text = report.getvalue()
        for keyword in ('Training', 'hold-out'):
            self.assertIn(keyword, text)
Exemplo n.º 27
0
    def test_SplitData(self):
        # the fractions -1.1 and 1.1 must be rejected with a ValueError
        for badFrac in (-1.1, 1.1):
            with self.assertRaises(ValueError):
                SplitData.SplitDataSet(None, badFrac)

        # seed the random numbers, then split ten points in half
        DataUtils.InitRandomNumbers((23, 42))
        data = list(range(10))
        report = StringIO()
        with redirect_stdout(report):
            split = SplitData.SplitDataSet(data, 0.5)

        # the two parts must not share any point
        overlap = set(split[0]).intersection(split[1])
        self.assertEqual(overlap, set())
        self.assertEqual(len(split[0]), 5)
        text = report.getvalue()
        for keyword in ('Training', 'hold-out'):
            self.assertIn(keyword, text)
Exemplo n.º 28
0
  def test_DescriptorCalculator(self):
    # A bare DescriptorCalculator must raise NotImplementedError from these
    # three methods.
    calc = Descriptors.DescriptorCalculator()
    with self.assertRaises(NotImplementedError):
      calc.ShowDescriptors()
    with self.assertRaises(NotImplementedError):
      calc.GetDescriptorNames()
    with self.assertRaises(NotImplementedError):
      calc.CalcDescriptors(None)

    # once both descriptor lists are set, ShowDescriptors() must print
    # every entry of each list
    calc.simpleList = ['simple1', 'simple2']
    calc.compoundList = ['cmpd1', 'cmpd2']
    listing = StringIO()
    with redirect_stdout(listing):
      calc.ShowDescriptors()
    shown = listing.getvalue()
    for descriptorList in (calc.simpleList, calc.compoundList):
      for name in descriptorList:
        self.assertIn(name, shown)
Exemplo n.º 29
0
  def test_DescriptorCalculator(self):
    # A bare DescriptorCalculator must raise NotImplementedError from these
    # three methods.
    calc = Descriptors.DescriptorCalculator()
    with self.assertRaises(NotImplementedError):
      calc.ShowDescriptors()
    with self.assertRaises(NotImplementedError):
      calc.GetDescriptorNames()
    with self.assertRaises(NotImplementedError):
      calc.CalcDescriptors(None)

    # once both descriptor lists are set, ShowDescriptors() must print
    # every entry of each list
    calc.simpleList = ['simple1', 'simple2']
    calc.compoundList = ['cmpd1', 'cmpd2']
    listing = StringIO()
    with redirect_stdout(listing):
      calc.ShowDescriptors()
    shown = listing.getvalue()
    for descriptorList in (calc.simpleList, calc.compoundList):
      for name in descriptorList:
        self.assertIn(name, shown)
Exemplo n.º 30
0
    def testRun(self):
        # " test that the CrossValidationDriver runs "
        examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
            nExamples=200)

        # verbose run: check the returned tree, error fraction and report
        capture = StringIO()
        with redirect_stdout(capture):
            builtTree, errFrac = CrossValidate.CrossValidationDriver(
                examples, attrs, nPossibleVals, silent=False)
        self.assertGreater(errFrac, 0)
        self.assertEqual('Var: 1', builtTree.GetName())
        self.assertIn('Validation error', capture.getvalue())

        # silent run with the alternative options only has to complete
        quietOptions = {
            'lessGreedy': True,
            'calcTotalError': True,
            'silent': True,
        }
        CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                            **quietOptions)
Exemplo n.º 31
0
  def test3(self):
    # Build a signature tree from seven two-bit examples, then check the tree
    # layout and the classification of one extra example.

    def makeCollection(*bitSets):
      # one 2-bit vector per entry of bitSets, added under ids 1, 2, ...
      coll = VectCollection()
      for vectId, onBits in enumerate(bitSets, 1):
        vect = ExplicitBitVect(2)
        for bit in onBits:
          vect.SetBit(bit)
        coll.AddVect(vectId, vect)
      return coll

    # each example is [name, vector collection, activity]
    examples = [
      ['a', makeCollection(()), 1],
      ['c', makeCollection((1, )), 0],
      ['c2', makeCollection((1, )), 0],
      ['d', makeCollection((0, )), 0],
      ['d2', makeCollection((0, ), (1, )), 0],
      ['d', makeCollection((0, 1)), 1],
      ['e', makeCollection((0, 1)), 1],
    ]

    buildLog = StringIO()
    with redirect_stdout(buildLog):
      t = BuildSigTree(examples, 2, metric=InfoTheory.InfoType.ENTROPY, maxDepth=2, verbose=True)
    self.assertIn('Build', buildLog.getvalue())

    # the root splits on bit 0 and both of its children split on bit 1
    self.assertEqual(t.GetName(), 'Bit-0')
    self.assertEqual(t.GetLabel(), 0)
    for childPos in (0, 1):
      child = t.GetChildren()[childPos]
      self.assertEqual(child.GetName(), 'Bit-1')
      self.assertEqual(child.GetLabel(), 1)

    # a collection shaped like 'd2' must be classified as 0
    probe = makeCollection((0, ), (1, ))
    r = t.ClassifyExample(['t', probe, 0])
    self.assertEqual(r, 0)
Exemplo n.º 32
0
    def test3(self):
        # Build a signature tree from seven two-bit examples, then check the
        # tree layout and the classification of one extra example.

        def makeCollection(*bitSets):
            # one 2-bit vector per entry of bitSets, added under ids 1, 2, ...
            coll = VectCollection()
            for vectId, onBits in enumerate(bitSets, 1):
                vect = ExplicitBitVect(2)
                for bit in onBits:
                    vect.SetBit(bit)
                coll.AddVect(vectId, vect)
            return coll

        # each example is [name, vector collection, activity]
        examples = [
            ['a', makeCollection(()), 1],
            ['c', makeCollection((1, )), 0],
            ['c2', makeCollection((1, )), 0],
            ['d', makeCollection((0, )), 0],
            ['d2', makeCollection((0, ), (1, )), 0],
            ['d', makeCollection((0, 1)), 1],
            ['e', makeCollection((0, 1)), 1],
        ]

        buildLog = StringIO()
        with redirect_stdout(buildLog):
            t = BuildSigTree(examples,
                             2,
                             metric=InfoTheory.InfoType.ENTROPY,
                             maxDepth=2,
                             verbose=True)
        self.assertIn('Build', buildLog.getvalue())

        # the root splits on bit 0 and both of its children split on bit 1
        self.assertEqual(t.GetName(), 'Bit-0')
        self.assertEqual(t.GetLabel(), 0)
        for childPos in (0, 1):
            child = t.GetChildren()[childPos]
            self.assertEqual(child.GetName(), 'Bit-1')
            self.assertEqual(child.GetLabel(), 1)

        # a collection shaped like 'd2' must be classified as 0
        probe = makeCollection((0, ), (1, ))
        r = t.ClassifyExample(['t', probe, 0])
        self.assertEqual(r, 0)
Exemplo n.º 33
0
 def test_exampleCode(self):
     # Smoke test: BuildQuantTree.TestTree() has to run and mention 'Var: 2'
     captured = StringIO()
     with redirect_stdout(captured):
         BuildQuantTree.TestTree()
     output = captured.getvalue()
     self.assertIn('Var: 2', output)
Exemplo n.º 34
0
 def test_exampleCode(self):
   # Smoke test: the Matcher example code has to run through to its
   # 'finished' message
   captured = StringIO()
   with redirect_stdout(captured):
     Matcher._exampleCode()
   output = captured.getvalue()
   self.assertIn('finished', output)
Exemplo n.º 35
0
 def test_exampleCode(self):
     # Smoke test: the Matcher example code has to run through to its
     # 'finished' message
     captured = StringIO()
     with redirect_stdout(captured):
         Matcher._exampleCode()
     output = captured.getvalue()
     self.assertIn('finished', output)
Exemplo n.º 36
0
 def test_exampleCode(self):
     # Only checks that the example code runs without raising; whatever it
     # prints is captured and discarded.
     discarded = StringIO()
     with redirect_stdout(discarded):
         CompoundDescriptors._exampleCode()
Exemplo n.º 37
0
 def test_exampleCode(self):
   # Smoke test: BuildQuantTree.TestTree() has to run and mention 'Var: 2'
   captured = StringIO()
   with redirect_stdout(captured):
     BuildQuantTree.TestTree()
   output = captured.getvalue()
   self.assertIn('Var: 2', output)