def __exit__(self, error_type, unused_value, unused_traceback):
  if error_type:
    # Allow errors that occurred inside this context manager to pass through
    # normally.
    return

  # Only run in V2 Function mode.
  if (context.executing_eagerly() or
      not ops.executing_eagerly_outside_functions()):
    return

  if (self._graph is not ops.get_default_graph() or
      self._graph.name != 'keras_graph'):
    # Only auto-track updates when the Keras Graph is the only one used.
    return

  new_operations = self._graph.get_operations()[self._num_operations:]
  new_stateful_ops = set()

  # pylint: disable=protected-access
  for op in new_operations:
    # While loops are not supported in general for automatic control
    # dependencies.
    if control_flow_util.IsInWhileLoop(op):
      continue

    # Track stateful ops via `add_update`.
    is_stateful_op = (
        op.type not in self._graph._registered_ops or
        auto_control_deps.op_is_stateful(
            self._graph._registered_ops[op.type]))

    # Ignore ReadVariableOps, as they do not need to be run separately.
    # This ensures existing Layers don't get extra updates.
    if is_stateful_op and op.type != 'ReadVariableOp':
      new_stateful_ops.add(op)

  explicit_updates = {
      u for u in self.layer._get_unfiltered_updates(check_trainable=False)
      if not isinstance(u, tuple)
  }
  # pylint: enable=protected-access

  # Don't add updates that will already be run by virtue of being consumed by
  # other stateful ops or by the Layer's outputs. This ensures that existing
  # Layers like `BatchNormalization` continue to return the same values for
  # `.updates` calls.
  minimum_ops = set()
  targets = new_stateful_ops.union(
      set(nest.flatten(self.outputs)), explicit_updates)
  for op in new_stateful_ops:
    # Scrub any ops that are consumed by the outputs or other stateful ops.
    reachable = tf_utils.get_reachable_from_inputs(op)
    if not (targets - {op}).intersection(reachable):
      minimum_ops.add(op)
  new_stateful_ops = minimum_ops

  # Don't double-track updates added by explicitly calling `add_update`.
  # Also don't double-track updates already tracked in sublayers.
  new_stateful_ops = new_stateful_ops - explicit_updates

  # Decide whether to track as input-conditional or unconditional.
  input_reachable_ops = tf_utils.get_reachable_from_inputs(
      self.inputs, targets=new_stateful_ops)
  unconditional_updates = new_stateful_ops - input_reachable_ops
  conditional_updates = new_stateful_ops - unconditional_updates
  if unconditional_updates:
    self.layer.add_update(list(unconditional_updates))
  if conditional_updates:
    self.layer.add_update(list(conditional_updates), inputs=self.inputs)
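
# --- Illustrative sketch (not part of the original file) ---
# A minimal standalone sketch of the pruning rule above: keep only stateful
# ops that no other target (an output or another stateful op) will already
# cause to run. `reachable_from` is a hypothetical stand-in for
# tf_utils.get_reachable_from_inputs, and the function name is invented for
# illustration only.
def minimal_update_ops(stateful_ops, outputs, reachable_from):
  """Returns the subset of `stateful_ops` that must be tracked explicitly."""
  targets = set(stateful_ops) | set(outputs)
  minimum = set()
  for op in stateful_ops:
    # Keep `op` only if no other target is reachable from it, i.e. nothing
    # downstream consumes it and would cause it to run anyway.
    if not (targets - {op}) & reachable_from(op):
      minimum.add(op)
  return minimum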
def __exit__(self, unused_type, unused_value, unused_traceback):
  # pylint: disable=protected-access
  if context.executing_eagerly():
    return

  if self._graph is not ops.get_default_graph():
    raise RuntimeError(
        "Graph changed while trying to add control dependencies.")

  if hasattr(self._graph, "outer_graph"):
    outer_val = self._graph.outer_graph._add_control_dependencies
    self._graph._add_control_dependencies = outer_val
  else:
    self._graph._add_control_dependencies = False

  # Map from resource tensor to the last op which wrote to it.
  last_write_to_resource = {}
  # Map from resource tensor to the list of reads from it since the last
  # write or since the beginning of the function.
  reads_since_last_write_to_resource = collections.defaultdict(list)
  # CollectiveManager manager_ids within a particular function call should not
  # be needed outside of that function call. So we keep them separate (though
  # the general idea of the maps is the same; in the future, we'll need to
  # correctly thread the control output outside).
  # Map from collective manager scope to the last op which used it.
  collective_manager_scopes_opened = {}
  collective_manager_scopes_used = {}
  # Set of conditional and loop exits.
  ops_which_must_run = set()
  # Merge ops which must depend on ops which use this resource.
  merge_for_resource = {}

  new_operations = self._graph.get_operations()[self._n_operations:]

  # Ensures that uses of resource tensors get serialized properly and all
  # execute. This is done by keeping a map from resource tensor to the last op
  # in graph-construction order which used it (last_write_to_resource).
  #
  # Conditionals are written in TensorFlow such that every external tensor
  # accessed in the conditional goes through a switch op and every return
  # tensor (it's guaranteed that there will be at least one) goes through a
  # merge op.
  #
  # To handle conditionals, switches are handled in a special way (see
  # comments for _process_switch). Merge nodes created by TF's conditional
  # logic (as opposed to by _process_switch) are forced to run and also get a
  # control dependency added to them to ensure all stateful ops inside their
  # control flow context run.
  #
  # We also ensure that if an op is using a resource output by a switch node
  # (that is, a resource tensor for which there's a value in
  # merge_for_resource) this op will run before the merge for that resource.
  #
  # We try to add control inputs to nodes respecting their control flow
  # contexts to avoid dead nodes propagating everywhere and leading to
  # "retval[0] doesn't have value" errors. If a node gets a control dependency
  # on a dead node (i.e. a node from an untaken control flow branch) that node
  # will be marked as dead unless it's a merge node.
  #
  # TODO(apassos): serialize non-resource-taking stateful ops as well, and
  # test that it works. Support while loops. Support init_scope escaping from
  # this.
  for op in new_operations:
    # TODO(apassos): make this code safely support while loops.
    if control_flow_util.IsInWhileLoop(op):
      continue
    control_inputs = set()
    is_building_function = op.graph.building_function
    # Ensure stateful ops run.
    if (op_def_registry.get(op.type) is None or
        (op_is_stateful(op) and
         op.type not in utils.RESOURCE_READ_OPS)):
      # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
      # they only have variable reads and are otherwise stateless.
      ops_which_must_run.add(op)
    # Make a note of all opened manager_ids.
    if op.type == "NoOp":
      try:
        collective_manager_scopes_opened[op.get_attr(
            "_collective_manager_id")] = op
      except ValueError:
        pass
    # Ignore switches (they're handled separately).
    if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
      continue
    # Make merges trigger all other computation which must run.
    if op.type == "Merge":
      for o in ops_which_must_run:
        op._add_control_input(o)
        for inp in o.inputs:
          input_id = ops.tensor_id(inp)
          if input_id in last_write_to_resource:
            last_write_to_resource[input_id] = op
      ops_which_must_run = {op}
      continue

    resource_inputs = set()
    # Check for any resource inputs. If we find any, we update control_inputs
    # and last_write_to_resource.
    for inp, resource_type in _get_resource_inputs(op):
      is_read = resource_type == ResourceType.READ_ONLY
      input_id = ops.tensor_id(inp)

      # If the op receives the same resource tensor twice as an input, skip
      # it to avoid the op getting a control dependency on itself.
      if input_id in resource_inputs:
        continue

      resource_inputs.add(input_id)
      # Deal with switches, finally.
      if inp.op.type == "Switch":
        self._process_switch(inp.op, ops_which_must_run,
                             last_write_to_resource, merge_for_resource)
      # Ensure uses of resources are serialized.
      if input_id in last_write_to_resource:
        if is_building_function or (
            last_write_to_resource[input_id]._control_flow_context
            is op._control_flow_context):
          control_inputs.add(last_write_to_resource[input_id])
      # Ensure merges happen after the closing of a cond block.
      if input_id in merge_for_resource:
        merge_for_resource[input_id]._add_control_input(op)
      if is_read:
        reads_since_last_write_to_resource[input_id].append(op)
      else:
        control_inputs.update(reads_since_last_write_to_resource[input_id])
        reads_since_last_write_to_resource[input_id] = []
        last_write_to_resource[input_id] = op

    if (op_is_stateful(op) and not resource_inputs
        and op._control_flow_context is None):
      if None in last_write_to_resource:
        op._add_control_input(last_write_to_resource[None])
      last_write_to_resource[None] = op

    # Ensure ordering of collective ops.
    manager_ids = collective_manager_ids_from_op(op)
    for manager_id in manager_ids:
      if manager_id in collective_manager_scopes_opened:
        # Chain this function call if the scope was opened.
        op._add_control_input(collective_manager_scopes_opened[manager_id])
        collective_manager_scopes_opened[manager_id] = op
      else:
        # If this op is in a scope not created here, create a chain starting
        # at this op.
        if manager_id in collective_manager_scopes_used:
          op._add_control_input(collective_manager_scopes_used[manager_id])
        collective_manager_scopes_used[manager_id] = op

    if control_inputs and not is_building_function:
      control_inputs = [
          c for c in control_inputs
          if c._control_flow_context is op._control_flow_context
      ]

    op._add_control_inputs(control_inputs)

  # Ensure all ops which must run do run.
  self.ops_which_must_run.update(ops_which_must_run)
  for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
    if self.ops_which_must_run:
      if r.graph.building_function:
        updated_ops_which_must_run = self.ops_which_must_run
      else:
        updated_ops_which_must_run = [
            o for o in self.ops_which_must_run
            if o._control_flow_context is r.op._control_flow_context
        ]
      r.op._add_control_inputs(updated_ops_which_must_run)

  self.collective_manager_ids_used = collective_manager_scopes_used
def __exit__(self, unused_type, unused_value, unused_traceback):
  if context.executing_eagerly():
    return

  if self._graph is not ops.get_default_graph():
    raise RuntimeError(
        "Graph changed while trying to add control dependencies.")

  # pylint: disable=protected-access
  if hasattr(self._graph, "outer_graph"):
    outer_val = self._graph.outer_graph._add_control_dependencies
    self._graph._add_control_dependencies = outer_val
  else:
    self._graph._add_control_dependencies = False
  # pylint: enable=protected-access

  # Map from resource tensor to the last op which used it.
  last_op_using_resource_tensor = {}
  # Set of conditional and loop exits.
  ops_which_must_run = set()
  # Merge ops which must depend on ops which use this resource.
  merge_for_resource = {}

  new_operations = self._graph.get_operations()[self._n_operations:]

  # Ensures that uses of resource tensors get serialized properly and all
  # execute. This is done by keeping a map from resource tensor to the last op
  # in graph-construction order which used it
  # (last_op_using_resource_tensor).
  #
  # Conditionals are written in TensorFlow such that every external tensor
  # accessed in the conditional goes through a switch op and every return
  # tensor (it's guaranteed that there will be at least one) goes through a
  # merge op.
  #
  # To handle conditionals, switches are handled in a special way (see
  # comments for _process_switch). Merge nodes created by TF's conditional
  # logic (as opposed to by _process_switch) are forced to run and also get a
  # control dependency added to them to ensure all stateful ops inside their
  # control flow context run.
  #
  # We also ensure that if an op is using a resource output by a switch node
  # (that is, a resource tensor for which there's a value in
  # merge_for_resource) this op will run before the merge for that resource.
  #
  # We try to add control inputs to nodes respecting their control flow
  # contexts to avoid dead nodes propagating everywhere and leading to
  # "retval[0] doesn't have value" errors. If a node gets a control dependency
  # on a dead node (i.e. a node from an untaken control flow branch) that node
  # will be marked as dead unless it's a merge node.
  #
  # TODO(apassos): serialize non-resource-taking stateful ops as well, and
  # test that it works. Support while loops. Support init_scope escaping from
  # this.
  for op in new_operations:
    # TODO(apassos): make this code safely support while loops.
    if control_flow_util.IsInWhileLoop(op):
      continue
    control_inputs = set()
    # Ensure stateful ops run.
    if (op.type not in self._graph._registered_ops  # pylint: disable=protected-access
        or (self._graph._registered_ops[op.type].is_stateful  # pylint: disable=protected-access
            and op.type not in ASYNC_STATEFUL_OPS)):
      ops_which_must_run.add(op)
    # Ignore switches (they're handled separately).
    if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
      continue
    # Make merges trigger all other computation which must run.
    if op.type == "Merge":
      for o in ops_which_must_run:
        op._add_control_input(o)  # pylint: disable=protected-access
        for inp in o.inputs:
          if inp in last_op_using_resource_tensor:
            last_op_using_resource_tensor[inp] = op
      ops_which_must_run = {op}
      continue
    found_resource = False
    # Check for any resource inputs. If we find any, we update control_inputs
    # and last_op_using_resource_tensor. Note that we dedup op.inputs in case
    # op receives the same resource tensor twice as input, which would result
    # in op getting a control dependency on itself.
    for inp in set(op.inputs):
      if inp.dtype != dtypes_module.resource:
        continue
      found_resource = True
      # Deal with switches, finally.
      if inp.op.type == "Switch":
        self._process_switch(inp.op, ops_which_must_run,
                             last_op_using_resource_tensor,
                             merge_for_resource)
      # Ensure uses of resources are serialized.
      if inp in last_op_using_resource_tensor:
        if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
            is op._control_flow_context):  # pylint: disable=protected-access
          control_inputs.add(last_op_using_resource_tensor[inp])
      # Ensure merges happen after the closing of a cond block.
      if inp in merge_for_resource:
        merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
      last_op_using_resource_tensor[inp] = op
    if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
        and not found_resource
        and op._control_flow_context is None):  # pylint: disable=protected-access
      if None in last_op_using_resource_tensor:
        op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
      last_op_using_resource_tensor[None] = op
    control_inputs = [
        c for c in control_inputs
        if c._control_flow_context is op._control_flow_context  # pylint: disable=protected-access
    ]
    op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

  # Ensure all ops which must run do run.
  for r in self._returned_tensors:
    if ops_which_must_run:
      r.op._add_control_inputs(  # pylint: disable=protected-access
          [o for o in ops_which_must_run
           if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
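
# --- Illustrative usage (not part of the original file) ---
# Sketch of the per-resource chaining done via last_op_using_resource_tensor:
# two writes to the same variable get control edges in graph-construction
# order, so the final value is deterministic rather than racy.
import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def two_writes():
  v.assign(1.0)
  v.assign(2.0)  # chained after the first assign by a control dependency
  return v.read_value()

assert two_writes().numpy() == 2.0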
def __exit__(self, unused_type, unused_value, unused_traceback):
  # pylint: disable=protected-access
  if context.executing_eagerly():
    return

  if self._graph is not ops.get_default_graph():
    raise RuntimeError(
        "Within the automatic control dependency context, the default graph"
        f" cannot change. Upon entry it was {self._graph}, but on exit it"
        f" changed to {ops.get_default_graph()}")

  if hasattr(self._graph, "outer_graph"):
    outer_val = self._graph.outer_graph._add_control_dependencies
    self._graph._add_control_dependencies = outer_val
  else:
    self._graph._add_control_dependencies = False

  # Map from resource tensor to the last op which wrote to it.
  last_write_to_resource = {}
  # Map from resource tensor to the list of reads from it since the last
  # write or since the beginning of the function.
  reads_since_last_write_to_resource = collections.defaultdict(list)
  # CollectiveManager manager_ids within a particular function call should not
  # be needed outside of that function call. So we keep them separate (though
  # the general idea of the maps is the same; in the future, we'll need to
  # correctly thread the control output outside).
  # Map from collective manager scope to the last op which used it.
  collective_manager_scopes_opened = {}
  collective_manager_scopes_used = {}
  # Set of conditional and loop exits.
  ops_which_must_run = set()
  # Merge ops which must depend on ops which use this resource.
  merge_for_resource = {}

  new_operations = self._graph.get_operations()[self._n_operations:]
  first_use_for_res = {}
  resources_by_op = {}

  # Ensures that uses of resource tensors get serialized properly and all
  # execute. This is done by keeping a map from resource tensor to the last op
  # in graph-construction order which used it (last_write_to_resource).
  #
  # Conditionals are written in TensorFlow such that every external tensor
  # accessed in the conditional goes through a switch op and every return
  # tensor (it's guaranteed that there will be at least one) goes through a
  # merge op.
  #
  # To handle conditionals, switches are handled in a special way (see
  # comments for _process_switch). Merge nodes created by TF's conditional
  # logic (as opposed to by _process_switch) are forced to run and also get a
  # control dependency added to them to ensure all stateful ops inside their
  # control flow context run.
  #
  # We also ensure that if an op is using a resource output by a switch node
  # (that is, a resource tensor for which there's a value in
  # merge_for_resource) this op will run before the merge for that resource.
  #
  # We try to add control inputs to nodes respecting their control flow
  # contexts to avoid dead nodes propagating everywhere and leading to
  # "retval[0] doesn't have value" errors. If a node gets a control dependency
  # on a dead node (i.e. a node from an untaken control flow branch) that node
  # will be marked as dead unless it's a merge node.
  #
  # TODO(apassos): serialize non-resource-taking stateful ops as well, and
  # test that it works. Support while loops. Support init_scope escaping from
  # this.
  for op in new_operations:
    # TODO(apassos): make this code safely support while loops.
    if control_flow_util.IsInWhileLoop(op):
      continue
    control_inputs = set()
    is_building_function = op.graph.building_function

    # Ensure stateful ops run.
    # Read-only ops are added to control outputs if the read value is
    # consumed. This covers the case when the read value is returned from
    # the function since that goes through a tf.identity in mark_as_return.
    if ((op_def_registry.get(op.type) is None) or
        (op_is_stateful(op) and
         (op.type not in utils.RESOURCE_READ_OPS or
          any(output.consumers() for output in op.outputs))) or
        (op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS)):
      ops_which_must_run.add(op)

    # Make a note of all opened manager_ids.
    if op.type == "NoOp":
      try:
        collective_manager_scopes_opened[op.get_attr(
            "_collective_manager_id")] = op
      except ValueError:
        pass
    # Ignore switches (they're handled separately).
    if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
      continue
    # Make merges trigger all other computation which must run.
    # TODO(mdan): Don't do this. Write a transform to chains instead.
    # See core/common_runtime/control_flow_deps_to_chains.cc.
    if op.type == "Merge":
      for o in ops_which_must_run:
        op._add_control_input(o)
        for inp in o.inputs:
          input_id = ops.tensor_id(inp)
          if input_id in last_write_to_resource:
            last_write_to_resource[input_id] = op
      ops_which_must_run = {op}
      continue

    resource_inputs = set()
    # Check for any resource inputs. If we find any, we update control_inputs
    # and last_write_to_resource.
    for inp, resource_type in _get_resource_inputs(op):
      is_read = resource_type == ResourceType.READ_ONLY
      input_id = ops.tensor_id(inp)

      # If the op receives the same resource tensor twice as an input, skip
      # it to avoid the op getting a control dependency on itself.
      if input_id in resource_inputs:
        continue

      resource_inputs.add(input_id)
      # Deal with switches, finally.
      if inp.op.type == "Switch":
        self._process_switch(inp.op, ops_which_must_run,
                             last_write_to_resource, merge_for_resource)
      # Ensure uses of resources are serialized.
      if input_id in last_write_to_resource:
        if is_building_function or (
            last_write_to_resource[input_id]._control_flow_context
            is op._control_flow_context):
          control_inputs.add(last_write_to_resource[input_id])
      # Ensure merges happen after the closing of a cond block.
      if input_id in merge_for_resource:
        merge_for_resource[input_id]._add_control_input(op)

      do_record = (
          self.record_initial_resource_uses and
          input_id not in first_use_for_res)

      if is_read:
        reads_list = reads_since_last_write_to_resource[input_id]
        reads_list.append(op)

        if do_record:
          # Note: this will track the entire list that
          # reads_since_last_write_to_resource maintains. Updates to it will
          # and should be tracked, until the first write is encountered. At
          # that point, reads_since_last_write_to_resource will contain a new
          # empty list. This logic relies on that behavior.
          first_use_for_res[input_id] = reads_list

      else:
        control_inputs.update(reads_since_last_write_to_resource[input_id])
        reads_since_last_write_to_resource[input_id] = []
        last_write_to_resource[input_id] = op

        if do_record:
          first_use_for_res[input_id] = [op]

    if self.record_initial_resource_uses and op_is_stateful(op):
      if resource_inputs:
        resources_by_op[op] = tuple(resource_inputs)
      else:
        if None not in first_use_for_res:
          first_use_for_res[None] = [op]
        resources_by_op[op] = (None,)

    if (op_is_stateful(op) and not resource_inputs
        and op._control_flow_context is None):
      if None in last_write_to_resource:
        op._add_control_input(last_write_to_resource[None])
      last_write_to_resource[None] = op

    # Ensure ordering of collective ops.
    manager_ids = collective_manager_ids_from_op(op)
    for manager_id in manager_ids:
      if manager_id in collective_manager_scopes_opened:
        # Chain this function call if the scope was opened.
        op._add_control_input(collective_manager_scopes_opened[manager_id])
        collective_manager_scopes_opened[manager_id] = op
      else:
        # If this op is in a scope not created here, create a chain starting
        # at this op.
        if manager_id in collective_manager_scopes_used:
          op._add_control_input(collective_manager_scopes_used[manager_id])
        collective_manager_scopes_used[manager_id] = op

    if control_inputs and not is_building_function:
      control_inputs = [
          c for c in control_inputs
          if c._control_flow_context is op._control_flow_context
      ]

    op._add_control_inputs(control_inputs)

  # Record the ops which first use resources touched by "ops which must run".
  if self.record_initial_resource_uses:
    first_uses_by_output_ops = {}
    for op in ops_which_must_run:
      if op not in resources_by_op:
        # This may happen with Merge/Switch nodes which are special cased
        # above.
        continue
      for r in resources_by_op[op]:
        if op not in first_uses_by_output_ops:
          first_uses_by_output_ops[op] = set()
        first_uses_by_output_ops[op].update(first_use_for_res[r])
    # For each "op which must run", set a private attr indicating the ops that
    # used the same resources it did.
    for op in first_uses_by_output_ops:
      others = [
          other.name.encode() for other in first_uses_by_output_ops[op]
      ]
      l = attr_value_pb2.AttrValue.ListValue(s=others)
      # TODO(mdan): Is there a way which doesn't use anonymous attrs?
      op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

  # Ensure all ops which must run do run.
  self.ops_which_must_run.update(ops_which_must_run)
  control_output_op = None
  for idx, r in enumerate(
      nest.flatten(list(self._returned_tensors), expand_composites=True)):
    if self.ops_which_must_run:
      if r.graph.building_function:
        # There may be many stateful ops in the graph. Adding them as
        # control inputs to each function output could create excessive
        # control edges in the graph. Thus we create an intermediate No-op
        # to chain the control dependencies between stateful ops and
        # function outputs.
        if idx == 0:
          control_output_op = control_flow_ops.no_op()
          control_output_op._set_attr("_acd_function_control_output",
                                      attr_value_pb2.AttrValue(b=True))
          control_output_op._add_control_inputs(self.ops_which_must_run)
        updated_ops_which_must_run = [control_output_op]
      else:
        updated_ops_which_must_run = [
            o for o in self.ops_which_must_run
            if o._control_flow_context is r.op._control_flow_context
        ]
      r.op._add_control_inputs(updated_ops_which_must_run)

  self.collective_manager_ids_used = collective_manager_scopes_used
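
# --- Illustrative usage (not part of the original file) ---
# Sketch of what ops_which_must_run guarantees: a side effect with no data
# dependency on the function's return value still executes, because it is
# chained to the function outputs (through the intermediate NoOp described
# above when building a function).
import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def side_effect_only(x):
  v.assign(x)     # nothing consumes this op's output
  return x + 1.0

side_effect_only(tf.constant(3.0))
assert v.numpy() == 3.0  # the assign ran anyway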