def ComputeSplits(batch_size, num_splits):
  """Builds a tensor holding the number of values assigned to each split.

  Every split gets floor(batch_size / num_splits) values. The remainder (if
  any) is handed out one value at a time to the leading splits.

  Example::

    batch_size: [5]
    num_splits: 3
    returns: [2, 2, 1]

  Args:
    batch_size: tensor of rank 0, size of tensor to be split
    num_splits: number of splits to split tensor into

  Returns:
    tensor of length num_splits containing sizes of each split
  """
  # Base share: the quotient repeated once per split.
  quotient = tf.div([batch_size], num_splits)
  split_count = tf.constant([num_splits], dtype=tf.int32)
  base_sizes = tf.tile(quotient, split_count)

  # One extra value for each of the first (batch_size mod num_splits) splits.
  one = tf.constant([1])
  remainder = tf.mod([batch_size], num_splits)
  extras = tf.tile(one, remainder)

  # Zero-pad the extras so they line up with base_sizes element by element.
  zero = tf.constant([0])
  pad_length = tf.subtract(tf.shape(base_sizes), tf.shape(extras))
  padding = tf.tile(zero, pad_length)
  extras = tf.concat([extras, padding], 0)

  split_sizes = tf.add(base_sizes, extras)
  # For some reason TF loses the static shape when num_splits is 1, so it is
  # restored explicitly.
  if num_splits == 1:
    split_sizes.set_shape([1])
  return split_sizes
def _InputBatch(self):
  """Builds a deterministic synthetic src/tgt batch.

  Ids and labels are int32 values drawn from `np.random.randint` (seeded with
  1) with shape [10, 7]; weights are all ones and paddings are all zeros.

  When `self.params.split` is truthy, every tensor is split into two halves
  along axis 0 and the half to use is picked by the parity of
  `py_utils.GetGlobalStep()`: the first half on even steps, the second half
  on odd steps. Otherwise the full tensors are returned.

  Returns:
    A `py_utils.NestedMap` with `src.ids`, `src.paddings`, `tgt.ids`,
    `tgt.labels`, `tgt.paddings` and `tgt.weights`.
  """
  np.random.seed(1)
  bs, sl = 10, 7

  def _RandomIds():
    # Each call consumes the seeded RNG, so the call order below matters.
    return tf.constant(
        np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))

  src_ids = _RandomIds()
  tgt_ids = _RandomIds()
  tgt_labels = _RandomIds()
  tgt_weights = tf.constant(np.ones(shape=[bs, sl], dtype=np.float32))
  src_paddings = tf.zeros([bs, sl])
  tgt_paddings = tf.zeros([bs, sl])

  batch = py_utils.NestedMap()
  batch.src = py_utils.NestedMap()
  batch.tgt = py_utils.NestedMap()

  if not self.params.split:
    batch.src.ids = src_ids
    batch.src.paddings = src_paddings
    batch.tgt.ids = tgt_ids
    batch.tgt.labels = tgt_labels
    batch.tgt.paddings = tgt_paddings
    batch.tgt.weights = tgt_weights
    return batch

  def _Halves(tensor):
    # Two equal parts along the batch axis.
    return tf.split(tensor, 2, 0)

  def _HalfForStep(halves):
    # Even global step -> first half, odd global step -> second half.
    on_even_step = tf.equal(tf.mod(py_utils.GetGlobalStep(), 2), 0)
    return tf.cond(on_even_step, lambda: halves[0], lambda: halves[1])

  src_id_halves = _Halves(src_ids)
  src_padding_halves = _Halves(src_paddings)
  tgt_id_halves = _Halves(tgt_ids)
  tgt_label_halves = _Halves(tgt_labels)
  tgt_padding_halves = _Halves(tgt_paddings)
  tgt_weight_halves = _Halves(tgt_weights)

  batch.src.ids = _HalfForStep(src_id_halves)
  batch.src.paddings = _HalfForStep(src_padding_halves)
  batch.tgt.ids = _HalfForStep(tgt_id_halves)
  batch.tgt.labels = _HalfForStep(tgt_label_halves)
  batch.tgt.paddings = _HalfForStep(tgt_padding_halves)
  batch.tgt.weights = _HalfForStep(tgt_weight_halves)
  return batch
def Apply(self, lr, var_grad):
  """Accumulates gradients and applies them once every `accum_steps` steps.

  Each gradient is added into a non-trainable `grad_accumulator` variable
  created under a variable scope named after its variable. On steps where
  `global_step % accum_steps == accum_steps - 1`, the accumulated gradients
  are scaled by `1 / accum_steps`, handed to the wrapped optimizer
  (`self._opt`), and the accumulators are then reset to zero. On all other
  steps the conditional branch is a no-op.

  Args:
    lr: learning rate, forwarded unchanged to `self._opt.Apply`.
    var_grad: structure supporting `Transform` and `Flatten` whose leaves
      unpack as (variable, gradient) pairs.

  Returns:
    The op produced by `tf.cond` selecting between apply-and-reset and no-op.
  """
  p = self.params

  def _Accumulate(var_and_grad):
    """Adds the incoming gradient into the variable's accumulator."""
    var, grad = var_and_grad
    with tf.variable_scope(var.op.name):
      accumulator_params = py_utils.WeightParams(
          var.get_shape(), py_utils.WeightInit.Constant(0.0),
          self.params.dtype)
      _, accumulator = py_utils.CreateVariable(
          'grad_accumulator', accumulator_params, trainable=False)
      # From here on `accumulator` is the output of the assign_add, so any
      # consumer sees the value that already includes this step's gradient.
      accumulator = tf.assign_add(accumulator, grad)
    return py_utils.VarGrad(var, accumulator)

  accumulated = var_grad.Transform(_Accumulate)

  def _ApplyAndReset():
    """Applies the averaged gradients, then zeroes the accumulators."""
    averaged = py_utils.ApplyGradMultiplier(accumulated, 1.0 / p.accum_steps)
    apply_op = self._opt.Apply(lr, averaged)
    # The resets must not run before the optimizer has consumed the values.
    with tf.control_dependencies([apply_op]):
      reset_ops = [
          tf.assign(acc, tf.zeros_like(acc)) for _, acc in accumulated.Flatten()
      ]
      return tf.group(*reset_ops)

  def _Skip():
    return tf.group(tf.no_op())

  is_apply_step = tf.equal(
      tf.mod(self.theta.global_step, p.accum_steps), p.accum_steps - 1)
  return tf.cond(is_apply_step, _ApplyAndReset, _Skip)
def WrapAngleRad(angles_rad, min_val=-np.pi, max_val=np.pi):
  """Wraps `angles_rad` into the half-open range [min_val, max_val).

  The value is shifted so that `min_val` maps to zero, reduced modulo the
  range width, and shifted back. Shifting by `-min_val` (rather than by
  `+max_val`) keeps the result correct for ranges that are not symmetric
  around zero; for the default [-pi, pi) range both shifts are identical.

  Args:
    angles_rad: angle(s) in radians; anything accepted by `tf.mod`.
    min_val: inclusive lower bound of the target range.
    max_val: exclusive upper bound of the target range; expected to be
      greater than `min_val`.

  Returns:
    `angles_rad` wrapped into [min_val, max_val).
  """
  period = max_val - min_val
  # Relies on tf.mod being a floored modulo, i.e. returning a value in
  # [0, period) for a positive period, including for negative inputs.
  return min_val + tf.mod(angles_rad - min_val, period)
def bucket_fn(num):
  """Returns the bucket key for a record based on the parity of `num[0]`.

  Even `num[0]` yields the key 1. Odd `num[0]` yields `-num[0]` cast to
  int32; per the original note this is what drops the record (presumably
  because the consumer discards records with a negative key).
  """

  def _KeepKey():
    return 1

  def _DropKey():
    return -tf.cast(num[0], tf.int32)

  is_even = tf.equal(tf.mod(num[0], 2), 0)
  return tf.cond(is_even, _KeepKey, _DropKey)
def __init__(self,
             learning_rate,
             momentum=0.0,
             initial_accumulator_value=0.0,
             start_preconditioning_steps=1000,
             statistics_computation_frequency=1,
             matrix_epsilon=1e-6,
             synchronous_preconditioning=False,
             second_moment_averaging=1.0,
             fallback_to_diagonal_dim=4096,
             max_any_dim=6656,
             block_size=4096,
             block_partition_threshold_size=1000000,
             global_step=None,
             exponent_multiplier=1.0,
             name="DistributedShampoo"):
  """Construct a DistributedShampoo optimizer.

  Args:
    learning_rate: A `Tensor` or a floating point value. The learning rate.
    momentum: A `Tensor` or a floating point value. Momentum is not applied
      to sparse updates.
    initial_accumulator_value: A floating point value.
    start_preconditioning_steps: An int32 value which indicates when to start
      preconditioning. The warmup factor ramps linearly from 0 at this step
      to 1 at twice this step.
    statistics_computation_frequency: An int32 step value which indicates how
      often to compute statistics for preconditioning.
    matrix_epsilon: An epsilon regularizer to make the matrices positive
      definite.
    synchronous_preconditioning: Whether to run preconditioning
      synchronously.
    second_moment_averaging: 1.0 means sum of gradients squares, while less
      than 1.0 switches to RMSProp style exponential moving averages of the
      second moments.
    fallback_to_diagonal_dim: Fallback to diagonal version of AFMA if any
      dimension is larger than fallback_to_diagonal_dim.
    max_any_dim: If maximum value for any dimension is greater than this
      value we skip preconditioning and fall back to the diagonal.
    block_size: Dimension of the partitioned tensors.
    block_partition_threshold_size: Partitions dimensions beyond this size.
    global_step: Global step for training. When None, the result of
      `tf.train.get_or_create_global_step()` is used.
    exponent_multiplier: A multiplier `e` for the exponent for the inverse
      calculation. e * -1/(2*rank). Only applies when calculating inverses
      through svd.
    name: Optional name prefix for the operations created when applying
      gradients.
  """
  # First positional argument is passed as False (presumably `use_locking`;
  # confirm against the base optimizer class).
  super(DistributedShampoo, self).__init__(False, name)
  self._learning_rate = learning_rate
  self._momentum = momentum
  self._initial_accumulator_value = initial_accumulator_value
  self._start_preconditioning_steps = start_preconditioning_steps
  self._matrix_epsilon = matrix_epsilon
  self._synchronous_preconditioning = synchronous_preconditioning
  self._second_moment_averaging = second_moment_averaging
  self._fallback_to_diagonal_dim = fallback_to_diagonal_dim
  self._max_any_dim = max_any_dim
  self._block_size = block_size
  # NOTE: On XLA - int64 is not handled properly.
  if global_step is not None:
    self._global_step = tf.cast(tf.identity(global_step), tf.int32)
  else:
    self._global_step = tf.cast(
        tf.identity(tf.train.get_or_create_global_step()), tf.int32)
  self._run_nondiagonal_update = tf.greater_equal(
      self._global_step, self._start_preconditioning_steps)
  start_steps_f = tf.cast(self._start_preconditioning_steps, tf.float32)
  global_step_f = tf.cast(self._global_step, tf.float32)
  # The warmup length equals start_preconditioning_steps. Clamp it to at
  # least one step so that start_preconditioning_steps == 0 does not divide
  # by zero (0/0 at step 0); values >= 1 are unaffected by the clamp.
  warmup_steps_f = tf.maximum(start_steps_f, 1.0)
  self._run_nondiagonal_update_warmup = tf.minimum(
      1.0, tf.maximum((global_step_f - start_steps_f) / warmup_steps_f, 0.0))
  # Computes statistics every K steps.
  self._statistics_computation_frequency = statistics_computation_frequency
  self._run_statistics_computation = tf.equal(
      tf.mod(self._global_step, self._statistics_computation_frequency), 0)
  # All vars that are preconditioned.
  self._all_vars_for_preconditioning = []
  self._exponent_multiplier = exponent_multiplier
  # block_partition_threshold_size is not stored on the instance; it is only
  # forwarded to the partition configuration.
  self._partition_info = PartitionConfig(block_partition_threshold_size,
                                         block_size)
  self._partitioner_metadata = {}