def _update(self, var, fn, args, kwargs, group):
    """Run `fn` on `var`, either once in-context or once per device.

    Args:
      var: Variable to update. Asserted to be a `TPUVariableMixin` or a
        `BaseResourceVariable`.
      fn: Callable invoked as `fn(value, *args, **kwargs)`.
      args: Positional arguments forwarded to `fn`.
      kwargs: Keyword arguments forwarded to `fn`.
      group: Inside a TPU context, chooses between returning the bare result
        (truthy) and a 1-tuple (falsy). Otherwise it is handed to
        `distribute_utils.update_regroup`.

    Returns:
      Inside a TPU context, the result of `fn` (wrapped in a 1-tuple when
      `group` is falsy). Otherwise whatever `update_regroup` returns for the
      collected per-device results.
    """
    accepted_types = (tpu_values.TPUVariableMixin,
                      resource_variable_ops.BaseResourceVariable)
    assert isinstance(var, accepted_types)

    if tpu_values.enclosing_tpu_context() is not None:
        in_context_result = fn(var, *args, **kwargs)
        return in_context_result if group else (in_context_result,)

    # Otherwise, we revert to MirroredStrategy behavior and update the
    # variable on each replica directly.
    packed = var._packed_variable  # pylint: disable=protected-access
    if packed is None:
        # No packed variable: pair every component with its own device.
        targets = [(component, component.device) for component in var.values]
    else:
        # Packed variable: the same object is updated once per device.
        targets = [(packed, device) for device in packed.devices]

    results = []
    for replica_id, (target, device) in enumerate(targets):
        scope_name = "update_%d" % replica_id
        with ops.device(device):
            with distribute_lib.UpdateContext(replica_id):
                with ops.name_scope(scope_name):
                    # If args and kwargs are not mirrored, the value is
                    # returned as is.
                    replica_args = distribute_utils.select_replica_mirrored(
                        replica_id, args)
                    replica_kwargs = distribute_utils.select_replica_mirrored(
                        replica_id, kwargs)
                    results.append(
                        fn(target, *replica_args, **replica_kwargs))
    return distribute_utils.update_regroup(self, results, group)
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    """Reduce `value` with `reduce_op` and place the result on `destinations`.

    Three paths are taken, in order:
      1. Inside a TPU context with a distributed value or tensor: scale for
         MEAN, then `cross_replica_sum`. Only SUM and MEAN are accepted.
      2. `value` is not a `DistributedValues`: delegate to
         `reduce_non_distributed_value`.
      3. Otherwise: sum the per-replica components with `add_n` (chunked
         when there are more than `_XLA_OP_BY_OP_INPUTS_LIMIT` of them),
         scale for MEAN, then copy or broadcast to `destinations`.

    Args:
      reduce_op: A `reduce_util.ReduceOp` value.
      value: The value to reduce.
      destinations: Where the result should live; resolved through
        `cross_device_ops_lib.get_devices_from`.
      experimental_hints: Not read by this implementation.

    Returns:
      The reduced value.

    Raises:
      NotImplementedError: inside a TPU context when `reduce_op` is neither
        SUM nor MEAN.
    """
    is_distributed_or_tensor = (
        isinstance(value, values.DistributedValues)
        or tensor_util.is_tensor(value))
    if (is_distributed_or_tensor
            and tpu_values.enclosing_tpu_context() is not None):
        wants_mean = reduce_op == reduce_util.ReduceOp.MEAN
        if not wants_mean and reduce_op != reduce_util.ReduceOp.SUM:
            raise NotImplementedError(
                "Currently only support sum & mean in TPUStrategy.")
        if wants_mean:
            # TODO(jhseu): Revisit once we support model-parallelism.
            value *= (1. / self._num_replicas_in_sync)
        return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on
        # all replicas in which case `value` would be a single value or value
        # could be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)

    components = value.values
    # pylint: disable=protected-access
    if isinstance(value, values.DistributedVariable):
        packed = value._packed_variable
        if packed is not None:
            components = tuple(
                packed.on_device(device) for device in packed.devices)
    # pylint: enable=protected-access

    # Currently XLA op by op mode has a limit for the number of inputs for a
    # single op, thus we break one `add_n` op into a group of `add_n` ops to
    # work around the constraint.
    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    chunk_size = _XLA_OP_BY_OP_INPUTS_LIMIT
    # The threshold is checked against `value.values`, not `components`.
    if len(value.values) <= chunk_size:
        output = math_ops.add_n(components)
    else:
        first = components[0]
        output = array_ops.zeros_like(first, dtype=first.dtype)
        for start in range(0, len(components), chunk_size):
            output += math_ops.add_n(components[start:start + chunk_size])

    if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(components))

    target_devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(target_devices) != 1:
        return cross_device_ops_lib.simple_broadcast(output, destinations)

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(target_devices[0])
    host_canonical = device_util.canonicalize(self._host_device)
    if dest_canonical != host_canonical:
        with ops.device(dest_canonical):
            output = array_ops.identity(output)
    return output
def experimental_logical_device(self, logical_device_id):
    """Places variables and ops on the specified logical device.

    This is a generator that yields exactly once. While suspended at the
    yield, `logical_device_id` sits on top of `self._logical_device_stack`;
    it is popped again when the generator is resumed or closed.
    NOTE(review): presumably wrapped with `contextlib.contextmanager` outside
    this view so that it can be used in a `with` statement -- confirm.

    Args:
      logical_device_id: Index of the logical device within a replica. Must
        satisfy `0 <= logical_device_id < self._tpu_devices.shape[1]`.

    Yields:
      None. Inside a TPU context the yield happens under
      `ops.device(tpu.core(logical_device_id))`.

    Raises:
      ValueError: if `logical_device_id` is negative or is not smaller than
        the number of logical devices per replica.
    """
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    # Reject negative ids as well as too-large ones: a negative id is not a
    # valid logical core and must not be pushed onto the device stack.
    if not 0 <= logical_device_id < num_logical_devices_per_replica:
        raise ValueError(
            "`logical_device_id` not in range (was {}, but there are only {} "
            "logical devices per replica).".format(
                logical_device_id, num_logical_devices_per_replica))

    self._logical_device_stack.append(logical_device_id)
    try:
        if tpu_values.enclosing_tpu_context() is None:
            yield
        else:
            with ops.device(tpu.core(logical_device_id)):
                yield
    finally:
        # Always undo the push, even if the body of the `with` raised.
        self._logical_device_stack.pop()
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
    """Reduce `value` with `reduce_op` and place the result on `destinations`.

    Inside a TPU context (for distributed values or tensors) the reduction is
    a `cross_replica_sum`, pre-scaled for MEAN. Non-distributed values are
    delegated to `reduce_non_distributed_value`. Everything else is summed
    with `add_n` on `self._host_device` and then copied or broadcast.

    Args:
      reduce_op: A `reduce_util.ReduceOp` value.
      value: The value to reduce.
      destinations: Where the result should live; resolved through
        `cross_device_ops_lib.get_devices_from`.
      experimental_hints: Not read by this implementation.

    Returns:
      The reduced value.

    Raises:
      NotImplementedError: inside a TPU context when `reduce_op` is neither
        SUM nor MEAN.
    """
    is_distributed_or_tensor = (
        isinstance(value, values.DistributedValues)
        or tensor_util.is_tensor(value))
    if (is_distributed_or_tensor
            and tpu_values.enclosing_tpu_context() is not None):
        wants_mean = reduce_op == reduce_util.ReduceOp.MEAN
        if not wants_mean and reduce_op != reduce_util.ReduceOp.SUM:
            raise NotImplementedError(
                "Currently only support sum & mean in TPUStrategy.")
        if wants_mean:
            # TODO(jhseu): Revisit once we support model-parallelism.
            value *= (1. / self._num_replicas_in_sync)
        return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on
        # all replicas in which case `value` would be a single value or value
        # could be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)

    # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
        output = math_ops.add_n(value.values)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            output *= (1. / len(value.values))

    target_devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(target_devices) != 1:
        return cross_device_ops_lib.simple_broadcast(output, destinations)

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(target_devices[0])
    host_canonical = device_util.canonicalize(self._host_device)
    if dest_canonical != host_canonical:
        with ops.device(dest_canonical):
            output = array_ops.identity(output)
    return output
def _broadcast_to(self, tensor, destinations): del destinations # This is both a fast path for Python constants, and a way to delay # converting Python values to a tensor until we know what type it # should be converted to. Otherwise we have trouble with: # global_step.assign_add(1) # since the `1` gets broadcast as an int32 but global_step is int64. if isinstance(tensor, (float, int)): return tensor if tpu_values.enclosing_tpu_context() is not None: broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] result = tpu_ops.all_to_all( broadcast_tensor, concat_dimension=0, split_dimension=0, split_count=self._num_replicas_in_sync) # This uses the broadcasted value from the first replica because the only # caller of this is for ONLY_FIRST_REPLICA variables aggregation. return result[0] return tensor
def _update(self, var, fn, args, kwargs, group):
    """Run `fn` on `var`, either once in-context or once per component.

    Args:
      var: Variable to update. Asserted to be a `TPUVariableMixin` or a
        `BaseResourceVariable`.
      fn: Callable invoked as `fn(value, *args, **kwargs)`.
      args: Positional arguments forwarded to `fn`.
      kwargs: Keyword arguments forwarded to `fn`.
      group: Inside a TPU context, chooses between returning the bare result
        (truthy) and a 1-tuple (falsy). Otherwise it is handed to
        `values.update_regroup`.

    Returns:
      Inside a TPU context, the result of `fn` (wrapped in a 1-tuple when
      `group` is falsy). Otherwise whatever `update_regroup` returns for the
      collected per-component results.
    """
    accepted_types = (tpu_values.TPUVariableMixin,
                      resource_variable_ops.BaseResourceVariable)
    assert isinstance(var, accepted_types)

    if tpu_values.enclosing_tpu_context() is not None:
        in_context_result = fn(var, *args, **kwargs)
        return in_context_result if group else (in_context_result,)

    # Otherwise, we revert to MirroredStrategy behavior and update each
    # variable directly.
    results = []
    for replica_id, component in enumerate(var.values):
        scope_name = "update_%d" % replica_id
        with ops.device(component.device):
            with distribute_lib.UpdateContext(replica_id):
                with ops.name_scope(scope_name):
                    # If args and kwargs are not mirrored, the value is
                    # returned as is.
                    replica_args = values.select_replica_mirrored(
                        replica_id, args)
                    replica_kwargs = values.select_replica_mirrored(
                        replica_id, kwargs)
                    results.append(
                        fn(component, *replica_args, **replica_kwargs))
    return values.update_regroup(self, results, group)