Esempio n. 1
0
    def test_is_attr_legal_verbose(self):

        ATTR_FUNCTIONS = [
            (lambda attr: setattr(attr, "f", 1.0)),
            (lambda attr: setattr(attr, "i", 1)),
            (lambda attr: setattr(attr, "s", b"str")),
            (lambda attr: attr.floats.extend([1.0, 2.0])),
            (lambda attr: attr.ints.extend([1, 2])),
            (lambda attr: attr.strings.extend([b"a", b"b"])),
            (lambda attr: attr.tensors.extend([TensorProto(), TensorProto()])),
            (lambda attr: attr.graphs.extend([GraphProto(), GraphProto()])),
        ]
        # Randomly set one field, and the result should be legal.
        for i in range(100):
            attr = AttributeProto()
            attr.name = "test"
            random.choice(ATTR_FUNCTIONS)(attr)
            self.assertTrue(helper.is_attribute_legal(attr))
        # Randomly set two fields, and then ensure helper function catches it.
        for i in range(100):
            attr = AttributeProto()
            attr.name = "test"
            for func in random.sample(ATTR_FUNCTIONS, 2):
                func(attr)
            self.assertFalse(helper.is_attribute_legal(attr))
Esempio n. 2
0
 def test_attr_float(self):
     # float
     attr = helper.make_attribute("float", 1.)
     self.assertEqual(attr.name, "float")
     self.assertEqual(attr.f, 1.)
     self.assertTrue(helper.is_attribute_legal(attr))
     # float with scientific
     attr = helper.make_attribute("float", 1e10)
     self.assertEqual(attr.name, "float")
     self.assertEqual(attr.f, 1e10)
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 3
0
 def test_is_attr_legal(self):
     # no name, no field
     attr = AttributeProto()
     self.assertFalse(helper.is_attribute_legal(attr))
     # name, but no field
     attr = AttributeProto()
     attr.name = "test"
     self.assertFalse(helper.is_attribute_legal(attr))
     # name, with two fields
     attr = AttributeProto()
     attr.name = "test"
     attr.f = 1.0
     attr.i = 2
     self.assertFalse(helper.is_attribute_legal(attr))
Esempio n. 4
0
 def test_attr_string(self):
     # bytes
     attr = helper.make_attribute("str", b"test")
     self.assertEqual(attr.name, "str")
     self.assertEqual(attr.s, b"test")
     self.assertTrue(helper.is_attribute_legal(attr))
     # unspecified
     attr = helper.make_attribute("str", "test")
     self.assertEqual(attr.name, "str")
     self.assertEqual(attr.s, b"test")
     self.assertTrue(helper.is_attribute_legal(attr))
     # unicode
     attr = helper.make_attribute("str", u"test")
     self.assertEqual(attr.name, "str")
     self.assertEqual(attr.s, b"test")
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 5
0
 def test_attr_repeated_tensor_proto(self):
     graphs = [GraphProto(), GraphProto()]
     graphs[0].name = "a"
     graphs[1].name = "b"
     attr = helper.make_attribute("graphs", graphs)
     self.assertEqual(attr.name, "graphs")
     self.assertEqual(list(attr.graphs), graphs)
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 6
0
 def test_attr_repeated_tensor_proto(self):
     tensors = [TensorProto(), TensorProto()]
     tensors[0].name = "a"
     tensors[1].name = "b"
     attr = helper.make_attribute("tensors", tensors)
     self.assertEqual(attr.name, "tensors")
     self.assertEqual(list(attr.tensors), tensors)
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 7
0
 def test_attr_int(self):
     # integer
     attr = helper.make_attribute("int", 3)
     self.assertEqual(attr.name, "int")
     self.assertEqual(attr.i, 3)
     self.assertTrue(helper.is_attribute_legal(attr))
     # long integer
     attr = helper.make_attribute("int", 5)
     self.assertEqual(attr.name, "int")
     self.assertEqual(attr.i, 5)
     self.assertTrue(helper.is_attribute_legal(attr))
     # octinteger
     attr = helper.make_attribute("int", 0o1701)
     self.assertEqual(attr.name, "int")
     self.assertEqual(attr.i, 0o1701)
     self.assertTrue(helper.is_attribute_legal(attr))
     # hexinteger
     attr = helper.make_attribute("int", 0x1701)
     self.assertEqual(attr.name, "int")
     self.assertEqual(attr.i, 0x1701)
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 8
0
 def test_attr_repeated_str(self):
     attr = helper.make_attribute("strings", ["str1", "str2"])
     self.assertEqual(attr.name, "strings")
     self.assertEqual(list(attr.strings), [b"str1", b"str2"])
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 9
0
 def test_attr_repeated_int(self):
     attr = helper.make_attribute("ints", [1, 2])
     self.assertEqual(attr.name, "ints")
     self.assertEqual(list(attr.ints), [1, 2])
     self.assertTrue(helper.is_attribute_legal(attr))
Esempio n. 10
0
 def test_attr_repeated_float(self):
     attr = helper.make_attribute("floats", [1.0, 2.0])
     self.assertEqual(attr.name, "floats")
     self.assertEqual(list(attr.floats), [1.0, 2.0])
     self.assertTrue(helper.is_attribute_legal(attr))