Example #1
0
  def testWriteRead(self):
    """Round-trip a CLAClassifier through capnp and check the copy matches.

    A classifier is trained on one record, serialized into a capnp message,
    written to and re-read from a temporary file, and deserialized into a
    second classifier. The test compares the persisted internal state of both
    instances, then feeds each the same next record and checks that their
    compute results agree.
    """
    c1 = CLAClassifier([1], 0.1, 0.1, 0)

    # Create a vector of input bit indices and train on one record so the
    # serialized state is non-trivial.
    input1 = [1, 5, 9]
    c1.compute(recordNum=0,
               patternNZ=input1,
               classification={'bucketIdx': 4, 'actValue': 34.7},
               learn=True, infer=True)

    proto1 = ClaClassifier_capnp.ClaClassifierProto.new_message()
    c1.write(proto1)

    # Write the proto to a temp file and read it back into a new proto
    with tempfile.TemporaryFile() as f:
      proto1.write(f)
      f.seek(0)
      proto2 = ClaClassifier_capnp.ClaClassifierProto.read(f)

    # Load the deserialized proto
    c2 = CLAClassifier.read(proto2)

    self.assertEqual(c1.steps, c2.steps)
    self.assertAlmostEqual(c1.alpha, c2.alpha)
    self.assertAlmostEqual(c1.actValueAlpha, c2.actValueAlpha)
    self.assertEqual(c1._learnIteration, c2._learnIteration)
    self.assertEqual(c1._recordNumMinusLearnIteration,
                     c2._recordNumMinusLearnIteration)
    self.assertEqual(c1._patternNZHistory, c2._patternNZHistory)

    # Compare key *sets*: the list order of dict.keys() depends on insertion
    # history, which deserialization is not required to reproduce.
    self.assertEqual(set(c1._activeBitHistory.keys()),
                     set(c2._activeBitHistory.keys()))
    for key, c1BitHistory in c1._activeBitHistory.items():
      c2BitHistory = c2._activeBitHistory[key]
      self.assertEqual(c1BitHistory._id, c2BitHistory._id)
      self.assertEqual(c1BitHistory._stats, c2BitHistory._stats)
      self.assertEqual(c1BitHistory._lastTotalUpdate,
                       c2BitHistory._lastTotalUpdate)
      self.assertEqual(c1BitHistory._learnIteration,
                       c2BitHistory._learnIteration)

    self.assertEqual(c1._maxBucketIdx, c2._maxBucketIdx)
    self.assertEqual(len(c1._actualValues), len(c2._actualValues))
    for value1, value2 in zip(c1._actualValues, c2._actualValues):
      self.assertAlmostEqual(value1, value2, 5)
    self.assertEqual(c1._version, c2._version)
    self.assertEqual(c1.verbosity, c2.verbosity)

    # Both classifiers must now behave identically on the same next record.
    result1 = c1.compute(recordNum=1,
                         patternNZ=input1,
                         classification={'bucketIdx': 4, 'actValue': 34.7},
                         learn=True, infer=True)
    result2 = c2.compute(recordNum=1,
                         patternNZ=input1,
                         classification={'bucketIdx': 4, 'actValue': 34.7},
                         learn=True, infer=True)

    self.assertEqual(set(result1.keys()), set(result2.keys()))
    for key in result1:
      # Check each entry against its own length instead of assuming it
      # matches len(c1._actualValues), so a length mismatch is caught.
      self.assertEqual(len(result1[key]), len(result2[key]))
      for value1, value2 in zip(result1[key], result2[key]):
        self.assertAlmostEqual(value1, value2, 5)
Example #2
0
  def testWriteRead(self):
    """Round-trip a CLAClassifier through capnp and check the copy matches.

    A classifier is trained on one record, serialized into a capnp message,
    written to and re-read from a temporary file, and deserialized into a
    second classifier. The test compares the persisted internal state of both
    instances, then feeds each the same next record and checks that their
    compute results agree.
    """
    c1 = CLAClassifier([1], 0.1, 0.1, 0)

    # Create a vector of input bit indices and train on one record so the
    # serialized state is non-trivial.
    input1 = [1, 5, 9]
    c1.compute(recordNum=0,
               patternNZ=input1,
               classification={'bucketIdx': 4, 'actValue': 34.7},
               learn=True, infer=True)

    proto1 = ClaClassifier_capnp.ClaClassifierProto.new_message()
    c1.write(proto1)

    # Write the proto to a temp file and read it back into a new proto
    with tempfile.TemporaryFile() as f:
      proto1.write(f)
      f.seek(0)
      proto2 = ClaClassifier_capnp.ClaClassifierProto.read(f)

    # Load the deserialized proto
    c2 = CLAClassifier.read(proto2)

    self.assertEqual(c1.steps, c2.steps)
    self.assertAlmostEqual(c1.alpha, c2.alpha)
    self.assertAlmostEqual(c1.actValueAlpha, c2.actValueAlpha)
    self.assertEqual(c1._learnIteration, c2._learnIteration)
    self.assertEqual(c1._recordNumMinusLearnIteration,
                     c2._recordNumMinusLearnIteration)
    self.assertEqual(c1._patternNZHistory, c2._patternNZHistory)

    # Compare key *sets*: the list order of dict.keys() depends on insertion
    # history, which deserialization is not required to reproduce.
    self.assertEqual(set(c1._activeBitHistory.keys()),
                     set(c2._activeBitHistory.keys()))
    for key, c1BitHistory in c1._activeBitHistory.items():
      c2BitHistory = c2._activeBitHistory[key]
      self.assertEqual(c1BitHistory._id, c2BitHistory._id)
      self.assertEqual(c1BitHistory._stats, c2BitHistory._stats)
      self.assertEqual(c1BitHistory._lastTotalUpdate,
                       c2BitHistory._lastTotalUpdate)
      self.assertEqual(c1BitHistory._learnIteration,
                       c2BitHistory._learnIteration)

    self.assertEqual(c1._maxBucketIdx, c2._maxBucketIdx)
    self.assertEqual(len(c1._actualValues), len(c2._actualValues))
    for value1, value2 in zip(c1._actualValues, c2._actualValues):
      self.assertAlmostEqual(value1, value2, 5)
    self.assertEqual(c1._version, c2._version)
    self.assertEqual(c1.verbosity, c2.verbosity)

    # Both classifiers must now behave identically on the same next record.
    result1 = c1.compute(recordNum=1,
                         patternNZ=input1,
                         classification={'bucketIdx': 4, 'actValue': 34.7},
                         learn=True, infer=True)
    result2 = c2.compute(recordNum=1,
                         patternNZ=input1,
                         classification={'bucketIdx': 4, 'actValue': 34.7},
                         learn=True, infer=True)

    self.assertEqual(set(result1.keys()), set(result2.keys()))
    for key in result1:
      # Check each entry against its own length instead of assuming it
      # matches len(c1._actualValues), so a length mismatch is caught.
      self.assertEqual(len(result1[key]), len(result2[key]))
      for value1, value2 in zip(result1[key], result2[key]):
        self.assertAlmostEqual(value1, value2, 5)
Example #3
0
 def read(proto):
   """Deserialize a classifier from a CLAClassifierRegionProto message.

   The concrete class is chosen by ``proto.classifierImp``: 'py' selects
   CLAClassifier and 'cpp' selects FastCLAClassifier. Either one is built
   from ``proto.claClassifier``.

   :param proto: CLAClassifierRegionProto capnproto object
   :returns: the classifier produced by the selected class's ``read``
   :raises NotImplementedError: if classifierImp is 'diff'
   :raises ValueError: for any other classifierImp value
   """
   implementation = proto.classifierImp

   if implementation == 'py':
     return CLAClassifier.read(proto.claClassifier)

   if implementation == 'cpp':
     return FastCLAClassifier.read(proto.claClassifier)

   if implementation == 'diff':
     # Recognized value, but no reader exists for this implementation yet.
     raise NotImplementedError("CLAClassifierDiff.read not implemented")

   message = ('Invalid classifier implementation (%r). Value must be '
              '"py" or "cpp".' % implementation)
   raise ValueError(message)
Example #4
0
 def read(proto):
     """Rebuild a classifier from a serialized CLAClassifierRegionProto.

     The reader class is picked from ``proto.classifierImp`` ('py' ->
     CLAClassifier, 'cpp' -> FastCLAClassifier). Its ``read`` is then
     applied to ``proto.claClassifier``.

     :param proto: CLAClassifierRegionProto capnproto object
     :returns: the classifier produced by the selected reader
     :raises NotImplementedError: if classifierImp is 'diff'
     :raises ValueError: for any other classifierImp value
     """
     kind = proto.classifierImp
     if kind == 'py':
         reader = CLAClassifier
     elif kind == 'cpp':
         reader = FastCLAClassifier
     elif kind == 'diff':
         # Known implementation name with no deserializer available.
         raise NotImplementedError("CLAClassifierDiff.read not implemented")
     else:
         raise ValueError(
             'Invalid classifier implementation (%r). Value must be '
             '"py" or "cpp".' % kind)
     return reader.read(proto.claClassifier)