Example #1
    def __exit__(self, error_type, unused_value, unused_traceback):
        if error_type:
            # Allow errors that occurred inside this context manager to pass through
            # normally.
            return

        # Only run in V2 Function mode.
        if (context.executing_eagerly()
                or not ops.executing_eagerly_outside_functions()):
            return

        if (self._graph is not ops.get_default_graph()
                or self._graph.name != 'keras_graph'):
            # Only auto-track updates when the Keras Graph is the only one used.
            return

        new_operations = self._graph.get_operations()[self._num_operations:]
        new_stateful_ops = set()

        # pylint: disable=protected-access
        for op in new_operations:
            # While loops are not supported in general for automatic control
            # dependencies.
            if control_flow_util.IsInWhileLoop(op):
                continue

            # Track stateful ops via `add_update`.
            is_stateful_op = (op.type not in self._graph._registered_ops
                              or auto_control_deps.op_is_stateful(
                                  self._graph._registered_ops[op.type]))

            # Ignore ReadVariableOps as they do not need to be run separately.
            # This ensures existing Layers don't get extra updates.
            if is_stateful_op and op.type != 'ReadVariableOp':
                new_stateful_ops.add(op)

        explicit_updates = set([
            u
            for u in self.layer._get_unfiltered_updates(check_trainable=False)
            if not isinstance(u, tuple)
        ])
        # pylint: enable=protected-access

        # Don't add updates that will already be run by virtue of being consumed
        # by other stateful ops or by the Layer's outputs. This ensures that
        # existing Layers like `BatchNormalization` continue to return the same
        # values for `.updates` calls.
        minimum_ops = set()
        targets = new_stateful_ops.union(set(nest.flatten(self.outputs)),
                                         explicit_updates)
        for op in new_stateful_ops:
            # Scrub any ops that are consumed by the outputs or other stateful ops.
            reachable = tf_utils.get_reachable_from_inputs(op)
            if not (targets - {op}).intersection(reachable):
                minimum_ops.add(op)
        new_stateful_ops = minimum_ops

        # Don't double-track updates added via explicitly calling `add_update`.
        # Also don't double-track updates already tracked in sublayers.
        new_stateful_ops = new_stateful_ops - explicit_updates

        # Decide whether to track as input-conditional or unconditional.
        input_reachable_ops = tf_utils.get_reachable_from_inputs(
            self.inputs, targets=new_stateful_ops)
        unconditional_updates = new_stateful_ops - input_reachable_ops
        conditional_updates = new_stateful_ops - unconditional_updates

        if unconditional_updates:
            self.layer.add_update(list(unconditional_updates))
        if conditional_updates:
            self.layer.add_update(list(conditional_updates),
                                  inputs=self.inputs)
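
The pruning step above (the `minimum_ops` computation) is plain set algebra over graph reachability: a stateful op is kept only when no other target (an output, an explicit update, or another stateful op) already consumes it. Below is a minimal, runnable sketch of that logic on a toy graph; `prune_consumed_updates`, `reachable_from`, and the node names are illustrative stand-ins for the real ops and for `tf_utils.get_reachable_from_inputs`.

def prune_consumed_updates(stateful_ops, outputs, explicit_updates, reachable_from):
    # Keep only stateful ops whose effects no other target already consumes.
    targets = set(stateful_ops) | set(outputs) | set(explicit_updates)
    minimum_ops = set()
    for op in stateful_ops:
        if not (targets - {op}) & reachable_from(op):
            minimum_ops.add(op)
    # Mirror the final dedup against explicitly added updates.
    return minimum_ops - set(explicit_updates)

# Toy graph: update u1 feeds the layer output, update u2 dangles.
edges = {"u1": {"out"}, "u2": set(), "out": set()}

def reachable_from(op):
    seen, stack = set(), [op]
    while stack:
        for nxt in edges[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

print(prune_consumed_updates({"u1", "u2"}, {"out"}, set(), reachable_from))
# {'u2'} -- u1 is already run by virtue of being consumed by the output.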
Example #2
    def __exit__(self, unused_type, unused_value, unused_traceback):
        # pylint: disable=protected-access
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Graph changed while trying to add control dependencies.")

        if hasattr(self._graph, "outer_graph"):
            outer_val = self._graph.outer_graph._add_control_dependencies
            self._graph._add_control_dependencies = outer_val
        else:
            self._graph._add_control_dependencies = False

        # map from resource tensor to the last op which wrote to it
        last_write_to_resource = {}
        # map from resource tensor to the list of reads from it since the last
        # write or since the beginning of the function.
        reads_since_last_write_to_resource = collections.defaultdict(list)
        # CollectiveManager manager_ids within a particular function call should
        # not be needed outside of that function call. So we keep them separate
        # (the general idea of the maps is the same; in the future we'll need to
        # correctly thread the control output outside).
        # Map from collective manager scope to the last op which used it
        collective_manager_scopes_opened = {}
        collective_manager_scopes_used = {}
        # set of ops which must run (stateful ops and cond merges)
        ops_which_must_run = set()
        # map from resource tensor to the merge op which must depend on ops
        # using that resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_write_to_resource).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
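        # For example, given ops created in this graph-construction order:
        #     w1 = AssignVariableOp(v, x)   # write to resource v
        #     r  = ReadVariableOp(v)        # read of v
        #     w2 = AssignVariableOp(v, y)   # second write
        # the loop below adds control edges w1 -> r and r -> w2, so the read
        # observes the first write and completes before the second write.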
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()
            # Ensure stateful ops run
            if (op_def_registry.get(op.type) is None
                    or (op_is_stateful(op)
                        and op.type not in utils.RESOURCE_READ_OPS)):
                # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
                # they only have variable reads and are otherwise stateless.
                ops_which_must_run.add(op)
            # Make a note of all opened manager_ids.
            if op.type == "NoOp":
                try:
                    collective_manager_scopes_opened[op.get_attr(
                        "_collective_manager_id")] = op
                except ValueError:
                    pass
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)
                    for inp in o.inputs:
                        input_id = ops.tensor_id(inp)
                        if input_id in last_write_to_resource:
                            last_write_to_resource[input_id] = op
                ops_which_must_run = set([op])
                continue

            resource_inputs = set()
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_write_to_resource.
            for inp, resource_type in _get_resource_inputs(op):
                is_read = resource_type == ResourceType.READ_ONLY
                input_id = ops.tensor_id(inp)

                # If the op receives the same resource tensor twice as an input, we skip
                # to avoid the op getting a control dependency on itself.
                if input_id in resource_inputs:
                    continue

                resource_inputs.add(input_id)
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_write_to_resource,
                                         merge_for_resource)
                is_building_function = op.graph.building_function
                # Ensure uses of resources are serialized
                if input_id in last_write_to_resource:
                    if is_building_function or (
                            last_write_to_resource[input_id].
                            _control_flow_context is op._control_flow_context):
                        control_inputs.add(last_write_to_resource[input_id])
                # Ensure merges happen after the closing of a cond block
                if input_id in merge_for_resource:
                    merge_for_resource[input_id]._add_control_input(op)
                if is_read:
                    reads_since_last_write_to_resource[input_id].append(op)
                else:
                    control_inputs.update(
                        reads_since_last_write_to_resource[input_id])
                    reads_since_last_write_to_resource[input_id] = []
                    last_write_to_resource[input_id] = op

            if (op_is_stateful(op) and not resource_inputs
                    and op._control_flow_context is None):
                if None in last_write_to_resource:
                    op._add_control_input(last_write_to_resource[None])
                last_write_to_resource[None] = op

            # Ensure ordering of collective ops
            manager_ids = collective_manager_ids_from_op(op)
            for manager_id in manager_ids:
                if manager_id in collective_manager_scopes_opened:
                    # Chain this function call if the scope was opened.
                    op._add_control_input(
                        collective_manager_scopes_opened[manager_id])
                    collective_manager_scopes_opened[manager_id] = op
                else:
                    # If this op is in a scope not created here, create a chain starting
                    # at this op.
                    if manager_id in collective_manager_scopes_used:
                        op._add_control_input(
                            collective_manager_scopes_used[manager_id])
                    collective_manager_scopes_used[manager_id] = op

        # Note: `control_inputs` is only non-empty if the resource-input loop
        # above ran at least once, so `is_building_function` is always bound
        # whenever it is consulted here.
        if control_inputs and not is_building_function:
                control_inputs = [
                    c for c in control_inputs
                    if c._control_flow_context is op._control_flow_context
                ]

            op._add_control_inputs(control_inputs)

        # Ensure all ops which must run do run
        self.ops_which_must_run.update(ops_which_must_run)
        for r in nest.flatten(list(self._returned_tensors),
                              expand_composites=True):
            if self.ops_which_must_run:
                updated_ops_which_must_run = []
                if r.graph.building_function:
                    updated_ops_which_must_run = self.ops_which_must_run
                else:
                    updated_ops_which_must_run = [
                        o for o in self.ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ]
                r.op._add_control_inputs(updated_ops_which_must_run)

        self.collective_manager_ids_used = collective_manager_scopes_used
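
The effect of the read/write bookkeeping above is observable through public APIs: inside a `tf.function`, operations on the same resource variable run in program order even though nothing connects them in the dataflow graph. A minimal sketch using only standard TensorFlow 2 APIs:

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def f():
    v.assign(2.0)        # write W1
    r = v.read_value()   # read R: serialized after W1
    v.assign(3.0)        # write W2: serialized after R (and W1)
    return r

print(f().numpy())  # 2.0 -- the read lands between the two writes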
Example #3
    def __exit__(self, unused_type, unused_value, unused_traceback):
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Graph changed while trying to add control dependencies.")

        # pylint: disable=protected-access
        if hasattr(self._graph, "outer_graph"):
            outer_val = self._graph.outer_graph._add_control_dependencies
            self._graph._add_control_dependencies = outer_val
        else:
            self._graph._add_control_dependencies = False
        # pylint: enable=protected-access

        # map from resource tensor to the last op which used it
        last_op_using_resource_tensor = {}
        # set of ops which must run (stateful ops and cond merges)
        ops_which_must_run = set()
        # map from resource tensor to the merge op which must depend on ops
        # using that resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_op_using_resource_tensor).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()
            # Ensure stateful ops run
            if (op.type not in self._graph._registered_ops  # pylint: disable=protected-access
                    or (self._graph._registered_ops[op.type].is_stateful  # pylint: disable=protected-access
                        and op.type not in ASYNC_STATEFUL_OPS)):
                ops_which_must_run.add(op)
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)  # pylint: disable=protected-access
                    for inp in o.inputs:
                        if inp in last_op_using_resource_tensor:
                            last_op_using_resource_tensor[inp] = op
                ops_which_must_run = set([op])
                continue
            found_resource = False
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_op_using_resource_tensor. Note that we dedup op.inputs in case
            # op receives the same resource tensor twice as input, which would result
            # in op getting a control dependency on itself.
            for inp in set(op.inputs):
                if inp.dtype != dtypes_module.resource:
                    continue
                found_resource = True
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_op_using_resource_tensor,
                                         merge_for_resource)
                # Ensure uses of resources are serialized
                if inp in last_op_using_resource_tensor:
                    if (last_op_using_resource_tensor[inp].
                            _control_flow_context  # pylint: disable=protected-access
                            is op._control_flow_context):  # pylint: disable=protected-access
                        control_inputs.add(last_op_using_resource_tensor[inp])
                # Ensure merges happen after the closing of a cond block
                if inp in merge_for_resource:
                    merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
                last_op_using_resource_tensor[inp] = op
            if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
                    and not found_resource
                    and op._control_flow_context is None):  # pylint: disable=protected-access
                if None in last_op_using_resource_tensor:
                    op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
                last_op_using_resource_tensor[None] = op
            control_inputs = [
                c for c in control_inputs
                if c._control_flow_context is op._control_flow_context
            ]  # pylint: disable=protected-access
            op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

        # Ensure all ops which must run do run
        for r in self._returned_tensors:
            if ops_which_must_run:
                r.op._add_control_inputs(  # pylint: disable=protected-access
                    [
                        o for o in ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ])  # pylint: disable=protected-access
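
Under the v1 Switch/Merge lowering handled above, a stateful op inside a cond branch is forced to run by making the branch's Merge depend on it. The user-visible guarantee can be sketched with public APIs: a `tf.cond` branch's side effect happens even when its result is unused.

import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def g(flag):
    def true_fn():
        v.assign_add(1.0)        # stateful, result unused
        return tf.constant(1.0)
    return tf.cond(flag, true_fn, lambda: tf.constant(0.0))

g(tf.constant(True))
print(v.numpy())  # 1.0 -- the assign in the taken branch still ran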
Example #4
  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should
    # not be needed outside of that function call. So we keep them separate
    # (the general idea of the maps is the same; in the future we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of ops which must run (stateful ops and cond merges)
    ops_which_must_run = set()
    # map from resource tensor to the merge op which must depend on ops
    # using that resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs))) or
          (op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS)):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this stores a reference to the entire list that
            # reads_since_last_write_to_resource maintains, so later appends to
            # that list remain visible here, as intended, until the first write
            # is encountered. At that point reads_since_last_write_to_resource
            # is rebound to a new empty list; this logic relies on that.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used
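
The `control_output_op` optimization above avoids a quadratic blow-up: instead of every returned tensor taking a control edge from every op which must run, all such ops feed one intermediate NoOp and each output depends only on that hub. The same wiring can be sketched with public graph-building APIs; the op names here are illustrative:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Stand-ins for the stateful ops which must run.
    must_run = [tf.no_op(name="stateful_%d" % i) for i in range(3)]
    # One hub NoOp depends on all of them...
    with tf.control_dependencies(must_run):
        hub = tf.no_op(name="acd_control_output")
    # ...and each returned tensor depends only on the hub:
    with tf.control_dependencies([hub]):
        rets = [tf.identity(tf.constant(i), name="ret_%d" % i)
                for i in range(2)]
    # 3 + 2 control edges instead of 3 * 2.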