Code example #1
File: function.py  Project: zhaoweijin/tensorflow
    def __init__(self, func_name, input_types, output_types):
        """Creates a `Declare` object.

        Args:
          func_name: The name of the function.
          input_types: A list of data types of function arguments.
          output_types: A list of data types of function return values.
        """
        self._sig = op_def_pb2.OpDef()
        self._sig.name = func_name

        def _to_argdef_list(types):
            return [
                op_def_pb2.OpDef.ArgDef(type=_.as_datatype_enum) for _ in types
            ]

        self._sig.input_arg.extend(_to_argdef_list(input_types))
        self._sig.output_arg.extend(_to_argdef_list(output_types))
Code example #2
  def __init__(self, func_name, inputs, outputs):
    """Creates a `Declare` object.

    Args:
      func_name: The name of the function.
      inputs: A list of (name, data type) pairs of function arguments.
      outputs: A list of (name, data type) pairs of function return values.
    """
    self._sig = op_def_pb2.OpDef()
    self._sig.name = func_name

    def _to_argdef_list(args):
      names = [n for n, t in args]
      if len(names) != len(set(names)):
        raise ValueError("Expected names to all be unique: %s" % str(names))
      return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
              for n, t in args]

    self._sig.input_arg.extend(_to_argdef_list(inputs))
    self._sig.output_arg.extend(_to_argdef_list(outputs))
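
For context, here is a minimal, hedged sketch of the same pattern outside the `Declare` class: building an OpDef signature proto directly with named, typed arguments and printing it. The op name "MyFunc" and the float32 dtypes are illustrative choices, not taken from the code above.

# Illustrative sketch only: construct an OpDef "signature" proto by hand.
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import dtypes

sig = op_def_pb2.OpDef()
sig.name = "MyFunc"  # hypothetical op/function name
sig.input_arg.extend([
    op_def_pb2.OpDef.ArgDef(name="x", type=dtypes.float32.as_datatype_enum),
    op_def_pb2.OpDef.ArgDef(name="y", type=dtypes.float32.as_datatype_enum),
])
sig.output_arg.extend([
    op_def_pb2.OpDef.ArgDef(name="z", type=dtypes.float32.as_datatype_enum),
])
print(sig)  # dumps the OpDef in text proto format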
Code example #3
def _create_kaldi_table_dataset_op_proto(op_name):
    # TODO(galv): See if I can call a C++ function to get this protobuf
    # directly from whatever REGISTER_OP expands to, so we can follow
    # DRY.
    op = _op_def_pb2.OpDef()
    # Does this need to be bytes rather than unicode str? Hmm...
    op.name = op_name

    r_specifier = _op_def_pb2.OpDef.ArgDef()
    r_specifier.name = "r_specifier"
    r_specifier.type = _types_pb2.DT_STRING
    op.input_arg.extend([r_specifier])

    output_handle = _op_def_pb2.OpDef.ArgDef()
    output_handle.name = "handle"
    output_handle.type = _types_pb2.DT_VARIANT
    op.output_arg.extend([output_handle])

    op.is_stateful = True

    return op
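
A hedged usage sketch for the helper above; the op name "KaldiTableDataset" and the aliased protobuf imports are assumptions, not taken from the original module. The returned OpDef behaves like any other protobuf message:

# Illustrative only: build the OpDef proto, then inspect and serialize it.
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
from tensorflow.core.framework import types_pb2 as _types_pb2
from google.protobuf import text_format

op = _create_kaldi_table_dataset_op_proto("KaldiTableDataset")
print(text_format.MessageToString(op))  # human-readable text proto
wire_bytes = op.SerializeToString()     # binary wire format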
Code example #4
File: op_def_registry.py  Project: Harryi0/tinyML
def get(name):
    """Returns an OpDef for a given `name` or None if the lookup fails."""
    try:
        return _cache[name]
    except KeyError:
        pass

    with _cache_lock:
        try:
            # Return if another thread has already populated the cache.
            return _cache[name]
        except KeyError:
            pass

        serialized_op_def = _op_def_registry.get(name)
        if serialized_op_def is None:
            return None

        op_def = op_def_pb2.OpDef()
        op_def.ParseFromString(serialized_op_def)
        _cache[name] = op_def
        return op_def
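
get() relies on module-level state that is not shown in the excerpt. A hedged sketch of what that setup typically looks like (the exact import path of the native _op_def_registry extension may differ between TensorFlow versions):

# Sketch of the module-level cache assumed by get() above.
import threading

from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import _op_def_registry  # native OpDef lookup

_cache = {}                     # op name -> op_def_pb2.OpDef
_cache_lock = threading.Lock()  # guards first-time population of _cache

The double-checked pattern in get() (an unlocked read first, then a second read under the lock) keeps the common cache-hit path lock-free while still ensuring each OpDef is parsed only once.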
Code example #5
def test_opdef_sig():
    """Make sure we can construct an `inspect.Signature` object for a protobuf OpDef when its corresponding function isn't present in `tf.raw_ops`."""
    from tensorflow.core.framework import op_def_pb2

    custom_opdef_tf = op_def_pb2.OpDef()
    custom_opdef_tf.name = "MyOpDef"

    arg1_tf = op_def_pb2.OpDef.ArgDef()
    arg1_tf.name = "arg1"
    arg1_tf.type_attr = "T"

    arg2_tf = op_def_pb2.OpDef.ArgDef()
    arg2_tf.name = "arg2"
    arg2_tf.type_attr = "T"

    custom_opdef_tf.input_arg.extend([arg1_tf, arg2_tf])

    attr1_tf = op_def_pb2.OpDef.AttrDef()
    attr1_tf.name = "T"
    attr1_tf.type = "type"

    attr2_tf = op_def_pb2.OpDef.AttrDef()
    attr2_tf.name = "axis"
    attr2_tf.type = "int"
    attr2_tf.default_value.i = 1

    custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])

    opdef_sig, opdef_func = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)

    import inspect
    # These are standard inputs
    assert opdef_sig.parameters['arg1'].default == inspect._empty
    assert opdef_sig.parameters['arg2'].default == inspect._empty
    # These are attributes that are sometimes required by the OpDef
    assert opdef_sig.parameters['axis'].default == inspect._empty
    # The obligatory tensor name parameter
    assert opdef_sig.parameters['name'].default is None
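
MetaOpDefLibrary.make_opdef_sig comes from the library under test, not from TensorFlow itself. As a rough, hedged sketch of the idea (not that library's actual implementation), an inspect.Signature can be assembled from an OpDef's input_arg and attr fields along these lines:

# Illustrative sketch: derive an inspect.Signature from an OpDef proto.
import inspect

def opdef_to_signature(op_def):
    params = []
    # One parameter per declared input tensor.
    for arg in op_def.input_arg:
        params.append(inspect.Parameter(
            arg.name, inspect.Parameter.POSITIONAL_OR_KEYWORD))
    # Attrs used as type attrs of the inputs are inferred from the inputs,
    # so skip them; the remaining attrs become keyword-style parameters.
    inferred = {arg.type_attr for arg in op_def.input_arg if arg.type_attr}
    for attr in op_def.attr:
        if attr.name in inferred:
            continue
        params.append(inspect.Parameter(
            attr.name, inspect.Parameter.POSITIONAL_OR_KEYWORD))
    # TF op wrappers conventionally take a trailing `name=None` argument.
    params.append(inspect.Parameter(
        "name", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None))
    return inspect.Signature(params)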
Code example #6
    def testStripDefaultAttrsInconsistentConsumerDefaults(self):
        export_dir = os.path.join(
            test.get_temp_dir(),
            "test_strip_default_attrs_no_consumer_defaults")
        builder = saved_model_builder.SavedModelBuilder(export_dir)

        # Add a graph with two float32 variables and a Complex Op composing them
        # with strip_default_attrs enabled. This must remove the following
        # defaults for the "Complex" Op:
        #   o "T"    : float32.   (input type)
        #   o "Tout" : complex64. (output type)
        with session.Session(graph=ops.Graph()) as sess:
            real_num = variables.Variable(1.0,
                                          dtype=dtypes.float32,
                                          name="real")
            imag_num = variables.Variable(2.0,
                                          dtype=dtypes.float32,
                                          name="imag")
            math_ops.complex(real_num, imag_num, name="complex")
            sess.run(variables.global_variables_initializer())
            builder.add_meta_graph_and_variables(sess, ["foo"],
                                                 strip_default_attrs=True)

        # Save the SavedModel to disk in text format.
        builder.save(as_text=True)

        # Update the Op registry to remove defaults for all attrs ("T", "Tout") from
        # the "Complex" OpDef.
        complex_op_def = op_def_registry.get_registered_ops()["Complex"]
        original_complex_op_def = op_def_pb2.OpDef()
        original_complex_op_def.CopyFrom(complex_op_def)
        for attr_def in complex_op_def.attr:
            attr_def.ClearField("default_value")

        # Loading the SavedModel via the loader must fail because the SavedModel
        # does not have any attr values for the "Complex" node and the current
        # op registry does not have any default values for the "Complex" op.
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
                ValueError,
                "Expected one attr with name .*T(out)?.* in name: \"complex\".*"
        ):
            loader.load(sess, ["foo"], export_dir)

        # Update the Op registry to change the defaults for attr "Tout"
        # (complex64 -> complex128).
        complex_op_def.CopyFrom(original_complex_op_def)
        for attr_def in complex_op_def.attr:
            if attr_def.name == "Tout":
                attr_def.default_value.type = types_pb2.DT_COMPLEX128

        # Loading the SavedModel via the loader must set "Tout" attr_value for the
        # "Complex" node according to the latest defaults (complex128). This is
        # expected to fail the model import as there is no OpKernel registered to
        # handle attrs "T" (float32) and "Tout" (complex128).
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
                errors.InvalidArgumentError,
                ".*No OpKernel was registered to support Op \'Complex\' with these "
                "attrs..*"):
            loader.load(sess, ["foo"], export_dir)