Ejemplo n.º 1
0
    def test_is_attr_legal_verbose(self):  # type: () -> None
        """Check that exactly one populated value field makes an attribute legal.

        Every single-field setter in ``SET_ATTR`` must produce an attribute
        that passes ``checker.check_attribute``; every ordered pair of distinct
        setters must make it raise ``checker.ValidationError``.

        All cases are enumerated exhaustively (instead of randomly sampled) so
        the test is deterministic and a failure is always reproducible.
        """
        import itertools

        def _set(attr, attr_type, var, value):  # type: (AttributeProto, AttributeProto.AttributeType, Text, Any) -> None
            # Assign a scalar field by name, then record the matching type tag.
            setattr(attr, var, value)
            setattr(attr, 'type', attr_type)

        def _extend(attr, attr_type, var, value):  # type: (AttributeProto, AttributeProto.AttributeType, List[Any], Any) -> None
            # ``var`` is a repeated field already bound to ``attr``; extend it
            # in place, then record the matching type tag.
            var.extend(value)
            setattr(attr, 'type', attr_type)

        SET_ATTR = [
            (lambda attr: _set(attr, AttributeProto.FLOAT, "f", 1.0)),
            (lambda attr: _set(attr, AttributeProto.INT, "i", 1)),
            (lambda attr: _set(attr, AttributeProto.STRING, "s", b"str")),
            (lambda attr: _extend(attr, AttributeProto.FLOATS, attr.floats, [1.0, 2.0])),
            (lambda attr: _extend(attr, AttributeProto.INTS, attr.ints, [1, 2])),
            (lambda attr: _extend(attr, AttributeProto.STRINGS, attr.strings, [b"a", b"b"])),
        ]

        def _fresh_attr():  # type: () -> AttributeProto
            # A named, otherwise empty attribute to populate.
            attr = AttributeProto()
            attr.name = "test"
            return attr

        # Set each field on its own; every result should be legal.
        for setter in SET_ATTR:
            attr = _fresh_attr()
            setter(attr)
            checker.check_attribute(attr)
        # Set every ordered pair of distinct fields (the order decides which
        # type tag is written last); the checker must reject each one.
        for first, second in itertools.permutations(SET_ATTR, 2):
            attr = _fresh_attr()
            first(attr)
            second(attr)
            self.assertRaises(checker.ValidationError,
                              checker.check_attribute,
                              attr)
Ejemplo n.º 2
0
 def test_attr_repeated_graph_proto(self):  # type: () -> None
     """make_attribute should pack a list of GraphProto into ``graphs``."""
     first, second = GraphProto(), GraphProto()
     first.name = "a"
     second.name = "b"
     expected = [first, second]
     attr = helper.make_attribute("graphs", expected)
     self.assertEqual(attr.name, "graphs")
     # Repeated fields are compared after materializing them as a list.
     self.assertEqual(list(attr.graphs), expected)
     checker.check_attribute(attr)
Ejemplo n.º 3
0
 def test_attr_float(self):  # type: () -> None
     """make_attribute should store plain and scientific floats in ``f``."""
     # A plain float literal first, then one in scientific notation.
     for value in (1., 1e10):
         attr = helper.make_attribute("float", value)
         self.assertEqual(attr.name, "float")
         self.assertEqual(attr.f, value)
         checker.check_attribute(attr)
Ejemplo n.º 4
0
 def test_attr_repeated_tensor_proto(self):  # type: () -> None
     """make_attribute should pack a list of TensorProto into ``tensors``."""
     # Two one-element FLOAT tensors that differ only by name.
     tensors = [
         helper.make_tensor(
             name=tensor_name,
             data_type=TensorProto.FLOAT,
             dims=(1,),
             vals=np.ones(1).tolist(),
         )
         for tensor_name in ('a', 'b')
     ]
     attr = helper.make_attribute("tensors", tensors)
     self.assertEqual(attr.name, "tensors")
     stored = list(attr.tensors)
     self.assertEqual(stored, tensors)
     checker.check_attribute(attr)
Ejemplo n.º 5
0
 def test_attr_int(self):  # type: () -> None
     """make_attribute should store integer literals of every form in ``i``.

     The "long integer" case used to be ``5L`` under Python 2. After porting
     it became a plain ``5`` that duplicated the first case. It now uses a
     value outside the 32-bit range so it checks something distinct.
     """
     cases = [
         3,           # integer
         2 ** 40,     # long integer: exceeds 32 bits (``i`` presumably int64 per onnx.proto — confirm)
         0o1701,      # octinteger
         0x1701,      # hexinteger
     ]
     for value in cases:
         attr = helper.make_attribute("int", value)
         self.assertEqual(attr.name, "int")
         self.assertEqual(attr.i, value)
         checker.check_attribute(attr)
Ejemplo n.º 6
0
 def test_attr_string(self):  # type: () -> None
     """make_attribute should store bytes, str and unicode inputs in ``s``."""
     # bytes, unspecified (native str) and explicit unicode literals all
     # end up as the same stored bytes.
     for raw in (b"test", "test", u"test"):
         attr = helper.make_attribute("str", raw)
         self.assertEqual(attr.name, "str")
         self.assertEqual(attr.s, b"test")
         checker.check_attribute(attr)
Ejemplo n.º 7
0
 def test_attr_repeated_str(self):  # type: () -> None
     """A list of text strings should be stored in ``strings`` as bytes."""
     values = ["str1", "str2"]
     attr = helper.make_attribute("strings", values)
     self.assertEqual(attr.name, "strings")
     stored = list(attr.strings)
     self.assertEqual(stored, [b"str1", b"str2"])
     checker.check_attribute(attr)
Ejemplo n.º 8
0
 def test_attr_repeated_int(self):  # type: () -> None
     """A list of ints should be stored unchanged in ``ints``."""
     values = [1, 2]
     attr = helper.make_attribute("ints", values)
     self.assertEqual(attr.name, "ints")
     stored = list(attr.ints)
     self.assertEqual(stored, [1, 2])
     checker.check_attribute(attr)
Ejemplo n.º 9
0
 def test_attr_repeated_float(self):  # type: () -> None
     """A list of floats should be stored unchanged in ``floats``."""
     values = [1.0, 2.0]
     attr = helper.make_attribute("floats", values)
     self.assertEqual(attr.name, "floats")
     stored = list(attr.floats)
     self.assertEqual(stored, [1.0, 2.0])
     checker.check_attribute(attr)
Ejemplo n.º 10
0
 def test_attr_repeated_float(self):  # type: () -> None
     """make_attribute should route a list of floats into ``floats``."""
     expected = [1.0, 2.0]
     attr = helper.make_attribute("floats", [1.0, 2.0])
     self.assertEqual(attr.name, "floats")
     # Materialize the repeated field so it compares against a plain list.
     self.assertEqual(list(attr.floats), expected)
     checker.check_attribute(attr)
Ejemplo n.º 11
0
 def test_attr_repeated_str(self):  # type: () -> None
     """make_attribute should encode a list of str into ``strings``."""
     expected = [b"str1", b"str2"]
     attr = helper.make_attribute("strings", ["str1", "str2"])
     self.assertEqual(attr.name, "strings")
     # Text inputs come back as bytes from the repeated field.
     self.assertEqual(list(attr.strings), expected)
     checker.check_attribute(attr)
Ejemplo n.º 12
0
 def test_attr_repeated_int(self):  # type: () -> None
     """make_attribute should route a list of ints into ``ints``."""
     expected = [1, 2]
     attr = helper.make_attribute("ints", [1, 2])
     self.assertEqual(attr.name, "ints")
     # Materialize the repeated field so it compares against a plain list.
     self.assertEqual(list(attr.ints), expected)
     checker.check_attribute(attr)
Ejemplo n.º 13
0
 def test_attr_repeated_mixed_floats_and_ints(self) -> None:
     """A list mixing ints and floats should be stored as ``floats``."""
     mixed = [1, 2, 3.0, 4.5]
     attr = helper.make_attribute("mixed", mixed)
     self.assertEqual(attr.name, "mixed")
     # The int entries are expected to come back promoted to float.
     stored = list(attr.floats)
     self.assertEqual(stored, [1.0, 2.0, 3.0, 4.5])
     checker.check_attribute(attr)