def __init__(self, func_name, input_types, output_types):
    """Creates a `Declare` object from bare argument and return types.

    The signature is stored on `self._sig` as an `OpDef` whose input and
    output `ArgDef`s carry only a type (no name).

    Args:
      func_name: The name of the function.
      input_types: A list of data types of function arguments.
      output_types: A list of data types of function return values.
    """
    self._sig = op_def_pb2.OpDef()
    self._sig.name = func_name

    # Inputs are filled in before outputs; each data type becomes one
    # ArgDef holding its datatype enum value.
    for arg_field, arg_types in (
        (self._sig.input_arg, input_types),
        (self._sig.output_arg, output_types),
    ):
        arg_defs = []
        for data_type in arg_types:
            arg_defs.append(
                op_def_pb2.OpDef.ArgDef(type=data_type.as_datatype_enum))
        arg_field.extend(arg_defs)
def __init__(self, func_name, inputs, outputs):
    """Creates a `Declare` object from named argument and return types.

    The signature is stored on `self._sig` as an `OpDef`.

    Args:
      func_name: The name of the function.
      inputs: A list of (name, data type) pairs of function arguments.
      outputs: A list of (name, data type) pairs of function return values.

    Raises:
      ValueError: If the names within `inputs`, or within `outputs`, are
        not all distinct.
    """
    self._sig = op_def_pb2.OpDef()
    self._sig.name = func_name

    def _build_arg_defs(named_types):
        # Reject duplicate names before building any ArgDef.
        arg_names = [arg_name for arg_name, _ in named_types]
        if len(set(arg_names)) != len(arg_names):
            raise ValueError(
                "Expected names to all be unique: %s" % str(arg_names))

        arg_defs = []
        for arg_name, data_type in named_types:
            arg_defs.append(
                op_def_pb2.OpDef.ArgDef(
                    type=data_type.as_datatype_enum, name=arg_name))
        return arg_defs

    # Inputs are validated and stored before outputs are looked at.
    self._sig.input_arg.extend(_build_arg_defs(inputs))
    self._sig.output_arg.extend(_build_arg_defs(outputs))
def _create_kaldi_table_dataset_op_proto(op_name):
    """Builds by hand the `OpDef` proto for the op called `op_name`.

    The resulting op is stateful, takes a single string input named
    "r_specifier" and produces a single variant output named "handle".

    Args:
      op_name: Name to give the op.

    Returns:
      The populated `OpDef` message.
    """
    # TODO(galv): See if I can call a C++ function to get this protobuf
    # directly from whatever REGISTER_OP expands to, so we can follow
    # DRY.
    op_proto = _op_def_pb2.OpDef()
    # Does this need to be bytes rather than unicode str? Hmm...
    op_proto.name = op_name
    op_proto.is_stateful = True

    # `add()` appends a fresh ArgDef to the repeated field and returns it,
    # so each argument is filled in place.
    specifier_arg = op_proto.input_arg.add()
    specifier_arg.name = "r_specifier"
    specifier_arg.type = _types_pb2.DT_STRING

    handle_arg = op_proto.output_arg.add()
    handle_arg.name = "handle"
    handle_arg.type = _types_pb2.DT_VARIANT

    return op_proto
def get(name):
    """Returns an OpDef for a given `name` or None if the lookup fails.

    Parsed OpDefs are memoized in `_cache`. Failed lookups are not cached,
    so they consult the registry again on every call.
    """
    # First attempt without taking the lock.
    try:
        return _cache[name]
    except KeyError:
        pass

    with _cache_lock:
        # Look again now that the lock is held: another thread may have
        # populated the cache in the meantime.
        try:
            return _cache[name]
        except KeyError:
            pass

        raw_op_def = _op_def_registry.get(name)
        if raw_op_def is not None:
            parsed = op_def_pb2.OpDef()
            parsed.ParseFromString(raw_op_def)
            _cache[name] = parsed
            return parsed

    return None
def test_opdef_sig():
    """Check `make_opdef_sig` on an OpDef that is absent from `tf.raw_ops`.

    Make sure we can construct an `inspect.Signature` object for a protobuf
    OpDef when its corresponding function isn't present in `tf.raw_ops`.
    """
    import inspect

    from tensorflow.core.framework import op_def_pb2

    custom_opdef_tf = op_def_pb2.OpDef()
    custom_opdef_tf.name = "MyOpDef"

    # Two inputs sharing the type attribute "T".
    arg1_tf = op_def_pb2.OpDef.ArgDef()
    arg1_tf.name = "arg1"
    arg1_tf.type_attr = "T"
    arg2_tf = op_def_pb2.OpDef.ArgDef()
    arg2_tf.name = "arg2"
    arg2_tf.type_attr = "T"
    custom_opdef_tf.input_arg.extend([arg1_tf, arg2_tf])

    # The type attribute itself, plus an int attribute carrying a default.
    attr1_tf = op_def_pb2.OpDef.AttrDef()
    attr1_tf.name = "T"
    attr1_tf.type = "type"
    attr2_tf = op_def_pb2.OpDef.AttrDef()
    attr2_tf.name = "axis"
    attr2_tf.type = "int"
    attr2_tf.default_value.i = 1
    custom_opdef_tf.attr.extend([attr1_tf, attr2_tf])

    # Only the signature is inspected; the generated function is not needed.
    opdef_sig, _ = MetaOpDefLibrary.make_opdef_sig(custom_opdef_tf)

    # `inspect.Parameter.empty` is the public spelling of the "no default"
    # sentinel; it is a singleton marker, so compare by identity.
    no_default = inspect.Parameter.empty

    # These are standard inputs
    assert opdef_sig.parameters['arg1'].default is no_default
    assert opdef_sig.parameters['arg2'].default is no_default
    # These are attributes that are sometimes required by the OpDef
    assert opdef_sig.parameters['axis'].default is no_default
    # The obligatory tensor name parameter
    assert opdef_sig.parameters['name'].default is None
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Loading must fail when registry defaults differ from the export's.

    A SavedModel is exported with `strip_default_attrs=True`, so the
    "Complex" node carries no "T"/"Tout" attr values. The registered
    "Complex" OpDef is then altered in two ways (defaults removed, then
    "Tout" default changed) and each load attempt is expected to fail.

    The registered OpDef is shared by the whole process, so it is always
    restored to its original contents before the test returns, including
    when an assertion fails.
    """
    export_dir = os.path.join(
        test.get_temp_dir(), "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Keep a pristine copy of the registered "Complex" OpDef so the in-place
    # edits below can be undone.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)

    try:
        # Update the Op registry to remove defaults for all attrs("T", "Tout")
        # from the "Complex" OpDef.
        for attr_def in complex_op_def.attr:
            attr_def.ClearField("default_value")

        # Loading the SavedModel via the loader must fail because the
        # SavedModel does not have any attr values for the "Complex" node and
        # the current op registry does not have any default values for the
        # "Complex" op.
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
            ValueError,
            "Expected one attr with name .*T(out)?.* in name: \"complex\".*"
        ):
            loader.load(sess, ["foo"], export_dir)

        # Update the Op registry to change the defaults for attr "Tout"
        # (complex64 -> complex128).
        complex_op_def.CopyFrom(original_complex_op_def)
        for attr_def in complex_op_def.attr:
            if attr_def.name == "Tout":
                attr_def.default_value.type = types_pb2.DT_COMPLEX128

        # Loading the SavedModel via the loader must set "Tout" attr_value for
        # the "Complex" node according to the latest defaults (complex128).
        # This is expected to fail the model import as there is no OpKernel
        # registered to handle attrs "T" (float32) and "Tout" (complex128).
        sess = session.Session(graph=ops.Graph())
        with self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            ".*No OpKernel was registered to support Op \'Complex\' with these "
            "attrs..*"):
            loader.load(sess, ["foo"], export_dir)
    finally:
        # Undo the registry edits so later tests see the real "Complex"
        # defaults, whether or not the assertions above passed.
        complex_op_def.CopyFrom(original_complex_op_def)