Example #1
0
    def test_is_attr_legal_verbose(self):
        """Check ``helper.is_attribute_legal`` on one- and two-field attributes.

        An AttributeProto with exactly one of the value fields below populated
        must be reported as legal; populating any two distinct value fields
        must be reported as illegal.
        """
        import itertools

        ATTR_FUNCTIONS = [
            lambda attr: setattr(attr, "f", 1.0),
            lambda attr: setattr(attr, "i", 1),
            lambda attr: setattr(attr, "s", b"str"),
            lambda attr: attr.floats.extend([1.0, 2.0]),
            lambda attr: attr.ints.extend([1, 2]),
            lambda attr: attr.strings.extend([b"a", b"b"]),
            lambda attr: attr.tensors.extend([TensorProto(), TensorProto()]),
            lambda attr: attr.graphs.extend([GraphProto(), GraphProto()]),
        ]

        def _fresh_attr():
            # Every case starts from a clean, named AttributeProto.
            attr = AttributeProto()
            attr.name = "test"
            return attr

        # Setting exactly one field must be legal. Iterating over every setter
        # (instead of sampling at random) makes the test deterministic and
        # guarantees each field is covered.
        for idx, set_field in enumerate(ATTR_FUNCTIONS):
            attr = _fresh_attr()
            set_field(attr)
            self.assertTrue(
                helper.is_attribute_legal(attr),
                "setter %d alone should produce a legal attribute" % idx,
            )

        # Setting two distinct fields must be rejected. Every ordered pair is
        # checked, so assignment order cannot hide a failure either.
        for (first_idx, first), (second_idx, second) in itertools.permutations(
            enumerate(ATTR_FUNCTIONS), 2
        ):
            attr = _fresh_attr()
            first(attr)
            second(attr)
            self.assertFalse(
                helper.is_attribute_legal(attr),
                "setters %d and %d together should be illegal"
                % (first_idx, second_idx),
            )
Example #2
0
 def test_attr_repeated_tensor_proto(self):
     """Build a repeated-tensor attribute and check it preserves its inputs.

     ``helper.make_attribute`` is given a list of TensorProto objects; the
     resulting attribute must keep the requested name, store the tensors in
     its ``tensors`` field in order, and be accepted by
     ``helper.is_attribute_legal``.
     """
     # TensorProto (not GraphProto) so that the ``tensors`` field is what
     # actually gets exercised, matching the test's name.
     tensors = [TensorProto(), TensorProto()]
     tensors[0].name = "a"
     tensors[1].name = "b"
     attr = helper.make_attribute("tensors", tensors)
     self.assertEqual(attr.name, "tensors")
     self.assertEqual(list(attr.tensors), tensors)
     self.assertTrue(helper.is_attribute_legal(attr))
Example #3
0
 def test_attr_repeated_graph_proto(self):
     """Round-trip a list of named GraphProtos through ``make_attribute``.

     The attribute must be called "graphs", hold the same graphs in the same
     order, and pass ``checker.check_attribute``.
     """
     expected = []
     for label in ("a", "b"):
         sub_graph = GraphProto()
         sub_graph.name = label
         expected.append(sub_graph)

     built = helper.make_attribute("graphs", expected)

     self.assertEqual(built.name, "graphs")
     self.assertEqual(list(built.graphs), expected)
     # The checker must accept the resulting attribute.
     checker.check_attribute(built)
Example #4
0
def make_graph(nodes, name, inputs, outputs, initializer=None):
    """Assemble and return a new GraphProto.

    ``nodes``, ``inputs`` and ``outputs`` are appended to the graph's
    ``node``, ``input`` and ``output`` fields respectively, and ``name`` is
    assigned to ``graph.name``. ``initializer`` is appended to
    ``graph.initializer``; omitting it (or passing ``None``) means no
    initializers.
    """
    result = GraphProto()
    result.node.extend(nodes)
    result.name = name
    result.input.extend(inputs)
    result.output.extend(outputs)
    # Fall back to an empty sequence rather than sharing a default list.
    result.initializer.extend([] if initializer is None else initializer)
    return result
Example #5
0
def make_graph(nodes, name, inputs, outputs, initializer=None):
    """Construct a GraphProto stamped with the current IR version.

    Args:
        nodes: items appended to ``graph.node``.
        name: value assigned to ``graph.name``.
        inputs: items appended to ``graph.input``.
        outputs: items appended to ``graph.output``.
        initializer: optional items appended to ``graph.initializer``;
            ``None`` (the default) means no initializers.

    Returns:
        The newly built GraphProto, with ``ir_version`` set to ``IR_VERSION``.
    """
    if initializer is None:
        # Use a fresh list per call instead of a shared mutable default.
        initializer = []
    graph = GraphProto()
    # Touch graph.ir_version so it is stored as the version from which it is
    # generated.
    graph.ir_version = IR_VERSION
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    return graph
Example #6
0
 def test_version_exists(self):
     """Check that ``ir_version`` survives a serialize/parse round trip."""
     g = GraphProto()
     # A freshly constructed graph has no version field set yet.
     self.assertFalse(g.HasField('ir_version'))

     # Stamp the graph with the IR version of the running ONNX build.
     g.ir_version = IR_VERSION

     # Serialize, then parse the bytes back into the same message.
     serialized = g.SerializeToString()
     g.ParseFromString(serialized)

     # The field must still be present and carry the stamped value.
     self.assertTrue(g.HasField('ir_version'))
     self.assertEqual(g.ir_version, IR_VERSION)