def _set_read_only_resource_inputs_attr(op, func_graph):
  """Annotates `op` with the indices of its read-only resource inputs.

  AutomaticControlDependencies consumes this attribute to avoid adding
  control edges for resources that the function never mutates.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  ops.set_int_list_attr(
      op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
      acd.get_read_only_resource_input_indices_graph(func_graph))
def _set_read_only_resource_inputs_attr(op, func_graph):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: PartitionedCall Operation.
    func_graph: FuncGraph.
  """
  # Op inputs correspond positionally to func_graph inputs; an input index is
  # read-only when its graph-side handle is a resource tensor that the function
  # body never writes. The slice keeps iteration bounded by the op's inputs,
  # matching the original `range(len(op.inputs))` loop.
  read_only_indices = [
      i for i, handle in enumerate(func_graph.inputs[:len(op.inputs)])
      if handle.dtype == dtypes.resource and not acd.resource_has_writes(handle)
  ]
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)
def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: If or Case Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # `op.inputs[0]` is the predicate, which branch graphs never receive, so
  # every branch has len(op.inputs) - 1 inputs and branch index i corresponds
  # to op input i + 1.
  num_branch_inputs = len(op.inputs) - 1
  candidates = set(range(num_branch_inputs))
  for branch_graph in branch_graphs:
    assert len(branch_graph.inputs) == num_branch_inputs, "should never happen"
    if not candidates:
      # Every input is already known to be written somewhere; stop early.
      break
    candidates = candidates.intersection(
        acd.get_read_only_resource_input_indices_graph(branch_graph))
  # Shift branch-graph indices back into `op.inputs` index space.
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(i + 1 for i in candidates))
def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: If or Case Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # Input 0 of the op is the branch-selection predicate, which is not passed
  # to the branches, so op input i maps to branch input i - 1.
  read_only_indices = []
  for i in range(1, len(op.inputs)):
    if op.inputs[i].dtype != dtypes.resource:
      continue
    # A resource input is read-only only if no branch writes to the
    # corresponding handle. any() short-circuits on the first writing branch,
    # exactly like the original flag-and-break loop.
    if not any(acd.resource_has_writes(branch_graph.inputs[i - 1])
               for branch_graph in branch_graphs):
      read_only_indices.append(i)
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        read_only_indices)
def partitioned_call(args,
                     f,
                     tout=None,
                     executing_eagerly=None,
                     config=None,
                     executor_type=None):
  """Executes a function while respecting device annotations.

  Currently, only those functions that execute within the same address space
  can be executed.

  Args:
    args: The arguments of the function, including captured inputs.
    f: The function to execute; an instance of `_DefinedFunction` or
      `_EagerDefinedFunction`.
    tout: a list containing the output dtypes enums; if `None`, inferred from
      the signature of `f`.
    executing_eagerly: (Optional) A boolean indicating whether the context is
      executing eagerly. If `None`, fetched from the global context.
    config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If
      `None`, all optimizations are disabled. Currently only handled for eager
      defined functions.
    executor_type: (Optional) A string for the name of the executor to be used
      in the function call. If not set, or set to an empty string, the default
      tensorflow executor will be used.

  Returns:
    The list of `Tensor`s returned by invoking `f(args)`. If the function does
    not return anything, then returns `None` if eager execution is enabled, or
    the `Operation` if not.
  """
  # Infer output dtypes from the function's signature when not given.
  if tout is None:
    tout = tuple(x.type for x in f.definition.signature.output_arg)

  if executing_eagerly is None:
    executing_eagerly = context.executing_eagerly()

  if config is None:
    config = function_utils.get_disabled_rewriter_config()

  if executor_type is None:
    executor_type = ""

  if executing_eagerly:
    # Eager path: dispatch through the generated op bindings. The stateful
    # variant is used whenever the function contains stateful ops.
    if f.stateful_ops:
      outputs = gen_functional_ops.stateful_partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    else:
      outputs = gen_functional_ops.partitioned_call(
          args=args,
          Tout=tout,
          f=f,
          config_proto=config,
          executor_type=executor_type)
    return outputs if outputs else None

  # The generated binding returns an empty list for functions that don't
  # return any Tensors, hence the need to use `create_op` directly.
  args = [ops.convert_to_tensor(x) for x in args]
  # Build the attr protos required by the (Stateful)PartitionedCall op:
  # input dtypes, output dtypes, function reference, and executor type.
  tin_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(
          type=[x.dtype.as_datatype_enum for x in args]))
  tout_attr = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(type=tout))
  func_attr = attr_value_pb2.AttrValue(
      func=attr_value_pb2.NameAttrList(name=f.name))
  executor_type_attr = attr_value_pb2.AttrValue(
      s=compat.as_bytes(executor_type))

  # When running in graph mode, the graph and function graphs are optimized
  # (i.e. run through grappler) per the session options, so we can disable any
  # eager-specific rewriting.
  config_proto = attr_value_pb2.AttrValue(s=config)

  # Register the function with the graph before creating the call op that
  # references it.
  graph = ops.get_default_graph()
  f.add_to_graph(graph)
  op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"

  # Propagate the attribute indicating the need to compile from function to the
  # call itself.
  xla_compile_attr = "_XlaMustCompile"
  op_attrs = {
      "Tin": tin_attr,
      "Tout": tout_attr,
      "f": func_attr,
      "config_proto": config_proto,
      "executor_type": executor_type_attr,
  }
  if xla_compile_attr in f.definition.attr:
    op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
  op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
  outputs = op.outputs
  if hasattr(f, "graph"):
    # Presumably only functions backed by a FuncGraph carry the information
    # needed for these annotations (read-only resources, collectives).
    _set_read_only_resource_inputs_attr(op, f.graph)
    if hasattr(f.graph, "collective_manager_ids_used"):
      ops.set_int_list_attr(op, acd.COLLECTIVE_MANAGER_IDS,
                            f.graph.collective_manager_ids_used)
  return outputs if outputs else op