def __exit__(self, error_type, unused_value, unused_traceback):
  """Adds stateful ops created inside this context to the Layer's updates.

  Nothing is tracked in any of these cases:
  - an exception was raised inside the context;
  - the code is not in V2 Function mode (it is executing eagerly, or eager
    execution is disabled outside functions);
  - `self._graph` is not both the default graph and named 'keras_graph'.

  Otherwise the graph operations past the first `self._num_operations`
  (presumably the count recorded when the context was entered; verify
  against `__enter__`) are scanned. A stateful op is kept when all of the
  following hold:
  - it is not inside a while loop;
  - it is not a `ReadVariableOp`;
  - no other target (the other new stateful ops, the flattened
    `self.outputs`, or the Layer's explicit non-tuple updates) is reachable
    from it;
  - it is not itself an explicit update.

  The survivors go to `self.layer.add_update`. Those reachable from
  `self.inputs` are added with `inputs=self.inputs`; the rest are added
  unconditionally.

  Args:
    error_type: Type of the exception raised inside the context, or None.
    unused_value: Ignored exception value.
    unused_traceback: Ignored traceback.

  Returns:
    None on every path, so exceptions raised inside the context propagate.
  """
  # Let errors raised inside this context manager pass through normally.
  if error_type:
    return

  # Only V2 Function mode is tracked.
  in_v2_function = (not context.executing_eagerly() and
                    ops.executing_eagerly_outside_functions())
  if not in_v2_function:
    return

  # Auto-tracking only applies when the Keras Graph is the only graph in use.
  keras_graph_active = (self._graph is ops.get_default_graph() and
                        self._graph.name == 'keras_graph')
  if not keras_graph_active:
    return

  fresh_ops = self._graph.get_operations()[self._num_operations:]

  # pylint: disable=protected-access
  stateful_candidates = set()
  for candidate in fresh_ops:
    # Automatic control dependencies do not support while loops in general.
    if control_flow_util.IsInWhileLoop(candidate):
      continue
    op_type = candidate.type
    # Op types missing from the registry are treated as stateful.
    stateful = (op_type not in self._graph._registered_ops or
                auto_control_deps.op_is_stateful(
                    self._graph._registered_ops[op_type]))
    # ReadVariableOps need not run on their own; skipping them keeps existing
    # Layers from gaining extra updates.
    if stateful and op_type != 'ReadVariableOp':
      stateful_candidates.add(candidate)

  # Tuple entries are excluded; only op/tensor updates are compared below.
  explicit_updates = {
      update
      for update in self.layer._get_unfiltered_updates(check_trainable=False)
      if not isinstance(update, tuple)
  }
  # pylint: enable=protected-access

  # Drop ops whose effects are already consumed by the Layer's outputs, by
  # other stateful ops or by explicit updates. This keeps the `.update`
  # results of existing Layers such as `BatchNormalization` unchanged.
  targets = stateful_candidates.union(
      set(nest.flatten(self.outputs)), explicit_updates)
  unconsumed_ops = {
      candidate for candidate in stateful_candidates
      if not (targets - {candidate}).intersection(
          tf_utils.get_reachable_from_inputs(candidate))
  }

  # Skip ops already registered through an explicit `add_update` call,
  # including those tracked in sublayers.
  pending = unconsumed_ops - explicit_updates

  # Ops reachable from the Layer's inputs become input-conditional updates.
  input_reachable_ops = tf_utils.get_reachable_from_inputs(
      self.inputs, targets=pending)
  unconditional_updates = pending - input_reachable_ops
  conditional_updates = pending - unconditional_updates

  if unconditional_updates:
    self.layer.add_update(list(unconditional_updates))
  if conditional_updates:
    self.layer.add_update(list(conditional_updates), inputs=self.inputs)
def __exit__(self, error_type, unused_value, unused_traceback):
  """Registers stateful ops built inside this context as Layer updates.

  Tracking is skipped in any of these cases:
  - an exception was raised inside the context;
  - the code is not in V2 Function mode;
  - `self._graph` is not the default graph named 'keras_graph'.

  Otherwise every graph operation past index `self._num_operations`
  (presumably the operation count captured on entry; verify against
  `__enter__`) is considered. It becomes an update when all of the
  following hold:
  - it is not inside a while loop;
  - it is stateful and not a `ReadVariableOp`;
  - it does not feed any other target (another new stateful op, a flattened
    `self.outputs` entry, or a non-tuple entry of
    `self.layer._unfiltered_updates`);
  - it is not already one of those explicit updates.

  Updates reachable from `self.inputs` are added conditionally on those
  inputs; the rest are added unconditionally.

  Args:
    error_type: Type of the exception raised inside the context, or None.
    unused_value: Ignored exception value.
    unused_traceback: Ignored traceback.

  Returns:
    None on every path, so any exception from the context propagates.
  """
  if error_type:
    # Errors raised inside this context manager propagate normally.
    return

  # Restrict tracking to V2 Function mode.
  if not (not context.executing_eagerly() and
          ops.executing_eagerly_outside_functions()):
    return

  # Only auto-track updates when the Keras Graph is the only one used.
  if not (self._graph is ops.get_default_graph() and
          self._graph.name == 'keras_graph'):
    return

  # pylint: disable=protected-access
  def _is_stateful(op):
    # Unregistered op types are conservatively treated as stateful.
    return (op.type not in self._graph._registered_ops or
            auto_control_deps.op_is_stateful(
                self._graph._registered_ops[op.type]))

  # Evaluation order per op matches a guard-first scan:
  # 1. skip while-loop ops (unsupported by automatic control dependencies);
  # 2. test statefulness;
  # 3. drop ReadVariableOps, which never need to run separately.
  created_stateful_ops = {
      op for op in self._graph.get_operations()[self._num_operations:]
      if not control_flow_util.IsInWhileLoop(op) and _is_stateful(op) and
      op.type != 'ReadVariableOp'
  }

  # Tuple entries are excluded from the comparison sets below.
  explicit_updates = {
      update for update in self.layer._unfiltered_updates
      if not isinstance(update, tuple)
  }
  # pylint: enable=protected-access

  # An op that feeds the Layer's outputs, another stateful op or an explicit
  # update already runs; omitting it keeps `.update` for Layers such as
  # `BatchNormalization` unchanged.
  all_targets = created_stateful_ops.union(
      set(nest.flatten(self.outputs)), explicit_updates)
  leaf_ops = set()
  for op in created_stateful_ops:
    downstream = tf_utils.get_reachable_from_inputs(op)
    if not (all_targets - {op}).intersection(downstream):
      leaf_ops.add(op)

  # Avoid double-tracking explicit `add_update` calls (sublayers included).
  remaining_ops = leaf_ops - explicit_updates

  # Split into input-conditional and unconditional updates.
  reachable_from_inputs = tf_utils.get_reachable_from_inputs(
      self.inputs, targets=remaining_ops)
  unconditional_updates = remaining_ops - reachable_from_inputs
  conditional_updates = remaining_ops - unconditional_updates

  if unconditional_updates:
    self.layer.add_update(list(unconditional_updates))
  if conditional_updates:
    self.layer.add_update(list(conditional_updates), inputs=self.inputs)
def _build_case(branch_index,
                branch_graphs,
                branch_inputs,
                name=None,
                lower_using_switch_merge=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must have the
  same output types.

  A stateful `Case` op is emitted when any branch graph contains a stateful
  op; otherwise a `StatelessCase` op is emitted.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    The outputs of the Case op (each wrapped in an identity), packed into the
    structure of `branch_graphs[0].structured_outputs`. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Only the existence of a stateful op matters, so stop at the first one
  # instead of collecting every stateful op of every branch.
  has_stateful_ops = any(
      auto_control_deps.op_is_stateful(op)
      for bg in branch_graphs
      for op in bg.get_operations())
  op_fn = (gen_functional_ops.case if has_stateful_ops
           else gen_functional_ops.stateless_case)

  # The control captures of all branches, in branch order, become control
  # dependencies of the Case op. A flat comprehension replaces
  # `sum(lists, [])`, which copies the accumulated list on every step.
  control_inputs = [
      capture for bg in branch_graphs for capture in bg.control_captures
  ]

  # Create the Case op.
  with ops.control_dependencies(control_inputs):
    tensors = op_fn(
        branch_index,
        case_inputs,
        [t.dtype for t in branch_graphs[0].outputs],
        [util.create_new_tf_function(g) for g in branch_graphs],
        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
        name=name)

  case_op, tensors = _get_op_and_outputs(tensors)

  if case_op is not None:
    util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)
    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
    # Prevent fetching since the variant outputs can't be fetched directly.
    case_op.graph.prevent_fetching(case_op)

  _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)