Example #1
    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and tpu_values.enclosing_tpu_context() is not None:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on
            # all replicas, in which case `value` would be a single value, or
            # `value` could be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        value_list = value.values
        # pylint: disable=protected-access
        if isinstance(value, values.DistributedVariable
                      ) and value._packed_variable is not None:
            value_list = tuple(
                value._packed_variable.on_device(d)
                for d in value._packed_variable.devices)
        # pylint: enable=protected-access

        # Currently XLA op by op mode has a limit for the number of inputs for a
        # single op, thus we break one `add_n` op into a group of `add_n` ops to
        # work around the constraint.
        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
            output = math_ops.add_n(value_list)
        else:
            output = array_ops.zeros_like(value_list[0],
                                          dtype=value_list[0].dtype)
            for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
                output += math_ops.add_n(
                    value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

        if reduce_op == reduce_util.ReduceOp.MEAN:
            output *= (1. / len(value_list))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
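
Both branches of the reduction above compute a plain sum; the `else` branch exists only because XLA's op-by-op mode caps how many inputs a single `add_n` may take, so the sum is built in chunks of `_XLA_OP_BY_OP_INPUTS_LIMIT`. Below is a minimal standalone sketch of that chunking pattern; the limit value and the list of tensors are made up for illustration and are not part of the example above.

import tensorflow as tf

_INPUTS_LIMIT = 10  # hypothetical stand-in for _XLA_OP_BY_OP_INPUTS_LIMIT

def chunked_add_n(tensors, limit=_INPUTS_LIMIT):
    # Sum `tensors` without feeding more than `limit` inputs to one add_n call.
    if len(tensors) <= limit:
        return tf.math.add_n(tensors)
    output = tf.zeros_like(tensors[0])
    for i in range(0, len(tensors), limit):
        output += tf.math.add_n(tensors[i:i + limit])
    return output

# Usage: sum 25 tensors while keeping each add_n under 10 inputs.
tensors = [tf.ones([2, 2]) for _ in range(25)]
total = chunked_add_n(tensors)  # every element equals 25.0
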
Example #2
    def _reduce_to(self, reduce_op, value, destinations):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on
            # all replicas, in which case `value` would be a single value, or
            # `value` could be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        # Always performs the reduction on the TPU host.
        with ops.device(self._host_device):
            output = math_ops.add_n(value.values)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                output *= (1. / len(value.values))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
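
In this older variant the whole reduction runs on the TPU host device: `add_n` produces the SUM, and MEAN is obtained by scaling that sum by `1 / len(value.values)`. The short standalone sketch below shows that sum-then-scale relationship; the per-replica tensors are made up for illustration and do not come from either example.

import tensorflow as tf

# Two hypothetical per-replica values for the same variable.
per_replica = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]

summed = tf.math.add_n(per_replica)       # SUM: elementwise sum over replicas
mean = summed * (1.0 / len(per_replica))  # MEAN: the sum scaled by 1/num_values

print(summed.numpy())  # [4. 6.]
print(mean.numpy())    # [2. 3.]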