示例#1
0
def _get_attribute(attribute_proto):
    """Decode one ONNX ``AttributeProto`` into a ``(name, value)`` pair.

    Singular fields are probed first (``f``, ``i``, ``s``, ``t``, ``g``, via
    ``HasField``), then the repeated fields (``floats``, ``ints``, ``strings``,
    ``tensors``, ``graphs``). The first one that is set (or non-empty) supplies
    the value. When nothing is populated the value is ``[]``.

    Raises:
        ParseException: if ``ref_attr_name`` is set, because attribute
            references are not expected in the main graph.
    """
    if attribute_proto.HasField('ref_attr_name'):
        raise ParseException('Unexpected ref_attr_name in main graph')

    attr_name = _fixstr(attribute_proto.name)

    # (field name, converter) pairs, probed in priority order.
    single_fields = (
        ('f', float),
        ('i', utils.anyint_to_int),
        ('s', utils.anystr_to_str),
        ('t', _get_tensor),
        ('g', _get_graph),
    )
    for field_name, convert in single_fields:
        if attribute_proto.HasField(field_name):
            return attr_name, convert(getattr(attribute_proto, field_name))

    repeated_fields = (
        ('floats', float),
        ('ints', utils.anyint_to_int),
        ('strings', utils.anystr_to_str),
        ('tensors', _get_tensor),
        ('graphs', _get_graph),
    )
    for field_name, convert in repeated_fields:
        entries = getattr(attribute_proto, field_name)
        if entries:
            return attr_name, [convert(entry) for entry in entries]

    # Nothing populated (e.g. an empty list attribute).
    return attr_name, []
示例#2
0
def _get_attribute(attribute_proto):
    """Convert an ONNX ``AttributeProto`` to a ``(name, value)`` pair.

    The value comes from the first populated field, checked in this order:
    the singular fields ``f``, ``i``, ``s``, ``t``, ``g`` (via ``HasField``),
    then the repeated fields ``floats``, ``ints``, ``strings``, ``tensors``,
    ``graphs`` (first non-empty one). Tensors and graphs are converted with
    ``_get_tensor`` / ``_get_graph``.

    If no field is populated, the value is ``[]``, so an empty list attribute
    decodes to ``[]``. NOTE(review): any attribute kind not probed here would
    also fall through to ``[]``.

    Raises:
        ParseException: if ``ref_attr_name`` is set; attribute references are
            not expected in the main graph.
    """
    if attribute_proto.HasField('ref_attr_name'):
        raise ParseException('Unexpected ref_attr_name in main graph')

    name = _fixstr(attribute_proto.name)

    if attribute_proto.HasField('f'):
        value = float(attribute_proto.f)
    elif attribute_proto.HasField('i'):
        value = utils.anyint_to_int(attribute_proto.i)
    elif attribute_proto.HasField('s'):
        value = utils.anystr_to_str(attribute_proto.s)
    elif attribute_proto.HasField('t'):
        value = _get_tensor(attribute_proto.t)
    elif attribute_proto.HasField('g'):
        value = _get_graph(attribute_proto.g)
    elif attribute_proto.floats:
        value = [float(f) for f in attribute_proto.floats]
    elif attribute_proto.ints:
        value = [utils.anyint_to_int(i) for i in attribute_proto.ints]
    elif attribute_proto.strings:
        value = [utils.anystr_to_str(s) for s in attribute_proto.strings]
    elif attribute_proto.tensors:
        value = [_get_tensor(t) for t in attribute_proto.tensors]
    elif attribute_proto.graphs:
        value = [_get_graph(g) for g in attribute_proto.graphs]
    else:
        # Repeated fields are falsy when empty, so an empty list lands here.
        value = []

    return name, value
示例#3
0
def run_test(output_name, output_nodes, recreate=True):
    """Export the current TF graph, run activation tests on it, parse the NNEF.

    All global variables are initialized and the graph is saved under
    ``out/pb/<output_name>`` with ``save_protobuf``. ``test_activations`` then
    runs on ``test.pb`` from that directory twice (``prefer_nhwc`` True and
    False), each time with random inputs.

    Args:
        output_name: relative output name. Trailing slashes are dropped and
            the remaining slashes become underscores to form the network name.
        output_nodes: forwarded to ``save_protobuf``.
        recreate: forwarded to ``save_protobuf``.

    Returns:
        The NNEF graph parsed from the NHWC run's
        ``out/<network>_nhwc/nnef/<network>_nnef/graph.nnef``.
    """
    batch_size = 1
    # Unknown (None) placeholder dimensions are replaced by the batch size.
    source_shapes = {
        ph.name: [
            int(d.value) if d.value is not None else batch_size
            for d in ph.shape.dims
        ]
        for ph in get_placeholders()
    }
    pb_path = os.path.join('out', 'pb', output_name)
    network_name = output_name.rstrip('/').replace('/', '_')
    # The `with` block closes the session on exit; no explicit close() needed.
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        save_protobuf(pb_path, output_nodes, sess, recreate)

    for prefer_nhwc in [True, False]:
        test_activations(filename=os.path.join(pb_path, 'test.pb'),
                         source_shapes=source_shapes,
                         feed_dict={
                             utils.anystr_to_str(k): np.random.random(v)
                             for k, v in six.iteritems(source_shapes)
                         },
                         prefer_nhwc=prefer_nhwc,
                         network_name=network_name,
                         delete_after_each=False,
                         export_io_only=True)

    return nnef.parse_file(
        os.path.join('out', network_name + '_nhwc', 'nnef',
                     network_name + '_nnef', 'graph.nnef'))
示例#4
0
def _get_attributes(attr_map_proto, graph):
    """Convert a TF ``NodeDef.attr`` map into a dict of Python values.

    Each entry's set oneof field (``WhichOneof('value')``) and its payload are
    passed to ``_get_attribute``. Keys are normalized with
    ``utils.anystr_to_str``. Entries whose name starts with ``'_'`` are
    skipped (presumably TF-internal attributes — confirm).
    """
    result = {}
    for attr_name, attr_value in attr_map_proto.items():
        if attr_name.startswith('_'):
            continue
        kind = attr_value.WhichOneof('value')
        payload = getattr(attr_value, kind)
        result[utils.anystr_to_str(attr_name)] = _get_attribute(kind, payload, graph)
    return result
示例#5
0
def write(
    tf_graph,  # type: TFGraph
    file_path,  # type: str
    write_weights=True,  # type: bool
    custom_op_protos=None,  # type: typing.Optional[typing.List[OpProto]]
    custom_imports=None  # type: str
):
    # type: (...) -> typing.Optional[conversion_info.ConversionInfo]
    """Write ``tf_graph`` as TensorFlow Python source to ``file_path``.

    Tensor names are temporarily replaced by the names from
    ``_generate_names`` while the source (and the checkpoint, if any) is
    written. They are always restored before returning, including when an
    exception is raised.

    Args:
        tf_graph: graph to write. It is sorted in place. The source operations
            added by ``generate_source_operations`` are removed again on exit.
        file_path: destination of the generated source. A missing parent
            directory is created.
        write_weights: if true and the graph has variables, also write a
            checkpoint to ``<file_path>.checkpoint/<basename>.ckpt``.
        custom_op_protos: forwarded to name generation and printing.
        custom_imports: forwarded to name generation and printing.

    Returns:
        The result of ``_get_rename_info(tf_graph, names_to_write)``.
    """
    generate_source_operations(tf_graph)
    tf_graph.sort()
    try:

        names_to_write = _generate_names(tf_graph=tf_graph,
                                         custom_imports=custom_imports,
                                         custom_op_protos=custom_op_protos)

        old_names = {}
        try:
            for tensor in tf_graph.tensors:
                old_names[tensor] = tensor.name
                tensor.name = utils.anystr_to_str(names_to_write[tensor])

            # dirname() is '' for a bare file name, and os.makedirs('') raises.
            dir_name = os.path.dirname(file_path)
            if dir_name and not os.path.exists(dir_name):
                os.makedirs(dir_name)

            with open(file_path, "w") as f:
                _print(tf_graph,
                       file_handle=f,
                       custom_op_protos=custom_op_protos,
                       custom_imports=custom_imports)

            # The checkpoint is created by running the freshly written source.
            with open(file_path, "r") as f:
                tf_source = f.read()

            if tf_graph.list_variables() and write_weights:
                checkpoint_dir = file_path + ".checkpoint"
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    os.path.basename(file_path) + ".ckpt")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)

                _create_checkpoint_with_values(
                    net_fun=_tfsource_to_function(tf_source, tf_graph.name),
                    file_name=checkpoint_path,
                    variable_value_by_name={
                        t.name: t.data
                        for t in tf_graph.tensors
                        if t.is_variable and t.name
                    })
        finally:
            # Undo the temporary renaming even on failure, so the caller's
            # graph is never left with the generated names.
            for tensor, old_name in old_names.items():
                tensor.name = old_name

        return _get_rename_info(tf_graph, names_to_write)
    finally:
        remove_source_operations(tf_graph)
示例#6
0
def _get_attribute(field, value, graph):
    if field == 'i' or field == 'f' or field == 'b' or field == 'placeholder':
        if utils.is_anyint(value):
            return utils.anyint_to_int(value)
        return value
    elif field == 's':
        return utils.anystr_to_str(value.decode())
    elif field == 'shape':
        return _get_shape(value)
    elif field == 'type':
        return _get_dtype(value)
    elif field == 'tensor':
        return _get_tensor(value, graph)
    elif field == 'func':
        return _get_func(value)
    elif field == 'list':
        field, items = _get_nonempty_items(value, fields=['i', 'f', 'b', 's', 'shape', 'type', 'tensor', 'func'])
        if items is None:
            return []
        return [_get_attribute(field, item, graph) for item in items]

    assert False
示例#7
0
def _normalize_types(arg):
    """Convert TensorFlow/NumPy values to plain Python values.

    The checks run in this order:
      * ints accepted by ``utils.is_anyint`` -> ``utils.anyint_to_int(arg)``
      * strings accepted by ``utils.is_anystr`` -> ``utils.anystr_to_str(arg)``
      * ``np.ndarray`` -> ``arg.tolist()``
      * ``tf.TensorShape`` -> ``None`` if ``arg.dims`` is ``None``, otherwise
        a list of ints with ``None`` kept for unknown dimensions
      * ``tf.Dimension`` -> ``arg.value``
      * ``tf.DType`` -> ``arg.name``
      * NumPy scalars of the listed int/uint/float/bool types -> ``arg.item()``

    Anything else is returned unchanged.
    """
    if utils.is_anyint(arg):
        return utils.anyint_to_int(arg)
    if utils.is_anystr(arg):
        return utils.anystr_to_str(arg)
    if isinstance(arg, np.ndarray):
        return arg.tolist()
    if isinstance(arg, tf.TensorShape):
        if arg.dims is None:
            return None
        return [int(dim) if dim is not None else None for dim in arg.as_list()]
    if isinstance(arg, tf.Dimension):
        return arg.value
    if isinstance(arg, tf.DType):
        return arg.name
    numpy_scalar_types = (np.int8, np.int16, np.int32, np.int64,
                          np.uint8, np.uint16, np.uint32, np.uint64,
                          np.float16, np.float32, np.float64, np.bool_)
    if isinstance(arg, numpy_scalar_types):
        return arg.item()
    return arg
示例#8
0
def read_tf_graph_from_protobuf(filename):
    """Read a binary TensorFlow ``GraphDef`` file into a ``TFGraph``.

    Node handling:
      * ``Placeholder`` nodes become graph inputs.
      * ``Const`` nodes become tensors carrying the constant value.
      * Every other node becomes a ``TFOperation``.

    Output tensors are named after their node: output 0 gets the node name
    (later suffixed with ``:0``), further outputs get ``<node>:<i>``.
    Operations none of whose outputs are consumed provide the graph outputs.

    Args:
        filename: path of the serialized (binary) ``GraphDef``.

    Returns:
        The populated ``TFGraph``. ``inputs`` is an ``OrderedDict`` keyed by
        placeholder name. ``outputs`` is an ``OrderedDict`` keyed by
        ``'output'`` (single output) or ``'output<i>'``.

    Raises:
        utils.NNEFToolsException: if a node consumes a tensor that no node in
            the graph produces.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    attrib_graph = TFGraph()  # just a graph to contain the tensors that are in attributes, no need to return this

    attributes_by_node_name = {}
    outputs_by_node_name = {}
    # NOTE(review): looked up by node.op below, so presumably maps op type ->
    # output count; confirm against _detect_output_counts.
    detected_output_count = _detect_output_counts(graph_def)
    # First pass: convert each node's attributes and create its output tensors.
    for node in graph_def.node:
        outputs = []
        attributes = _get_attributes(node.attr, attrib_graph)
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            # A string entry names the attribute that holds the output count.
            output_count = attributes[output_count]
        output_count = max(output_count, detected_output_count.get(node.op, 1))
        assert isinstance(output_count, int)
        if output_count >= 1:
            # Output 0 is named after the node itself; the rest get ':<i>'.
            output = TFTensor(graph, utils.anystr_to_str(node.name))
            outputs.append(output)
            for i in range(1, output_count):
                tensor_name = utils.anystr_to_str(node.name) + ':' + str(i)
                output = TFTensor(graph, tensor_name)
                outputs.append(output)

        outputs_by_node_name[node.name] = outputs
        attributes_by_node_name[node.name] = attributes

    # Tensor lookup by the names used in NodeDef.input.
    tensor_by_name = {tensor.name: tensor
                      for outputs in six.itervalues(outputs_by_node_name)
                      for tensor in outputs}

    placeholders = []
    # Second pass: fill in placeholder/constant tensors and build operations.
    for node in graph_def.node:
        attributes = attributes_by_node_name[node.name]
        outputs = outputs_by_node_name[node.name]

        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes['shape'] if 'shape' in attributes else None
            tensor.dtype = attributes['dtype'] if 'dtype' in attributes else None
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            # Tensor-valued attributes arrive as TFTensor objects (living in
            # attrib_graph); copy their shape, dtype and data over.
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Drop control-dependency inputs ('^name') and map 'name:0' to
            # 'name', the key under which output 0 was registered above.
            input_names = [name[:-2] if name.endswith(':0') else name for name in node.input if not name.startswith('^')]
            for name in input_names:
                if name not in tensor_by_name:
                    # Diagnostic aid: list the op types present before failing.
                    print('Info: List of node types in graph: {}\n'.format(
                        sorted(list({node.op for node in graph_def.node}))))

                    raise utils.NNEFToolsException(
                        "Tensor {} is used, but it is not clear which operation produced it. "
                        "Probably the graph has unsupported dynamic operations.".format(name))
            inputs = tuple([tensor_by_name[name] for name in input_names])
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=inputs,
                        outputs=outputs,
                        attribs=attributes)

    # Input resolution is done; now give output-0 tensors an explicit ':0'.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor) for tensor in placeholders])
    # Graph outputs: every output of operations whose outputs are all unused.
    graph_outputs = []
    for op in graph.operations:
        if all(len(output.consumers) == 0 for output in op.outputs):
            for output in op.outputs:
                graph_outputs.append(output)

    graph.outputs = OrderedDict([('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
                                 for i, tensor in enumerate(graph_outputs)])

    return graph
示例#9
0
def read_tflite_graph_from_flatbuffers(filename):
    """Read a single-subgraph TFLite flatbuffer model into a ``TFGraph``.

    One ``TFTensor`` is created per subgraph tensor and one ``TFOperation``
    per operator. Builtin options are decoded into attributes. For CUSTOM
    operators, the custom op code and the custom options are stored under
    ``_custom_op_type_key`` / ``_custom_op_options_key``.

    Args:
        filename: path of the ``.tflite`` flatbuffer file.

    Returns:
        The ``TFGraph``. Its ``inputs`` and ``outputs`` are lists of the
        subgraph's input/output tensors, in file order.

    Raises:
        NotImplementedError: if the model does not have exactly one subgraph.
    """
    with open(filename, 'rb') as file:
        # Named `buf` so the `bytes` builtin is not shadowed.
        buf = bytearray(file.read())

    model = tflite_fb.Model.GetRootAsModel(buf, 0)

    if model.SubgraphsLength() != 1:
        raise NotImplementedError(
            'graphs with multiple sub-graphs are not supported')

    subgraph = model.Subgraphs(0)
    graph_name = subgraph.Name()

    graph = TFGraph(graph_name.decode() if graph_name is not None else None)

    # Built in file order, so the tensor indices stored in operators and in
    # the subgraph inputs/outputs index straight into this list.
    tensors = []
    for tensor_index in range(subgraph.TensorsLength()):
        tensor = subgraph.Tensors(tensor_index)
        name = tensor.Name().decode()
        shape = [tensor.Shape(dim) for dim in range(tensor.ShapeLength())]
        dtype = _TensorTypeNameByValue[tensor.Type()]
        buffer = model.Buffers(tensor.Buffer())
        data = _get_data_as_ndarray(buffer, _TensorDtypeAsNumpy[tensor.Type()],
                                    shape)
        quant = _get_quantization(tensor)
        # Only tensors that have data get a label (their own name).
        label = name if data is not None else None
        tensors.append(
            TFTensor(graph, utils.anystr_to_str(name), shape, dtype, data,
                     utils.anystr_to_str(label) if label is not None else None,
                     quant))

    for op_index in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(op_index)
        operator_code = model.OperatorCodes(operator.OpcodeIndex())
        builtin_code = operator_code.BuiltinCode()
        op_name = _BuiltinOperatorNameByValue[builtin_code]

        options = operator.BuiltinOptions()
        options_class = _BuiltinOptionsClasses[operator.BuiltinOptionsType()]

        # Index -1 entries are skipped (presumably TFLite's marker for an
        # omitted optional tensor).
        input_indices = [operator.Inputs(k) for k in range(operator.InputsLength())]
        inputs = [tensors[index] for index in input_indices if index != -1]
        output_indices = [operator.Outputs(k) for k in range(operator.OutputsLength())]
        outputs = [tensors[index] for index in output_indices if index != -1]

        if options_class is not None:
            options_object = options_class()
            options_object.Init(options.Bytes, options.Pos)
            attribs = _enumerate_attributes(options_class, options_object)
        else:
            attribs = {}

        if builtin_code == tflite_fb.BuiltinOperator.CUSTOM:
            assert _custom_op_type_key not in attribs, \
                "'{}' shall not be set as an attribute".format(_custom_op_type_key)
            attribs[_custom_op_type_key] = operator_code.CustomCode().decode(
                'ascii')
            assert _custom_op_options_key not in attribs, \
                "'{}' shall not be set as an attribute".format(_custom_op_options_key)
            attribs[_custom_op_options_key] = operator.CustomOptionsAsNumpy(
            ).tolist()
        TFOperation(graph, op_name, inputs, outputs, attribs)

    graph_inputs = [tensors[subgraph.Inputs(k)]
                    for k in range(subgraph.InputsLength())]
    graph_outputs = [tensors[subgraph.Outputs(k)]
                     for k in range(subgraph.OutputsLength())]

    graph.inputs = graph_inputs
    graph.outputs = graph_outputs

    return graph
示例#10
0
def get_feed_dict():
    """Build a feed dict of random values for every placeholder.

    Keys are the placeholder names normalized with ``utils.anystr_to_str``.
    Values are ``np.random.random`` arrays of the placeholder's shape.
    """
    return {
        utils.anystr_to_str(placeholder.name): np.random.random(placeholder.shape)
        for placeholder in get_placeholders()
    }
示例#11
0
 def fixstr(s):
     """Return ``utils.anystr_to_str(s)``, passing ``None`` through unchanged."""
     if s is None:
         return None
     return utils.anystr_to_str(s)
示例#12
0
def read_tf_graph_from_protobuf(filename):
    """Read a binary TensorFlow ``GraphDef`` file into a ``TFGraph``.

    Node handling:
      * ``Placeholder`` nodes become graph inputs.
      * ``Const`` nodes become tensors carrying the constant value.
      * Every other node becomes a ``TFOperation``.

    Output tensors are named ``<node>:0``, ``<node>:1``, ... Operations none of
    whose outputs are consumed provide the graph outputs.

    Args:
        filename: path of the serialized (binary) ``GraphDef``.

    Returns:
        The populated ``TFGraph``. ``inputs`` is an ``OrderedDict`` keyed by
        placeholder name. ``outputs`` is an ``OrderedDict`` keyed by
        ``'output'`` (single output) or ``'output<i>'``.

    Raises:
        KeyError: if a node consumes a tensor that no node in the graph
            produces.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    # just a graph to contain the tensors that are in attributes
    # no need to return this
    attrib_graph = TFGraph()

    # Per-node results of the first pass, keyed by node name. Names must be
    # unique anyway, since tensor names are derived from them. id(node) is not
    # safe here: an id is only unique among objects alive at the same time,
    # and nothing here keeps the node objects alive between the two passes.
    attributes_by_node_name = {}
    outputs_by_node_name = {}

    # First pass: convert attributes and create each node's output tensors.
    for node in graph_def.node:
        outputs = []
        attributes = _get_attributes(node.attr, attrib_graph)
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            # A string entry names the attribute that holds the output count.
            output_count = attributes[output_count]
        assert isinstance(output_count, int)
        if output_count >= 1:
            # Output 0 is named after the node itself; the rest get ':<i>'.
            node_name = utils.anystr_to_str(node.name)
            outputs.append(TFTensor(graph, node_name))
            for i in range(1, output_count):
                outputs.append(TFTensor(graph, node_name + ':' + str(i)))
        outputs_by_node_name[node.name] = outputs
        attributes_by_node_name[node.name] = attributes

    tensor_by_name = {
        tensor.name: tensor
        for outputs in six.itervalues(outputs_by_node_name) for tensor in outputs
    }
    placeholders = []
    # Second pass: fill in placeholder/constant tensors and build operations.
    for node in graph_def.node:
        attributes = attributes_by_node_name[node.name]
        outputs = outputs_by_node_name[node.name]

        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes[
                'shape'] if 'shape' in attributes else None
            tensor.dtype = attributes[
                'dtype'] if 'dtype' in attributes else None
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Control-dependency inputs ('^name') carry no data and are
            # skipped. 'name:0' denotes output 0, which is registered as 'name'.
            input_names = [name[:-2] if name.endswith(':0') else name
                           for name in node.input
                           if not name.startswith('^')]
            inputs = tuple(tensor_by_name[name] for name in input_names)
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=inputs,
                        outputs=outputs,
                        attribs=attributes)

    # Input resolution is done; now give output-0 tensors an explicit ':0'.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor)
                                for tensor in placeholders])
    # Graph outputs: every output of operations whose outputs are all unused.
    graph_outputs = []
    for op in graph.operations:
        if all(len(output.consumers) == 0 for output in op.outputs):
            for output in op.outputs:
                graph_outputs.append(output)

    graph.outputs = OrderedDict([
        ('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
        for i, tensor in enumerate(graph_outputs)
    ])

    return graph