def _RemoveDefaultAttrs(producer_op_list, graph_def):
  """Removes unknown default attrs according to `producer_op_list`.

  Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in
  registered OpDefs) that have a default value in `producer_op_list`.

  Args:
    producer_op_list: OpList proto.
    graph_def: GraphDef proto
  """
  producer_op_dict = {op.name: op for op in producer_op_list.op}
  for node in graph_def.node:
    # Remove any default attr values that aren't in op_def.
    if node.op in producer_op_dict:
      op_def = op_def_registry.get(node.op)
      if op_def is None:
        # Some custom op registrations won't show up here. That's OK,
        # attribute stripping just won't be available.
        continue
      producer_op_def = producer_op_dict[node.op]
      # We make a copy of node.attr to iterate through since we may modify
      # node.attr inside the loop.
      for key in list(node.attr):
        if _FindAttrInOpDef(key, op_def) is None:
          # No attr_def in consumer, look in producer.
          attr_def = _FindAttrInOpDef(key, producer_op_def)
          if (attr_def and attr_def.HasField('default_value') and
              node.attr[key] == attr_def.default_value):
            # Unknown attr had default value in producer, delete it so it can
            # be understood by consumer.
            del node.attr[key]
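# Hedged illustration (not part of the original source): the producer OpList
# that _RemoveDefaultAttrs consumes typically travels inside a MetaGraphDef as
# `meta_info_def.stripped_op_list`, so a caller holding a loaded MetaGraphDef
# could strip unknown default attrs in place before importing the graph. The
# wrapper name below is hypothetical.
def _strip_unknown_default_attrs(meta_graph_def):
  """Sketch: strip producer-default attrs from a loaded MetaGraphDef."""
  _RemoveDefaultAttrs(meta_graph_def.meta_info_def.stripped_op_list,
                      meta_graph_def.graph_def)
  return meta_graph_def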
def _get_num_inputs_outputs(op_type):
  """Returns (num_inputs, num_outputs).

  Args:
    op_type: String. The type of the Operation. Used to look up the op in the
      registry.

  Returns:
    (num_inputs, num_outputs). For either value, -1 is returned if it can't be
    statically inferred from the OpDef alone or if the OpDef lookup fails.
  """

  def _is_list_arg(arg):
    return arg.number_attr or arg.type_list_attr

  def _count_args(arg_defs):
    for arg in arg_defs:
      if _is_list_arg(arg):
        # Op has list type args which could be variable.
        return -1
    return len(arg_defs)

  op_def = op_def_registry.get(op_type)
  if not op_def:
    return -1, -1
  return _count_args(op_def.input_arg), _count_args(op_def.output_arg)
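# Hedged usage sketch for _get_num_inputs_outputs: ops with only fixed-arity
# args yield exact counts, while ops with list-typed args (number_attr or
# type_list_attr) and unknown ops yield -1. "AddV2" and "ConcatV2" are just
# examples of registered op names; the expected values assume their standard
# OpDefs.
assert _get_num_inputs_outputs("AddV2") == (2, 1)         # fixed arity
assert _get_num_inputs_outputs("ConcatV2") == (-1, 1)     # list-typed input
assert _get_num_inputs_outputs("NotARealOp") == (-1, -1)  # lookup fails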
def _ProcessGraphDefParam(graph_def):
  """Type-checks and possibly canonicalizes `graph_def`."""
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  else:
    # If we're using the graph_def provided by the caller, modify graph_def
    # in-place to add attr defaults to the NodeDefs (this is visible to the
    # caller).
    # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
    # depends on. It might make sense to move this to meta_graph.py and have
    # import_graph_def not modify the graph_def argument (we'd have to make
    # sure this doesn't break anything else.)
    for node in graph_def.node:
      op_def = op_def_registry.get(node.op)
      if op_def is None:
        # Assume unrecognized ops are functions for now. TF_ImportGraphDef
        # will report an error if the op is actually missing.
        continue
      _SetDefaultAttrValues(node, op_def)

  return graph_def
def visit_FunctionDef(self, node):
  # TODO(fengliuai): create one utility method to match different apis and
  # share it with the tfr_gen.py module.
  compose_dec = []
  for dec in node.decorator_list:
    if isinstance(dec, ast.Call):
      if isinstance(dec.func, ast.Attribute) and dec.func.attr == 'Composite':
        compose_dec.append(dec)
      if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
        compose_dec.append(dec)

  if not compose_dec:
    # skip a non-composition function
    return
  elif len(compose_dec) > 1:
    raise KeyError('More than one Composite decorator found.')

  all_dec_args = {}
  for arg_name, arg_value in zip(_COMPOSITE_ARG_LIST, compose_dec[0].args):
    all_dec_args[arg_name] = self.visit(arg_value)

  kw_dec_args = dict([self.visit(kw) for kw in compose_dec[0].keywords])

  if all_dec_args.keys() & kw_dec_args.keys():
    raise KeyError('More arguments than expected.')

  all_dec_args.update(kw_dec_args)

  op_name = all_dec_args['op_name']
  op_def = op_def_registry.get(op_name)
  if op_def:
    if len(all_dec_args) > 1:
      # Op has been registered, so it is a user error to specify an op def.
      raise ValueError('Op has been registered: ' + op_name)
    else:
      # Op has been registered, so we don't need to generate registration
      # code.
      return

  # Validate that the function inputs match those in the decorator.
  inputs = all_dec_args.get('inputs', [])
  attrs = all_dec_args.get('attrs', [])
  expected_args = [arg.split(':')[0] for arg in inputs + attrs]
  all_func_args = self.visit(node.args)

  if len(expected_args) != len(all_func_args):
    raise KeyError('Composition arguments do not match the registration.')

  cxx_reg_code = '\nREGISTER_OP("{0}")'.format(op_name)
  for input_ in inputs:
    cxx_reg_code += '\n    .Input("{0}")'.format(input_)
  for attr in attrs:
    py_str = attr.replace('"', '\'')
    cxx_reg_code += '\n    .Attr("{0}")'.format(py_str)
  for attr in all_dec_args.get('derived_attrs', []):
    py_str = attr.replace('"', '\'')
    cxx_reg_code += '\n    .Attr("{0}")'.format(py_str)
  for output_ in all_dec_args.get('outputs', []):
    cxx_reg_code += '\n    .Output("{0}")'.format(output_)
  cxx_reg_code += ';\n'
  self.emit(cxx_reg_code)
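# Hedged illustration (hypothetical op name and signature): for a composite
# decorated roughly as
#   @Composite('MyRelu', inputs=['x: T'], attrs=['alpha: float = 0.2'],
#              derived_attrs=['T: numbertype'], outputs=['y: T'])
# the C++ registration emitted by visit_FunctionDef above would look like:
#
#   REGISTER_OP("MyRelu")
#       .Input("x: T")
#       .Attr("alpha: float = 0.2")
#       .Attr("T: numbertype")
#       .Output("y: T");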
def get_metaopdef(cls, name):
    """Obtain a MetaOpDef for a given string name.

    This is more flexible because it ignores things like string case (when
    the non-`raw_ops` name differs from the TF user-level API).
    """
    raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
    op_def = op_def_registry.get(raw_op_name)
    if op_def is not None:
        return TFlowMetaOpDef(obj=op_def)
def _get_output_type(node: r.NodeDef, output_idx: int) -> int:
    """Return the type of the nth output of a node."""
    op = op_def_registry.get(node.op)
    output_arg = op.output_arg[output_idx]
    if output_arg.type != 0:
        return output_arg.type
    elif len(output_arg.type_attr) > 0:
        return node.attr[output_arg.type_attr].type
    else:
        raise ValueError(f'cannot determine output type of node "{node.name}"'
                         f' op={op.name}')
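# Minimal sketch of the two resolution paths in _get_output_type. Assumes the
# standard NodeDef/types protos that ship with TensorFlow; "Identity" is just
# an example of an op whose output type is attr-driven (via its "T" attr).
from tensorflow.core.framework import node_def_pb2, types_pb2

_example_node = node_def_pb2.NodeDef(name="x", op="Identity")
_example_node.attr["T"].type = types_pb2.DT_INT32
assert _get_output_type(_example_node, 0) == types_pb2.DT_INT32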
def __init__(self):
    #
    # We need this in order to construct "Const" tensors directly, since
    # the "value" attr in a meta `NodeDef` is just a NumPy array and not
    # the `TensorProto` expected by `raw_ops.Const`.
    #
    def mt_const(value, dtype, name=None):
        return tf.raw_ops.Const(
            value=tensor_util.make_tensor_proto(value),
            dtype=dtype,
            name=name)

    opdef = op_def_registry.get("Const")
    self.opdef_signatures[opdef.name] = self.make_opdef_sig(opdef, mt_const)
def get_op_def(op_name: Text) -> Optional[OpDef]:
    """
    Get the definition for a native TF operation.

    This is useful for checking whether an operation is supported or to
    get all valid inputs and attributes.

    Args:
        op_name: Name of the native TF operation (e.g. "AddV2")

    Returns:
        Protobuf object containing the operation definition.
        `None` is returned if the operation is not registered with TF.
    """
    return op_def_registry.get(op_name)
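# Usage sketch: probing whether an op can be handled before attempting a
# conversion. "AddV2" and "SomeCustomOp" are illustrative names only.
add_def = get_op_def("AddV2")
if add_def is not None:
    print([arg.name for arg in add_def.input_arg])  # e.g. ['x', 'y']
assert get_op_def("SomeCustomOp") is None  # unless such an op is registered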
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    tf.errors.NotFoundError: if `graph_ops` contains ops that are in neither
      the python nor the c++ registry.
  """
  if all(op_def_registry.get(op) is not None for op in graph_ops):
    return

  # Note: Only raise the missing-op error after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  op_def_registry.sync()
  missing_ops = {op for op in graph_ops if op_def_registry.get(op) is None}
  if missing_ops:
    raise tf.errors.NotFoundError(
        None, None,
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry." % missing_ops)
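# Hedged usage sketch: `graph_def` stands for a GraphDef loaded elsewhere; the
# op names of its nodes are checked against the registry before the graph is
# imported. Ops that only appear inside the graph's function library are not
# covered by this simple collection.
graph_ops = {node.op for node in graph_def.node}
register_ops_if_needed(graph_ops)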
def _strip_node_default_valued_attrs(node_def):
  """Removes default valued attributes from a single node def."""
  if node_def.op in op_name_to_function:
    return

  op_def = op_def_registry.get(node_def.op)
  if op_def is None:
    return

  attrs_to_strip = set()
  for attr_name, attr_value in node_def.attr.items():
    if _is_default_attr_value(op_def, attr_name, attr_value):
      attrs_to_strip.add(attr_name)

  for attr in attrs_to_strip:
    del node_def.attr[attr]
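# Hedged sketch of the surrounding driver (in TensorFlow this helper is nested
# inside meta_graph.py's strip_graph_default_valued_attrs, which also supplies
# `op_name_to_function` and `_is_default_attr_value`): every NodeDef in the
# main graph and in the function library gets stripped, and the MetaGraphDef
# records that stripping happened.
def _strip_graph_default_valued_attrs_sketch(meta_graph_def):
  for function_def in meta_graph_def.graph_def.library.function:
    for node_def in function_def.node_def:
      _strip_node_default_valued_attrs(node_def)
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)
  meta_graph_def.meta_info_def.stripped_default_attrs = True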
def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  return op_def, g, producer
def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
  """Replace function calls and shared names in `node_def`."""
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    logging.warning(
        "Importing a function (%s) with ops with custom gradients. Will likely "
        "fail if a gradient is requested.", debug_name)
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))
def __init__(self, op_def, node_def, inputs, outputs=None, obj=None):
    """Create a TensorFlow meta `Operation`.

    The real signature of `tf.Operation.__init__` includes the graph object,
    so we can't really use the signature directly.  This is part of the
    reason why we have `TFlowMetaOpFactory.__call__` and
    `TFlowMetaTensor.operator` + `TFlowMetaTensor.inputs` that do not
    directly use `__all_props__`/`TFlowMetaTensor.rands` and construct the
    objects directly.
    """
    super().__init__(obj=obj)

    if isinstance(op_def, str):
        op_def = op_def_registry.get(op_def)

    self.op_def = metatize(op_def)
    self.node_def = metatize(node_def)

    if isvar(inputs):
        self.inputs = inputs
    else:
        # Inputs are supposed to be immutable, so we're able to convert
        # lists to tuples.
        def _convert_inputs(arg, nested):
            if nested and isinstance(arg, list):
                arg = tuple(metatize(i) for i in arg)
            else:
                arg = metatize(arg)

            return arg

        if not isvar(self.op_def):
            self.inputs = tuple(
                _convert_inputs(i, hasattr(info, "number_attr"))
                for i, info in zip(inputs, self.op_def.obj.input_arg))
        else:
            self.inputs = tuple(_convert_inputs(i, False) for i in inputs)

    if outputs is not None:
        if isvar(outputs):
            self._outputs = outputs
        else:
            self._outputs = tuple(metatize(o) for o in outputs)
def _get_ref_args(self, node):
  """Determine whether an input of an op is ref-type.

  Args:
    node: A `NodeDef`.

  Returns:
    A list of the arg names (as strs) that are ref-type.
  """
  op_def = op_def_registry.get(node.op)
  if op_def is None:
    return []

  ref_args = []
  for i, output_arg in enumerate(op_def.output_arg):
    if output_arg.is_ref:
      arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
      ref_args.append(arg_name)
  return ref_args
def lookup(self, f_name, func_def=None, optional=False):
  if f_name in self._op_defs.keys():
    return self._op_defs[f_name]

  if isinstance(func_def, types.FunctionType):
    if not hasattr(func_def, '_tfr_op_name'):
      # skip a non-composition function
      if optional:
        return (None, None)
      else:
        raise KeyError('OpDef does not exist: ' + f_name)
    op_name = getattr(func_def, '_tfr_op_name')
  elif not func_def:
    op_name = f_name
  else:
    # TODO(fengliuai): create one utility method to match different apis.
    compose_dec = []
    for dec in func_def.decorator_list:
      if isinstance(dec, ast.Call):
        if isinstance(dec.func, ast.Attribute) and dec.func.attr == 'Composite':
          compose_dec.append(dec)
        if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
          compose_dec.append(dec)

    if not compose_dec:
      # skip a non-composition function
      if optional:
        return (None, None)
      else:
        raise KeyError('OpDef does not exist: ' + f_name)
    elif len(compose_dec) > 1:
      raise KeyError('More than one Composite decorator found.')
    else:
      op_name = compose_dec[0].args[0].value

  op_def = op_def_registry.get(op_name)
  if not op_def:
    raise ValueError('Not a registered op: ' + op_name)
  derived_attrs = _collect_derived_attrs_from_proto(op_def)
  self._op_defs[f_name] = (op_def, derived_attrs)
  return (op_def, derived_attrs)
def get_op_info(cls, opdef):
    """Return the TF Python API function signature for a given `OpDef`.

    Parameters
    ----------
    opdef: str or `OpDef` object (meta or base)
    """
    if isinstance(opdef, str):
        opdef_name = opdef
        opdef = op_def_registry.get(opdef_name)
    else:
        opdef_name = opdef.name

    opdef_sig = cls.opdef_signatures.get(opdef_name, None)

    if opdef_sig is None and opdef is not None:
        opdef_func = getattr(tf.raw_ops, opdef.name, None)
        opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
        cls.opdef_signatures[opdef.name] = opdef_sig

    return opdef_sig
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++,
  # we can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)

  # These internal ops used by functions are not registered, so we need to
  # whitelist them.
  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  op_defs = []
  for op in sorted(used_ops):
    op_def = op_def_registry.get(op)
    if op_def is not None:
      op_defs.append(op_def)
    elif op not in op_whitelist:
      raise ValueError(
          "Op %s is used by the graph, but is not registered" % op)

  return op_def_pb2.OpList(op=op_defs)
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
  # This is similar to StrippedOpListForGraph in C++, but unlike its
  # C++ counterpart, this version does not require all ops to be registered.
  # This is done to support Prelu fusion in tfjs.
  used_ops = ops_used_by_graph_def(graph_def)

  op_defs = []
  for op in sorted(used_ops):
    op_def = op_def_registry.get(op)
    if op_def is not None:
      op_defs.append(op_def)

  return op_def_pb2.OpList(op=op_defs)
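# Usage sketch that works with either variant of stripped_op_list_for_graph
# above: `graph` stands for an existing tf.Graph built elsewhere; the
# resulting OpList is what a producer would place in
# MetaGraphDef.meta_info_def.stripped_op_list.
op_list = stripped_op_list_for_graph(graph.as_graph_def())
print(sorted(op.name for op in op_list.op))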
def __exit__(self, unused_type, unused_value, unused_traceback): if context.executing_eagerly(): return if self._graph is not ops.get_default_graph(): raise RuntimeError( "Graph changed while trying to add control dependencies.") # pylint: disable=protected-access if hasattr(self._graph, "outer_graph"): outer_val = self._graph.outer_graph._add_control_dependencies self._graph._add_control_dependencies = outer_val else: self._graph._add_control_dependencies = False # pylint: enable=protected-access # map from resource tensor to the last op which used it last_op_using_resource_tensor = {} # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource merge_for_resource = {} new_operations = self._graph.get_operations()[self._n_operations:] # Ensures that uses of resource tensors get serialized properly and all # execute. This is done by keeping a map from resource tensor to the last op # in graph-construction order which used it (last_op_using_resource_tensor). # # Conditionals are written in TensorFlow such that every external tensor # accessed in the conditional goes through a switch op and every return # tensor (it's guaranteed that there will be at least one) goes through a # merge op. # # To handle conditionals, switches are handled in a special way (see # comments for _process_switch). Merge nodes created by TF's conditional # logic (as opposed to by _process_switch) are forced to run and also get a # control dependency added to them to ensure all stateful ops inside their # control flow context run. # # We also ensure that if an op is using a resource output by a switch node # (that is, a resource tensor for which there's a value in # merge_for_resource) this op will run before the merge for that resource. # # We try to add control inputs to nodes respecting their control flow # contexts to avoid dead nodes propagating everywhere and leading to # "retval[0] doesn't have value" errors. If a node gets a control dependency # on a dead node (i.e. a note from an untaken control flow branch) that node # will be marked as dead unless it's a merge node. # # TODO(apassos): serialize non-resource-taking stateful ops as well, and # test that it works. Support while loops. Support init_scope escaping from # this. for op in new_operations: # TODO(apassos) make this code safely support while loops. if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() # Ensure stateful ops run if op_def_registry.get(op.type) is None or op_is_stateful(op): ops_which_must_run.add(op) # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: continue # Make merges trigger all other computation which must run if op.type == "Merge": for o in ops_which_must_run: op._add_control_input(o) # pylint: disable=protected-access for inp in o.inputs: input_id = ops.tensor_id(inp) if input_id in last_op_using_resource_tensor: last_op_using_resource_tensor[input_id] = op ops_which_must_run = set([op]) continue resource_inputs = set() # Check for any resource inputs. If we find any, we update control_inputs # and last_op_using_resource_tensor. for inp in op.inputs: if inp.dtype != dtypes_module.resource: continue input_id = ops.tensor_id(inp) # If the op receives the same resource tensor twice as an input, we skip # to avoid the op getting a control dependency on itself. if input_id in resource_inputs: continue resource_inputs.add(input_id) # Deal with switches, finally. 
if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, last_op_using_resource_tensor, merge_for_resource) # Ensure uses of resources are serialized if input_id in last_op_using_resource_tensor: if (last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access is op._control_flow_context): # pylint: disable=protected-access control_inputs.add(last_op_using_resource_tensor[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(op) # pylint: disable=protected-access last_op_using_resource_tensor[input_id] = op if (op_is_stateful(op) and not resource_inputs and op._control_flow_context is None): # pylint: disable=protected-access if None in last_op_using_resource_tensor: op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access last_op_using_resource_tensor[None] = op control_inputs = [c for c in control_inputs if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access op._add_control_inputs(control_inputs) # pylint: disable=protected-access # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) for r in nest.flatten(list(self._returned_tensors), expand_composites=True): if self.ops_which_must_run: r.op._add_control_inputs( # pylint: disable=protected-access [o for o in self.ops_which_must_run if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access
def convert_int64_to_int32(graph_def: r.GraphDef) -> r.GraphDef: """Convert int64 input to int32 for TFJS compatibility Args: graph_def: GraphDef proto containing the network layout Returns: Updated graph with int64 inputs converted to int32 """ inputs = util.get_input_nodes(graph_def) convert = [info.name for info in inputs if info.dtype == util.np.int64] if len(convert) == 0: return graph_def # quick access to nodes by name node_map = r.get_input_node_map(graph_def) # map of all node inputs to their referencing node and their argument index input_map = defaultdict(list) for node in graph_def.node: for index, name in enumerate(node.input): input_map[name].append((index, node)) # type cast ops to add to the graph type_cast_ops = [] # nodes that require a type cast op type_cast_candidates: Dict[str, Tuple[int, r.NodeDef]] = {} for node in map(lambda x: node_map[x], convert): _set_tensor_dtype(node, _DT_INT32) # find all nodes that reference this input and adjust their datatype # attributes if required # technical note: referenced_by is a stack, this really is a # depth-first recursion referenced_by = input_map[node.name] while len(referenced_by) > 0: idx, ref = referenced_by.pop() # get the input node and the index of the output tensor input_node, output_idx = _get_input_node(ref, idx, node_map) # find the description of this node's operation op = op_def_registry.get(ref.op) desc = op.input_arg[idx] # find out whether we can just change the input type and which # attributes we might need to touch if desc.type != 0 and desc.type != _DT_INT32: # input type is fixed and cannot be changed: add a type cast cast_op = _make_cast_node(input_node, output_idx, _DT_INT32, desc.type) ref.input[idx] = cast_op.name type_cast_ops.append(cast_op) node_map[cast_op.name] = cast_op input_map[cast_op.name].append((idx, ref)) elif desc.type_list_attr != '' or desc.type_attr == '': # input arrays of potentially mixed types cannot be handled raise ValueError("don't know how to handle input type changes" f' for node "{ref.name}" op={ref.op}') else: # change the type of this input type_attr = desc.type_attr ref.attr[type_attr].type = _DT_INT32 if ref.name in type_cast_candidates: del type_cast_candidates[ref.name] # check the other inputs for type compatibility for i, desc in enumerate(op.input_arg): if i == idx or desc.type_attr != type_attr: continue # not a matching input input_node, output_idx = _get_input_node(ref, i, node_map) if input_node.name in convert: continue # Placeholder that will be converted src_type = _get_output_type(input_node, output_idx) if src_type == _DT_INT32: continue # type matches already if input_node.op == 'Const': # weight tensor: harmonize_dtypes() will fix these _set_tensor_dtype(input_node, _DT_INT32) else: # add node as a candidate for needing type cast op type_cast_candidates[input_node.name] = (i, ref) # process any changed outputs next for idx, output in enumerate(op.output_arg): if output.type_attr == type_attr: input_name = _get_tensor_name(ref, idx) referenced_by += input_map[input_name] for idx, ref in type_cast_candidates.values(): # add type cast operations for all nodes that have a type mismatch inp_node, channel = _get_input_node(ref, idx, node_map) src_type = _get_output_type(inp_node, channel) if src_type != _DT_INT32: cast_op = _make_cast_node(inp_node, channel, src_type, _DT_INT32) ref.input[idx] = cast_op.name type_cast_ops.append(cast_op) node_map[cast_op.name] = cast_op graph_def.node.extend(type_cast_ops) return graph_def
def _op_def(self, op_name):
    return op_def_registry.get(op_name)
def enable_jit_nonstateful(node_def):
  op_def = op_def_registry.get(node_def.op)
  if op_def is None:
    raise ValueError("Unregistered op being created: %s" % node_def)
  return not op_def.is_stateful
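# Usage sketch, assuming TF 2.x where a predicate like this can be handed to
# tf.xla.experimental.jit_scope via `compile_ops`, so that only non-stateful
# ops are marked for XLA compilation. The function name below is hypothetical.
import tensorflow as tf


@tf.function
def _scaled_add(x, y):
  with tf.xla.experimental.jit_scope(compile_ops=enable_jit_nonstateful):
    return 2.0 * x + y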
def __exit__(self, unused_type, unused_value, unused_traceback): # pylint: disable=protected-access if context.executing_eagerly(): return if self._graph is not ops.get_default_graph(): raise RuntimeError( "Graph changed while trying to add control dependencies.") if hasattr(self._graph, "outer_graph"): outer_val = self._graph.outer_graph._add_control_dependencies self._graph._add_control_dependencies = outer_val else: self._graph._add_control_dependencies = False # map from resource tensor to the last op which wrote to it last_write_to_resource = {} # map from resource tensor to the list of reads from it since the last # write or since the beginning of the function. reads_since_last_write_to_resource = collections.defaultdict(list) # CollectiveManager manager_ids within a particular function call should not # be needed outside of that function call. So we keep them separate (though # the general idea of the maps is the same, in the future, we'll need to # correctly thread the control output outside). # Map from collective manager scope to the last op which used it collective_manager_scopes_opened = {} collective_manager_scopes_used = {} # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource merge_for_resource = {} new_operations = self._graph.get_operations()[self._n_operations:] # Ensures that uses of resource tensors get serialized properly and all # execute. This is done by keeping a map from resource tensor to the last op # in graph-construction order which used it (last_write_to_resource). # # Conditionals are written in TensorFlow such that every external tensor # accessed in the conditional goes through a switch op and every return # tensor (it's guaranteed that there will be at least one) goes through a # merge op. # # To handle conditionals, switches are handled in a special way (see # comments for _process_switch). Merge nodes created by TF's conditional # logic (as opposed to by _process_switch) are forced to run and also get a # control dependency added to them to ensure all stateful ops inside their # control flow context run. # # We also ensure that if an op is using a resource output by a switch node # (that is, a resource tensor for which there's a value in # merge_for_resource) this op will run before the merge for that resource. # # We try to add control inputs to nodes respecting their control flow # contexts to avoid dead nodes propagating everywhere and leading to # "retval[0] doesn't have value" errors. If a node gets a control dependency # on a dead node (i.e. a note from an untaken control flow branch) that node # will be marked as dead unless it's a merge node. # # TODO(apassos): serialize non-resource-taking stateful ops as well, and # test that it works. Support while loops. Support init_scope escaping from # this. for op in new_operations: # TODO(apassos) make this code safely support while loops. if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() # Ensure stateful ops run if (op_def_registry.get(op.type) is None or (op_is_stateful(op) and op.type not in utils.RESOURCE_READ_OPS)): # TODO(srbs): Do not add functional ops to `ops_which_must_run` if # they only have variable reads and are otherwise stateless. ops_which_must_run.add(op) # Make a note of all opened manager_ids. 
if op.type == "NoOp": try: collective_manager_scopes_opened[op.get_attr( "_collective_manager_id")] = op except ValueError: pass # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[ 0].dtype == dtypes_module.resource: continue # Make merges trigger all other computation which must run if op.type == "Merge": for o in ops_which_must_run: op._add_control_input(o) for inp in o.inputs: input_id = ops.tensor_id(inp) if input_id in last_write_to_resource: last_write_to_resource[input_id] = op ops_which_must_run = set([op]) continue resource_inputs = set() # Check for any resource inputs. If we find any, we update control_inputs # and last_write_to_resource. for inp, resource_type in _get_resource_inputs(op): is_read = resource_type == ResourceType.READ_ONLY input_id = ops.tensor_id(inp) # If the op receives the same resource tensor twice as an input, we skip # to avoid the op getting a control dependency on itself. if input_id in resource_inputs: continue resource_inputs.add(input_id) # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, last_write_to_resource, merge_for_resource) is_building_function = op.graph.building_function # Ensure uses of resources are serialized if input_id in last_write_to_resource: if is_building_function or ( last_write_to_resource[input_id]. _control_flow_context is op._control_flow_context): control_inputs.add(last_write_to_resource[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(op) if is_read: reads_since_last_write_to_resource[input_id].append(op) else: control_inputs.update( reads_since_last_write_to_resource[input_id]) reads_since_last_write_to_resource[input_id] = [] last_write_to_resource[input_id] = op if (op_is_stateful(op) and not resource_inputs and op._control_flow_context is None): if None in last_write_to_resource: op._add_control_input(last_write_to_resource[None]) last_write_to_resource[None] = op # Ensure ordering of collective ops manager_ids = collective_manager_ids_from_op(op) for manager_id in manager_ids: if manager_id in collective_manager_scopes_opened: # Chain this function call if the scope was opened. op._add_control_input( collective_manager_scopes_opened[manager_id]) collective_manager_scopes_opened[manager_id] = op else: # If this op is in a scope not created here, create a chain starting # at this op. if manager_id in collective_manager_scopes_used: op._add_control_input( collective_manager_scopes_used[manager_id]) collective_manager_scopes_used[manager_id] = op if control_inputs and not is_building_function: control_inputs = [ c for c in control_inputs if c._control_flow_context is op._control_flow_context ] op._add_control_inputs(control_inputs) # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) for r in nest.flatten(list(self._returned_tensors), expand_composites=True): if self.ops_which_must_run: updated_ops_which_must_run = [] if r.graph.building_function: updated_ops_which_must_run = self.ops_which_must_run else: updated_ops_which_must_run = [ o for o in self.ops_which_must_run if o._control_flow_context is r.op._control_flow_context ] r.op._add_control_inputs(updated_ops_which_must_run) self.collective_manager_ids_used = collective_manager_scopes_used
def __exit__(self, unused_type, unused_value, unused_traceback): # pylint: disable=protected-access if context.executing_eagerly(): return if self._graph is not ops.get_default_graph(): raise RuntimeError( "Within the automatic control dependency context, the default graph" f" cannot change. Upon entry it was {self._graph}, but on exit it" f" changed to {ops.get_default_graph()}") if hasattr(self._graph, "outer_graph"): outer_val = self._graph.outer_graph._add_control_dependencies self._graph._add_control_dependencies = outer_val else: self._graph._add_control_dependencies = False # map from resource tensor to the last op which wrote to it last_write_to_resource = {} # map from resource tensor to the list of reads from it since the last # write or since the beginning of the function. reads_since_last_write_to_resource = collections.defaultdict(list) # CollectiveManager manager_ids within a particular function call should not # be needed outside of that function call. So we keep them separate (though # the general idea of the maps is the same, in the future, we'll need to # correctly thread the control output outside). # Map from collective manager scope to the last op which used it collective_manager_scopes_opened = {} collective_manager_scopes_used = {} # set of conditional and loop exits ops_which_must_run = set() # merge which must depend on ops which use this resource merge_for_resource = {} new_operations = self._graph.get_operations()[self._n_operations:] first_use_for_res = {} resources_by_op = {} # Ensures that uses of resource tensors get serialized properly and all # execute. This is done by keeping a map from resource tensor to the last op # in graph-construction order which used it (last_write_to_resource). # # Conditionals are written in TensorFlow such that every external tensor # accessed in the conditional goes through a switch op and every return # tensor (it's guaranteed that there will be at least one) goes through a # merge op. # # To handle conditionals, switches are handled in a special way (see # comments for _process_switch). Merge nodes created by TF's conditional # logic (as opposed to by _process_switch) are forced to run and also get a # control dependency added to them to ensure all stateful ops inside their # control flow context run. # # We also ensure that if an op is using a resource output by a switch node # (that is, a resource tensor for which there's a value in # merge_for_resource) this op will run before the merge for that resource. # # We try to add control inputs to nodes respecting their control flow # contexts to avoid dead nodes propagating everywhere and leading to # "retval[0] doesn't have value" errors. If a node gets a control dependency # on a dead node (i.e. a note from an untaken control flow branch) that node # will be marked as dead unless it's a merge node. # # TODO(apassos): serialize non-resource-taking stateful ops as well, and # test that it works. Support while loops. Support init_scope escaping from # this. for op in new_operations: # TODO(apassos) make this code safely support while loops. if control_flow_util.IsInWhileLoop(op): continue control_inputs = set() # Ensure stateful ops run. # Read-only ops are added to control outputs if the read value is # consumed. This covers the case when the read value is returned from # the function since that goes through a tf.identity in mark_as_return. 
if ((op_def_registry.get(op.type) is None) or (op_is_stateful(op) and (op.type not in utils.RESOURCE_READ_OPS or any(output.consumers() for output in op.outputs))) or (op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS)): ops_which_must_run.add(op) # Make a note of all opened manager_ids. if op.type == "NoOp": try: collective_manager_scopes_opened[op.get_attr( "_collective_manager_id")] = op except ValueError: pass # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: continue # Make merges trigger all other computation which must run # TODO(mdan): Don't do this. Write a transform to chains instead. # See core/common_runtime/control_flow_deps_to_chains.cc. if op.type == "Merge": for o in ops_which_must_run: op._add_control_input(o) for inp in o.inputs: input_id = ops.tensor_id(inp) if input_id in last_write_to_resource: last_write_to_resource[input_id] = op ops_which_must_run = set([op]) continue resource_inputs = set() # Check for any resource inputs. If we find any, we update control_inputs # and last_write_to_resource. for inp, resource_type in _get_resource_inputs(op): is_read = resource_type == ResourceType.READ_ONLY input_id = ops.tensor_id(inp) # If the op receives the same resource tensor twice as an input, we skip # to avoid the op getting a control dependency on itself. if input_id in resource_inputs: continue resource_inputs.add(input_id) # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, last_write_to_resource, merge_for_resource) is_building_function = op.graph.building_function # Ensure uses of resources are serialized if input_id in last_write_to_resource: if is_building_function or ( last_write_to_resource[input_id]._control_flow_context is op._control_flow_context): control_inputs.add(last_write_to_resource[input_id]) # Ensure merges happen after the closing of a cond block if input_id in merge_for_resource: merge_for_resource[input_id]._add_control_input(op) do_record = ( self.record_initial_resource_uses and input_id not in first_use_for_res) if is_read: reads_list = reads_since_last_write_to_resource[input_id] reads_list.append(op) if do_record: # Note: this will track the entire list that # reads_since_last_write_to_resource maintains. Updates to it will # and should be tracked, until the first write is encountered. At # that point, reads_since_last_write_to_resource will contain a new # empty list. This logic relies on that behavior. first_use_for_res[input_id] = reads_list else: control_inputs.update(reads_since_last_write_to_resource[input_id]) reads_since_last_write_to_resource[input_id] = [] last_write_to_resource[input_id] = op if do_record: first_use_for_res[input_id] = [op] if self.record_initial_resource_uses and op_is_stateful(op): if resource_inputs: resources_by_op[op] = tuple(resource_inputs) else: if None not in first_use_for_res: first_use_for_res[None] = [op] resources_by_op[op] = (None,) if (op_is_stateful(op) and not resource_inputs and op._control_flow_context is None): if None in last_write_to_resource: op._add_control_input(last_write_to_resource[None]) last_write_to_resource[None] = op # Ensure ordering of collective ops manager_ids = collective_manager_ids_from_op(op) for manager_id in manager_ids: if manager_id in collective_manager_scopes_opened: # Chain this function call if the scope was opened. 
op._add_control_input(collective_manager_scopes_opened[manager_id]) collective_manager_scopes_opened[manager_id] = op else: # If this op is in a scope not created here, create a chain starting # at this op. if manager_id in collective_manager_scopes_used: op._add_control_input(collective_manager_scopes_used[manager_id]) collective_manager_scopes_used[manager_id] = op if control_inputs and not is_building_function: control_inputs = [ c for c in control_inputs if c._control_flow_context is op._control_flow_context ] op._add_control_inputs(control_inputs) # Record the ops which first use resources touched by "ops which must run". if self.record_initial_resource_uses: first_uses_by_output_ops = {} for op in ops_which_must_run: if op not in resources_by_op: # This may happen with Merge/Switch nodes which are special cased # above. continue for r in resources_by_op[op]: if op not in first_uses_by_output_ops: first_uses_by_output_ops[op] = set() first_uses_by_output_ops[op].update(first_use_for_res[r]) # For each "op which must run", set a private attr indicating the ops that # used the same resources it did. for op in first_uses_by_output_ops: others = [ other.name.encode() for other in first_uses_by_output_ops[op] ] l = attr_value_pb2.AttrValue.ListValue(s=others) # TODO(mdan): Is there a way which doesn't use anonymous attrs? op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l)) # Ensure all ops which must run do run self.ops_which_must_run.update(ops_which_must_run) control_output_op = None for idx, r in enumerate( nest.flatten(list(self._returned_tensors), expand_composites=True)): if self.ops_which_must_run: updated_ops_which_must_run = [] if r.graph.building_function: # There may be many stateful ops in the graph. Adding them as # control inputs to each function output could create excessive # control edges in the graph. Thus we create an intermediate No-op # to chain the control dependencies between stateful ops and # function outputs. if idx == 0: control_output_op = control_flow_ops.no_op() control_output_op._set_attr("_acd_function_control_output", attr_value_pb2.AttrValue(b=True)) control_output_op._add_control_inputs(self.ops_which_must_run) updated_ops_which_must_run = [control_output_op] else: updated_ops_which_must_run = [ o for o in self.ops_which_must_run if o._control_flow_context is r.op._control_flow_context ] r.op._add_control_inputs(updated_ops_which_must_run) self.collective_manager_ids_used = collective_manager_scopes_used
def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=invalid-name """Implementation of apply_op that returns output_structure, op.""" op_def = op_def_registry.get(op_type_name) if op_def is None: raise RuntimeError("Unrecognized Op name " + op_type_name) # Determine the graph context. try: # Need to flatten all the arguments into a list. # pylint: disable=protected-access g = ops._get_graph_from_inputs(_Flatten(keywords.values())) # pylint: enable=protected-access except AssertionError as e: raise RuntimeError("Cannot determine graph for Op '%s' due to: %s" % (op_type_name, e.message)) # Default name if not specified. if name is None: name = op_type_name # Check for deprecation deprecation_version = op_def.deprecation.version if deprecation_version: producer = g.graph_def_versions.producer if producer >= deprecation_version: raise NotImplementedError( ("Op %s is not available in GraphDef version %d. " "It has been removed in version %d. %s.") % (op_type_name, producer, deprecation_version, op_def.deprecation.explanation)) # Fill in the list of default types for all "type" attrs. This # will be used to choose a preferred dtype to convert to in the # absence of input type information. # # TODO(b/31302892): Currently the defaults don't work in the right # way if you have two inputs, one of whose type resolution depends # on the other. Handling this will require restructuring this code # significantly. default_type_attr_map = {} for attr_def in op_def.attr: if attr_def.type != "type": continue key = attr_def.name if attr_def.HasField("default_value"): default_type_attr_map[key] = dtypes.as_dtype( attr_def.default_value.type) # Requires that op_def has passed validation (using the C++ # ValidateOpDef() from ../framework/op_def_util.h). attrs = {} inputs = [] input_types = [] with g.as_default(), ops.name_scope(name) as scope: # Perform input type inference inferred_from = {} for input_arg in op_def.input_arg: input_name = input_arg.name if input_name in keywords: values = keywords.pop(input_name) elif input_name + "_" in keywords: # Handle the case where the name is a keyword or built-in # for Python so we use the name + _ instead. input_name += "_" values = keywords.pop(input_name) else: raise TypeError("No argument for input " + input_name) # Goals: # * Convert values to Tensors if it contains constants. # * Verify that values is a list if that matches the input_arg's # type. # * If the input_arg's type is determined by attrs, either set # those attrs and validate those attr values are legal (if # they have not yet been set) or validate the input matches # the type indicated by the attrs (if they have already been # inferred via an earlier input). # * If the input_arg has an explicit type, make sure the input # conforms. if _IsListParameter(input_arg): if not _IsListValue(values): raise TypeError( "Expected list for '%s' argument to '%s' Op, not %s." % (input_name, op_type_name, values)) # In cases where we expect all elements of the list to have the # same dtype, try to cast non-Tensor elements to that type. dtype = None default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.number_attr: if input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] else: for t in values: if isinstance(t, ops.Tensor): dtype = t.dtype break # dtype still not found, prefer using the default dtype # from the attr. 
if dtype is None and input_arg.type_attr in default_type_attr_map: default_dtype = default_type_attr_map[ input_arg.type_attr] try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype values = ops.internal_convert_n_to_tensor( values, name=input_arg.name, dtype=dtype if dtype else None, preferred_dtype=default_dtype, as_ref=input_arg.is_ref) if input_arg.number_attr and len( set(v.dtype.base_dtype for v in values)) > 1: raise TypeError() # All types should match. except (TypeError, ValueError): # What types does the conversion function think values have? observed_types = [] for value in values: try: converted_value = ops.internal_convert_to_tensor( value, as_ref=input_arg.is_ref) observed_types.append( converted_value.dtype.base_dtype.name) except (TypeError, ValueError): observed_types.append( "<NOT CONVERTIBLE TO TENSOR>") observed = ", ".join(observed_types) prefix = ( "Tensors in list passed to '%s' of '%s' Op have types [%s]" % (input_name, op_type_name, observed)) if input_arg.number_attr: if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s that do not match expected type %s." % (prefix, dtype.name)) elif input_arg.type_attr in attrs: raise TypeError( "%s that do not match type %s inferred from " "earlier arguments." % (prefix, dtype.name)) else: raise TypeError("%s that don't all match." % prefix) else: raise TypeError("%s that are invalid. Tensors: %s" % (prefix, values)) types = [x.dtype for x in values] inputs.extend(values) else: # In cases where we have an expected type, try to convert non-Tensor # arguments to that type. dtype = None default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] elif input_arg.type_attr in default_type_attr_map: # The dtype could not be inferred solely from the inputs, # so we prefer the attr's default, so code that adds a new attr # with a default is backwards compatible. default_dtype = default_type_attr_map[input_arg.type_attr] try: values = ops.internal_convert_to_tensor( values, name=input_arg.name, dtype=dtype, as_ref=input_arg.is_ref, preferred_dtype=default_dtype) except TypeError as err: if dtype is None: raise err else: raise TypeError( "Expected %s passed to parameter '%s' of op '%s', got %s of " "type '%s' instead. Error: %s" % (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name, repr(values), type(values).__name__, err)) except ValueError: # What type does convert_to_tensor think it has? try: observed = ops.internal_convert_to_tensor( values, as_ref=input_arg.is_ref).dtype.name except ValueError as err: raise ValueError( "Tried to convert '%s' to a tensor and failed. Error: %s" % (input_name, err)) prefix = ( "Input '%s' of '%s' Op has type %s that does not match" % (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError( "%s expected type of %s." % (prefix, dtypes.as_dtype(input_arg.type).name)) else: # Update the maps with the default, if needed. k = input_arg.type_attr if k in default_type_attr_map: if k not in attrs: attrs[k] = default_type_attr_map[k] if k not in inferred_from: inferred_from[k] = "Default in OpDef" raise TypeError( "%s type %s of argument '%s'." 
% (prefix, dtypes.as_dtype( attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) types = [values.dtype] inputs.append(values) base_types = [x.base_dtype for x in types] if input_arg.number_attr: # <number-attr> * <type> or <number-attr> * <type-attr> if input_arg.number_attr in attrs: if len(values) != attrs[input_arg.number_attr]: raise ValueError( "List argument '%s' to '%s' Op with length %d must match " "length %d of argument '%s'." % (input_name, op_type_name, len(values), attrs[input_arg.number_attr], inferred_from[input_arg.number_attr])) else: attrs[input_arg.number_attr] = len(values) inferred_from[input_arg.number_attr] = input_name num_attr = _Attr(op_def, input_arg.number_attr) if num_attr.has_minimum and len(values) < num_attr.minimum: raise ValueError( "List argument '%s' to '%s' Op with length %d shorter " "than minimum length %d." % (input_name, op_type_name, len(values), num_attr.minimum)) # All tensors must have the same base type. if any(bt != base_types[0] for bt in base_types): raise TypeError("All tensors passed to '%s' of '%s' Op " "must have the same type." % (input_name, op_type_name)) if input_arg.type != types_pb2.DT_INVALID: # <number-attr> * <type> case if base_types and base_types[0] != input_arg.type: assert False, "Unreachable" elif input_arg.type_attr in attrs: # <number-attr> * <type-attr> case, where <type-attr> already # has an inferred value. if base_types and base_types[0] != attrs[ input_arg.type_attr]: assert False, "Unreachable" else: # <number-attr> * <type-attr> case, where we are now setting # the <type-attr> based on this input if not base_types: raise TypeError( "Don't know how to infer type variable from empty input " "list passed to input '%s' of '%s' Op." % (input_name, op_type_name)) attrs[input_arg.type_attr] = base_types[0] inferred_from[input_arg.type_attr] = input_name type_attr = _Attr(op_def, input_arg.type_attr) _SatisfiesTypeConstraint(base_types[0], type_attr, param_name=input_name) elif input_arg.type_attr: # <type-attr> attr_value = base_types[0] if input_arg.type_attr in attrs: if attrs[input_arg.type_attr] != attr_value: raise TypeError( "Input '%s' of '%s' Op has type %s that does not " "match type %s of argument '%s'." % (input_name, op_type_name, dtypes.as_dtype(attr_value).name, dtypes.as_dtype(attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) else: for base_type in base_types: _SatisfiesTypeConstraint(base_type, _Attr(op_def, input_arg.type_attr), param_name=input_name) attrs[input_arg.type_attr] = attr_value inferred_from[input_arg.type_attr] = input_name elif input_arg.type_list_attr: # <type-list-attr> attr_value = base_types if input_arg.type_list_attr in attrs: if attrs[input_arg.type_list_attr] != attr_value: raise TypeError( "Input '%s' of '%s' Op has type list of %s that does not " "match type list %s of argument '%s'." 
% (input_name, op_type_name, ", ".join( dtypes.as_dtype(x).name for x in attr_value), ", ".join( dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]), inferred_from[input_arg.type_list_attr])) else: for base_type in base_types: _SatisfiesTypeConstraint(base_type, _Attr( op_def, input_arg.type_list_attr), param_name=input_name) attrs[input_arg.type_list_attr] = attr_value inferred_from[input_arg.type_list_attr] = input_name else: # single Tensor with specified type if base_types[0] != input_arg.type: assert False, "Unreachable" if input_arg.is_ref: if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access raise TypeError(( "'%s' Op requires that input '%s' be a mutable tensor " "(e.g.: a tf.Variable)") % (op_type_name, input_name)) input_types.extend(types) else: input_types.extend(base_types) # Process remaining attrs for attr in op_def.attr: # Skip attrs that have already had their values inferred if attr.name in attrs: if attr.name in keywords: raise TypeError( "Should not specify value for inferred attr '%s'." % attr.name) continue if attr.name in keywords: attrs[attr.name] = keywords.pop(attr.name) elif attr.name + "_" in keywords: # Attrs whose names match Python keywords have an extra '_' # appended, so we must check for that as well. attrs[attr.name] = keywords.pop(attr.name + "_") else: raise TypeError("No argument for attr " + attr.name) # Convert attr values to AttrValue protos. attr_protos = {} for attr_def in op_def.attr: key = attr_def.name value = attrs[key] attr_value = attr_value_pb2.AttrValue() if attr_def.HasField("default_value") and value is None: attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue if attr_def.type.startswith("list("): if not _IsListValue(value): raise TypeError("Expected list for attr " + key) if attr_def.has_minimum: if len(value) < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed list of length %d " "less than minimum %d." % (key, op_type_name, len(value), attr_def.minimum)) attr_value.list.SetInParent() if attr_def.type == "string": attr_value.s = _MakeStr(value, key) if attr_def.HasField("allowed_values"): if attr_value.s not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text( attr_value.s), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "list(string)": attr_value.list.s.extend([_MakeStr(x, key) for x in value]) if attr_def.HasField("allowed_values"): for x in attr_value.list.s: if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % (key, op_type_name, compat.as_text(x), '", "'.join( map(compat.as_text, attr_def.allowed_values.list.s)))) elif attr_def.type == "int": attr_value.i = _MakeInt(value, key) if attr_def.has_minimum: if attr_value.i < attr_def.minimum: raise ValueError( "Attr '%s' of '%s' Op passed %d less than minimum %d." 
% (key, op_type_name, attr_value.i, attr_def.minimum)) elif attr_def.type == "list(int)": attr_value.list.i.extend([_MakeInt(x, key) for x in value]) elif attr_def.type == "float": attr_value.f = _MakeFloat(value, key) elif attr_def.type == "list(float)": attr_value.list.f.extend([_MakeFloat(x, key) for x in value]) elif attr_def.type == "bool": attr_value.b = _MakeBool(value, key) elif attr_def.type == "list(bool)": attr_value.list.b.extend([_MakeBool(x, key) for x in value]) elif attr_def.type == "type": attr_value.type = _MakeType(value, attr_def) elif attr_def.type == "list(type)": attr_value.list.type.extend( [_MakeType(x, attr_def) for x in value]) elif attr_def.type == "shape": attr_value.shape.CopyFrom(_MakeShape(value, key)) elif attr_def.type == "list(shape)": attr_value.list.shape.extend( [_MakeShape(x, key) for x in value]) elif attr_def.type == "tensor": attr_value.tensor.CopyFrom(_MakeTensor(value, key)) elif attr_def.type == "list(tensor)": attr_value.list.tensor.extend( [_MakeTensor(x, key) for x in value]) elif attr_def.type == "func": attr_value.func.CopyFrom(_MakeFunc(value, key)) elif attr_def.type == "list(func)": attr_value.list.func.extend([_MakeFunc(x, key) for x in value]) else: raise TypeError("Unrecognized Attr type " + attr_def.type) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead # Determine output types (possibly using attrs) output_structure = [] for arg in op_def.output_arg: if arg.number_attr: n = _AttrValue(attr_protos, arg.number_attr).i output_structure.append(n) elif arg.type_attr: t = _AttrValue(attr_protos, arg.type_attr) output_structure.append(None) elif arg.type_list_attr: t = _AttrValue(attr_protos, arg.type_list_attr) output_structure.append(len(t.list.type)) else: output_structure.append(None) if keywords: raise TypeError("apply_op() got unexpected keyword arguments: " + ", ".join(sorted(keywords.keys()))) # NOTE(mrry): We add an explicit colocation constraint between # the newly created op and any of its reference-typed inputs. must_colocate_inputs = [ val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref ] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph # pylint: disable=protected-access op = g._create_op_internal(op_type_name, inputs, dtypes=None, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) # `outputs` is returned as a separate return value so that the output # tensors can the `op` per se can be decoupled so that the # `op_callbacks` can function properly. See framework/op_callbacks.py # for more details. outputs = op.outputs # Conditionally invoke tfdbg v2's op callback(s). if op_callbacks.should_invoke_op_callbacks(): callback_outputs = op_callbacks.invoke_op_callbacks( op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs), op_name=op.name, graph=g) if callback_outputs is not None: outputs = callback_outputs return output_structure, op_def.is_stateful, op, outputs
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}")

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "
          f"{op_def.deprecation.explanation}.")

  # Fill in the list of default types for all "type" attrs. This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other. Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError(
            f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if they contain constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.
      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break
          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
                              f"{dtype.name}.")
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
            else:
              raise TypeError(f"{prefix} that don't all match.")
          else:
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtype. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError as err:
              # When converting a python object such as a list of Dimensions,
              # we need a dtype to be specified, thus tensor conversion may
              # throw an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
                            f"{dtypes.as_dtype(input_arg.type).name}.")
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
                f"'{inferred_from[input_arg.number_attr]}'.")
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
                f"{num_attr.minimum}.")
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
                  "Op.")
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                dtypes.as_dtype(x).name
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
              f"{op_type_name}.")
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError(f"No argument found for attr {attr.name} for "
                        f"{op_type_name}")

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr, op_type_name)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                      f"{all_keywords}.")

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs

    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
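
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library): roughly how a generated
# op wrapper consumes _apply_op_helper. The op name "Add" and its keyword
# arguments x/y are assumptions chosen for illustration; the authoritative
# callers are the machine-generated wrappers in gen_*_ops.py.
def _example_add_wrapper(x, y, name=None):
  # output_structure and is_stateful are ignored here because an op with a
  # single, non-list output needs no re-nesting; `outputs` already holds the
  # flat list of output tensors.
  _, _, _op, outputs = _apply_op_helper("Add", x=x, y=y, name=name)
  return outputs[0]
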
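# ---------------------------------------------------------------------------
# Illustrative sketch only (assumption, not library code): what the
# `output_structure` return value encodes. Each entry corresponds to one
# output_arg of the OpDef: None means a single tensor, while an integer n
# means the next n tensors of the flat `outputs` list form one list output.
# A caller could re-nest the flat outputs like this:
def _example_renest_outputs(output_structure, outputs):
  nested = []
  i = 0
  for spec in output_structure:
    if spec is None:
      nested.append(outputs[i])
      i += 1
    else:
      nested.append(list(outputs[i:i + spec]))
      i += spec
  return nested
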
def _get_op_def(op):
  """Returns the OpDef of an op, looking it up in the registry if needed."""
  return op.op_def or op_def_registry.get(op.type)
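
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library): one way _get_op_def can
# be used, e.g. to look up the declared number of inputs for an Operation that
# may or may not carry its own op_def. `some_op` is an assumed
# tf.Operation-like object supplied by the caller.
def _example_declared_input_count(some_op):
  op_def = _get_op_def(some_op)
  # A None result means the op type is not registered (e.g. a custom op whose
  # OpDef is not visible to Python).
  return len(op_def.input_arg) if op_def is not None else None
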
def is_registered_stateful_op_without_inputs(name):
  """Checks if an op is registered, stateful and does not expect inputs."""
  op_def = op_def_registry.get(name)
  return op_def is not None and (op_def.is_stateful and not op_def.input_arg)
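
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library): the predicate above can
# be used to find "source" nodes in a GraphDef, i.e. stateful ops that take no
# inputs (variable handles, iterators, etc.). `graph_def` is an assumed
# GraphDef proto supplied by the caller.
def _example_stateful_source_node_names(graph_def):
  return [node.name for node in graph_def.node
          if is_registered_stateful_op_without_inputs(node.op)]
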