import tensorflow as tf
import horovod.tensorflow as hvd
import sparse_operation_kit as sok

# The snippets below come from different training scripts; names such as
# `args`, `model`, `strategy`, and the optimizers are defined in their
# surrounding scripts.


def _step(self, samples, labels, first_batch):
    # Method of a trainer class that owns the model, loss function,
    # LR scheduler, both optimizers, and the AMP flag.
    self._lr_scheduler()

    with tf.GradientTape() as tape:
        probs = self._model(samples, training=True)
        loss = self._loss_fn(labels, probs)
        if self._amp:
            _loss = self._embedding_optimizer.get_scaled_loss(loss)
        else:
            _loss = loss

    embedding_vars, dense_vars = sok.split_embedding_variable_from_others(
        self._model.trainable_variables)
    embedding_grads, dense_grads = tape.gradient(
        _loss, [embedding_vars, dense_vars])
    if self._amp:
        embedding_grads = self._embedding_optimizer.get_unscaled_gradients(
            embedding_grads)
        dense_grads = self._embedding_optimizer.get_unscaled_gradients(
            dense_grads)

    # embedding_grads = [scale_grad(g, hvd.size()) for g in embedding_grads]
    with sok.OptimizerScope(embedding_vars):
        self._embedding_optimizer.apply_gradients(
            zip(embedding_grads, embedding_vars),
            experimental_aggregate_gradients=False)

    # with tf.control_dependencies(embedding_grads):
    dense_grads = [hvd.allreduce(grad, op=hvd.Average,
                                 compression=hvd.compression.NoneCompressor)
                   for grad in dense_grads]
    self._dense_optimizer.apply_gradients(
        zip(dense_grads, dense_vars),
        experimental_aggregate_gradients=False)

    if first_batch:
        hvd.broadcast_variables(dense_vars, root_rank=0)
        hvd.broadcast_variables(self._dense_optimizer.variables(), root_rank=0)

    return loss
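# Driver sketch for `_step` above -- a minimal sketch, assuming `trainer` is a
# hypothetical instance of the class that owns `_step`, and `dataset` is a
# tf.data.Dataset yielding (samples, labels). `first_batch` is True exactly
# once, so the Horovod broadcast runs after the optimizer slots are created.
step_fn = tf.function(trainer._step)
for i, (samples, labels) in enumerate(dataset):
    loss = step_fn(samples, labels, i == 0)
    if i % 100 == 0 and hvd.rank() == 0:
        print(f"iter {i}, loss: {loss.numpy()}")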
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logit, embedding_vector = plugin_demo(inputs, training=True)
        loss = _replica_loss(labels, logit)
        if args.mixed_precision:
            _loss = emb_opt.get_scaled_loss(loss)
        else:
            _loss = loss

    embedding_variables, other_variable = sok.split_embedding_variable_from_others(
        plugin_demo.trainable_variables)
    grads, emb_grads = tape.gradient(_loss, [other_variable, embedding_variables])
    if args.mixed_precision:
        grads = emb_opt.get_unscaled_gradients(grads)
        emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

    # Serialize on the embedding gradients, in case NCCL runs concurrently
    # via SOK and TF.
    with tf.control_dependencies([*emb_grads]):
        if 'plugin' not in args.optimizer:
            with sok.OptimizerScope(embedding_variables):
                emb_opt.apply_gradients(
                    zip(emb_grads, embedding_variables),
                    experimental_aggregate_gradients=False)
        else:
            emb_opt.apply_gradients(
                zip(emb_grads, embedding_variables),
                experimental_aggregate_gradients=False)
        dense_opt.apply_gradients(zip(grads, other_variable))
    return loss, embedding_vector
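# Dispatch sketch (assumption): under a tf.distribute strategy, `_train_step`
# above is typically invoked through strategy.run; `strategy` and the input
# tensors are hypothetical here.
@tf.function
def train_dispatch(inputs, labels):
    per_replica_loss, _ = strategy.run(_train_step, args=(inputs, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss,
                           axis=None)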
def _step_fn(inputs, labels):
    logit, embedding_vector = sok_dense_demo(inputs, training=training)
    loss = _replica_loss(labels, logit)
    if args.mixed_precision:
        _loss = emb_opt.get_scaled_loss(loss)
    else:
        _loss = loss

    emb_var, other_var = sok.split_embedding_variable_from_others(
        sok_dense_demo.trainable_variables)
    grads = tf.gradients(_loss, emb_var + other_var,
                         colocate_gradients_with_ops=True,
                         unconnected_gradients=tf.UnconnectedGradients.NONE)
    emb_grads, other_grads = grads[:len(emb_var)], grads[len(emb_var):]
    if args.mixed_precision:
        other_grads = emb_opt.get_unscaled_gradients(other_grads)
        emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

    if "plugin" in args.optimizer:
        emb_train_op = emb_opt.apply_gradients(zip(emb_grads, emb_var))
    else:
        with sok.OptimizerScope(emb_var):
            emb_train_op = emb_opt.apply_gradients(zip(emb_grads, emb_var))

    # Serialize on the embedding gradients, in case NCCL runs concurrently
    # via SOK and Horovod.
    with tf.control_dependencies([*emb_grads]):
        other_grads = strategy.reduce("sum", other_grads)
        other_train_op = dense_opt.apply_gradients(zip(other_grads, other_var))

    with tf.control_dependencies([emb_train_op, other_train_op]):
        total_loss = strategy.reduce("sum", loss)
        total_loss = tf.identity(total_loss)
    return total_loss, embedding_vector
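# Graph-mode driver sketch (assumption): `_step_fn` above uses tf.gradients,
# so it runs with eager execution disabled and ops built once, then executed
# repeatedly under a tf.compat.v1.Session. `dataset` and `num_iters` are
# hypothetical.
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
inputs, labels = iterator.get_next()
total_loss, _ = _step_fn(inputs, labels)
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for _ in range(num_iters):
        print(sess.run(total_loss))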
def _train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logit, all_vectors = model(inputs, training=True)
        loss = _replica_loss(labels, logit)
        if args.mixed_precision:
            _loss = emb_opt.get_scaled_loss(loss)
        else:
            _loss = loss

    emb_variable, other_variable = sok.split_embedding_variable_from_others(
        model.trainable_variables)
    grads, emb_grads = tape.gradient(_loss, [other_variable, emb_variable])
    if args.mixed_precision:
        grads = emb_opt.get_unscaled_gradients(grads)
        emb_grads = emb_opt.get_unscaled_gradients(emb_grads)

    if "plugin" not in args.optimizer:
        with sok.OptimizerScope(emb_variable):
            emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                    experimental_aggregate_gradients=False)
    else:
        emb_opt.apply_gradients(zip(emb_grads, emb_variable),
                                experimental_aggregate_gradients=False)

    with tf.control_dependencies(emb_grads):
        # Manually all-reduce the dense gradients.
        replica_context = tf.distribute.get_replica_context()
        grads = replica_context.all_reduce("sum", grads, options=comm_options)
        dense_opt.apply_gradients(zip(grads, other_variable),
                                  experimental_aggregate_gradients=False)

        # Manually all-reduce the loss. This is safe because the per-replica
        # loss has already been used to update the local variables.
        loss = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, loss,
                                          options=comm_options)
    return loss, all_vectors
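# `comm_options` above is defined outside the snippet. A plausible definition,
# assuming NCCL is the intended all-reduce backend (TF >= 2.4):
comm_options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)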
def _train_step(inputs, labels, first_batch):
    with tf.GradientTape() as tape:
        logit, all_vectors = model(inputs, training=True)
        replica_loss = _replica_loss(labels, logit)
        if args.mixed_precision:
            _loss = emb_opt.get_scaled_loss(replica_loss)
        else:
            _loss = replica_loss

    emb_var, other_var = sok.split_embedding_variable_from_others(
        model.trainable_variables)
    emb_grads, grads = tape.gradient(_loss, [emb_var, other_var])
    if args.mixed_precision:
        emb_grads = emb_opt.get_unscaled_gradients(emb_grads)
        grads = emb_opt.get_unscaled_gradients(grads)

    if "plugin" not in args.optimizer:
        with sok.OptimizerScope(emb_var):
            emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                    experimental_aggregate_gradients=False)
    else:
        emb_opt.apply_gradients(zip(emb_grads, emb_var),
                                experimental_aggregate_gradients=False)

    with tf.control_dependencies(emb_grads):
        grads = [hvd.allreduce(grad) for grad in grads]
        dense_opt.apply_gradients(zip(grads, other_var))

        if first_batch:
            hvd.broadcast_variables(other_var, root_rank=0)
            hvd.broadcast_variables(dense_opt.variables(), root_rank=0)

        total_loss = hvd.allreduce(replica_loss)
    return total_loss, all_vectors
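# Process setup the Horovod snippets here assume: one GPU per process, pinned
# by local rank. This is standard Horovod boilerplate, not SOK-specific.
hvd.init()
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")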
def split_embedding_variables_from_others(model):
    if isinstance(model.embedding_layer, SOKEmbedding):
        return sok.split_embedding_variable_from_others(model.trainable_variables)
    else:
        dense_vars = []
        for layer in model.layers:
            if layer != model.embedding_layer:
                dense_vars.extend(layer.trainable_variables)
        return model.embedding_layer.trainable_variables, dense_vars
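# Usage sketch for the helper above, assuming a hypothetical `model` whose
# `embedding_layer` is either a SOKEmbedding or a native tf.keras embedding
# layer; the caller gets the same (embedding, dense) split either way, so the
# rest of the training step does not need to know which backend is in use.
emb_vars, dense_vars = split_embedding_variables_from_others(model)
print(f"{len(emb_vars)} embedding variables, {len(dense_vars)} dense variables")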
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logit = model(inputs, training=True)
        loss = _replica_loss(labels, logit)
        scaled_loss = optimizer.get_scaled_loss(loss)

    emb_vars, other_vars = sok.split_embedding_variable_from_others(
        model.trainable_variables)
    scaled_emb_grads, scaled_other_grads = tape.gradient(
        scaled_loss, [emb_vars, other_vars])
    emb_grads = optimizer.get_unscaled_gradients(scaled_emb_grads)
    other_grads = optimizer.get_unscaled_gradients(scaled_other_grads)

    with sok.OptimizerScope(emb_vars):
        optimizer.apply_gradients(zip(emb_grads, emb_vars),
                                  experimental_aggregate_gradients=False)
    optimizer.apply_gradients(zip(other_grads, other_vars))
    return loss
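# The single `optimizer` above must expose get_scaled_loss /
# get_unscaled_gradients, i.e. be a mixed-precision LossScaleOptimizer.
# A minimal sketch, assuming Adam as the inner optimizer:
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=1e-3))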
def train_step(inputs, labels):
    logit = model(inputs, training=True)
    loss = _replica_loss(labels, logit)
    scaled_loss = optimizer.get_scaled_loss(loss)
    # Assumes the embedding variables come first in model.trainable_variables,
    # so the gradient list can be split by slicing below.
    scaled_gradients = tf.gradients(scaled_loss, model.trainable_variables)
    emb_vars, other_vars = sok.split_embedding_variable_from_others(
        model.trainable_variables)
    scaled_emb_grads = scaled_gradients[:len(emb_vars)]
    scaled_other_grads = scaled_gradients[len(emb_vars):]
    emb_grads = optimizer.get_unscaled_gradients(scaled_emb_grads)
    other_grads = optimizer.get_unscaled_gradients(scaled_other_grads)
    other_grads = [hvd.allreduce(grad) for grad in other_grads]

    with sok.OptimizerScope(emb_vars):
        emb_train_op = optimizer.apply_gradients(zip(emb_grads, emb_vars))
    other_train_op = optimizer.apply_gradients(zip(other_grads, other_vars))

    total_loss = hvd.allreduce(loss)
    with tf.control_dependencies([emb_train_op, other_train_op]):
        return tf.identity(total_loss)
def _train_step(inputs, labels, first_batch):
    # Two tapes: `emb_tape` stays a plain tape because SOK exchanges embedding
    # gradients internally, while `tape` is wrapped with Horovod so the dense
    # gradients are all-reduced inside tape.gradient().
    with tf.GradientTape() as tape, tf.GradientTape() as emb_tape:
        logit = model(inputs, training=True)
        replica_loss = _replica_loss(labels, logit)
        if args.mixed_precision:
            _loss = embedding_optimizer.get_scaled_loss(replica_loss)
        else:
            _loss = replica_loss

    tape = hvd.DistributedGradientTape(tape)

    emb_variable, other_variable = sok.split_embedding_variable_from_others(
        model.trainable_variables)
    emb_grads = emb_tape.gradient(_loss, emb_variable)
    grads = tape.gradient(_loss, other_variable)
    if args.mixed_precision:
        emb_grads = embedding_optimizer.get_unscaled_gradients(emb_grads)
        grads = embedding_optimizer.get_unscaled_gradients(grads)

    if 'plugin' not in args.optimizer:
        with sok.OptimizerScope(emb_variable):
            embedding_optimizer.apply_gradients(
                zip(emb_grads, emb_variable),
                experimental_aggregate_gradients=False)
    else:
        embedding_optimizer.apply_gradients(
            zip(emb_grads, emb_variable),
            experimental_aggregate_gradients=False)
    dense_optimizer.apply_gradients(zip(grads, other_variable))

    # Note: the broadcast must happen after the first gradient step so that
    # the optimizer's slot variables exist before they are synchronized.
    if first_batch:
        hvd.broadcast_variables(other_variable, root_rank=0)
        hvd.broadcast_variables(dense_optimizer.variables(), root_rank=0)

    return replica_loss