def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
        if grad is None:
            summed_grads_and_vars.append((grad, var))
        else:
            with ops.colocate_with(grad):
                # gradient accumulation
                if self._gradients_to_accumulate > 1 and not self._pipelining:
                    grad = gen_poputil_ops.ipu_stateful_gradient_accumulate(
                        grad / self._gradients_to_accumulate,
                        num_mini_batches=self._gradients_to_accumulate)
                # replication
                if self._replicas > 1:
                    grad = gen_poputil_ops.ipu_replication_normalise(
                        cross_replica_ops.cross_replica_sum(grad))
                grad = math_ops.cast(grad, var.dtype)
                summed_grads_and_vars.append((grad, var))
    if self._pipelining:
        # can do weight decay here as apply_gradients is only called on the
        # last accumulation step
        summed_grads_and_vars = self.add_WD(summed_grads_and_vars)
    ret = self._optimizer.apply_gradients(summed_grads_and_vars, global_step,
                                          name)
    if self._sharded:
        sharding.propagate_sharding(ops.get_default_graph())
    return ret

def grad_fn(grad, param, m, v):
    if self.add_cross_replica_sums:
        grad = cross_replica_ops.cross_replica_sum(grad)
    cast_grad = tf.cast(grad, dtype=tf.float32)
    cast_grad = cast_grad / self.loss_scaling

    # Standard Adam update.
    next_m = (tf.multiply(self.beta_1, m) +
              tf.multiply(1.0 - self.beta_1, cast_grad))
    next_v = (tf.multiply(self.beta_2, v) +
              tf.multiply(1.0 - self.beta_2, tf.square(cast_grad)))
    update = tf.cast(next_m / (tf.sqrt(next_v) + self.epsilon), param.dtype)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param

    update_with_lr = self.learning_rate * update
    next_param = param - update_with_lr
    return next_param, next_v, next_m

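# --- Added sketch (not from the original source) ---
# `param_name` in grad_fn above is assumed to be resolved in the enclosing
# scope. The helper below is a hedged sketch of what `_do_use_weight_decay`
# typically looks like in BERT-style weight-decay optimizers: decay is skipped
# for parameters matching any pattern in `exclude_from_weight_decay` (commonly
# LayerNorm scales and biases). It is written as a method of the optimizer
# class; the attribute names are assumptions.
import re


def _do_use_weight_decay(self, param_name):
    """Whether to apply weight decay to the parameter named `param_name`."""
    if not self.weight_decay_rate:
        return False
    if self.exclude_from_weight_decay:
        for pattern in self.exclude_from_weight_decay:
            if re.search(pattern, param_name) is not None:
                return False
    return True
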
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls popops_cross_replica_sum.cross_replica_sum() to sum gradient
    contributions across replicas, and then applies the real optimizer.

    Args:
        grads_and_vars: List of (gradient, variable) pairs as returned by
            compute_gradients().
        global_step: Optional Variable to increment by one after the
            variables have been updated.
        name: Optional name for the returned operation. Defaults to the
            name passed to the Optimizer constructor.

    Returns:
        An `Operation` that applies the gradients. If `global_step` was not
        None, that operation also increments `global_step`.

    Raises:
        ValueError: If grads_and_vars is malformed.
    """
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
        if grad is None:
            summed_grads_and_vars.append((grad, var))
        else:
            with ops.colocate_with(grad):
                summed_grads_and_vars.append(
                    (gen_poputil_ops.ipu_replication_normalise(
                        cross_replica_ops.cross_replica_sum(grad)), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)

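# --- Added usage sketch (not from the original source) ---
# A minimal, hedged illustration of how a cross-replica optimizer wrapper like
# the `apply_gradients` above is used: wrap a plain TF1 optimizer so that each
# replica's gradients are summed and normalised before being applied. The
# import path and class name are assumptions and may differ between IPU SDK
# versions.
import tensorflow.compat.v1 as tf
from tensorflow.python.ipu.optimizers import cross_replica_optimizer


def build_train_op(loss, learning_rate=0.01):
    # The wrapper performs cross_replica_sum + ipu_replication_normalise on
    # each gradient before delegating to the inner optimizer.
    opt = cross_replica_optimizer.CrossReplicaOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate))
    return opt.minimize(loss)
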
def comp_fn():
    def body(total_accuracy, image, label):
        accuracy = validation_graph_builder(model, image, label, opts)
        return total_accuracy + (tf.cast(accuracy, tf.float32) /
                                 opts["validation_batches_per_step"])

    accuracy = loops.repeat(int(opts["validation_batches_per_step"]), body,
                            [tf.constant(0, tf.float32)], valid_iterator)
    if opts['replicas'] > 1:
        accuracy = cross_replica_ops.cross_replica_sum(accuracy) / (
            opts['replicas'] * opts['shards'])
    return accuracy

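# --- Added compilation sketch (not from the original source) ---
# How a `comp_fn` like the one above is typically wrapped for the IPU in TF1:
# the function is compiled with ipu_compiler inside an IPU device scope, and
# `valid_iterator` is an IPUInfeedQueue that feeds `body` via loops.repeat.
# The surrounding setup below is an assumption about the usual pattern, not
# part of the original code.
from tensorflow.python import ipu

with ipu.scopes.ipu_scope('/device:IPU:0'):
    [validation_accuracy] = ipu.ipu_compiler.compile(comp_fn, inputs=[])

# A tf.compat.v1.Session would then initialise the infeed and fetch the
# accuracy, e.g.:
#     sess.run(valid_iterator.initializer)
#     final_accuracy = sess.run(validation_accuracy)
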
def _reduce_to(self, reduce_op, value, destinations):
    del destinations
    if not _is_current_device_ipu():
        return value
    if reduce_op not in (reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN):
        raise ValueError("Unsupported reduce op: {}".format(reduce_op))
    result = cross_replica_ops.cross_replica_sum(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
        result = gen_poputil_ops.ipu_replication_normalise(result)
    return result

def _reduce_to(self, reduce_op, value, destinations):
    del destinations
    if not _is_current_device_ipu():
        # If not on IPU, use Horovod for the reduction.
        return hvd_allreduce(value, op=_to_horovod_op(reduce_op))
    # On IPU we do a compiled reduction with GCL.
    if reduce_op not in (reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN):
        raise ValueError("Unsupported reduce op: {}".format(reduce_op))
    result = cross_replica_ops.cross_replica_sum(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
        result = gen_poputil_ops.ipu_replication_normalise(result)
    return result

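# --- Added usage sketch (not from the original source) ---
# `_reduce_to` is not called directly; it backs tf.distribute's public reduce
# API. With the implementation above, a MEAN reduction on the IPU lowers to
# cross_replica_sum followed by ipu_replication_normalise (divide by the
# replica count), while off-IPU it falls back to a Horovod allreduce.
# `strategy` is assumed to be an instance of the IPU strategy these methods
# belong to, and `per_replica_loss` a per-replica value produced by the
# strategy's run method (whose exact name depends on the TF version).
import tensorflow as tf

mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss,
                            axis=None)
summed_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss,
                              axis=None)
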
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
        if grad is None:
            summed_grads_and_vars.append((grad, var))
        else:
            with ops.colocate_with(grad):
                # gradient accumulation
                if self._gradient_accumulation_count > 1 and not self._pipelining:
                    grad = gen_poputil_ops.ipu_stateful_gradient_accumulate(
                        grad,
                        num_mini_batches=self._gradient_accumulation_count)
                # replication
                if self._replicas > 1:
                    grad = gen_poputil_ops.ipu_replication_normalise(
                        cross_replica_ops.cross_replica_sum(grad))
                # distribution with IPUMultiWorkerStrategy needs additional
                # normalisation by the number of workers
                if isinstance(
                        distribute.get_strategy(),
                        ipu_multi_worker_strategy.IPUMultiWorkerStrategy):
                    grad /= distribute.get_strategy().num_replicas_in_sync
                grad = math_ops.cast(grad, var.dtype)
                summed_grads_and_vars.append((grad, var))
    if self._pipelining:
        # can do weight decay here as apply_gradients is only called on the
        # last accumulation step
        summed_grads_and_vars = self.add_WD(summed_grads_and_vars)
    if self._grad_scale != 1.0:
        # don't rescale batch norm moving average statistics as they are not
        # affected by loss scaling
        summed_grads_and_vars = [
            (grad, var) if 'batch_norm/moving_' in var.name else
            (grad / self._grad_scale, var)
            for grad, var in summed_grads_and_vars
        ]
    ret = self._optimizer.apply_gradients(summed_grads_and_vars, global_step,
                                          name)
    if self._sharded:
        sharding.propagate_sharding(ops.get_default_graph())
    return ret

def comp_fn():
    def body(total_accuracy, data_dict):
        accuracy = validation_graph_builder(model, data_dict, opts)
        if opts['latency']:
            timestamp_enqueue = timestamp_queue.enqueue(data_dict['timestamp'])
            return (total_accuracy + (tf.cast(accuracy, tf.float32) /
                                      opts["validation_batches_per_step"]),
                    timestamp_enqueue)
        else:
            return total_accuracy + (tf.cast(accuracy, tf.float32) /
                                     opts["validation_batches_per_step"])

    accuracy = loops.repeat(int(opts["validation_batches_per_step"]), body,
                            [tf.constant(0, tf.float32)], valid_iterator)
    if opts['total_replicas'] * opts['shards'] > 1 and not opts.get(
            'inference', False):
        accuracy = cross_replica_ops.cross_replica_sum(accuracy) / (
            opts['total_replicas'] * opts['shards'])
    return accuracy

def grad_fn(grad, param, m, v):
    if self.add_cross_replica_sums:
        grad = cross_replica_ops.cross_replica_sum(grad)

    # We convert the gradient to fp32 and we rescale it
    cast_grad = tf.cast(grad, dtype=tf.float32)
    cast_grad = cast_grad / self.loss_scaling

    if self.use_nvlamb:
        # We de-normalize the gradients
        cast_grad = cast_grad * self.clipping_value / global_norm

    # Standard Adam update.
    next_m = (tf.multiply(self.beta_1, m) +
              tf.multiply(1.0 - self.beta_1, cast_grad))
    next_v = (tf.multiply(self.beta_2, v) +
              tf.multiply(1.0 - self.beta_2, tf.square(cast_grad)))

    # Bias correction (debiasing) of momentum and velocity
    if self.debiasing:
        m_hat = next_m / (1.0 - tf.pow(
            self.beta_1, tf.cast(self.step, dtype=tf.float32)))  # x10
        v_hat = next_v / (1.0 - tf.pow(
            self.beta_2, tf.cast(self.step, dtype=tf.float32)))  # x1000
    else:
        m_hat = next_m
        v_hat = next_v

    # TODO: Check if it is possible to convert to fp16 here.
    # m_hat = tf.cast(m_hat, dtype=tf.float16)
    # v_hat = tf.cast(v_hat, dtype=tf.float16)
    update = m_hat / (tf.sqrt(tf.math.abs(v_hat)) +
                      tf.cast(self.epsilon, dtype=v_hat.dtype))

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
        update += tf.cast(self.weight_decay_rate,
                          dtype=update.dtype) * tf.cast(param,
                                                        dtype=update.dtype)

    if 'qkv' in param_name:
        # We reshape the parameters
        reshaped_param = self.forward_transform(param)
        reshaped_update = self.forward_transform(update)
    else:
        reshaped_param = tf.reshape(param, [-1])
        reshaped_update = tf.reshape(update, [-1])

    # Norms are then computed in fp32
    w_norm = linalg_ops.norm(tf.cast(reshaped_param, dtype=tf.float32),
                             ord=2,
                             axis=-1)
    u_norm = linalg_ops.norm(reshaped_update, ord=2, axis=-1)
    reshaped_update = tf.cast(reshaped_update, dtype=self.target_type)

    if self.weight_clip:
        w_norm = tf.math.minimum(
            w_norm, tf.cast(self.weight_clip, dtype=w_norm.dtype))

    # We set the ratio to 1 if either the w norm or the u norm is 0
    ratio = array_ops.where(
        math_ops.greater(w_norm, 0),
        array_ops.where(
            math_ops.greater(u_norm, 0),
            (tf.cast(w_norm, dtype=tf.float32) / u_norm),
            tf.constant(1.0, dtype=tf.float32, shape=w_norm.shape)),
        tf.constant(1.0, dtype=tf.float32, shape=w_norm.shape))

    # We reshape the ratio so that it is broadcastable
    ratio = tf.reshape(ratio, shape=ratio.shape.as_list() + [1])

    # We combine the learning rate and the ratio in fp32
    ratio = ratio * tf.cast(self.learning_rate, dtype=tf.float32)

    # We now downcast to do the next operation.
    # If a fused scaled add is available we do not need this operation.
    ratio = tf.cast(ratio, dtype=tf.float16)
    update_with_lr = ratio * reshaped_update

    # Backward transform to the same shape as param
    if 'qkv' in param_name:
        update_with_lr = self.backward_transform(update_with_lr)
    else:
        update_with_lr = tf.reshape(update_with_lr, shape=param.shape)

    update_with_lr = tf.cast(update_with_lr, dtype=param.dtype)
    next_param = param - update_with_lr
    return next_param, next_m, next_v

def my_body(acc, x):
    index = math_ops.cast(replication_index(), np.float32) + 1.0
    acc += cross_replica_sum(index)
    acc += x * x
    return acc, outfeed.enqueue(index)

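# --- Added harness sketch (not from the original source) ---
# `my_body` above reads like a loop body for ipu.loops.repeat: `acc` is the
# loop accumulator, `x` comes from an infeed queue, and each replica's
# one-based `index` is summed across replicas and streamed out through
# `outfeed`. A typical (assumed) harness looks roughly like this; `infeed` and
# `outfeed` are pre-built IPUInfeedQueue/IPUOutfeedQueue instances whose
# constructor arguments vary between SDK versions.
from tensorflow.python import ipu
from tensorflow.python.ipu import loops


def my_net():
    # Run the body a fixed number of times, starting the accumulator at 0.
    return loops.repeat(10, my_body, [0.0], infeed_queue=infeed)


with ipu.scopes.ipu_scope('/device:IPU:0'):
    [result] = ipu.ipu_compiler.compile(my_net, inputs=[])

# Dequeue the per-iteration replication indices streamed to the host.
dequeued_indices = outfeed.dequeue()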