  def __exit__(self, error_type, unused_value, unused_traceback):
    if error_type:
      # Allow errors that occurred inside this context manager to pass through
      # normally.
      return

    # Only run in V2 Function mode.
    if (context.executing_eagerly() or
        not ops.executing_eagerly_outside_functions()):
      return

    if (self._graph is not ops.get_default_graph() or
        self._graph.name != 'keras_graph'):
      # Only auto-track updates when the Keras Graph is the only one used.
      return

    new_operations = self._graph.get_operations()[self._num_operations:]
    new_stateful_ops = set()

    # pylint: disable=protected-access
    for op in new_operations:
      # While loops are not supported in general for automatic control
      # dependencies.
      if control_flow_util.IsInWhileLoop(op):
        continue

      # Track stateful ops via `add_update`.
      is_stateful_op = (
          op.type not in self._graph._registered_ops or
          auto_control_deps.op_is_stateful(
              self._graph._registered_ops[op.type]))
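      # Op types missing from `_registered_ops` (e.g. some custom ops) have no
      # OpDef to inspect, so they are conservatively treated as stateful.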

      # Ignore ReadVariableOps, as they do not need to be run separately.
      # This ensures existing Layers don't get extra updates.
      if is_stateful_op and op.type != 'ReadVariableOp':
        new_stateful_ops.add(op)

    explicit_updates = {
        u for u in self.layer._unfiltered_updates if not isinstance(u, tuple)}
    # pylint: enable=protected-access

    # Don't add updates that will already be run by virtue of being consumed by
    # other stateful ops or by the Layer's outputs. This ensures that existing
    # Layers like `BatchNormalization` continue to return the same values for
    # `.updates`.
    minimum_ops = set()
    targets = new_stateful_ops.union(
        set(nest.flatten(self.outputs)), explicit_updates)
    for op in new_stateful_ops:
      # Scrub any ops that are consumed by the outputs or other stateful ops.
      reachable = tf_utils.get_reachable_from_inputs(op)
      if not (targets - {op}).intersection(reachable):
        minimum_ops.add(op)
    new_stateful_ops = minimum_ops

    # Don't double-track updates added via explicitly calling `add_update`.
    # Also don't double-track updates already tracked in sublayers.
    new_stateful_ops = new_stateful_ops - explicit_updates

    # Decide whether to track as input-conditional or unconditional.
    input_reachable_ops = tf_utils.get_reachable_from_inputs(
        self.inputs, targets=new_stateful_ops)
    unconditional_updates = new_stateful_ops - input_reachable_ops
    conditional_updates = new_stateful_ops - unconditional_updates
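    # Updates reachable from `self.inputs` depend on the values this Layer is
    # called on, so they are tracked as input-conditional; the rest always run.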

    if unconditional_updates:
      self.layer.add_update(list(unconditional_updates))
    if conditional_updates:
      self.layer.add_update(list(conditional_updates), inputs=self.inputs)
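

# A minimal, framework-free sketch of the reachability-based scrubbing done in
# `__exit__` above, using a hypothetical `consumers` adjacency map in place of
# real graph edges: a stateful op is kept only when no other target op is
# reachable from it, since a reachable target will already cause it to run.
def _toy_minimum_ops(stateful_ops, targets, consumers):
  """Returns the subset of `stateful_ops` not consumed by another target."""

  def reachable_from(start):
    seen, stack = set(), [start]
    while stack:
      node = stack.pop()
      if node not in seen:
        seen.add(node)
        stack.extend(consumers.get(node, ()))
    return seen

  return {op for op in stateful_ops
          if not (targets - {op}) & reachable_from(op)}


# Example: 'a' feeds 'b', so running 'b' already runs 'a'; only 'b' is kept.
assert _toy_minimum_ops({'a', 'b'}, {'a', 'b'}, {'a': ['b']}) == {'b'}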
def _build_case(branch_index,
                branch_graphs,
                branch_inputs,
                name=None,
                lower_using_switch_merge=None):
  """Creates an `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediates values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same outpute types.

  Args:
    branch_index: integer scalar Tensor selecting the branch to run.
    branch_graphs: List of FuncGraphs, one per branch.
    branch_inputs: List of lists of Tensors to be passed to the corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: whether to lower this op using switch/merge ops
      (optional).

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  stateful_ops = []
  for bg in branch_graphs:
    stateful_ops.extend([
        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
    ])

  if stateful_ops:
    op_fn = gen_functional_ops.case
  else:
    op_fn = gen_functional_ops.stateless_case
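  # Note: `StatelessCase` advertises the op as side-effect free, so graph
  # optimizations may prune it when its outputs are unused.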

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):
    tensors = op_fn(
        branch_index,
        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
        [util.create_new_tf_function(g) for g in branch_graphs],
        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
        name=name)

  case_op, tensors = _get_op_and_outputs(tensors)

  if case_op is not None:
    util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)
    _set_read_only_resource_inputs_attr(case_op, branch_graphs)
    # Prevent fetching since the variant outputs can't be fetched directly.
    case_op.graph.prevent_fetching(case_op)

  _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
  # Return identities for each output of the Case op, rather than the output
  # of the Case op directly. This makes pruning work if the output of
  # switch_case() is fetched: the lowering pass converts the Case outputs into
  # IdentityN outputs, which, if fetched, will cause all ops in the taken
  # branch to be run (since it takes all merge ops as input). After lowering,
  # each output identity op will end up with only the appropriate merge op as
  # input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure.
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
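

# Hedged usage sketch (assuming TensorFlow 2.x): `_build_case` underlies the
# public `tf.switch_case` API, which is the usual way to exercise this code
# path. The `_example_pick` name below is illustrative only. Out-of-range
# indices fall through to the `default` branch.
import tensorflow as tf


@tf.function
def _example_pick(branch_index):
  return tf.switch_case(
      branch_index,
      branch_fns={0: lambda: tf.constant(1.0),
                  1: lambda: tf.constant(2.0)},
      default=lambda: tf.constant(-1.0))


# _example_pick(tf.constant(1)) ==> 2.0
# _example_pick(tf.constant(7)) ==> -1.0 (default branch)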