Example #1
    def _update(self, var, fn, args, kwargs, group):
        assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
            var, resource_variable_ops.BaseResourceVariable)
        if tpu_values.enclosing_tpu_context() is not None:
            if group:
                return fn(var, *args, **kwargs)
            else:
                return (fn(var, *args, **kwargs), )

        # Otherwise, we revert to MirroredStrategy behavior and update the variable
        # on each replica directly.
        updates = []
        values_and_devices = []
        packed_var = var._packed_variable  # pylint: disable=protected-access
        if packed_var is not None:
            for device in packed_var.devices:
                values_and_devices.append((packed_var, device))
        else:
            for value in var.values:
                values_and_devices.append((value, value.device))

        for i, value_and_device in enumerate(values_and_devices):
            value = value_and_device[0]
            device = value_and_device[1]
            name = "update_%d" % i
            with ops.device(device), \
                 distribute_lib.UpdateContext(i), \
                 ops.name_scope(name):
                # If args and kwargs are not mirrored, the value is returned as is.
                updates.append(
                    fn(value,
                       *distribute_utils.select_replica_mirrored(i, args),
                       **distribute_utils.select_replica_mirrored(i, kwargs)))
        return distribute_utils.update_regroup(self, updates, group)
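
For context, a minimal usage sketch (not part of the source above) showing how `_update` is normally reached through the public `strategy.extended.update` API; `MirroredStrategy` stands in for `TPUStrategy` so the sketch runs without a TPU, and `add_delta` is an illustrative function name.

import tensorflow as tf

# Sketch only: MirroredStrategy stands in for TPUStrategy, and `add_delta`
# is an illustrative update function, not part of the snippet above.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    v = tf.Variable(1.0)

def add_delta(var, delta):
    return var.assign_add(delta)

# Called in a cross-replica context, `extended.update` runs `add_delta` once
# per replica component of `v`, under the matching device and UpdateContext.
strategy.extended.update(v, add_delta, args=(tf.constant(0.5),))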
Example #2
    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and tpu_values.enclosing_tpu_context() is not None:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on
            # all replicas, in which case `value` would be a single value, or
            # `value` could be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        value_list = value.values
        # pylint: disable=protected-access
        if isinstance(value, values.DistributedVariable
                      ) and value._packed_variable is not None:
            value_list = tuple(
                value._packed_variable.on_device(d)
                for d in value._packed_variable.devices)
        # pylint: enable=protected-access

        # Currently XLA op by op mode has a limit for the number of inputs for a
        # single op, thus we break one `add_n` op into a group of `add_n` ops to
        # work around the constraint.
        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
            output = math_ops.add_n(value_list)
        else:
            output = array_ops.zeros_like(value_list[0],
                                          dtype=value_list[0].dtype)
            for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
                output += math_ops.add_n(
                    value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

        if reduce_op == reduce_util.ReduceOp.MEAN:
            output *= (1. / len(value_list))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
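
For reference, a hedged sketch of how `_reduce_to` is reached from user code via `Strategy.reduce`; again `MirroredStrategy` is used as a stand-in so the sketch runs without a TPU.

import tensorflow as tf

# Sketch only: MirroredStrategy stands in for TPUStrategy.
strategy = tf.distribute.MirroredStrategy()

per_replica = strategy.run(lambda: tf.constant(2.0))

# Both calls funnel into the strategy's `_reduce_to` implementation; only
# SUM and MEAN are supported on the TPU path above.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)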
Example #3
    @contextlib.contextmanager
    def experimental_logical_device(self, logical_device_id):
        """Places variables and ops on the specified logical device."""
        num_logical_devices_per_replica = self._tpu_devices.shape[1]
        if logical_device_id >= num_logical_devices_per_replica:
            raise ValueError(
                "`logical_device_id` not in range (was {}, but there are only {} "
                "logical devices per replica).".format(
                    logical_device_id, num_logical_devices_per_replica))

        self._logical_device_stack.append(logical_device_id)
        try:
            if tpu_values.enclosing_tpu_context() is None:
                yield
            else:
                with ops.device(tpu.core(logical_device_id)):
                    yield
        finally:
            self._logical_device_stack.pop()
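
A hedged usage sketch, assuming the method is used as the context manager its `yield`s imply (hence the `contextlib.contextmanager` decorator) and that `strategy` is a `TPUStrategy` configured for model parallelism; the variable shapes are illustrative.

# Sketch only: requires at least two logical devices per replica, otherwise
# the range check above raises ValueError.
with strategy.scope():
    with strategy.extended.experimental_logical_device(0):
        w1 = tf.Variable(tf.zeros([128, 128]))
    with strategy.extended.experimental_logical_device(1):
        w2 = tf.Variable(tf.zeros([128, 10]))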
Example #4
    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and tpu_values.enclosing_tpu_context() is not None:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on
            # all replicas, in which case `value` would be a single value, or
            # `value` could be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        # Always performs the reduction on the TPU host.
        with ops.device(self._host_device):
            output = math_ops.add_n(value.values)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                output *= (1. / len(value.values))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
Example #5
  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if tpu_values.enclosing_tpu_context() is not None:
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variables aggregation.
      return result[0]
    return tensor
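
A hedged sketch of the call path described in the final comment: assigning to a variable created with `ONLY_FIRST_REPLICA` aggregation from inside a replica function is what ultimately broadcasts through `_broadcast_to`; `MirroredStrategy` stands in for `TPUStrategy` here.

import tensorflow as tf

# Sketch only: MirroredStrategy stands in for TPUStrategy.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    global_step = tf.Variable(
        0, dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

def step():
    # The Python `1` takes the fast path above and is returned as-is, so it
    # is not prematurely converted to int32 for this int64 variable.
    global_step.assign_add(1)

strategy.run(step)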
Example #6
    def _update(self, var, fn, args, kwargs, group):
        assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
            var, resource_variable_ops.BaseResourceVariable)
        if tpu_values.enclosing_tpu_context() is not None:
            if group:
                return fn(var, *args, **kwargs)
            else:
                return (fn(var, *args, **kwargs), )

        # Otherwise, we revert to MirroredStrategy behavior and update each variable
        # directly.
        updates = []
        for i, v in enumerate(var.values):
            name = "update_%d" % i
            with ops.device(v.device), \
                 distribute_lib.UpdateContext(i), \
                 ops.name_scope(name):
                # If args and kwargs are not mirrored, the value is returned as is.
                updates.append(
                    fn(v, *values.select_replica_mirrored(i, args),
                       **values.select_replica_mirrored(i, kwargs)))
        return values.update_regroup(self, updates, group)