Exemplo n.º 1
0
 def attribute_value_test(attribute_value):
     """Round-trip ``attribute_value`` through an ONNX node attribute.

     Builds a single ``Abs`` node carrying ``test_attribute=attribute_value``,
     places it in a one-input graph/model, wraps the model with
     ``ModelWrapper`` and returns the attribute value as read back through
     the wrapper.
     """
     abs_node = make_node('Abs', ['X'], [], name='test_node',
                          test_attribute=attribute_value)
     graph_inputs = [make_tensor_value_info('X', onnx.TensorProto.FLOAT, [1, 2])]
     graph = make_graph([abs_node], 'test_graph', graph_inputs, [])
     model = make_model(graph, producer_name='ngraph')
     first_node = ModelWrapper(model).graph.node[0]
     return first_node.get_attribute('test_attribute').get_value()
Exemplo n.º 2
0
 def attribute_value_test(attribute_value):
     """Return ``attribute_value`` after storing it on an ONNX node and reading it back.

     The value is attached as ``test_attribute`` to an ``Abs`` node inside a
     minimal model; it is then retrieved via ``ModelWrapper``'s node
     attribute accessor. (This definition duplicates an earlier helper of
     the same name; whichever is defined last is the one in effect.)
     """
     node_under_test = make_node("Abs", ["X"], [], name="test_node",
                                 test_attribute=attribute_value)
     x_info = make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 2])
     test_graph = make_graph([node_under_test], "test_graph", [x_info], [])
     onnx_model = make_model(test_graph, producer_name="ngraph")
     wrapped_node = ModelWrapper(onnx_model).graph.node[0]
     wrapped_attribute = wrapped_node.get_attribute('test_attribute')
     return wrapped_attribute.get_value()
Exemplo n.º 3
0
def test_value_info_wrapper(onnx_model):
    """Exercise the ValueInfoWrapper accessors on the model's first graph input.

    Checks dtype, initializer data, ngraph axes, and the three
    AssignableTensorOp flavours (placeholder, variable, constant). It also
    checks that ``get_ng_node`` returns the constant op.
    """
    value_info = ModelWrapper(onnx_model).graph.input[0]
    tensor_op_class = ng.op_graph.op_graph.AssignableTensorOp

    assert value_info.get_dtype() == np.float32
    assert value_info.has_initializer

    init_tensor = value_info.get_initializer()
    assert np.all(init_tensor.to_array() == np.array([[1., 1.]], dtype=np.float32))

    ng_axes = value_info.get_ng_axes()
    assert len(ng_axes) == 2
    assert ng_axes[1].length == 2

    def check_tensor_op(op, flag_name):
        # Every flavour must be an AssignableTensorOp, carry its flag, and share the axes.
        assert op.__class__ == tensor_op_class
        assert getattr(op, flag_name)
        assert op.axes == ng_axes

    check_tensor_op(value_info.get_ng_placeholder(), 'is_placeholder')
    check_tensor_op(value_info.get_ng_variable(), 'is_trainable')

    constant_op = value_info.get_ng_constant()
    check_tensor_op(constant_op, 'is_constant')

    # The generic node accessor is expected to yield the constant op built above.
    assert value_info.get_ng_node() == constant_op
Exemplo n.º 4
0
def import_onnx_model(onnx_protobuf):  # type: (ModelProto) -> List[Dict]
    """
    Import an ONNX Protocol Buffers model (onnx_pb2.ModelProto) object
    and convert it into a list of ngraph operations.

    An ONNX model defines a set of output nodes. Each output node will be added to the
    returned list as a dict with the following fields:

    * 'name' - name of the output, as specified in the imported ONNX model
    * 'inputs' - a list of ngraph placeholder ops, used to feed data into the model
    * 'output' - ngraph Op representing the output of the model

    :param onnx_protobuf: the loaded ONNX ModelProto to import
    :return: a list with one dict per model output, structured as described above

    Usage example (illustrative only: the printed reprs contain memory
    addresses, so the examples are marked to be skipped by doctest):

    >>> onnx_protobuf = onnx.load('y_equals_a_plus_b.onnx.pb')  # doctest: +SKIP
    >>> import_onnx_model(onnx_protobuf)  # doctest: +SKIP
    [{
        'name': 'Y',
        'inputs': [<AssignableTensorOp(placeholder):4552991464>,
                   <AssignableTensorOp(placeholder):4510192360>],
        'output': <Add(Add_0):4552894504>
    }]

    >>> ng_model = import_onnx_model(onnx_protobuf)[0]  # doctest: +SKIP
    >>> transformer = ng.transformers.make_transformer()  # doctest: +SKIP
    >>> computation = transformer.computation(ng_model['output'], *ng_model['inputs'])  # doctest: +SKIP
    >>> computation(4, 6)  # doctest: +SKIP
    array([ 10.], dtype=float32)
    """
    model = ModelWrapper(onnx_protobuf)
    return model.graph.get_ng_model()
Exemplo n.º 5
0
def test_node_wrapper(onnx_model):
    """Check ngraph input/output conversion for the model's first NodeWrapper.

    Expects two AssignableTensorOp inputs and a single ``Add`` output keyed 'Z'.
    """
    first_node = ModelWrapper(onnx_model).graph.node[0]
    op_graph_module = ng.op_graph.op_graph

    node_inputs = first_node.get_ng_inputs()
    assert len(node_inputs) == 2
    assert node_inputs[0].__class__ == op_graph_module.AssignableTensorOp

    outputs_by_name = first_node.get_ng_nodes_dict()
    assert len(outputs_by_name) == 1
    assert outputs_by_name['Z'].__class__ == op_graph_module.Add
Exemplo n.º 6
0
def test_graph_wrapper(onnx_model):
    """Check GraphWrapper's wrapped collections, initializer lookup and ngraph export.

    Each repeated graph field must have the expected length, and its first
    element must be the matching wrapper class. 'X' must have an initializer
    equal to [[1., 1.]], and 'Y' must not. The first exported ngraph model
    must have an ``Add`` output fed by AssignableTensorOp inputs.
    """
    graph = ModelWrapper(onnx_model).graph

    expected_fields = (
        ('node', 1, NodeWrapper),
        ('input', 2, ValueInfoWrapper),
        ('output', 1, ValueInfoWrapper),
        ('initializer', 1, TensorWrapper),
    )
    for field_name, expected_len, expected_class in expected_fields:
        assert len(getattr(graph, field_name)) == expected_len
        assert getattr(graph, field_name)[0].__class__ == expected_class

    x_initializer = graph.get_initializer('X')
    assert np.all(x_initializer.to_array() == np.array([[1., 1.]], dtype=np.float32))
    assert not graph.get_initializer('Y')

    first_ng_model = graph.get_ng_model()[0]
    op_graph_module = ng.op_graph.op_graph
    assert first_ng_model['output'].__class__ == op_graph_module.Add
    assert first_ng_model['inputs'][0].__class__ == op_graph_module.AssignableTensorOp
Exemplo n.º 7
0
def test_model_wrapper(onnx_model):
    """Check that ModelWrapper exposes producer_name and wraps its graph in GraphWrapper."""
    model = ModelWrapper(onnx_model)
    expected_producer = 'ngraph ONNXImporter'
    assert model.producer_name == expected_producer
    assert model.graph.__class__ == GraphWrapper