def testReadWrite(self):
        """Round-trip SDRCategoryEncoder through SDRCategoryEncoderProto.

        Two scenarios are covered:

        * An encoder built from an explicit ``categoryList``. After a
          write/read cycle the copy must agree with the original on n, w,
          verbosity, description, name, ``categoryToIndex``, the encoding
          of "ES", and the decoding of that encoding.
        * An encoder built with ``categoryList=None`` (the "autogrow" case).
          Values encoded before serialization must encode to the same
          arrays after deserialization.
        """

        def roundTrip(encoder):
            """Write ``encoder`` to a temp file and return the encoder read back."""
            proto1 = SDRCategoryEncoderProto.new_message()
            encoder.write(proto1)

            # Go through a real file so the serialized bytes are exercised,
            # not just the in-memory message builder.
            with tempfile.TemporaryFile() as f:
                proto1.write(f)
                f.seek(0)  # rewind so the reader starts at the message
                proto2 = SDRCategoryEncoderProto.read(f)

            return SDRCategoryEncoder.read(proto2)

        categories = [
            "ES", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10",
            "S11", "S12", "S13", "S14", "S15", "S16", "S17", "S18", "S19",
            "GB", "US"
        ]

        fieldWidth = 100
        bitsOn = 10

        original = SDRCategoryEncoder(n=fieldWidth,
                                      w=bitsOn,
                                      categoryList=categories,
                                      name="foo",
                                      verbosity=0,
                                      forced=True)

        # Internal check on the encoder's SDR table. 32 rows exceeds the 22
        # categories given; presumably the encoder pre-allocates spare rows
        # (TODO confirm against SDRCategoryEncoder).
        self.assertEqual(original.sdrs.shape, (32, fieldWidth))

        # ES: a vector of length fieldWidth with exactly bitsOn active bits.
        es = original.encode("ES")
        self.assertEqual(es.shape, (fieldWidth, ))
        self.assertEqual(es.sum(), bitsOn)

        decoded = original.decode(es)

        encoder = roundTrip(original)

        self.assertIsInstance(encoder, SDRCategoryEncoder)
        self.assertEqual(encoder.n, original.n)
        self.assertEqual(encoder.w, original.w)
        self.assertEqual(encoder.verbosity, original.verbosity)
        self.assertEqual(encoder.description, original.description)
        self.assertEqual(encoder.name, original.name)
        self.assertDictEqual(encoder.categoryToIndex, original.categoryToIndex)
        numpy.testing.assert_array_equal(encoder.encode("ES"), es)
        # Cross-decode: each encoder must decode the other's output identically.
        self.assertEqual(original.decode(encoder.encode("ES")),
                         encoder.decode(original.encode("ES")))
        self.assertEqual(decoded, encoder.decode(es))

        # Test autogrow serialization: no categoryList is supplied up front.
        autogrow = SDRCategoryEncoder(n=fieldWidth,
                                      w=bitsOn,
                                      categoryList=None,
                                      name="bar",
                                      forced=True)

        # Encode in a fixed order before serializing; the same order is used
        # when checking the restored encoder.
        grownCategories = ("ES", "US", "GS")
        expected = [autogrow.encode(category) for category in grownCategories]

        restored = roundTrip(autogrow)

        for category, encoding in zip(grownCategories, expected):
            numpy.testing.assert_array_equal(restored.encode(category), encoding)
Beispiel #2
0
  def testReadWrite(self):
    """Round-trip SDRCategoryEncoder through SDRCategoryEncoderProto.

    Two scenarios are covered:

    * An encoder built from an explicit ``categoryList``. After a write/read
      cycle the copy must agree with the original on n, w, verbosity,
      description, name, ``categoryToIndex``, the encoding of "ES", and the
      decoding of that encoding.
    * An encoder built with ``categoryList=None`` (the "autogrow" case).
      Values encoded before serialization must encode to the same arrays
      after deserialization.
    """

    def roundTrip(encoder):
      """Write ``encoder`` to a temp file and return the encoder read back."""
      proto1 = SDRCategoryEncoderProto.new_message()
      encoder.write(proto1)

      # Go through a real file so the serialized bytes are exercised, not
      # just the in-memory message builder.
      with tempfile.TemporaryFile() as f:
        proto1.write(f)
        f.seek(0)  # rewind so the reader starts at the message
        proto2 = SDRCategoryEncoderProto.read(f)

      return SDRCategoryEncoder.read(proto2)

    categories = ["ES", "S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8",
                  "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16",
                  "S17", "S18", "S19", "GB", "US"]

    fieldWidth = 100
    bitsOn = 10

    original = SDRCategoryEncoder(n=fieldWidth, w=bitsOn,
                                  categoryList=categories,
                                  name="foo", verbosity=0, forced=True)

    # Internal check on the encoder's SDR table. 32 rows exceeds the 22
    # categories given; presumably the encoder pre-allocates spare rows
    # (TODO confirm against SDRCategoryEncoder).
    self.assertEqual(original.sdrs.shape, (32, fieldWidth))

    # ES: a vector of length fieldWidth with exactly bitsOn active bits.
    es = original.encode("ES")
    self.assertEqual(es.shape, (fieldWidth,))
    self.assertEqual(es.sum(), bitsOn)

    decoded = original.decode(es)

    encoder = roundTrip(original)

    self.assertIsInstance(encoder, SDRCategoryEncoder)
    self.assertEqual(encoder.n, original.n)
    self.assertEqual(encoder.w, original.w)
    self.assertEqual(encoder.verbosity, original.verbosity)
    self.assertEqual(encoder.description, original.description)
    self.assertEqual(encoder.name, original.name)
    self.assertDictEqual(encoder.categoryToIndex, original.categoryToIndex)
    numpy.testing.assert_array_equal(encoder.encode("ES"), es)
    # Cross-decode: each encoder must decode the other's output identically.
    self.assertEqual(original.decode(encoder.encode("ES")),
                     encoder.decode(original.encode("ES")))
    self.assertEqual(decoded, encoder.decode(es))

    # Test autogrow serialization: no categoryList is supplied up front.
    autogrow = SDRCategoryEncoder(n=fieldWidth, w=bitsOn, categoryList=None,
                                  name="bar", forced=True)

    # Encode in a fixed order before serializing; the same order is used when
    # checking the restored encoder.
    grownCategories = ("ES", "US", "GS")
    expected = [autogrow.encode(category) for category in grownCategories]

    restored = roundTrip(autogrow)

    for category, encoding in zip(grownCategories, expected):
      numpy.testing.assert_array_equal(restored.encode(category), encoding)