Esempio n. 1
0
def _InitOpDefLibrary():
  """Builds an OpDefLibrary from the text-format OpList attached to this function.

  The OpList text is read from the `op_list_ascii` attribute of this function
  object, parsed with `text_format.Merge`, handed to
  `op_def_registry.register_op_list`, and finally loaded into a new
  `op_def_library.OpDefLibrary`.

  Returns:
    The newly built `OpDefLibrary` holding every op from the OpList.
  """
  parsed_ops = op_def_pb2.OpList()
  text_format.Merge(_InitOpDefLibrary.op_list_ascii, parsed_ops)
  op_def_registry.register_op_list(parsed_ops)
  library = op_def_library.OpDefLibrary()
  library.add_op_list(parsed_ops)
  return library
Esempio n. 2
0
def _InitOpDefLibrary():
    """Builds an OpDefLibrary from the text-format OpList attached to this function.

    The OpList text is read from the `op_list_ascii` attribute of this
    function object, parsed with `_text_format.Merge`, passed to
    `_op_def_registry.register_op_list`, and then loaded into a new
    `_op_def_library.OpDefLibrary`.

    Returns:
        The newly built `OpDefLibrary` holding every op from the OpList.
    """
    parsed_ops = _op_def_pb2.OpList()
    _text_format.Merge(_InitOpDefLibrary.op_list_ascii, parsed_ops)
    _op_def_registry.register_op_list(parsed_ops)
    library = _op_def_library.OpDefLibrary()
    library.add_op_list(parsed_ops)
    return library
Esempio n. 3
0
def _InitOpDefLibrary(op_list_proto_bytes):
  """Builds an OpDefLibrary from a serialized OpList proto.

  Args:
    op_list_proto_bytes: the binary-serialized `OpList` to load.

  Returns:
    A new `OpDefLibrary` containing every op from the OpList, after the same
    OpList has been passed to `_op_def_registry.register_op_list`.
  """
  decoded_ops = _op_def_pb2.OpList()
  decoded_ops.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(decoded_ops)
  library = _op_def_library.OpDefLibrary()
  library.add_op_list(decoded_ops)
  return library
def _InitOpDefLibrary(op_list_proto_bytes):
    """Builds an OpDefLibrary from a serialized OpList proto.

    Args:
        op_list_proto_bytes: the binary-serialized `OpList` to load.

    Returns:
        A new `OpDefLibrary` containing every op from the OpList, after the
        same OpList has been passed to `_op_def_registry.register_op_list`.
    """
    decoded_ops = _op_def_pb2.OpList()
    decoded_ops.ParseFromString(op_list_proto_bytes)
    _op_def_registry.register_op_list(decoded_ops)
    library = _op_def_library.OpDefLibrary()
    library.add_op_list(decoded_ops)
    return library
Esempio n. 5
0
def register_prelu_op():
    """Register a virtual `Prelu` OpDef with the Python op registry.

    The OpDef carries only its name. Registering it allows MetaGraph
    validity checks to be bypassed on TensorFlow 1.X and 2.0.
    """
    placeholder_op = op_def_pb2.OpDef(name='Prelu')
    op_def_registry.register_op_list(op_def_pb2.OpList(op=[placeholder_op]))
Esempio n. 6
0
def register_prelu_op():
    """Register a `Prelu` OpDef with the Python op registry.

    The OpDef declares one type attribute, `T`, whose allowed values are
    limited to DT_FLOAT. Registering it lets a MetaGraph be generated for a
    graph that uses `Prelu` even though the op is otherwise unregistered.
    """
    float_only = attr_value_pb2.AttrValue()
    float_only.list.type.append(types_pb2.DT_FLOAT)

    type_attr = op_def_pb2.OpDef.AttrDef(name='T', type='type')
    type_attr.allowed_values.CopyFrom(float_only)

    prelu_def = op_def_pb2.OpDef(name='Prelu')
    prelu_def.attr.extend([type_attr])

    op_def_registry.register_op_list(op_def_pb2.OpList(op=[prelu_def]))
Esempio n. 7
0
def register_prelu_op():
    """Register a virtual `Prelu` OpDef with the Python op registry.

    The OpDef has a single type attribute `T` restricted to DT_FLOAT.
    Registering it allows MetaGraph validity checks to be bypassed on
    TensorFlow 1.X and 2.0.
    """
    allowed = attr_value_pb2.AttrValue()
    allowed.list.type.append(types_pb2.DT_FLOAT)

    t_attr = op_def_pb2.OpDef.AttrDef(name='T', type='type')
    t_attr.allowed_values.CopyFrom(allowed)

    virtual_prelu = op_def_pb2.OpDef(name='Prelu')
    virtual_prelu.attr.extend([t_attr])

    op_def_registry.register_op_list(op_def_pb2.OpList(op=[virtual_prelu]))
Esempio n. 8
0
def _create_op_def_library(op_protos):
    """Build an OpDefLibrary for ops that are already registered.

    Every OpDef in `op_protos` must already be known to `_registry`. All of
    them are then registered again as one OpList and loaded into a new
    OpDefLibrary.

    Args:
      op_protos: iterable of `OpDef` protos.

    Returns:
      An `OpDefLibrary` containing every op in `op_protos`.

    Raises:
      LookupError: if any op in `op_protos` is not registered.
    """
    # Materialize once: the protos are iterated twice (validation + extend).
    op_protos = list(op_protos)

    # The registry does not change while validating, so fetch it once.
    registered_ops = _registry.get_registered_ops()
    for op_proto in op_protos:
        if op_proto.name not in registered_ops:
            raise LookupError("Op with name {0} not registered".format(
                op_proto.name))

    # Build a single OpList holding *all* ops. Creating it inside the loop
    # would keep only the last op, and would leave it unbound for empty input.
    ops_proto = _op_def_pb2.OpList()
    ops_proto.op.extend(op_protos)

    # Fails if the interfaces ("op schemas") don't match between the
    # previously registered op and this one.
    _registry.register_op_list(ops_proto)

    op_def_lib = _op_def_library.OpDefLibrary()
    op_def_lib.add_op_list(ops_proto)

    return op_def_lib
Esempio n. 9
0
def register_ops_if_needed(graph_ops):
    """Register graph ops absent in op_def_registry, if present in c++ registry.

    Ops already known to the Python `op_def_registry` are left alone. For the
    others, the OpDefs are looked up in the list returned by
    `c_api.TF_GetAllOpList` and registered with `op_def_registry`.

    Args:
      graph_ops: set (or any iterable) with graph op names to register.

    Raises:
      RuntimeError: if `graph_ops` contains ops that are not in either python or
        c++ registry.
    """
    # Iterating the registry mapping yields its op names.
    missing_ops = set(graph_ops).difference(
        op_def_registry.get_registered_ops())

    if not missing_ops:
        return

    # The TF_Buffer returned by TF_GetAllOpList is owned by the caller; free it
    # once its bytes have been parsed, even if parsing fails.
    p_buffer = c_api.TF_GetAllOpList()
    try:
        cpp_op_list = op_def_pb2.OpList()
        cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    finally:
        c_api.TF_DeleteBuffer(p_buffer)
    cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

    missing_op_list = op_def_pb2.OpList()
    for missing_op in missing_ops:
        if missing_op not in cpp_registry_ops:
            tf.logging.info(
                "Op %s is missing from both the python and C++ registry.",
                missing_op)
        else:
            missing_op_list.op.extend([cpp_registry_ops[missing_op]])
            tf.logging.info(
                "Adding op %s from c++ registry to python registry.",
                missing_op)

    op_def_registry.register_op_list(missing_op_list)

    # Note: Only raise the missing-op RuntimeError after trying to load ops.
    # This allows the test to exercise all the calls into TensorFlow
    # without having to write a C + python test.
    unknown_ops = missing_ops.difference(cpp_registry_ops)
    if unknown_ops:
        raise RuntimeError(
            "Graph ops missing from the python registry (%s) are also absent from "
            "the c++ registry." % (unknown_ops,))
Esempio n. 10
0
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Ops already known to the Python `op_def_registry` are left alone. For the
  others, the OpDefs are looked up in the list returned by
  `c_api.TF_GetAllOpList` and registered with `op_def_registry`.

  Args:
    graph_ops: set (or any iterable) with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
  # Iterating the registry mapping yields its op names.
  missing_ops = set(graph_ops).difference(
      op_def_registry.get_registered_ops())

  if not missing_ops:
    return

  # The TF_Buffer returned by TF_GetAllOpList is owned by the caller; free it
  # once its bytes have been parsed, even if parsing fails.
  p_buffer = c_api.TF_GetAllOpList()
  try:
    cpp_op_list = op_def_pb2.OpList()
    cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  finally:
    c_api.TF_DeleteBuffer(p_buffer)
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  missing_op_list = op_def_pb2.OpList()
  for missing_op in missing_ops:
    if missing_op not in cpp_registry_ops:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          missing_op)
    else:
      missing_op_list.op.extend([cpp_registry_ops[missing_op]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.",
          missing_op)

  op_def_registry.register_op_list(missing_op_list)

  # Note: Only raise the missing-op RuntimeError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  unknown_ops = missing_ops.difference(cpp_registry_ops)
  if unknown_ops:
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry." % (unknown_ops,))
Esempio n. 11
0
    attr { name: 'N' type: 'int' minimum: 1 }
    attr { name: 'T' type: 'type' }
  }
  op {
    name: 'Otl'
    output_arg { name: 'a' type_list_attr: 't' }
    attr { name: 'T' type: 'list(type)' minimum: 1 }
  }
  op {
    name: 'Unary'
    input_arg { name: 'a' type_attr: 'T' }
    output_arg { name: 'b' type_attr: 'T' }
    attr { name: 'T' type: 'type' }
  }
""", _op_list)
# Module-level side effect: register the ops in `_op_list` (built by the
# statement ending just above) with the Python op registry.
op_def_registry.register_op_list(_op_list)
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
# Registers `None` as the shape function for each op — presumably so that
# graph construction does not require a real shape fn; TODO confirm.
# Note: the loop variable `op_def` remains bound at module level afterwards.
for op_def in _op_list.op:
  tf.RegisterShape(op_def.name)(None)

class ImportGraphDefTest(tf.test.TestCase):

  def _MakeGraphDef(self, text):
    """Parses text-format `text` into a fresh `GraphDef` and returns it."""
    graph_def = tf.GraphDef()
    text_format.Merge(text, graph_def)
    return graph_def

  def testBasic(self):
    with tf.Graph().as_default():
      a, b, c, d = tf.import_graph_def(
          self._MakeGraphDef("""
Esempio n. 12
0
  op {
    name: 'Unary'
    input_arg { name: 'a' type_attr: 'T' }
    output_arg { name: 'b' type_attr: 'T' }
    attr { name: 'T' type: 'type' }
  }
  op {
    name: 'OpWithDefaultAttr'
    output_arg { name: 'a' type: DT_INT32 }
    attr { name: 'default_float' type: 'float' default_value { f: 123.0 } }
  }
  op {
    name: 'OpWithFutureDefaultAttr'
  }
""", _op_list)
# Module-level side effect: register the ops in `_op_list` (built by the
# statement ending just above) with the Python op registry.
op_def_registry.register_op_list(_op_list)
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
# Registers `None` as the shape function for each op — presumably so that
# graph construction does not require a real shape fn; TODO confirm.
# Note: the loop variable `op_def` remains bound at module level afterwards.
for op_def in _op_list.op:
    ops.RegisterShape(op_def.name)(None)


class ImportGraphDefTest(test.TestCase):
    def _MakeGraphDef(self,
                      text,
                      producer=versions.GRAPH_DEF_VERSION,
                      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER):
        """Build a GraphDef from text-format `text` prefixed with a versions header.

        Args:
          text: text-format GraphDef content to merge after the header.
          producer: value written to `versions.producer`.
          min_consumer: value written to `versions.min_consumer`.

        Returns:
          A `graph_pb2.GraphDef` with the versions header and `text` merged in.
        """
        versioned_text = "versions: { producer: %d min_consumer: %d };\n%s" % (
            producer, min_consumer, text)
        graph_def = graph_pb2.GraphDef()
        text_format.Merge(versioned_text, graph_def)
        return graph_def