def test_is_attr_legal_verbose(self) -> None:
    """check_attribute accepts exactly one populated value field.

    Every single setter must yield an attribute that passes
    ``checker.check_attribute``. Every ordered pair of distinct setters must be
    rejected with ``checker.ValidationError``: the second setter overwrites
    ``type`` but leaves the first setter's value field populated.

    All cases are enumerated rather than drawn at random. This keeps the test
    deterministic and guarantees every combination is exercised.
    """
    import itertools

    def _set(attr, attr_type, field_name, value):
        # type: (AttributeProto, AttributeProto.AttributeType, Text, Any) -> None
        # Assign a scalar value field by name, then record the matching type tag.
        setattr(attr, field_name, value)
        setattr(attr, 'type', attr_type)

    def _extend(attr, attr_type, container, value):
        # type: (AttributeProto, AttributeProto.AttributeType, List[Any], Any) -> None
        # Append to a repeated value field (the container itself is passed in),
        # then record the matching type tag.
        container.extend(value)
        setattr(attr, 'type', attr_type)

    setters = [
        (lambda attr: _set(attr, AttributeProto.FLOAT, "f", 1.0)),
        (lambda attr: _set(attr, AttributeProto.INT, "i", 1)),
        (lambda attr: _set(attr, AttributeProto.STRING, "s", b"str")),
        (lambda attr: _extend(attr, AttributeProto.FLOATS, attr.floats, [1.0, 2.0])),
        (lambda attr: _extend(attr, AttributeProto.INTS, attr.ints, [1, 2])),
        (lambda attr: _extend(attr, AttributeProto.STRINGS, attr.strings, [b"a", b"b"])),
    ]

    def _new_attr():
        # type: () -> AttributeProto
        # Fresh attribute carrying only a name, so each case starts clean.
        attr = AttributeProto()
        attr.name = "test"
        return attr

    # Exactly one populated field is legal.
    for setter in setters:
        attr = _new_attr()
        setter(attr)
        checker.check_attribute(attr)

    # Two populated fields must be rejected. Order matters because the last
    # setter decides ``type``, so every ordered pair is checked.
    for first, second in itertools.permutations(setters, 2):
        attr = _new_attr()
        first(attr)
        second(attr)
        self.assertRaises(checker.ValidationError, checker.check_attribute, attr)
def test_attr_repeated_graph_proto(self) -> None:
    """A list of GraphProto values is stored in the repeated ``graphs`` field."""
    expected_graphs = []
    for graph_name in ("a", "b"):
        graph = GraphProto()
        graph.name = graph_name
        expected_graphs.append(graph)

    attribute = helper.make_attribute("graphs", expected_graphs)

    self.assertEqual(attribute.name, "graphs")
    self.assertEqual(list(attribute.graphs), expected_graphs)
    checker.check_attribute(attribute)
def test_attr_float(self) -> None:
    """Float values, plain and in scientific notation, populate ``f``."""
    # 1.0 is the plain form; 1e10 is written in scientific notation.
    for expected in (1.0, 1e10):
        attribute = helper.make_attribute("float", expected)
        self.assertEqual(attribute.name, "float")
        self.assertEqual(attribute.f, expected)
        checker.check_attribute(attribute)
def test_attr_repeated_tensor_proto(self) -> None:
    """A list of TensorProto values is stored in the repeated ``tensors`` field."""
    # Two one-element FLOAT tensors that differ only by name.
    expected_tensors = [
        helper.make_tensor(
            name=tensor_name,
            data_type=TensorProto.FLOAT,
            dims=(1,),
            vals=np.ones(1).tolist(),
        )
        for tensor_name in ("a", "b")
    ]

    attribute = helper.make_attribute("tensors", expected_tensors)

    self.assertEqual(attribute.name, "tensors")
    self.assertEqual(list(attribute.tensors), expected_tensors)
    checker.check_attribute(attribute)
def test_attr_int(self) -> None:
    """Integer values in several literal notations populate ``AttributeProto.i``."""
    cases = [
        3,  # plain integer
        # Long integer: a value outside the 32-bit range. The previous value, 5,
        # was indistinguishable from the plain case above. This assumes the
        # ``i`` field is 64-bit (int64 in onnx.proto).
        2 ** 40,
        0o1701,  # octinteger
        0x1701,  # hexinteger
    ]
    for value in cases:
        attr = helper.make_attribute("int", value)
        self.assertEqual(attr.name, "int")
        self.assertEqual(attr.i, value)
        checker.check_attribute(attr)
def test_attr_string(self) -> None:
    """String attributes are stored as bytes in ``AttributeProto.s``.

    bytes input is stored unchanged. str input is expected to be UTF-8 encoded.
    """
    cases = [
        (b"test", b"test"),  # bytes
        ("test", b"test"),  # str, ASCII only
        # Non-ASCII str. On Python 3 ``u"test"`` is identical to ``"test"``, so
        # the old "unicode" case never exercised encoding; this one does.
        ("t\u00e9st", "t\u00e9st".encode("utf-8")),
    ]
    for value, expected in cases:
        attr = helper.make_attribute("str", value)
        self.assertEqual(attr.name, "str")
        self.assertEqual(attr.s, expected)
        checker.check_attribute(attr)
def test_attr_repeated_str(self) -> None:
    """A list of str values is stored as bytes in the repeated ``strings`` field."""
    attribute = helper.make_attribute("strings", ["str1", "str2"])
    stored = list(attribute.strings)

    self.assertEqual(attribute.name, "strings")
    self.assertEqual(stored, [b"str1", b"str2"])
    checker.check_attribute(attribute)
def test_attr_repeated_int(self) -> None:
    """A list of ints is stored in the repeated ``ints`` field."""
    expected = [1, 2]
    attribute = helper.make_attribute("ints", expected)

    self.assertEqual(attribute.name, "ints")
    self.assertEqual(list(attribute.ints), expected)
    checker.check_attribute(attribute)
def test_attr_repeated_float(self) -> None:
    """A list of floats is stored in the repeated ``floats`` field."""
    expected = [1.0, 2.0]
    attribute = helper.make_attribute("floats", expected)

    self.assertEqual(attribute.name, "floats")
    self.assertEqual(list(attribute.floats), expected)
    checker.check_attribute(attribute)
def test_attr_repeated_int_negative_and_large(self) -> None:
    """Repeated ints keep their sign and magnitude in the ``ints`` field.

    This needs a distinct name: a second ``test_attr_repeated_int`` in the
    class body would replace the first definition, so only one would ever run.
    """
    # Negative, zero, and a value beyond the 32-bit range. The last assumes
    # ``ints`` is a 64-bit repeated field (int64 in onnx.proto).
    expected = [-1, 0, 2 ** 40]
    attr = helper.make_attribute("ints", expected)
    self.assertEqual(attr.name, "ints")
    self.assertEqual(list(attr.ints), expected)
    checker.check_attribute(attr)
def test_attr_repeated_mixed_floats_and_ints(self) -> None:
    """A list mixing ints and floats is stored entirely as ``floats``."""
    mixed_values = [1, 2, 3.0, 4.5]
    attribute = helper.make_attribute("mixed", mixed_values)

    self.assertEqual(attribute.name, "mixed")
    # The ints 1 and 2 come back promoted to floats.
    self.assertEqual(list(attribute.floats), [1.0, 2.0, 3.0, 4.5])
    checker.check_attribute(attribute)