Example #1
def _InitOpDefLibrary():
    op_list = _op_def_pb2.OpList()
    _text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
    _op_def_registry.register_op_list(op_list)
    op_def_lib = _op_def_library.OpDefLibrary()
    op_def_lib.add_op_list(op_list)
    return op_def_lib
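In generated gen_*_ops.py modules, this first variant reads the op list from a text-format string attached to the function itself. A minimal sketch of that pattern, using a hypothetical ZeroOut op (not from a real library):

# Sketch: attach the ASCII OpList to the function, then build the library.
# The ZeroOut op definition is illustrative only.
_InitOpDefLibrary.op_list_ascii = """op {
  name: "ZeroOut"
  input_arg { name: "to_zero" type: DT_INT32 }
  output_arg { name: "zeroed" type: DT_INT32 }
}"""
_op_def_lib = _InitOpDefLibrary()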
Example #2
def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
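A minimal usage sketch for the serialized-bytes variant: build an OpList proto, serialize it, and let the loader parse and register it (the op name is illustrative only):

_op_list = _op_def_pb2.OpList()
_op_list.op.add().name = "ZeroOut"  # illustrative op name
_op_def_lib = _InitOpDefLibrary(_op_list.SerializeToString())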
Example #3
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors_impl._make_specific_exception(None, None, error_msg,
                                                       error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Delete the library handle to release any memory held in C
    # that is no longer needed.
    py_tf.TF_DeleteLibraryHandle(lib_handle)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
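A minimal usage sketch, assuming a plugin zero_out.so compiled from the custom-op tutorial exists on disk; the wrapper function name depends on the ops the plugin registers:

zero_out_module = load_op_library('./zero_out.so')
print(zero_out_module.OP_LIST)  # OpDefs exported by the plugin
# result = zero_out_module.zero_out([1, 2, 3])  # hypothetical wrapper call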
Example #4
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.
  Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
  defined in the library.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _OP_LIBRARY_MAP_LOCK:
                if (error_code == error_codes_pb2.ALREADY_EXISTS
                        and 'has already been loaded' in error_msg
                        and library_filename in _OP_LIBRARY_MAP):
                    return _OP_LIBRARY_MAP[library_filename]
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg,
                                                  error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    # Memoize the filename to module mapping.
    with _OP_LIBRARY_MAP_LOCK:
        _OP_LIBRARY_MAP[library_filename] = module
    return module
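This memoizing variant relies on module-level state that the snippet does not show; a sketch of what it assumes:

import threading

_OP_LIBRARY_MAP = {}  # library_filename -> loaded module
_OP_LIBRARY_MAP_LOCK = threading.Lock()  # guards the map above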
Example #5
def get_ops():
    sdk_path = os.path.dirname(__file__)
    ops_file = os.path.join(sdk_path, 'mmdnn_ops.pbtxt')

    ops = op_def_pb2.OpList()
    with open(ops_file) as fn:
        Merge(fn.read(), ops)
    ops_list = [MessageToDict(op) for op in ops.op]
    return ops_list
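Each returned entry is a plain dict produced by MessageToDict, with camelCase keys mirroring the OpDef proto fields; a usage sketch:

for op in get_ops():
    print(op['name'])  # other keys include 'inputArg', 'outputArg', 'attr'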
Example #6
def register_ops_if_needed(graph_ops):
    """Register graph ops absent in op_def_registry, if present in c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
    missing_ops = graph_ops - set(op_def_registry.get_registered_ops().keys())

    if not missing_ops:
        return

    p_buffer = c_api.TF_GetAllOpList()
    cpp_op_list = op_def_pb2.OpList()
    cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

    missing_op_list = op_def_pb2.OpList()
    for missing_op in missing_ops:
        if missing_op not in cpp_registry_ops:
            tf.logging.info(
                "Op %s is missing from both the python and C++ registry.",
                missing_op)
        else:
            missing_op_list.op.extend([cpp_registry_ops[missing_op]])
            tf.logging.info(
                "Adding op %s from c++ registry to python registry.",
                missing_op)

    op_def_registry.register_op_list(missing_op_list)

    # Note: Only raise the missing-op RuntimeError after trying to load ops.
    # This allows the test to exercise all the calls into TensorFlow
    # without having to write a C + python test.
    if not missing_ops <= set(cpp_registry_ops.keys()):
        raise RuntimeError(
            "Graph ops missing from the python registry (%s) are also absent from "
            "the c++ registry." %
            missing_ops.difference(set(cpp_registry_ops.keys())))
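A usage sketch, assuming graph_def is a GraphDef proto whose nodes may reference kernel-only ops that exist in the C++ registry but have no generated Python wrappers:

graph_op_names = {node.op for node in graph_def.node}
register_ops_if_needed(graph_op_names)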
Example #7
def register_prelu_op():
    """Register a virtual PReLU OpDef.

  This allows bypassing MetaGraph validity checks on TensorFlow 1.X and 2.0.
  """

    prelu_op_def = op_def_pb2.OpDef()
    prelu_op_def.name = 'Prelu'
    missing_op_list = op_def_pb2.OpList()
    missing_op_list.op.extend([prelu_op_def])
    op_def_registry.register_op_list(missing_op_list)
Example #8
    def __init__(self):
        op_def_proto = op_def_pb2.OpList()
        buf = c_api.TF_GetAllOpList()
        try:
            op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
            self._api_def_map = c_api.TF_NewApiDefMap(buf)
        finally:
            c_api.TF_DeleteBuffer(buf)

        self._op_per_name = {}
        for op in op_def_proto.op:
            self._op_per_name[op.name] = op
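After construction, the instance can resolve an OpDef by name from the snapshot it took of the C++ registry. A lookup sketch; the class name OpDefCache is hypothetical, since the snippet only shows __init__:

cache = OpDefCache()  # hypothetical name for the class above
conv2d = cache._op_per_name['Conv2D']
print(conv2d.input_arg[0].name)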
Example #9
def sync():
  """Synchronize the contents of the Python registry with C++."""
  with _sync_lock:
    p_buffer = c_api.TF_GetAllOpList()
    cpp_op_list = op_def_pb2.OpList()
    cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    for op_def in cpp_op_list.op:
      # If an OpList is registered from a gen_*_ops.py, it does not have any
      # descriptions. Strip them here as well to satisfy validation in
      # register_op_list.
      _remove_non_deprecated_descriptions(op_def)
      _registered_ops[op_def.name] = op_def
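The helper _remove_non_deprecated_descriptions is assumed here (and in the next example) but not shown; a plausible sketch consistent with its name, clearing free-text docs while leaving the separate op_def.deprecation field intact:

def _remove_non_deprecated_descriptions(op_def):
    # Sketch: clear human-readable docs; op_def.deprecation (version and
    # explanation) lives in its own field and is untouched.
    op_def.summary = ''
    op_def.description = ''
    for arg in list(op_def.input_arg) + list(op_def.output_arg):
        arg.description = ''
    for attr in op_def.attr:
        attr.description = ''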
Example #10
  def sync():
    p_buffer = c_api.TF_GetAllOpList()
    cpp_op_list = op_def_pb2.OpList()
    cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))

    registered_ops = op_def_registry.get_registered_ops()
    for op_def in cpp_op_list.op:
      # If an OpList is registered from a gen_*_ops.py, it does not have any
      # descriptions. Strip them here as well to satisfy validation in
      # register_op_list.
      _remove_non_deprecated_descriptions(op_def)
      registered_ops[op_def.name] = op_def
Example #11
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    lib_handle = py_tf.TF_LoadLibrary(library_filename)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Delete the library handle to release any memory held in C
    # that is no longer needed.
    py_tf.TF_DeleteLibraryHandle(lib_handle)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    # Allow this to be recognized by AutoGraph.
    setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
    sys.modules[module_name] = module
    return module
Example #12
def register_prelu_op():
    """global registry of PReLU op for python, this allow metagraph to be
  properly generated with unregistered Prelu op
  """

    value = attr_value_pb2.AttrValue()
    value.list.type.extend([types_pb2.DT_FLOAT])
    attr = op_def_pb2.OpDef.AttrDef()
    attr.name = 'T'
    attr.type = 'type'
    attr.allowed_values.CopyFrom(value)
    prelu_op_def = op_def_pb2.OpDef()
    prelu_op_def.name = 'Prelu'
    prelu_op_def.attr.extend([attr])
    missing_op_list = op_def_pb2.OpList()
    missing_op_list.op.extend([prelu_op_def])
    op_def_registry.register_op_list(missing_op_list)
Example #13
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.
  Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
  defined in the library.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        if py_tf.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
Example #14
def _create_op_def_library(op_protos):
    # Build a single OpList containing every requested op; creating the
    # library and the proto once (not per iteration) so all ops are kept.
    registered_ops = _registry.get_registered_ops()
    op_def_lib = _op_def_library.OpDefLibrary()
    ops_proto = _op_def_pb2.OpList()
    for op_proto in op_protos:
        if op_proto.name not in registered_ops:
            raise LookupError("Op with name {0} not registered".format(
                op_proto.name))
        ops_proto.op.extend([op_proto])

    # Fails if the interfaces ("op schemas") don't match between the
    # previously registered op and this one.
    _registry.register_op_list(ops_proto)

    op_def_lib.add_op_list(ops_proto)

    return op_def_lib
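A usage sketch: wrap an op that is already in the registry into its own OpDefLibrary.

identity_def = _registry.get_registered_ops()['Identity']
op_def_lib = _create_op_def_library([identity_def])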
Example #15
def register_prelu_op():
    """Register a virtual PReLU OpDef.

  This allows bypassing MetaGraph validity checks on TensorFlow 1.X and 2.0.
  """

    value = attr_value_pb2.AttrValue()
    value.list.type.extend([types_pb2.DT_FLOAT])
    attr = op_def_pb2.OpDef.AttrDef()
    attr.name = 'T'
    attr.type = 'type'
    attr.allowed_values.CopyFrom(value)
    prelu_op_def = op_def_pb2.OpDef()
    prelu_op_def.name = 'Prelu'
    prelu_op_def.attr.extend([attr])
    missing_op_list = op_def_pb2.OpList()
    missing_op_list.op.extend([prelu_op_def])
    op_def_registry.register_op_list(missing_op_list)
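A usage sketch: register the stub before importing a MetaGraph that references Prelu (the checkpoint path is hypothetical):

register_prelu_op()
saver = tf.compat.v1.train.import_meta_graph('fused_prelu_model.meta')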
Example #16
 def testDefaultAttrsRemoved(self):
   producer_op_list = op_def_pb2.OpList()
   text_format.Merge("""
     op {
       name: 'OpWithFutureDefaultAttr'
       attr { name: 'default_int' type: 'int' default_value { i: 456 } }
     }
   """, producer_op_list)
   # Attr only in producer_op_list with default value gets removed.
   with ops.Graph().as_default():
     a = importer.import_graph_def(
         self._MakeGraphDef("""
         node { name: 'A' op: 'OpWithFutureDefaultAttr'
                attr { key: 'default_int' value { i: 456 } } }
         """),
         return_elements=["A"],
         producer_op_list=producer_op_list)
     with self.assertRaisesRegexp(
         ValueError, "Operation 'import/A' has no attr named 'default_int'."):
       a[0].get_attr("default_int")
Example #17
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
    # This is the Python equivalent of StrippedOpListForGraph in C++.
    # Unfortunately, since the Python op registry can differ from that in C++, we
    # can't remove the duplication using swig (at least naively).
    # TODO(irving): Support taking graphs directly.

    used_ops = ops_used_by_graph_def(graph_def)

    # Verify that all used ops are registered.
    registered_ops = op_def_registry.get_registered_ops()
    # These internal ops used by functions are not registered, so we need to
    # whitelist them.  # TODO(irving): Do something better here.
    op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op in used_ops:
        if op not in registered_ops and op not in op_whitelist:
            raise ValueError(
                "Op %s is used by the graph, but is not registered" % op)

    # Build the stripped op list in sorted order
    return op_def_pb2.OpList(op=[
        registered_ops[op] for op in sorted(used_ops) if op in registered_ops
    ])
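A usage sketch: embed the stripped list into a MetaGraphDef by hand, assuming graph is a tf.Graph:

from tensorflow.core.protobuf import meta_graph_pb2

meta_graph_def = meta_graph_pb2.MetaGraphDef()
meta_graph_def.meta_info_def.stripped_op_list.CopyFrom(
    stripped_op_list_for_graph(graph.as_graph_def()))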
Example #18
  def testDefaultAttrsRemoved(self):
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge("""
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """, producer_op_list)
    # Attr only in producer_op_list with default value gets removed.
    with ops.Graph().as_default():
      a = importer.import_graph_def(
          self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """),
          return_elements=["A"],
          producer_op_list=producer_op_list)
      if ops._USE_C_API:
        error_msg = "Operation 'import/A' has no attr named 'default_int'."
      else:
        error_msg = "No attr named 'default_int'"
      with self.assertRaisesRegexp(ValueError, error_msg):
        a[0].get_attr("default_int")

    # Unknown attrs cannot be imported using C API. This test will eventually be
    # deleted.
    if not ops._USE_C_API:
      # Attr only in producer_op_list with non-default value is preserved.
      with ops.Graph().as_default():
        a = importer.import_graph_def(
            self._MakeGraphDef("""
            node { name: 'A' op: 'OpWithFutureDefaultAttr'
                   attr { key: 'default_int' value { i: 987 } } }
            """),
            return_elements=["A"],
            producer_op_list=producer_op_list)
        self.assertEqual(987, a[0].get_attr("default_int"))
Example #19
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
    # This is similar to StrippedOpListForGraph in C++, but unlike its
    # C++ counterpart, this version does not require all ops to be registered.
    # This is done to support Prelu fusion in tfjs.
    used_ops = ops_used_by_graph_def(graph_def)
    op_defs = []
    for op in sorted(used_ops):
        op_def = op_def_registry.get(op)
        if op_def is not None:
            op_defs.append(op_def)
    return op_def_pb2.OpList(op=op_defs)
Example #20
import tensorflow as tf
from tensorflow.core.framework import op_def_pb2
from google.protobuf import text_format
from termcolor import colored
import re
import textwrap

ops = op_def_pb2.OpList()
with open('ops.pbtxt') as ops_handle:
    text_format.Merge(ops_handle.read(), ops)

data_type_name = [
    '', 'FLOAT', 'DOUBLE', 'INT32', 'UINT8', 'INT16', 'INT8', 'STRING',
    'COMPLEX64', 'COMPLEX', 'INT64', 'BOOL', 'QINT8', 'QUINT8', 'QINT32',
    'BFLOAT16', 'QINT16', 'QUINT16', 'UINT16', 'COMPLEX128', 'HALF',
    'RESOURCE', 'VARIANT', 'UINT32', 'UINT64'
]
lua_keywords = [
    'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function',
    'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then',
    'true', 'until', 'while'
]


def proc_var_name(name):
    if name in lua_keywords:
        name += '_'
    return name


pattern_arg = '  Args:'
pattern_ret = '  Returns:'
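proc_var_name exists because generated Lua bindings cannot use reserved words as variable names; for example:

assert proc_var_name('end') == 'end_'     # 'end' is a Lua keyword
assert proc_var_name('input') == 'input'  # unchanged otherwise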
Example #21
def metadata():
    categories = {
        'Assign': 'Control',
        'AvgPool': 'Pool',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'BiasAdd': 'Layer',
        'Concat': 'Tensor',
        'ConcatV2': 'Tensor',
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Dequantize': 'Tensor',
        'Elu': 'Activation',
        'FusedBatchNorm': 'Normalization',
        'FusedBatchNormV2': 'Normalization',
        'FusedBatchNormV3': 'Normalization',
        'Gather': 'Transform',
        'Identity': 'Control',
        'LeakyRelu': 'Activation',
        'LRN': 'Normalization',
        'LSTMBlockCell': 'Layer',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'MaxPoolWithArgmax': 'Pool',
        'Pad': 'Tensor',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Reshape': 'Shape',
        'Sigmoid': 'Activation',
        'Slice': 'Tensor',
        'Softmax': 'Activation',
        'Split': 'Tensor',
        'Squeeze': 'Shape',
        'StridedSlice': 'Tensor',
        'swish_f32': 'Activation',
        'Variable': 'Control',
        'VariableV2': 'Control',
    }

    def find_multiline(line, colon):
        if colon == -1:
            return None
        line = line[colon + 1:]
        while line.startswith(' '):
            line = line[1:]
        if line.startswith('<<'):
            line = line[2:]
            return line
        return None

    def str_escape(text):
        result = ''
        for c in text:
            if (c == '\n'):
                result += '\\n'
            elif (c == '\r'):
                result += "\\r"
            elif (c == '\t'):
                result += "\\t"
            elif (c == '\"'):
                result += "\\\""
            elif (c == '\''):
                result += "\\'"
            elif (c == '\\'):
                result += "\\\\"
            else:
                result += c
        return result

    def pbtxt_from_multiline(multiline_pbtxt):
        pbtxt = ''
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find('\n')
            if index == -1:
                pbtxt = pbtxt + multiline_pbtxt
                multiline_pbtxt = ''
                break
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index + 1:]
            colon = line.find(':')
            end = find_multiline(line, colon)
            if end is None:
                pbtxt = pbtxt + line + '\n'
                continue
            pbtxt = pbtxt + line[0:colon + 1]
            unescaped = ''
            newline = False
            line = ''
            while len(multiline_pbtxt) > 0:
                index = multiline_pbtxt.find('\n')
                line = multiline_pbtxt[0:index]
                multiline_pbtxt = multiline_pbtxt[index + 1:]
                if line.startswith(end):
                    line = line[len(end):]
                    break
                if newline:
                    unescaped = unescaped + '\n'
                newline = True
                unescaped = unescaped + line
                line = ''
            pbtxt = pbtxt + '\"' + str_escape(unescaped) + '\"' + line + '\n'
        return pbtxt

    def read_api_def_map(folder):
        api_def_map = {}
        file_list = os.listdir(folder)
        file_list = sorted(file_list)
        for filename in file_list:
            if filename.endswith('.pbtxt'):
                api_defs = api_def_pb2.ApiDefs()
                filename = folder + '/' + filename
                with open(filename) as handle:
                    multiline_pbtxt = handle.read()
                    pbtxt = pbtxt_from_multiline(multiline_pbtxt)
                    text_format.Merge(pbtxt, api_defs)
                for api_def in api_defs.op:
                    api_def_map[api_def.graph_op_name] = api_def
        return api_def_map

    def convert_type(type):
        return {'type': 'type', 'value': type}

    def convert_tensor(tensor):
        return {'type': 'tensor', 'value': '?'}

    def convert_shape(shape):
        return {'type': 'shape', 'value': '?'}

    def convert_number(number):
        if number == float('inf'):
            return 'NaN'
        if number == float('-inf'):
            return '-NaN'
        return number

    attr_type_table = {
        'type': 'type',
        'list(type)': 'type[]',
        'bool': 'boolean',
        'int': 'int64',
        'list(int)': 'int64[]',
        'float': 'float32',
        'list(float)': 'float32[]',
        'string': 'string',
        'list(string)': 'string[]',
        'shape': 'shape',
        'list(shape)': 'shape[]',
        'tensor': 'tensor',
        'func': 'function',
        'list(func)': 'function[]'
    }

    def convert_attr_type(type):
        if type in attr_type_table:
            return attr_type_table[type]
        print(type)
        return type

    def convert_attr_value(attr_value):
        if attr_value.HasField('list'):
            list = []
            attr_value_list = attr_value.list
            if len(attr_value_list.s) > 0:
                for s in attr_value_list.s:
                    list.append(s.decode('utf8'))
            if len(attr_value_list.i) > 0:
                for i in attr_value_list.i:
                    list.append(i)
            if len(attr_value_list.f) > 0:
                for f in attr_value_list.f:
                    list.append(convert_number(f))
            if len(attr_value_list.type) > 0:
                for type in attr_value_list.type:
                    list.append(convert_type(type))
            if len(list) == 0:
                for _, value in attr_value_list.ListFields():
                    if len(value) > 0:
                        raise Exception()
            return list
        if attr_value.HasField('s'):
            return attr_value.s.decode('utf8')
        if attr_value.HasField('i'):
            return attr_value.i
        if attr_value.HasField('f'):
            return convert_number(attr_value.f)
        if attr_value.HasField('b'):
            return attr_value.b
        if attr_value.HasField('type'):
            return convert_type(attr_value.type)
        if attr_value.HasField('tensor'):
            return convert_tensor(attr_value.tensor)
        if attr_value.HasField('shape'):
            return convert_shape(attr_value.shape)
        raise Exception()

    _TYPE_TO_STRING = {
        types_pb2.DataType.DT_HALF: "float16",
        types_pb2.DataType.DT_FLOAT: "float32",
        types_pb2.DataType.DT_DOUBLE: "float64",
        types_pb2.DataType.DT_INT32: "int32",
        types_pb2.DataType.DT_UINT8: "uint8",
        types_pb2.DataType.DT_UINT16: "uint16",
        types_pb2.DataType.DT_UINT32: "uint32",
        types_pb2.DataType.DT_UINT64: "uint64",
        types_pb2.DataType.DT_INT16: "int16",
        types_pb2.DataType.DT_INT8: "int8",
        types_pb2.DataType.DT_STRING: "string",
        types_pb2.DataType.DT_COMPLEX64: "complex64",
        types_pb2.DataType.DT_COMPLEX128: "complex128",
        types_pb2.DataType.DT_INT64: "int64",
        types_pb2.DataType.DT_BOOL: "bool",
        types_pb2.DataType.DT_QINT8: "qint8",
        types_pb2.DataType.DT_QUINT8: "quint8",
        types_pb2.DataType.DT_QINT16: "qint16",
        types_pb2.DataType.DT_QUINT16: "quint16",
        types_pb2.DataType.DT_QINT32: "qint32",
        types_pb2.DataType.DT_BFLOAT16: "bfloat16",
        types_pb2.DataType.DT_RESOURCE: "resource",
        types_pb2.DataType.DT_VARIANT: "variant",
        types_pb2.DataType.DT_HALF_REF: "float16_ref",
        types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
        types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
        types_pb2.DataType.DT_INT32_REF: "int32_ref",
        types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
        types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
        types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
        types_pb2.DataType.DT_INT16_REF: "int16_ref",
        types_pb2.DataType.DT_INT8_REF: "int8_ref",
        types_pb2.DataType.DT_STRING_REF: "string_ref",
        types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
        types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
        types_pb2.DataType.DT_INT64_REF: "int64_ref",
        types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
        types_pb2.DataType.DT_BOOL_REF: "bool_ref",
        types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
        types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
        types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
        types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
        types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
        types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
        types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
        types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
    }

    def format_data_type(data_type):
        if data_type in _TYPE_TO_STRING:
            return _TYPE_TO_STRING[data_type]
        raise Exception()

    def format_attribute_value(value):
        if (type(value) is dict and 'type' in value and 'value' in value
                and value['type'] == 'type'):
            return format_data_type(value['value'])
        if type(value) is str:
            return value
        if value == True:
            return 'true'
        if value == False:
            return 'false'
        raise Exception()

    tensorflow_repo_dir = os.path.join(os.path.dirname(__file__),
                                       '../third_party/source/tensorflow')
    api_def_map = read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir,
                              'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with open(input_file) as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []

    for op in ops_list.op:
        # print(op.name)
        json_schema = {}
        json_schema['name'] = op.name
        if op.name in categories:
            json_schema['category'] = categories[op.name]
        api_def = api_def_pb2.ApiDef()
        if op.name in api_def_map:
            api_def = api_def_map[op.name]
        # if op.deprecation.version != 0:
        #    print('[' + op.name + ']')
        #    print(op.deprecation.version)
        #    print(op.deprecation.explanation)
        api_def_attr_map = {}
        for attr in api_def.attr:
            api_def_attr_map[attr.name] = attr
        api_def_in_arg_map = {}
        for in_arg in api_def.in_arg:
            api_def_in_arg_map[in_arg.name] = in_arg
        api_def_out_arg_map = {}
        for out_arg in api_def.out_arg:
            api_def_out_arg_map[out_arg.name] = out_arg
        if api_def.summary:
            json_schema['summary'] = api_def.summary
        if api_def.description:
            json_schema['description'] = api_def.description
        for attr in op.attr:
            if not 'attributes' in json_schema:
                json_schema['attributes'] = []
            json_attribute = {}
            json_attribute['name'] = attr.name
            attr_type = convert_attr_type(attr.type)
            if attr_type:
                json_attribute['type'] = attr_type
            else:
                del json_attribute['type']
            if attr.name in api_def_attr_map:
                api_def_attr = api_def_attr_map[attr.name]
                if api_def_attr.description:
                    json_attribute['description'] = api_def_attr.description
            if attr.has_minimum:
                json_attribute['minimum'] = attr.minimum
            if attr.HasField('allowed_values'):
                allowed_values = convert_attr_value(attr.allowed_values)
                description = json_attribute['description'] + ' ' \
                    if 'description' in json_attribute else ''
                description += 'Must be one of the following: ' + ', '.join(
                    '`' + format_attribute_value(x) + '`'
                    for x in allowed_values) + '.'
                json_attribute['description'] = description
            if attr.HasField('default_value'):
                default_value = convert_attr_value(attr.default_value)
                json_attribute['default'] = default_value
            json_schema['attributes'].append(json_attribute)
        for input_arg in op.input_arg:
            if not 'inputs' in json_schema:
                json_schema['inputs'] = []
            json_input = {}
            json_input['name'] = input_arg.name
            if input_arg.name in api_def_in_arg_map:
                api_def_in_arg = api_def_in_arg_map[input_arg.name]
                if api_def_in_arg.description:
                    json_input['description'] = api_def_in_arg.description
            if input_arg.number_attr:
                json_input['numberAttr'] = input_arg.number_attr
            if input_arg.type:
                json_input['type'] = input_arg.type
            if input_arg.type_attr:
                json_input['typeAttr'] = input_arg.type_attr
            if input_arg.type_list_attr:
                json_input['typeListAttr'] = input_arg.type_list_attr
            if input_arg.is_ref:
                json_input['isRef'] = True
            json_schema['inputs'].append(json_input)
        for output_arg in op.output_arg:
            if not 'outputs' in json_schema:
                json_schema['outputs'] = []
            json_output = {}
            json_output['name'] = output_arg.name
            if output_arg.name in api_def_out_arg_map:
                api_def_out_arg = api_def_out_arg_map[output_arg.name]
                if api_def_out_arg.description:
                    json_output['description'] = api_def_out_arg.description
            if output_arg.number_attr:
                json_output['numberAttr'] = output_arg.number_attr
            if output_arg.type:
                json_output['type'] = output_arg.type
            elif output_arg.type_attr:
                json_output['typeAttr'] = output_arg.type_attr
            elif output_arg.type_list_attr:
                json_output['typeListAttr'] = output_arg.type_list_attr
            if output_arg.is_ref:
                json_output['isRef'] = True
            json_schema['outputs'].append(json_output)
        json_root.append(json_schema)

    json_file = os.path.join(os.path.dirname(__file__),
                             '../source/tf-metadata.json')
    with io.open(json_file, 'w', newline='') as fout:
        json_data = json.dumps(json_root, sort_keys=False, indent=2)
        for line in json_data.splitlines():
            line = line.rstrip()
            fout.write(line)
            fout.write('\n')
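For orientation, each entry appended to json_root by the loop above has roughly this shape (field values are illustrative, not taken from ops.pbtxt):

{
    'name': 'Conv2D',
    'category': 'Layer',
    'summary': '...',
    'attributes': [{'name': 'strides', 'type': 'int64[]'}],
    'inputs': [{'name': 'input', 'typeAttr': 'T'}],
    'outputs': [{'name': 'output', 'typeAttr': 'T'}],
}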
Example #22
ops.RegisterShape("Iif")(_UnknownShape)
ops.RegisterShape("Iii")(_UnknownShape)
ops.RegisterShape("In")(_UnknownShape)
ops.RegisterShape("Iri")(_UnknownShape)
ops.RegisterShape("None")(_UnknownShape)
ops.RegisterShape("Of")(_UnknownShape)
ops.RegisterShape("Oi")(_UnknownShape)
ops.RegisterShape("Oif")(_UnknownShape)
ops.RegisterShape("Oii")(_UnknownShape)
ops.RegisterShape("OpWithDefaultAttr")(_UnknownShape)
ops.RegisterShape("OpWithFutureDefaultAttr")(_UnknownShape)
ops.RegisterShape("Or")(_UnknownShape)
ops.RegisterShape("Otl")(_UnknownShape)
ops.RegisterShape("Unary")(_UnknownShape)

_op_list = op_def_pb2.OpList()
text_format.Merge(
    """
  op {
    name: 'None'
  }
  op {
    name: 'Oi'
    output_arg { name: 'a' type: DT_INT32 }
  }
  op {
    name: 'Or'
    output_arg { name: 'a' type: DT_INT32 is_ref: true }
  }
  op {
    name: 'Of'
Example #23
def metadata():
    categories = {
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'BiasAdd': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Elu': 'Activation',
        'Softmax': 'Activation',
        'Sigmoid': 'Activation',
        'LRN': 'Normalization',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'AvgPool': 'Pool',
        'Reshape': 'Shape',
        'Squeeze': 'Shape',
        'ConcatV2': 'Tensor',
        'Split': 'Tensor',
        'Dequantize': 'Tensor',
        'Identity': 'Control',
        'Variable': 'Control',
        'VariableV2': 'Control',
        'Assign': 'Control',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'FusedBatchNorm': 'Normalization',
        # 'VariableV2':
        # 'Assign':
        # 'BiasAdd':
    }

    def find_multiline(line, colon):
        if colon == -1:
            return None
        line = line[colon + 1:]
        while line.startswith(' '):
            line = line[1:]
        if line.startswith('<<'):
            line = line[2:]
            return line
        return None

    def str_escape(text):
        result = ''
        for c in text:
            if (c == '\n'):
                result += '\\n'
            elif (c == '\r'):
                result += "\\r"
            elif (c == '\t'):
                result += "\\t"
            elif (c == '\"'):
                result += "\\\""
            elif (c == '\''):
                result += "\\'"
            elif (c == '\\'):
                result += "\\\\"
            else:
                result += c
        return result

    def pbtxt_from_multiline(multiline_pbtxt):
        pbtxt = ''
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find('\n')
            if index == -1:
                pbtxt = pbtxt + multiline_pbtxt
                multiline_pbtxt = ''
                break
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index + 1:]
            colon = line.find(':')
            end = find_multiline(line, colon)
            if end is None:
                pbtxt = pbtxt + line + '\n'
                continue
            pbtxt = pbtxt + line[0:colon + 1]
            unescaped = ''
            newline = False
            line = ''
            while len(multiline_pbtxt) > 0:
                index = multiline_pbtxt.find('\n')
                line = multiline_pbtxt[0:index]
                multiline_pbtxt = multiline_pbtxt[index + 1:]
                if line.startswith(end):
                    line = line[len(end):]
                    break
                if newline:
                    unescaped = unescaped + '\n'
                newline = True
                unescaped = unescaped + line
                line = ''
            pbtxt = pbtxt + '\"' + str_escape(unescaped) + '\"' + line + '\n'
        return pbtxt

    def read_api_def_map(folder):
        api_def_map = {}
        file_list = os.listdir(folder)
        file_list = sorted(file_list)
        for filename in file_list:
            api_defs = api_def_pb2.ApiDefs()
            filename = folder + '/' + filename
            with open(filename) as handle:
                multiline_pbtxt = handle.read()
                pbtxt = pbtxt_from_multiline(multiline_pbtxt)
                text_format.Merge(pbtxt, api_defs)
            for api_def in api_defs.op:
                api_def_map[api_def.graph_op_name] = api_def
        return api_def_map

    def convert_type(type):
        return {'type': 'type', 'value': type}

    def convert_tensor(tensor):
        return {'type': 'tensor', 'value': '?'}

    def convert_shape(shape):
        return {'type': 'shape', 'value': '?'}

    def convert_number(number):
        if number == float('inf'):
            return 'NaN'
        if number == float('-inf'):
            return '-NaN'
        return number

    attr_type_table = {
        'type': 'type',
        'list(type)': 'type[]',
        'bool': 'boolean',
        'int': 'int64',
        'list(int)': 'int64[]',
        'float': 'float32',
        'list(float)': 'float32[]',
        'string': 'string',
        'list(string)': 'string[]',
        'shape': 'shape',
        'list(shape)': 'shape[]',
        'tensor': 'tensor',
        'func': 'function',
        'list(func)': 'function[]'
    }

    def convert_attr_type(type):
        if type in attr_type_table:
            return attr_type_table[type]
        print(type)
        return type

    def convert_attr_value(attr_value):
        if attr_value.HasField('list'):
            list = []
            attr_value_list = attr_value.list
            if len(attr_value_list.s) > 0:
                for s in attr_value_list.s:
                    list.append(s.decode('utf8'))
            if len(attr_value_list.i) > 0:
                for i in attr_value_list.i:
                    list.append(i)
            if len(attr_value_list.f) > 0:
                for f in attr_value_list.f:
                    list.append(convert_number(f))
            if len(attr_value_list.type) > 0:
                for type in attr_value_list.type:
                    list.append(convert_type(type))
            if len(list) == 0:
                for _, value in attr_value_list.ListFields():
                    if len(value) > 0:
                        raise Exception()
            return list
        if attr_value.HasField('s'):
            return attr_value.s.decode('utf8')
        if attr_value.HasField('i'):
            return attr_value.i
        if attr_value.HasField('f'):
            return convert_number(attr_value.f)
        if attr_value.HasField('b'):
            return attr_value.b
        if attr_value.HasField('type'):
            return convert_type(attr_value.type)
        if attr_value.HasField('tensor'):
            return convert_tensor(attr_value.tensor)
        if attr_value.HasField('shape'):
            return convert_shape(attr_value.shape)
        raise Exception()

    tensorflow_repo_dir = os.path.join(os.path.dirname(__file__),
                                       '../third_party/src/tensorflow')
    api_def_map = read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir,
                              'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with open(input_file) as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []

    for op in ops_list.op:
        # print(op.name)
        json_schema = {}
        if op.name in categories:
            json_schema['category'] = categories[op.name]
        api_def = api_def_pb2.ApiDef()
        if op.name in api_def_map:
            api_def = api_def_map[op.name]
        # if op.deprecation.version != 0:
        #    print('[' + op.name + ']')
        #    print(op.deprecation.version)
        #    print(op.deprecation.explanation)
        api_def_attr_map = {}
        for attr in api_def.attr:
            api_def_attr_map[attr.name] = attr
        api_def_in_arg_map = {}
        for in_arg in api_def.in_arg:
            api_def_in_arg_map[in_arg.name] = in_arg
        api_def_out_arg_map = {}
        for out_arg in api_def.out_arg:
            api_def_out_arg_map[out_arg.name] = out_arg
        if api_def.summary:
            json_schema['summary'] = api_def.summary
        if api_def.description:
            json_schema['description'] = api_def.description
        for attr in op.attr:
            if not 'attributes' in json_schema:
                json_schema['attributes'] = []
            json_attribute = {}
            json_attribute['name'] = attr.name
            attr_type = convert_attr_type(attr.type)
            if attr_type:
                json_attribute['type'] = attr_type
            else:
                del json_attribute['type']
            if attr.name in api_def_attr_map:
                api_def_attr = api_def_attr_map[attr.name]
                if api_def_attr.description:
                    json_attribute['description'] = api_def_attr.description
            if attr.has_minimum:
                json_attribute['minimum'] = attr.minimum
            if attr.HasField('allowed_values'):
                json_attribute['allowedValues'] = convert_attr_value(
                    attr.allowed_values)
            if attr.HasField('default_value'):
                json_attribute['default'] = convert_attr_value(
                    attr.default_value)
            json_schema['attributes'].append(json_attribute)
        for input_arg in op.input_arg:
            if not 'inputs' in json_schema:
                json_schema['inputs'] = []
            json_input = {}
            json_input['name'] = input_arg.name
            if input_arg.name in api_def_in_arg_map:
                api_def_in_arg = api_def_in_arg_map[input_arg.name]
                if api_def_in_arg.description:
                    json_input['description'] = api_def_in_arg.description
            if input_arg.number_attr:
                json_input['numberAttr'] = input_arg.number_attr
            if input_arg.type:
                json_input['type'] = input_arg.type
            if input_arg.type_attr:
                json_input['typeAttr'] = input_arg.type_attr
            if input_arg.type_list_attr:
                json_input['typeListAttr'] = input_arg.type_list_attr
            if input_arg.is_ref:
                json_input['isRef'] = True
            json_schema['inputs'].append(json_input)
        for output_arg in op.output_arg:
            if not 'outputs' in json_schema:
                json_schema['outputs'] = []
            json_output = {}
            json_output['name'] = output_arg.name
            if output_arg.name in api_def_out_arg_map:
                api_def_out_arg = api_def_out_arg_map[output_arg.name]
                if api_def_out_arg.description:
                    json_output['description'] = api_def_out_arg.description
            if output_arg.number_attr:
                json_output['numberAttr'] = output_arg.number_attr
            if output_arg.type:
                json_output['type'] = output_arg.type
            elif output_arg.type_attr:
                json_output['typeAttr'] = output_arg.type_attr
            elif output_arg.type_list_attr:
                json_output['typeListAttr'] = output_arg.type_list_attr
            if output_arg.is_ref:
                json_output['isRef'] = True
            json_schema['outputs'].append(json_output)
        json_root.append({'name': op.name, 'schema': json_schema})

    json_file = os.path.join(os.path.dirname(__file__),
                             '../src/tf-metadata.json')
    with io.open(json_file, 'w', newline='') as fout:
        json_data = json.dumps(json_root, sort_keys=True, indent=2)
        for line in json_data.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            fout.write(line)
            fout.write('\n')