Exemplo n.º 1
0
def test_generic_array(generic_array):
    """Round-trip a generic array through a TensorProto and back.

    Verifies that:
    * the intermediate value's class name equals the converter's declared
      ``__tfproto_type__`` name;
    * converting back yields an instance of the input's own type;
    * the restored ``np_array`` is element-wise equal to the original one.
    """
    converter = TensorProtoConverter
    proto = converter.get_tf_value(generic_array)
    # Compared by class name rather than identity, as the original test does.
    assert type(proto).__name__ == converter.__tfproto_type__.__name__

    restored = converter.get_generic_value(proto)
    assert isinstance(restored, type(generic_array))
    elementwise_equal = restored.np_array == generic_array.np_array
    assert elementwise_equal.all()
Exemplo n.º 2
0
    def _build_param_ops(self, onnx_graph, ugraph, op_types_cnt,
                         tensor_names_map):
        """Create a ``Const`` op in ``ugraph`` for each initializer of ``onnx_graph``.

        The initializers (normally constants) are converted to generic values,
        each wrapped as a ``'value'`` attribute. One ``Const`` operation with
        ``lib_name='onnx'`` is then registered per entry. Its single output
        tensor carries the value's dtype and shape.

        Note that this method updates ``op_types_cnt`` and ``tensor_names_map``
        **in place**:
        * ``op_types_cnt['Const']`` is incremented once per created op;
        * each initializer name is mapped to the ``TensorInfo`` of its output.
        """
        # FIXME: avoid using internal api of other library
        initializer_items = _onnx_initializer_to_input_dict_items(
            onnx_graph.initializer)
        # Collect the constant values keyed by initializer name; being a dict,
        # a repeated name keeps only its last value.
        const_attrs = {
            init_name: AttrValueConverter.GenericType(
                value_name='value',
                value=TensorProtoConverter.get_generic_value(
                    tensor.op.get_attr('value')))
            for init_name, tensor in initializer_items
        }
        # Emit one Const op per collected value, numbering them via the shared
        # 'Const' counter (read before formatting the name, then bumped).
        for init_name, const_attr in const_attrs.items():
            op_name = self._format_node_name(
                init_name, 'Const', op_types_cnt['Const'])
            op_types_cnt['Const'] += 1
            tensor_names_map[init_name] = TensorInfo(
                name=self._format_tensor_name('', op_name, 0),
                op_name=op_name,
                dtype=const_attr.value.dtype,
                shape=list(const_attr.value.np_array.shape),
                ugraph=ugraph)
            OperationInfo(
                name=op_name,
                input_tensors=[],
                output_tensors=[tensor_names_map[init_name]],
                op_type='Const',
                lib_name='onnx',
                ugraph=ugraph,
                op_attr={'value': const_attr},
            )
Exemplo n.º 3
0
def test_tf_tensor_quint8(tf_quint8_tensor):
    """Round-trip a quint8 TensorProto through its generic representation.

    The generic value's ``dtype[0]`` must be ``uint8``. Converting the generic
    value back must reproduce the original ``tensor_content`` and
    ``tensor_shape`` fields.
    """
    generic_value = TensorProtoConverter.get_generic_value(tf_quint8_tensor)
    assert generic_value.dtype[0] == np.dtype('uint8')

    round_tripped = TensorProtoConverter.get_tf_value(generic_value)
    for field_name in ('tensor_content', 'tensor_shape'):
        assert getattr(round_tripped, field_name) == getattr(
            tf_quint8_tensor, field_name)
Exemplo n.º 4
0
def generic_array():
    """Return a random 3x3 float32 matrix wrapped in the converter's generic type.

    Values come from NumPy's global random state. The wrapper class is
    ``TensorProtoConverter.__utensor_generic_type__``, constructed with the
    ``np_array`` keyword.
    """
    samples = np.random.standard_normal(size=(3, 3)).astype(np.float32)
    generic_cls = TensorProtoConverter.__utensor_generic_type__
    return generic_cls(np_array=samples)