Example 1
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, self._device_map, value, destinations)

        # Validate that the destination is the same as the host device.
        # Note we don't do this when in a replica context, as the reduction
        # is performed on the TPU device itself.
        devices = cross_device_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host_device)
        else:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        output = math_ops.add_n(value)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return output * (1. / len(value))
        return output
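
A quick illustration of the TPU branch in this excerpt: MEAN is obtained by pre-scaling each replica's contribution by 1/num_replicas and then doing an ordinary cross-replica SUM. The sketch below is plain Python, not TF code; `cross_replica_sum` is a hypothetical stand-in for `tpu_ops.cross_replica_sum`.

# Hedged sketch (plain Python, not TF): MEAN as a pre-scaled cross-replica SUM.
def cross_replica_sum(per_replica_contributions):
    # Stand-in for tpu_ops.cross_replica_sum: add up the per-replica terms.
    return sum(per_replica_contributions)

def mean_across_replicas(per_replica_values):
    n = len(per_replica_values)
    # Mirrors `value *= (1. / self._num_replicas_in_sync)` followed by the SUM.
    scaled = [v * (1.0 / n) for v in per_replica_values]
    return cross_replica_sum(scaled)

assert mean_across_replicas([2.0, 4.0, 6.0, 8.0]) == 5.0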
Example 2
    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.Mirrored)
                and reduce_op == reduce_util.ReduceOp.MEAN):
            return value
        assert not isinstance(value, values.Mirrored)

        if (isinstance(value, values.DistributedValues)
                and len(self.worker_devices) == 1):
            value = value.values[0]

        # When there are multiple workers, we need to reduce across workers using
        # collective ops.
        if (not isinstance(value, values.DistributedValues)
                and self._num_workers == 1):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, len(self.worker_devices))
        return self._get_cross_device_ops(value).reduce(
            reduce_op,
            value,
            destinations=destinations,
            experimental_hints=experimental_hints)
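
For context, `_reduce_to` is not called directly by user code; it sits behind the public `tf.distribute.Strategy.reduce` entry point. A hedged usage sketch follows (assumes TF 2.x; the default strategy is used so it runs without multiple devices, and which of these excerpts it would dispatch to depends on the strategy in scope).

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default single-replica strategy

@tf.function
def step():
    per_replica = strategy.run(lambda: tf.constant(2.0))
    # ReduceOp.MEAN / ReduceOp.SUM correspond to the reduce_op handled above.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)

print(step().numpy())  # 2.0 under the default single-replica strategy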
Example 3
    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and tpu_values.enclosing_tpu_context() is not None:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        value_list = value.values
        # pylint: disable=protected-access
        if isinstance(value, values.DistributedVariable
                      ) and value._packed_variable is not None:
            value_list = tuple(
                value._packed_variable.on_device(d)
                for d in value._packed_variable.devices)
        # pylint: enable=protected-access

        # Currently XLA op by op mode has a limit for the number of inputs for a
        # single op, thus we break one `add_n` op into a group of `add_n` ops to
        # work around the constraint.
        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
            output = math_ops.add_n(value_list)
        else:
            output = array_ops.zeros_like(value_list[0],
                                          dtype=value_list[0].dtype)
            for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
                output += math_ops.add_n(
                    value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

        if reduce_op == reduce_util.ReduceOp.MEAN:
            output *= (1. / len(value_list))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
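
The chunked `add_n` above exists only to respect XLA's per-op input limit; arithmetically it is the same sum. A plain-Python sketch of the chunking pattern (the limit constant here is made up for illustration and is not the real `_XLA_OP_BY_OP_INPUTS_LIMIT` value):

# Hedged sketch (plain Python, not TF) of summing a long list in bounded chunks.
_INPUTS_LIMIT = 3  # hypothetical stand-in for _XLA_OP_BY_OP_INPUTS_LIMIT

def chunked_sum(value_list):
    if len(value_list) <= _INPUTS_LIMIT:
        return sum(value_list)  # single "add_n"
    output = 0  # plays the role of zeros_like(value_list[0])
    for i in range(0, len(value_list), _INPUTS_LIMIT):
        output += sum(value_list[i:i + _INPUTS_LIMIT])  # one bounded "add_n" per chunk
    return output

assert chunked_sum(list(range(10))) == sum(range(10))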
Example 4
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    # Validate that the destination is the same as the host device.
    # Note we don't do this when in a replica context, as the reduction
    # is performed on the TPU device itself.
    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host_device)
    else:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    output = math_ops.add_n(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      return output * (1. / len(value))
    return output
Example 5
  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)
Example 6
  def _reduce_to(self, reduce_op, value, destinations):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations)
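
Most of these excerpts share the `reduce_non_distributed_value` branch for plain (non-PerReplica, non-Mirrored) inputs. As the comments note, such a value is implicitly the same on every replica, so its cross-replica MEAN is the value itself, while a cross-replica SUM would scale it by the replica count. The sketch below only illustrates that arithmetic; the real `cross_device_ops_lib` helper may handle the SUM case differently (for example, by rejecting it).

# Hedged sketch of the arithmetic of a value that is identical on every replica.
# This is NOT cross_device_ops_lib.reduce_non_distributed_value.
def reduce_identical_copies(reduce_op, value, num_replicas_in_sync):
    if reduce_op == "MEAN":
        return value  # mean of identical copies is the value itself
    if reduce_op == "SUM":
        # Assumption for illustration: each replica would contribute `value` once.
        return value * num_replicas_in_sync
    raise NotImplementedError(reduce_op)

assert reduce_identical_copies("MEAN", 3.0, 4) == 3.0
assert reduce_identical_copies("SUM", 3.0, 4) == 12.0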
Example 7
  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._get_cross_device_ops().reduce(
        reduce_op, value, destinations=destinations)
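
The early return above rests on the Mirrored invariant: a Mirrored value holds identical copies of one tensor, one per replica, so its MEAN is just any single copy and no communication is needed. A small plain-Python illustration (the `MirroredLike` class is hypothetical):

# Hedged sketch (plain Python): why MEAN of a mirrored value needs no reduction.
class MirroredLike:
    """Hypothetical stand-in for a mirrored value: one identical copy per replica."""
    def __init__(self, value, num_replicas):
        self.values = [value] * num_replicas

def mean_of_mirrored(mirrored):
    copies = mirrored.values
    assert all(v == copies[0] for v in copies)  # the mirrored invariant
    return sum(copies) / len(copies)            # equals any single copy

m = MirroredLike(7.0, num_replicas=4)
assert mean_of_mirrored(m) == m.values[0] == 7.0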
Example 8
  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._get_cross_device_ops().reduce(
        reduce_op, value, destinations=destinations)
Example 9
  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))
Example 10
  def get_values(value):
    # Helper excerpted from `_reduce_to`; `reduce_op`, `destinations`,
    # `options` and `self` are resolved from the enclosing method's scope.
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    if self._use_merge_call() and self._collective_ops_in_use and ((
        not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
        any("cpu" in d.lower()
            for d in cross_device_ops_lib.get_devices_from(destinations)))):
      return cross_device_ops_lib.ReductionToOneDevice().reduce(
          reduce_op, value, destinations)
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))
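
The middle branch above falls back to a copy-everything-to-one-device reduction when collectives would be awkward: the strategy is using merge calls with collective ops, and either the value's devices do not match the destinations or one of the destinations is a CPU. The sketch below restates just that device condition as a plain predicate over device strings; the helper name is made up, and the sorted-string comparison is a simplification of `cross_device_ops_lib._devices_match`.

# Hedged sketch (plain Python) of the fallback condition in the excerpt above.
def should_fall_back_to_one_device(value_devices, destination_devices):
    devices_match = sorted(value_devices) == sorted(destination_devices)
    any_cpu_destination = any("cpu" in d.lower() for d in destination_devices)
    return (not devices_match) or any_cpu_destination

gpus = ["/job:worker/device:GPU:0", "/job:worker/device:GPU:1"]
cpu = ["/job:worker/device:CPU:0"]
assert should_fall_back_to_one_device(gpus, cpu)       # CPU destination -> one-device path
assert not should_fall_back_to_one_device(gpus, gpus)  # matching GPUs -> keep collectives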
Example 11
    def _reduce_to(self, reduce_op, value, destinations):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        # Always performs the reduction on the TPU host.
        with ops.device(self._host_device):
            output = math_ops.add_n(value.values)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                output *= (1. / len(value.values))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
Example 12
  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.Mirrored) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not isinstance(value, values.Mirrored)

    if (isinstance(value, values.DistributedValues) and
        len(self.worker_devices) == 1):
      value = value.values[0]

    # When there are multiple workers, we need to reduce across workers using
    # collective ops.
    if (not isinstance(value, values.DistributedValues) and
        self._num_workers == 1):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)
    return self._get_cross_device_ops().reduce(
        reduce_op, value, destinations=destinations)
Example 13
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(devices[0]):
        output = array_ops.identity(output)

    return output
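
The final step above only materializes the result on the requested destination when it differs from the host that performed the reduction. With public TF APIs the equivalent move is a `tf.identity` under a `tf.device` scope. The sketch below is hedged: the plain string comparison stands in for `device_util.canonicalize`, and the device strings are illustrative.

import tensorflow as tf

def copy_to_destination(output, dest_device, host_device):
    # Only copy when the destination differs from where `output` was produced.
    if dest_device != host_device:
        with tf.device(dest_device):
            output = tf.identity(output)
    return output

x = tf.constant([1.0, 2.0, 3.0])
y = copy_to_destination(x, "/device:CPU:0", "/device:CPU:0")  # same device: no copy
print(y.numpy())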