def _reduce(self, aggregation, value, destinations):
  """Reduces `value` according to `aggregation`.

  Inside an enclosing TPU context the reduction is delegated to
  `tpu_ops.cross_replica_sum`, and only SUM and MEAN are accepted. Outside of
  it, `value` is indexed / summed with `math_ops.add_n` on the host.

  Args:
    aggregation: A `vs.VariableAggregation` member selecting the reduction.
    value: The value to reduce. Outside a TPU context it is passed to
      `math_ops.add_n`, `len()` and indexed, so it is presumably a sequence of
      per-tower values -- confirm against callers.
    destinations: Target of the reduction. Only inspected outside a TPU
      context, where it must resolve to exactly one device.

  Returns:
    The reduced value.

  Raises:
    NotImplementedError: Inside a TPU context, if `aggregation` is neither
      MEAN nor SUM.
    ValueError: Outside a TPU context, if `destinations` does not resolve to
      exactly one device.
    AssertionError: Outside a TPU context, if the single destination is not
      host CPU device 0.
  """
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    wants_mean = aggregation == vs.VariableAggregation.MEAN
    if not wants_mean and aggregation != vs.VariableAggregation.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    if wants_mean:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self.num_towers)
    return tpu_ops.cross_replica_sum(value)

  # Validate that the destination is same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  target_devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(target_devices) != 1:
    raise ValueError('Multiple devices are not supported for TPUStrategy')
  destination_device = device_util.canonicalize(target_devices[0])
  host_device = device_util.canonicalize(self.get_host_cpu_device(0))
  assert destination_device == host_device

  if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
    return value[0]

  total = math_ops.add_n(value)
  if aggregation == vs.VariableAggregation.MEAN:
    total = total * (1. / len(value))
  return total
def _update(self, var, fn, *args, **kwargs):
  """Applies `fn` to `var`, either directly or once per device.

  Inside an enclosing TPU context `fn` is called a single time on `var`.
  Otherwise `fn` is called on every per-device component of `var`, and the
  results are tied together with `control_flow_ops.tuple` before regrouping.

  Args:
    var: A `values.TPUMirroredVariable` (asserted).
    fn: Callable invoked as `fn(variable, *args, **kwargs)`.
    *args: Positional arguments for `fn`; passed through
      `values.select_device_mirrored` per device outside a TPU context.
    **kwargs: Keyword arguments for `fn`, handled like `args`.

  Returns:
    `fn`'s result inside a TPU context; otherwise the per-device results
    regrouped with `values.regroup(..., values.Mirrored)`.
  """
  # TODO(jhseu): Consider supporting grouped==False.
  assert isinstance(var, values.TPUMirroredVariable)
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  per_device_updates = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          per_device_updates[device] = fn(
              component, *device_args, **device_kwargs)

  # Make a single control dependency to keep the variables mirrored. If one
  # assignment is fetched, then run all assignments.
  ordered_devices = sorted(per_device_updates)
  grouped_updates = control_flow_ops.tuple(
      [per_device_updates[device] for device in ordered_devices])
  for device, grouped_update in zip(ordered_devices, grouped_updates):
    per_device_updates[device] = grouped_update
  return values.regroup(per_device_updates, values.Mirrored)
def _reduce(self, aggregation, value, destinations):
  """Combines `value` as requested by `aggregation`.

  Two paths exist:
    * In an enclosing TPU context, MEAN pre-scales `value` by
      `1 / self.num_towers` and the result of `tpu_ops.cross_replica_sum` is
      returned; anything other than MEAN or SUM is rejected.
    * Otherwise the destination is validated against the host CPU device and
      `value` is reduced on the host via indexing or `math_ops.add_n`.

  Args:
    aggregation: A `vs.VariableAggregation` member.
    value: The value to reduce; on the host path it is indexed, measured with
      `len()` and passed to `math_ops.add_n`.
    destinations: Passed to `cross_tower_ops_lib.get_devices_from` on the host
      path; ignored on the TPU path.

  Returns:
    The reduced value.

  Raises:
    NotImplementedError: On the TPU path for aggregations other than MEAN/SUM.
    ValueError: On the host path when the number of destination devices is
      not exactly one.
    AssertionError: On the host path when the destination differs from host
      CPU device 0 after canonicalization.
  """
  on_tpu = values._enclosing_tpu_context() is not None  # pylint: disable=protected-access
  if on_tpu:
    if aggregation == vs.VariableAggregation.MEAN:
      # TODO(jhseu): Revisit once we support model-parallelism.
      scale = 1. / self.num_towers
      value *= scale
    elif aggregation != vs.VariableAggregation.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    return tpu_ops.cross_replica_sum(value)

  # Validate that the destination is same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  resolved = cross_tower_ops_lib.get_devices_from(destinations)
  if len(resolved) != 1:
    raise ValueError('Multiple devices are not supported for TPUStrategy')
  requested = device_util.canonicalize(resolved[0])
  expected = device_util.canonicalize(self.get_host_cpu_device(0))
  assert requested == expected

  if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
    return value[0]

  summed = math_ops.add_n(value)
  is_mean = aggregation == vs.VariableAggregation.MEAN
  return summed * (1. / len(value)) if is_mean else summed
def _reduce_to(self, reduce_op, value, destinations):
  """Reduces `value` with `reduce_op`.

  Inside an enclosing TPU context, MEAN pre-scales `value` by
  `1 / self._num_replicas_in_sync` and the result of
  `tpu_ops.cross_replica_sum` is returned. Outside of it, the destination is
  checked against `self._host_device` and `value` is summed with
  `math_ops.add_n`.

  Args:
    reduce_op: A `reduce_util.ReduceOp` member.
    value: The value to reduce; outside a TPU context it is passed to
      `math_ops.add_n` and `len()`.
    destinations: Passed to `cross_tower_ops_lib.get_devices_from` outside a
      TPU context; not inspected inside one.

  Returns:
    The reduced value.

  Raises:
    NotImplementedError: Inside a TPU context, if `reduce_op` is neither MEAN
      nor SUM.
    ValueError: Outside a TPU context, if `destinations` does not resolve to
      exactly one device.
    AssertionError: Outside a TPU context, if the single destination differs
      from `self._host_device` after canonicalization.
  """
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    wants_mean = reduce_op == reduce_util.ReduceOp.MEAN
    if not wants_mean and reduce_op != reduce_util.ReduceOp.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    if wants_mean:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self._num_replicas_in_sync)
    return tpu_ops.cross_replica_sum(value)

  # Validate that the destination is same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  target_devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(target_devices) != 1:
    raise ValueError("Multiple devices are not supported for TPUStrategy")
  destination_device = device_util.canonicalize(target_devices[0])
  host_device = device_util.canonicalize(self._host_device)
  assert destination_device == host_device

  total = math_ops.add_n(value)
  if reduce_op == reduce_util.ReduceOp.MEAN:
    total = total * (1. / len(value))
  return total
def _update(self, var, fn, args, kwargs, group):
  """Applies `fn` to `var`, either directly or once per device.

  Args:
    var: A `values.TPUMirroredVariable` (asserted).
    fn: Callable invoked as `fn(variable, *args, **kwargs)`.
    args: Positional arguments for `fn`.
    kwargs: Keyword arguments for `fn`.
    group: Inside a TPU context, a truthy value returns `fn`'s result as is
      and a falsy value wraps it in a one-element list. Outside a TPU context
      it is forwarded to `values.update_regroup`.

  Returns:
    Inside a TPU context, `fn`'s result (possibly wrapped in a list, see
    `group`). Otherwise the result of `values.update_regroup` over the
    per-device results.
  """
  assert isinstance(var, values.TPUMirroredVariable)
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if not group:
      return [fn(var, *args, **kwargs)]
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  per_device_updates = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          per_device_updates[device] = fn(
              component, *device_args, **device_kwargs)
  return values.update_regroup(self, per_device_updates, group)
def _update(self, var, options, fn, *args, **kwargs):
  """Applies `fn` to `var`, either directly or once per device.

  Args:
    var: A `values.TPUMirroredVariable` (asserted).
    options: Mapping that must contain exactly the key "grouped". The key is
      popped, so the caller's mapping is mutated and left empty.
    fn: Callable invoked as `fn(variable, *args, **kwargs)`.
    *args: Positional arguments for `fn`; passed through
      `values.select_device_mirrored` per device outside a TPU context.
    **kwargs: Keyword arguments for `fn`, handled like `args`.

  Returns:
    Inside a TPU context, `fn`'s result, wrapped in a one-element list when
    the "grouped" option is falsy. Otherwise the result of
    `values.update_regroup` over the per-device results.
  """
  assert isinstance(var, values.TPUMirroredVariable)
  should_group = options.pop("grouped")
  assert not options  # Validate that we are processing all of the options.

  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if not should_group:
      return [fn(var, *args, **kwargs)]
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  per_device_updates = {}
  for device, component in var._index.items():  # pylint: disable=protected-access
    scope_name = "update_%d" % self._device_index.get(device)
    with ops.device(device):
      with distribute_lib.UpdateContext(device):
        with ops.name_scope(scope_name):
          # If args and kwargs are not mirrored, the value is returned as is.
          device_args = values.select_device_mirrored(device, args)
          device_kwargs = values.select_device_mirrored(device, kwargs)
          per_device_updates[device] = fn(
              component, *device_args, **device_kwargs)
  return values.update_regroup(self, per_device_updates, should_group)
def _update(self, var, fn, *args, **kwargs):
  """Runs `fn` on `var`, or on each of its per-device components.

  When there is an enclosing TPU context, `fn(var, *args, **kwargs)` is
  returned directly. Otherwise `fn` is run for every entry of `var._index`
  under that device's scope, the results are bundled through
  `control_flow_ops.tuple`, and the bundle is regrouped as `values.Mirrored`.

  Args:
    var: A `values.TPUMirroredVariable` (asserted).
    fn: Callable invoked as `fn(variable, *args, **kwargs)`.
    *args: Positional arguments for `fn`.
    **kwargs: Keyword arguments for `fn`.

  Returns:
    `fn`'s result inside a TPU context; otherwise the value produced by
    `values.regroup`.
  """
  # TODO(jhseu): Consider supporting grouped==False.
  assert isinstance(var, values.TPUMirroredVariable)
  in_tpu_context = values._enclosing_tpu_context() is not None  # pylint: disable=protected-access
  if in_tpu_context:
    return fn(var, *args, **kwargs)

  # Otherwise, we revert to MirroredStrategy behavior and update each variable
  # directly.
  results = {}
  for dev, replica_var in var._index.items():  # pylint: disable=protected-access
    scope = "update_%d" % self._device_index.get(dev)
    with ops.device(dev), distribute_lib.UpdateContext(dev):
      with ops.name_scope(scope):
        # If args and kwargs are not mirrored, the value is returned as is.
        local_args = values.select_device_mirrored(dev, args)
        local_kwargs = values.select_device_mirrored(dev, kwargs)
        results[dev] = fn(replica_var, *local_args, **local_kwargs)

  # Make a single control dependency to keep the variables mirrored. If one
  # assignment is fetched, then run all assignments.
  devices_in_order = sorted(results)
  bundled = control_flow_ops.tuple([results[dev] for dev in devices_in_order])
  for dev, bundled_result in zip(devices_in_order, bundled):
    results[dev] = bundled_result
  return values.regroup(results, values.Mirrored)