def test_is_attr_legal_verbose(self):
    """Check ``helper.is_attribute_legal`` on randomly populated attributes.

    An AttributeProto with exactly one populated value field must be
    reported legal. One with two distinct populated value fields must be
    reported illegal.
    """
    # Each setter populates a different value field of an AttributeProto.
    setters = [
        lambda proto: setattr(proto, "f", 1.0),
        lambda proto: setattr(proto, "i", 1),
        lambda proto: setattr(proto, "s", b"str"),
        lambda proto: proto.floats.extend([1.0, 2.0]),
        lambda proto: proto.ints.extend([1, 2]),
        lambda proto: proto.strings.extend([b"a", b"b"]),
        lambda proto: proto.tensors.extend([TensorProto(), TensorProto()]),
        lambda proto: proto.graphs.extend([GraphProto(), GraphProto()]),
    ]

    def fresh_attr():
        """Return a new AttributeProto named "test" with no value set."""
        proto = AttributeProto()
        proto.name = "test"
        return proto

    # Exactly one randomly chosen field set: the helper must accept it.
    for _ in range(100):
        proto = fresh_attr()
        random.choice(setters)(proto)
        self.assertTrue(helper.is_attribute_legal(proto))

    # Two distinct randomly chosen fields set: the helper must reject it.
    for _ in range(100):
        proto = fresh_attr()
        for setter in random.sample(setters, 2):
            setter(proto)
        self.assertFalse(helper.is_attribute_legal(proto))
def test_attr_repeated_tensor_proto(self):
    """Check that a list of TensorProto values becomes a repeated ``tensors`` attribute.

    ``helper.make_attribute`` must store the tensors in ``attr.tensors``,
    preserving their order. The resulting attribute must be reported legal
    by ``helper.is_attribute_legal``.
    """
    # Build TensorProto values (not GraphProto) so this test covers the
    # repeated-tensor path instead of duplicating the repeated-graph test.
    tensors = [TensorProto(), TensorProto()]
    tensors[0].name = "a"
    tensors[1].name = "b"
    attr = helper.make_attribute("tensors", tensors)
    self.assertEqual(attr.name, "tensors")
    self.assertEqual(list(attr.tensors), tensors)
    self.assertTrue(helper.is_attribute_legal(attr))
def test_attr_repeated_graph_proto(self):
    """Check that a list of GraphProto values becomes a repeated ``graphs`` attribute.

    The graphs must be stored in ``attr.graphs`` in order, and the resulting
    attribute is passed to ``checker.check_attribute``.
    """
    first, second = GraphProto(), GraphProto()
    first.name, second.name = "a", "b"
    graphs = [first, second]

    attr = helper.make_attribute("graphs", graphs)

    self.assertEqual(attr.name, "graphs")
    self.assertEqual(list(attr.graphs), graphs)
    # NOTE(review): presumably raises if the attribute is malformed; confirm
    # against the checker module.
    checker.check_attribute(attr)
def make_graph(nodes, name, inputs, outputs, initializer=None):
    """Assemble a GraphProto from its component parts.

    Args:
        nodes: iterable appended to ``graph.node``.
        name: value assigned to ``graph.name``.
        inputs: iterable appended to ``graph.input``.
        outputs: iterable appended to ``graph.output``.
        initializer: optional iterable appended to ``graph.initializer``.
            ``None`` (the default) adds no initializers.

    Returns:
        The newly built GraphProto.
    """
    proto = GraphProto()
    proto.node.extend(nodes)
    proto.name = name
    proto.input.extend(inputs)
    proto.output.extend(outputs)
    if initializer is not None:
        proto.initializer.extend(initializer)
    return proto
def make_graph(nodes, name, inputs, outputs, initializer=None):
    """Assemble a GraphProto stamped with the current ``IR_VERSION``.

    Args:
        nodes: iterable appended to ``graph.node``.
        name: value assigned to ``graph.name``.
        inputs: iterable appended to ``graph.input``.
        outputs: iterable appended to ``graph.output``.
        initializer: optional iterable appended to ``graph.initializer``.
            ``None`` (the default) adds no initializers.

    Returns:
        The newly built GraphProto with ``ir_version`` set to ``IR_VERSION``.
    """
    # Use a None sentinel rather than a shared mutable ``[]`` default.
    if initializer is None:
        initializer = []
    graph = GraphProto()
    # Touch graph.ir_version so it is stored as the version from which it is
    # generated.
    graph.ir_version = IR_VERSION
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    return graph
def test_version_exists(self):
    """Check that ``ir_version`` starts unset and survives a serialize/parse round trip.

    A fresh GraphProto has no ``ir_version``. Once it is set to
    ``IR_VERSION``, serializing and re-parsing the graph must keep the field
    present with the same value.
    """
    graph = GraphProto()
    # A freshly constructed graph carries no ir_version field.
    self.assertFalse(graph.HasField('ir_version'))

    # Stamp the graph with the IR version of the running ONNX, then
    # round-trip it through its serialized form.
    graph.ir_version = IR_VERSION
    wire_bytes = graph.SerializeToString()
    graph.ParseFromString(wire_bytes)

    # The field must still be present and hold the same value.
    self.assertTrue(graph.HasField('ir_version'))
    self.assertEqual(graph.ir_version, IR_VERSION)