Example 1
    def execute(self, fn, *args, **kwargs):
        """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
        if another execution in another `CriticalSection` was created without
        `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
        # `name` and `exclusive_resource_access` are reserved keywords for
        # `execute` itself; pop them so they are not forwarded to `fn`.
        name = kwargs.pop("name", None)
        exclusive_resource_access = kwargs.pop("exclusive_resource_access",
                                               True)

        with ops.name_scope(name, "critical_section_execute", []):

            # Ensure that mutex locking only happens *after* all args and
            # kwargs have been executed.  This avoids certain types of deadlocks.
            lock = gen_resource_variable_ops.mutex_lock(self._handle)

            if not context.executing_eagerly():
                # NOTE(ebrevdo): Hold the graph lock while snapshotting and
                # diffing the op set, so we don't pick up spurious Operations
                # created concurrently by other threads.
                with ops.get_default_graph()._lock:  # pylint: disable=protected-access
                    existing_ops = ops.get_default_graph().get_operations()
                    with ops.control_dependencies([lock]):
                        r = fn(*args, **kwargs)
                    # TODO(ebrevdo): If creating critical sections in a python
                    # loop, this makes graph creation time quadratic.  Revisit
                    # if this becomes a problem.
                    created_ops = (set(
                        ops.get_default_graph().get_operations()).difference(
                            existing_ops))
            else:
                with ops.control_dependencies([lock]):
                    r = fn(*args, **kwargs)

            if not context.executing_eagerly():
                self._add_control_dependencies_to_lock(created_ops, lock.op)

                # captured_resources is a list of resources that are directly
                # accessed only by ops created during fn(), not by any
                # ancestors of those ops in the graph.
                captured_resources = set([
                    input_ for op in created_ops for input_ in op.inputs
                    if input_.dtype == dtypes.resource
                ])

                # NOTE(ebrevdo): The only time self._is_self_handle() is True
                # in this call is if one of the recently created ops, within
                # the execute(), themselves attempt to access the
                # CriticalSection.  This will cause a deadlock.
                if any(self._is_self_handle(x) for x in captured_resources):
                    raise ValueError(
                        "The function fn attempts to directly access the "
                        "CriticalSection in which it would be running.  "
                        "This is illegal and would cause deadlocks.")

                self._check_multiple_access_to_resources(
                    captured_resources, exclusive_resource_access)

            r_flat = [_identity(x) for x in nest.flatten(r)]

            with ops.control_dependencies(r_flat):
                # The identity must run on the same machine as self._handle
                with ops.colocate_with(self._handle):
                    # Do not use array_ops.identity as there are special
                    # optimizations within TensorFlow which seem to elide it
                    # even when optimizations are disabled(!).
                    ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
                        lock)

                # Make sure that if any element of r is accessed, all of
                # them are executed together.
                r = nest.pack_sequence_as(
                    r, control_flow_ops.tuple(nest.flatten(r)))

            # All outputs depend on the lock having been consumed, i.e. on the
            # critical section having been exited.
            with ops.control_dependencies([ensure_lock_exists]):
                outputs = nest.map_structure(_identity, r)

            if not context.executing_eagerly():
                # Publish this execution so later `execute` calls (possibly on
                # other CriticalSections) can check for resource conflicts.
                signature = _ExecutionSignature(
                    op=lock.op,
                    handle=self._handle,
                    resources=list(captured_resources),
                    exclusive_resource_access=exclusive_resource_access)
                ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

            return outputs
Example 2
    def execute(self, fn, *args, **kwargs):
        """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
        if another execution in another `CriticalSection` was created without
        `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
        # `name` and `exclusive_resource_access` are reserved keywords for
        # `execute` itself; pop them so they are not forwarded to `fn`.
        name = kwargs.pop("name", None)
        exclusive_resource_access = kwargs.pop("exclusive_resource_access",
                                               True)

        with ops.name_scope(name, "critical_section_execute", []):
            # Acquire the critical section's mutex.  Every op created below is
            # placed (directly or transitively) after this lock op.
            lock = gen_resource_variable_ops.mutex_lock(self._handle)

            with ops.control_dependencies([lock]):
                # Track every op created while tracing `fn`, plus any input
                # tensors those ops capture from pre-existing ops (i.e.
                # tensors `fn` closes over rather than computes itself).
                c_known_ops = set()
                c_captured_tensors = set()

                def add_op_internal(op):
                    # Called by the HelperContext for each op created under it.
                    c_known_ops.add(op)
                    for i in op.inputs:
                        if i.op not in c_known_ops:
                            c_captured_tensors.add(i)

                c = function.HelperContext(add_op_internal)
                with c:
                    r = fn(*args, **kwargs)

                # All resource-dtype tensors `fn` touches: explicit args and
                # kwargs plus anything captured from the enclosing graph.
                resource_inputs = set([
                    x for x in list(nest.flatten(args)) +
                    nest.flatten(kwargs.values()) + list(c_captured_tensors)
                    if tensor_util.is_tensor(x) and x.dtype == dtypes.resource
                ])

            # `fn` locking the very CriticalSection it runs inside would
            # deadlock at runtime, so reject it at construction time.
            if self._handle in resource_inputs:
                raise ValueError(
                    "The function fn attempts to access the "
                    "CriticalSection in which it would be running.  "
                    "This is illegal and would cause deadlocks.  "
                    "CriticalSection: %s." % self._handle)

            # NOTE(review): uses the legacy `context.in_graph_mode()` API;
            # later revisions use `not context.executing_eagerly()`.
            if context.in_graph_mode():
                # Collections and op introspection does not work in eager
                # mode.  This is generally ok; since eager mode (as of
                # writing) executes sequentially anyway.
                for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
                    if sg.handle.name == self._handle.name:
                        # Other executions in the same critical section are allowed.
                        continue
                    if not (exclusive_resource_access
                            or sg.exclusive_resource_access):
                        # Neither execution requested exclusive access.
                        continue
                    resource_intersection = resource_inputs.intersection(
                        sg.resources)
                    if resource_intersection:
                        raise ValueError(
                            "This execution would access resources: %s.  Either this "
                            "lock (CriticalSection: %s) or lock '%s' "
                            "(CriticalSection: %s) requested exclusive resource access "
                            "of this resource.  Did you mean to call execute with keyword "
                            "argument exclusive_resource_access=False?" %
                            (list(resource_intersection), self._handle.name,
                             sg.op.name, sg.handle.name))

            def identity(x):  # pylint: disable=invalid-name
                # Structure-preserving identity over the values `fn` returned.
                if isinstance(x, tensor_array_ops.TensorArray):
                    return x.identity()
                elif isinstance(x, ops.Operation):
                    return control_flow_ops.group(x)
                elif context.in_eager_mode() and x is None:
                    return None
                else:
                    return array_ops.identity(x)

            r_flat = [identity(x) for x in nest.flatten(r)]

            with ops.control_dependencies(r_flat):
                # The identity must run on the same machine as self._handle
                with ops.colocate_with(self._handle):
                    # Do not use array_ops.identity as there are special
                    # optimizations within TensorFlow which seem to elide it
                    # even when optimizations are disabled(!).
                    ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
                        lock)

                # Make sure that if any element of r is accessed, all of
                # them are executed together.
                r = nest.pack_sequence_as(
                    r, control_flow_ops.tuple(nest.flatten(r)))

            # All outputs depend on the lock having been consumed, i.e. on
            # the critical section having been exited.
            with ops.control_dependencies([ensure_lock_exists]):
                outputs = nest.map_structure(identity, r)

            if context.in_graph_mode():
                # Publish this execution so later `execute` calls (possibly on
                # other CriticalSections) can check for resource conflicts.
                signature = _ExecutionSignature(
                    op=lock.op,
                    handle=self._handle,
                    resources=list(resource_inputs),
                    exclusive_resource_access=exclusive_resource_access)
                ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

            return outputs
Example 3
  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
        if another execution in another `CriticalSection` was created without
        `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    # `name` and `exclusive_resource_access` are reserved keywords for
    # `execute` itself; pop them so they are not forwarded to `fn`.
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    with ops.name_scope(name, "critical_section_execute", []):
      # Acquire the critical section's mutex.  Every op created below is
      # placed (directly or transitively) after this lock op.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      with ops.control_dependencies([lock]):
        # Track every op created while tracing `fn`, plus any input tensors
        # those ops capture from pre-existing ops (i.e. tensors `fn` closes
        # over rather than computes itself).
        c_known_ops = set()
        c_captured_tensors = set()

        def add_op_internal(op):
          # Called by the HelperContext for each op created under it.
          c_known_ops.add(op)
          for i in op.inputs:
            if i.op not in c_known_ops:
              c_captured_tensors.add(i)

        c = function.HelperContext(add_op_internal)
        with c:
          r = fn(*args, **kwargs)

        # All resource-dtype tensors `fn` touches: explicit args and kwargs
        # plus anything captured from the enclosing graph.
        resource_inputs = set([
            x for x in
            list(nest.flatten(args)) + nest.flatten(kwargs.values()) +
            list(c_captured_tensors)
            if tensor_util.is_tensor(x) and x.dtype == dtypes.resource])

      # `fn` locking the very CriticalSection it runs inside would deadlock
      # at runtime, so reject it at construction time.
      if self._handle in resource_inputs:
        raise ValueError("The function fn attempts to access the "
                         "CriticalSection in which it would be running.  "
                         "This is illegal and would cause deadlocks.  "
                         "CriticalSection: %s." % self._handle)

      if not context.executing_eagerly():
        # Collections and op introspection does not work in eager
        # mode.  This is generally ok; since eager mode (as of
        # writing) executes sequentially anyway.
        for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
          # Compare handles by tensor name; convert_to_tensor normalizes
          # resource handles (e.g. eager handles) to graph tensors first.
          sg_handle_name = ops.convert_to_tensor(sg.handle).name
          self_handle_name = ops.convert_to_tensor(self._handle).name
          if sg_handle_name == self_handle_name:
            # Other executions in the same critical section are allowed.
            continue
          if not (exclusive_resource_access or sg.exclusive_resource_access):
            # Neither execution requested exclusive access.
            continue
          resource_intersection = resource_inputs.intersection(sg.resources)
          if resource_intersection:
            raise ValueError(
                "This execution would access resources: %s.  Either this "
                "lock (CriticalSection: %s) or lock '%s' "
                "(CriticalSection: %s) requested exclusive resource access "
                "of this resource.  Did you mean to call execute with keyword "
                "argument exclusive_resource_access=False?" %
                (list(resource_intersection), self._handle.name,
                 sg.op.name, sg.handle.name))

      def identity(x):  # pylint: disable=invalid-name
        # Structure-preserving identity over the values `fn` returned.
        if isinstance(x, tensor_array_ops.TensorArray):
          return x.identity()
        elif isinstance(x, ops.Operation):
          return control_flow_ops.group(x)
        elif context.executing_eagerly() and x is None:
          return None
        else:
          return array_ops.identity(x)

      r_flat = [identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(
            r, control_flow_ops.tuple(nest.flatten(r)))

      # All outputs depend on the lock having been consumed, i.e. on the
      # critical section having been exited.
      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(identity, r)

      if not context.executing_eagerly():
        # Publish this execution so later `execute` calls (possibly on other
        # CriticalSections) can check for resource conflicts.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(resource_inputs),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs
    def execute(self, fn, exclusive_resource_access=True, name=None):
        """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
        with ops.name_scope(name, "critical_section_execute", []):
            # Ensure that mutex locking only happens *after* all args and
            # kwargs have been executed.  This avoids certain types of deadlocks.
            # NOTE(review): presumably `_push_critical_section_stack` records
            # this section on a thread-local stack so nested lock attempts can
            # be detected — confirm against its definition.
            with _push_critical_section_stack(self._signature):
                lock = gen_resource_variable_ops.mutex_lock(self._handle)

                if not context.executing_eagerly():
                    # NOTE(ebrevdo): This is to ensure we don't pick up spurious
                    # Operations created by other threads.
                    with ops.get_default_graph()._lock:  # pylint: disable=protected-access
                        existing_ops = ops.get_default_graph().get_operations()
                        with ops.control_dependencies([lock]):
                            r = fn()
                        # TODO(ebrevdo): If creating critical sections in a python loop,
                        # this makes graph creation time quadratic.  Revisit if this
                        # becomes a problem.
                        created_ops = (set(ops.get_default_graph(
                        ).get_operations()).difference(existing_ops))
                else:
                    with ops.control_dependencies([lock]):
                        r = fn()

            if not context.executing_eagerly():
                self._add_control_dependencies_to_lock(created_ops, lock.op)

                # captured_resources is a list of resources that are directly
                # accessed only by ops created during fn(), not by any
                # ancestors of those ops in the graph.
                captured_resources = object_identity.ObjectIdentitySet([
                    input_ for op in created_ops for input_ in op.inputs
                    if input_.dtype == dtypes.resource
                ])

                # NOTE(ebrevdo): The only time self._is_self_handle() is True
                # in this call is if one of the recently created ops, within
                # the execute(), themselves attempt to access the
                # CriticalSection.  This will cause a deadlock.
                if any(self._is_self_handle(x) for x in captured_resources):
                    raise ValueError(
                        "Attempting to lock a CriticalSection in which we are "
                        f"already running (signature={self._signature}). This is illegal "
                        "and may cause deadlocks.")

                self._check_multiple_access_to_resources(
                    captured_resources, exclusive_resource_access)

            r_flat = [_identity(x) for x in nest.flatten(r)]

            with ops.control_dependencies(r_flat):
                # The identity must run on the same machine as self._handle
                with ops.colocate_with(self._handle):
                    # Do not use array_ops.identity as there are special
                    # optimizations within TensorFlow which seem to elide it
                    # even when optimizations are disabled(!).
                    ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
                        lock)

                # Make sure that if any element of r is accessed, all of
                # them are executed together.
                r = nest.pack_sequence_as(
                    r, control_flow_ops.tuple(nest.flatten(r)))

            # All outputs depend on the lock having been consumed, i.e. on
            # the critical section having been exited.
            with ops.control_dependencies([ensure_lock_exists]):
                outputs = nest.map_structure(_identity, r)

            if not context.executing_eagerly():
                # Publish this execution so later `execute` calls (possibly on
                # other CriticalSections) can check for resource conflicts.
                signature = _ExecutionSignature(
                    op=lock.op,
                    handle=self._handle,
                    resources=list(captured_resources),
                    exclusive_resource_access=exclusive_resource_access)
                ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

            return outputs
  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = set([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running.  "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      # All outputs depend on the lock having been consumed, i.e. on the
      # critical section having been exited.
      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        # Publish this execution so later `execute` calls (possibly on other
        # CriticalSections) can check for resource conflicts.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs