def allreduce(self, x, mesh_axes, reduction_fn_string):
  """Grouped allreduce (summed across the given mesh axes).

  Args:
    x: a LaidOutTensor
    mesh_axes: a list of integers
    reduction_fn_string: "SUM"

  Returns:
    a LaidOutTensor

  Raises:
    ValueError: if the reduction is not yet implemented.
  """
  if not mesh_axes:
    return x
  x = x.to_laid_out_tensor()
  if reduction_fn_string == "SUM":
    group_assignment = self._create_group_assignment(mesh_axes)
    return self.LaidOutTensor(
        [tpu_ops.cross_replica_sum(x.one_slice, group_assignment)])
  else:
    for axis in mesh_axes:
      x = self.allconcat(x, axis, 0, stack=True)
      x = self.LaidOutTensor(
          [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
    return x
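# A hedged sketch (an assumption, not taken from the snippet above) of the
# kind of grouping _create_group_assignment plausibly returns: processors
# that differ only in the reduced mesh axes end up in the same group.
# Assumes row-major processor numbering; `group_assignment_for` is a
# hypothetical helper for illustration only.
import itertools

def group_assignment_for(mesh_shape, mesh_axes):
  groups = {}
  for pnum, coords in enumerate(
      itertools.product(*[range(d) for d in mesh_shape])):
    # Group by the coordinates of the axes that are NOT being reduced.
    key = tuple(c for i, c in enumerate(coords) if i not in mesh_axes)
    groups.setdefault(key, []).append(pnum)
  return list(groups.values())

# On a [2, 4] mesh, reducing over mesh axis 0 groups processors that share
# a position along axis 1.
assert group_assignment_for([2, 4], [0]) == [[0, 4], [1, 5], [2, 6], [3, 7]]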
def mtf_model_fn(self, features, mesh):
  hparams = self._hparams
  hparams.batch_size = 10
  hparams.io_size = 4
  hparams.hidden_size = 2
  tf_x = tf.matmul(
      tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                 [hparams.batch_size, 1]),
      tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                 [1, hparams.io_size]))
  # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])
  hidden_1_variable = tf.get_variable(
      "a",
      shape=[hparams.io_size, hparams.hidden_size],
      initializer=tf.random_normal_initializer())
  hidden_2_variable = tf.get_variable(
      "b",
      shape=[hparams.hidden_size, hparams.io_size],
      initializer=tf.random_normal_initializer())
  hidden_layer_1 = tf.matmul(tf_x, hidden_1_variable)
  hidden_layer_2 = tf.matmul(hidden_layer_1, hidden_2_variable)
  hidden_layer_2 = tpu_ops.cross_replica_sum(hidden_layer_2)
  loss = tf.reduce_mean(tf.square(hidden_layer_2 - tf_x))
  return None, loss
def _reduce(self, aggregation, value, destinations):
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if aggregation == vs.VariableAggregation.MEAN:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self.num_towers)
    elif aggregation != vs.VariableAggregation.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    return tpu_ops.cross_replica_sum(value)

  # Validate that the destination is the same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
        self.get_host_cpu_device(0))
  else:
    raise ValueError('Multiple devices are not supported for TPUStrategy')

  if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
    return value[0]
  output = math_ops.add_n(value)
  if aggregation == vs.VariableAggregation.MEAN:
    return output * (1. / len(value))
  return output
def _reduce(self, aggregation, value, destinations):
  graph = ops.get_default_graph()
  cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  # If we're inside the ReplicateContext, reduction should be done using
  # CrossReplicaSum while outside we can directly use an add_n op.
  while cf_context:
    if isinstance(cf_context, tpu.TPUReplicateContext):
      if aggregation == vs.VariableAggregation.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        value *= (1. / self.num_towers)
      return tpu_ops.cross_replica_sum(value)
    cf_context = cf_context.outer_context

  # Validate that the destination is the same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  devices = cross_tower_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
        self._host)
  else:
    raise ValueError('Multiple devices are not supported for TPUStrategy')

  output = math_ops.add_n(value)
  if aggregation == vs.VariableAggregation.MEAN:
    return output * (1. / len(value))
  return output
def _reduce_to(self, reduce_op, value, destinations):
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if reduce_op == reduce_util.ReduceOp.MEAN:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self._num_replicas_in_sync)
    elif reduce_op != reduce_util.ReduceOp.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    return tpu_ops.cross_replica_sum(value)

  if not isinstance(value, values.DistributedValues):
    # This function handles reducing values that are not PerReplica or
    # Mirrored values. For example, the same value could be present on all
    # replicas in which case `value` would be a single value or value could
    # be 0.
    return cross_device_ops_lib.reduce_non_distributed_value(
        reduce_op, self._device_map, value, destinations)

  # Validate that the destination is the same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  devices = cross_device_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
        self._host_device)
  else:
    raise ValueError("Multiple devices are not supported for TPUStrategy")

  output = math_ops.add_n(value)
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return output * (1. / len(value))
  return output
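# A worked check (not from the source) of why the MEAN branch above
# pre-scales by 1/num_replicas before calling cross_replica_sum: an
# all-reduce mean over n replicas is just sum(x_i / n). Plain Python
# stands in for the cross-replica sum here.
values_per_replica = [2., 4., 6., 8.]  # one value per replica
scaled = [v * (1. / len(values_per_replica)) for v in values_per_replica]
assert sum(scaled) == 5.0  # the cross-replica sum of scaled values is the mean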
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
  replicas, and then applies the real optimizer.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      compute_gradients().
    global_step: Optional Variable to increment by one after the variables
      have been updated.
    name: Optional name for the returned operation. Defaults to the name
      passed to the Optimizer constructor.

  Returns:
    An `Operation` that applies the gradients. If `global_step` was not
    None, that operation also increments `global_step`.

  Raises:
    ValueError: If grads_and_vars is malformed.
  """
  summed_grads_and_vars = []
  for (grad, var) in grads_and_vars:
    if grad is None:
      summed_grads_and_vars.append((grad, var))
    else:
      summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
  return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
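# A minimal usage sketch, assuming the TF1-style TPU API: wrapping a plain
# optimizer in CrossShardOptimizer, whose apply_gradients (shown above)
# sums per-replica gradients with cross_replica_sum before the update.
# The toy model and hyperparameters are illustrative assumptions.
import tensorflow.compat.v1 as tf

def train_step(features, labels):
  logits = tf.layers.dense(features, 10)  # toy single-layer model
  loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  opt = tf.tpu.CrossShardOptimizer(opt)  # sum grads across replicas
  return opt.minimize(loss)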
def allreduce(self, x, mesh_axes, reduction_fn_string):
  """Grouped allreduce (summed across the given mesh axes).

  Args:
    x: a LaidOutTensor
    mesh_axes: a list of integers
    reduction_fn_string: "SUM"

  Returns:
    a LaidOutTensor

  Raises:
    ValueError: if the reduction is not yet implemented.
  """
  if not mesh_axes:
    return x
  x = x.to_laid_out_tensor()
  if reduction_fn_string == "SUM":
    group_assignment = self._create_group_assignment(mesh_axes)
    tf_in = x.one_slice
    dtype = tf_in.dtype
    if not (dtype == tf.float32 or dtype == tf.bfloat16):
      tf.logging.info("Casting %s to float32 for allreduce" % tf_in.dtype)
      tf_in = tf.cast(tf_in, tf.float32)
    tf_out = tpu_ops.cross_replica_sum(tf_in, group_assignment)
    if tf_out.dtype != dtype:
      tf_out = tf.cast(tf_out, dtype)
    return self.LaidOutTensor([tf_out])
  else:
    for axis in mesh_axes:
      x = self.allconcat(x, axis, 0, stack=True)
      x = self.LaidOutTensor(
          [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
    return x
def _reduce_to(self, reduce_op, value, destinations):
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if reduce_op == reduce_util.ReduceOp.MEAN:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self._num_replicas_in_sync)
    elif reduce_op != reduce_util.ReduceOp.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    return tpu_ops.cross_replica_sum(value)

  # Validate that the destination is the same as the host device.
  # Note we don't do this when in replicate context as the reduction is
  # performed on the TPU device itself.
  devices = cross_device_ops_lib.get_devices_from(destinations)
  if len(devices) == 1:
    assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
        self._host_device)
  else:
    raise ValueError("Multiple devices are not supported for TPUStrategy")

  output = math_ops.add_n(value)
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return output * (1. / len(value))
  return output
def allreduce(self, x, mesh_axes, reduction_fn_string):
  """Grouped allreduce (summed across the given mesh axes).

  Args:
    x: a LaidOutTensor
    mesh_axes: a list of integers
    reduction_fn_string: "SUM"

  Returns:
    a LaidOutTensor

  Raises:
    ValueError: if the reduction is not yet implemented.
  """
  if not mesh_axes:
    return x
  x = x.to_laid_out_tensor()
  if reduction_fn_string == "SUM":
    # NOTE: xrange here presumably comes from six.moves in the original
    # module (Python 2/3 compatibility).
    partitioning = [
        mtf.pnum_to_group(self.shape, mesh_axes, pnum)
        for pnum in xrange(self.size)]
    return self.LaidOutTensor(
        [tpu_ops.cross_replica_sum(x.one_slice, partitioning)])
  else:
    for axis in mesh_axes:
      x = self.allconcat(x, axis, 0, stack=True)
      x = self.LaidOutTensor(
          [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
    return x
def cross_replica_average(inputs,
                          num_shards=None,
                          num_shards_per_group=None,
                          physical_shape=None,
                          tile_shape=None,
                          use_spatial_partitioning=False):
  """Customized cross replica sum op."""
  # If num_shards_per_group is defined, apply distributed batch norm.
  group_assignment = None
  # Guard against the None default: `None > 0` raises a TypeError on
  # Python 3. Callers are expected to pass both num_shards and
  # num_shards_per_group.
  if num_shards_per_group is not None and num_shards_per_group > 0:
    if num_shards % num_shards_per_group != 0:
      raise ValueError(
          'num_shards: %d mod num_shards_per_group: %d, should be 0' %
          (num_shards, num_shards_per_group))
    num_groups = num_shards // num_shards_per_group
    if physical_shape is not None and tile_shape is not None:
      if use_spatial_partitioning:
        group_assignment = spatial_partitioning_group_assignment(
            physical_shape, tile_shape, num_groups)
      else:
        group_assignment = normal_group_assignment(physical_shape, tile_shape,
                                                   num_groups)
    else:
      group_assignment = [
          [  # pylint: disable=g-complex-comprehension
              x for x in range(num_shards) if x // num_shards_per_group == y
          ] for y in range(num_groups)
      ]
  return tpu_ops.cross_replica_sum(inputs, group_assignment) / math_ops.cast(
      num_shards_per_group, inputs.dtype)
def cross_replica_average(inputs, num_shards, distributed_group_size):
  """Calculates the average value of inputs tensor across TPU replicas."""
  group_assignment = None
  if num_shards is not None and distributed_group_size != num_shards:
    # cross_replica_sum expects group_assignment as a 2D list of replica-id
    # groups, e.g. [[0, 1], [2, 3]]; the original flat list of group
    # indices ([i // distributed_group_size for i in range(num_shards)])
    # does not match that shape, so build explicit groups instead.
    group_assignment = [
        [g * distributed_group_size + i
         for i in range(distributed_group_size)]
        for g in range(num_shards // distributed_group_size)
    ]
  return tpu_ops.cross_replica_sum(inputs, group_assignment) / tf.cast(
      distributed_group_size, inputs.dtype)
def _cross_replica_sum(self, grads_and_vars):
  summed_grads_and_vars = []
  for (grad, var) in grads_and_vars:
    if grad is None:
      summed_grads_and_vars.append((grad, var))
    else:
      with ops.colocate_with(grad):
        summed_grads_and_vars.append(
            (tpu_ops.cross_replica_sum(grad, self._group_assignment), var))
  return summed_grads_and_vars
def cross_replica_average(inputs, num_shards, distributed_group_size):
  """Calculates the average value of inputs tensor across TPU replicas."""
  group_assignment = None
  if num_shards is not None and distributed_group_size != num_shards:
    group_size = distributed_group_size
    group_assignment = []
    for g in range(num_shards // group_size):
      replica_ids = [g * group_size + i for i in range(group_size)]
      group_assignment.append(replica_ids)
  return tpu_ops.cross_replica_sum(inputs, group_assignment) / tf.cast(
      distributed_group_size, inputs.dtype)
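# A quick check of the grouping the loop above builds (the helper name is
# hypothetical; the values follow directly from the code): with 8 shards
# and a distributed group size of 2, each pair of adjacent replicas
# averages only within its own group, as used by distributed batch norm.
def make_group_assignment(num_shards, group_size):
  return [[g * group_size + i for i in range(group_size)]
          for g in range(num_shards // group_size)]

assert make_group_assignment(8, 2) == [[0, 1], [2, 3], [4, 5], [6, 7]]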
def _cross_replica_average(self, t, num_shards_per_group):
  """Calculates the average value of input tensor across TPU replicas."""
  num_shards = tpu_function.get_tpu_context().number_of_shards
  group_assignment = None
  if num_shards_per_group > 1:
    if num_shards % num_shards_per_group != 0:
      raise ValueError(
          'num_shards: %d mod shards_per_group: %d, should be 0' %
          (num_shards, num_shards_per_group))
    num_groups = num_shards // num_shards_per_group
    group_assignment = [[
        x for x in range(num_shards) if x // num_shards_per_group == y
    ] for y in range(num_groups)]
  return tpu_ops.cross_replica_sum(t, group_assignment) / tf.cast(
      num_shards_per_group, t.dtype)
def cross_replica_average(inputs, num_shards=None, num_shards_per_group=None):
  """Customized cross replica sum op."""
  # If num_shards_per_group is defined, apply distributed batch norm.
  group_assignment = None
  # Guard against the None default: `None > 0` raises a TypeError on
  # Python 3.
  if num_shards_per_group is not None and num_shards_per_group > 0:
    if num_shards % num_shards_per_group != 0:
      raise ValueError(
          'num_shards: %d mod num_shards_per_group: %d, should be 0' %
          (num_shards, num_shards_per_group))
    num_groups = num_shards // num_shards_per_group
    group_assignment = [[
        x for x in range(num_shards) if x // num_shards_per_group == y
    ] for y in range(num_groups)]
  return tpu_ops.cross_replica_sum(inputs, group_assignment) / math_ops.cast(
      num_shards_per_group, inputs.dtype)
def cross_replica_mean(tensor, name=None):
  """Takes mean value of a Tensor across all TPU cores.

  Args:
    tensor: Tensor to be synchronized.
    name: None or string. Name of Op.

  Returns:
    Average of Tensor across all TPU cores.

  Raises:
    ValueError: If called outside of TPU context.
  """
  with ops.name_scope(name, "cross_replica_mean", [tensor]):
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      raise ValueError(
          "Cannot take cross_replica_mean() outside of TPU Context.")
    if num_shards == 1:
      return tensor
    return tpu_ops.cross_replica_sum(tensor / num_shards)
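# A minimal usage sketch (assumed setup, not part of the snippet above):
# cross_replica_mean must run inside a TPU computation, since that is what
# sets number_of_shards in the TPU function context. tf.compat.v1.tpu.shard
# provides such a context; `computation` is illustrative and needs a real
# TPU to execute.
import tensorflow.compat.v1 as tf

def computation():
  per_replica_value = tf.random.uniform([])  # differs on each core
  # Every core receives the same averaged scalar.
  return cross_replica_mean(per_replica_value)

outputs = tf.tpu.shard(computation, num_shards=8)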
def _reduce_to(self, reduce_op, value, destinations):
  if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
    if reduce_op == reduce_util.ReduceOp.MEAN:
      # TODO(jhseu): Revisit once we support model-parallelism.
      value *= (1. / self._num_replicas_in_sync)
    elif reduce_op != reduce_util.ReduceOp.SUM:
      raise NotImplementedError(
          "Currently only support sum & mean in TPUStrategy.")
    return tpu_ops.cross_replica_sum(value)

  if not isinstance(value, values.DistributedValues):
    # This function handles reducing values that are not PerReplica or
    # Mirrored values. For example, the same value could be present on all
    # replicas in which case `value` would be a single value or value could
    # be 0.
    return cross_device_ops_lib.reduce_non_distributed_value(
        reduce_op, self._device_map, value, destinations)

  devices = cross_device_ops_lib.get_devices_from(destinations)
  if len(devices) != 1:
    raise ValueError("Multiple devices are not supported for TPUStrategy")

  # Always performs the reduction on the TPU host.
  with ops.device(self._host_device):
    output = math_ops.add_n(value.values)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      output *= (1. / len(value.values))

  # If necessary, copy to requested destination.
  dest_canonical = device_util.canonicalize(devices[0])
  host_canonical = device_util.canonicalize(self._host_device)

  if dest_canonical != host_canonical:
    with ops.device(devices[0]):
      output = array_ops.identity(output)

  return output
def get_gradients(self, loss, params):
  num_shards = tpu_function.get_tpu_context().number_of_shards
  grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
  # Sum each gradient across replicas, then divide by the shard count to
  # get the cross-replica mean.
  return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
def _reduce(self, method_string, value, destinations):
  del destinations  # TPU is graph mode only. Rely on implicit Send/Recv.
  if method_string == 'mean':
    # TODO(jhseu): Revisit once we support model-parallelism.
    value *= (1. / self._num_cores_per_host)
  return tpu_ops.cross_replica_sum(value)
def _reduce(self, aggregation, value, destinations):
  del destinations  # TPU is graph mode only. Rely on implicit Send/Recv.
  if aggregation == vs.VariableAggregation.MEAN:
    # TODO(jhseu): Revisit once we support model-parallelism.
    value *= (1. / self._num_cores_per_host)
  return tpu_ops.cross_replica_sum(value)
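# A self-contained numeric sketch (hypothetical setup) of the semantics all
# of the snippets above rely on: cross_replica_sum returns, on every
# replica in a group, the sum of the inputs across that group. With four
# replicas holding [1., 2., 3., 4.] and group_assignment=[[0, 1], [2, 3]],
# replicas 0 and 1 each receive 3., while replicas 2 and 3 each receive 7.
# The replicate() call is illustrative and requires an actual TPU to run.
import tensorflow.compat.v1 as tf

def computation(x):
  return tf.tpu.cross_replica_sum(x, group_assignment=[[0, 1], [2, 3]])

per_replica_inputs = [[tf.constant(v)] for v in [1., 2., 3., 4.]]
outputs = tf.tpu.replicate(computation, inputs=per_replica_inputs)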