Example #1
    def _reduce(self, aggregation, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if aggregation == vs.VariableAggregation.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self.num_towers)
            elif aggregation != vs.VariableAggregation.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is the same as the host device.
        # Note that we don't do this in a replicate context, as the reduction
        # is performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert (device_util.canonicalize(devices[0]) ==
                    device_util.canonicalize(self.get_host_cpu_device(0)))
        else:
            raise ValueError(
                'Multiple devices are not supported for TPUStrategy')

        if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
            return value[0]
        output = math_ops.add_n(value)
        if aggregation == vs.VariableAggregation.MEAN:
            return output * (1. / len(value))
        return output
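The MEAN branch above exploits a standard identity: a cross-replica mean is just a cross-replica sum of values pre-scaled by 1/num_replicas. A minimal NumPy sketch of that identity (simulated_cross_replica_sum is a stand-in written for this note, not a TensorFlow API):

import numpy as np

def simulated_cross_replica_sum(per_replica_values):
    # Stand-in for tpu_ops.cross_replica_sum: every replica receives the
    # elementwise sum of all replicas' tensors.
    total = np.sum(per_replica_values, axis=0)
    return [total] * len(per_replica_values)

num_replicas = 4
per_replica = [np.array([1.0, 2.0]) * (i + 1) for i in range(num_replicas)]

# MEAN: pre-scale each replica's value by 1/num_replicas, then all-reduce
# with a plain sum, exactly as the MEAN branch above does.
scaled = [v * (1.0 / num_replicas) for v in per_replica]
mean_result = simulated_cross_replica_sum(scaled)[0]

assert np.allclose(mean_result, np.mean(per_replica, axis=0))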
Example #2
    def _update(self, var, fn, *args, **kwargs):
        # TODO(jhseu): Consider supporting grouped==False.
        assert isinstance(var, values.TPUMirroredVariable)
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            return fn(var, *args, **kwargs)

        # Otherwise, we revert to MirroredStrategy behavior and update each variable
        # directly.
        updates = {}
        for d, v in var._index.items():  # pylint: disable=protected-access
            name = "update_%d" % self._device_index.get(d)
            with ops.device(d), distribute_lib.UpdateContext(d), \
                    ops.name_scope(name):
                # If args and kwargs are not mirrored, the value is returned as is.
                updates[d] = fn(v, *values.select_device_mirrored(d, args),
                                **values.select_device_mirrored(d, kwargs))

        # Make a single control dependency to keep the variables mirrored. If one
        # assignment is fetched, then run all assignments.
        sorted_keys = sorted(updates.keys())
        update_tuple = control_flow_ops.tuple(
            [updates[d] for d in sorted_keys])
        for i, d in enumerate(sorted_keys):
            updates[d] = update_tuple[i]
        return values.regroup(updates, values.Mirrored)
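The per-device loop relies on values.select_device_mirrored to unwrap mirrored arguments for each device. A rough sketch of that dispatch, using a toy MiniMirrored container invented for illustration rather than the real values.Mirrored class:

class MiniMirrored:
    """Toy stand-in for values.Mirrored: one component per device."""
    def __init__(self, index):
        self.index = index  # dict: device string -> component

def select_device_mirrored_sketch(device, structure):
    # Mirrored arguments yield their component for `device`; ordinary
    # values pass through unchanged, matching the comment in _update.
    return [x.index[device] if isinstance(x, MiniMirrored) else x
            for x in structure]

lr = MiniMirrored({"/gpu:0": 0.1, "/gpu:1": 0.1})
print(select_device_mirrored_sketch("/gpu:0", (lr, 42)))  # [0.1, 42]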
Example #3
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is the same as the host device.
        # Note that we don't do this in a replicate context, as the reduction
        # is performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert (device_util.canonicalize(devices[0]) ==
                    device_util.canonicalize(self._host_device))
        else:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        output = math_ops.add_n(value)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return output * (1. / len(value))
        return output
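Note how this newer variant renames aggregation to reduce_op (reduce_util.ReduceOp), num_towers to _num_replicas_in_sync, and drops the ONLY_FIRST_TOWER branch. Off-TPU, the reduction itself is plain host-side arithmetic; a sketch with plain floats standing in for tensors:

per_replica = [2.0, 4.0, 6.0]

output = sum(per_replica)                  # math_ops.add_n(value)
mean = output * (1.0 / len(per_replica))   # the ReduceOp.MEAN branch

assert output == 12.0 and mean == 4.0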
Example #4
    def _update(self, var, fn, args, kwargs, group):
        assert isinstance(var, values.TPUMirroredVariable)
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if group:
                return fn(var, *args, **kwargs)
            else:
                return [fn(var, *args, **kwargs)]

        # Otherwise, we revert to MirroredStrategy behavior and update each variable
        # directly.
        updates = {}
        for d, v in var._index.items():  # pylint: disable=protected-access
            name = "update_%d" % self._device_index.get(d)
            with ops.device(d), distribute_lib.UpdateContext(d), \
                    ops.name_scope(name):
                # If args and kwargs are not mirrored, the value is returned as is.
                updates[d] = fn(v, *values.select_device_mirrored(d, args),
                                **values.select_device_mirrored(d, kwargs))
        return values.update_regroup(self, updates, group)
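values.update_regroup replaces the hand-rolled control_flow_ops.tuple logic from Example #2. A hypothetical sketch of its grouped/ungrouped contract (return shapes only; the real helper also wires control dependencies):

def update_regroup_sketch(updates, group):
    # updates: dict mapping device string -> per-device update result.
    ordered = [updates[d] for d in sorted(updates)]
    if group:
        # Grouped: hand back one combined result (the real helper wires
        # control dependencies so fetching any update runs them all).
        return tuple(ordered)
    # Ungrouped: hand back the individual per-device results.
    return ordered

updates = {"/device:TPU:1": "assign_1", "/device:TPU:0": "assign_0"}
print(update_regroup_sketch(updates, group=True))   # ('assign_0', 'assign_1')
print(update_regroup_sketch(updates, group=False))  # ['assign_0', 'assign_1']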
Example #5
  def _update(self, var, options, fn, *args, **kwargs):
    assert isinstance(var, values.TPUMirroredVariable)
    should_group = options.pop("grouped")
    assert not options  # Validate that we are processing all of the options.

    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if should_group:
        return fn(var, *args, **kwargs)
      else:
        return [fn(var, *args, **kwargs)]

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    updates = {}
    for d, v in var._index.items():  # pylint: disable=protected-access
      name = "update_%d" % self._device_index.get(d)
      with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates[d] = fn(v,
                        *values.select_device_mirrored(d, args),
                        **values.select_device_mirrored(d, kwargs))
    return values.update_regroup(self, updates, should_group)
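The options.pop("grouped") / assert not options pair is a useful idiom: consume every recognized option, then fail loudly if anything is left over. A generic, self-contained version of the same pattern (names here are invented for illustration):

def apply_update(var, fn, *args, **options):
    # Consume every option we understand...
    should_group = options.pop("grouped", True)
    # ...then fail loudly if the caller passed anything unexpected.
    assert not options, "Unrecognized options: %r" % options
    result = fn(var, *args)
    return result if should_group else [result]

print(apply_update(3, lambda v, x: v + x, 4, grouped=False))  # prints [7]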