def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        with tf.GradientTape() as tape:
            inputs, targets = batch
            outputs = model(inputs, training=True)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value:.4f}")

        if SDP_ENABLED:
            if i == 0:
                # Broadcast model and optimizer variables from rank 0 after the first step
                sdp.broadcast_variables(model.variables, root_rank=0)
                sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    train_results = {"loss": loss_value.numpy()}
    return train_results
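# A minimal sketch (assumed, not part of the snippet above) of how fit() might be wired
# up when SMDataParallel is enabled. The alias `sdp`, the SDP_ENABLED flag, the learning
# rate value, and the model/dataset creation are placeholders; sdp.init(), sdp.size(),
# sdp.rank(), and sdp.local_rank() are the smdistributed.dataparallel.tensorflow calls
# this assumes.
import tensorflow as tf

SDP_ENABLED = True
if SDP_ENABLED:
    import smdistributed.dataparallel.tensorflow as sdp

    sdp.init()
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # Pin one GPU per local rank
        tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")

# Scale the learning rate by the number of workers (assumed base rate)
learning_rate = 5e-5 * (sdp.size() if SDP_ENABLED else 1)
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# model and train_dataset creation elided
train_results = fit(model, loss, opt, train_dataset, epochs=1, train_batch_size=32)
if not SDP_ENABLED or sdp.rank() == 0:
    model.save_pretrained("/opt/ml/model")  # save only on rank 0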
def training_step(images, labels, batch):
    """
    Runs one training step, with considerations for the distributed nature of the job.

    :param images: batch of images
    :param labels: labels for the images, in the same order
    :param batch: batch number; variables must be broadcast on the first batch
    :return: loss for this step, averaged across all workers
    """
    with tf.GradientTape() as tape:  # record operations for automatic differentiation
        probs = model(images, training=True)
        loss_value = loss(labels, probs)

    # Wrap the tape with SMDataParallel's DistributedGradientTape so gradients are
    # allreduced across workers before they are applied.
    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # apply gradients on this worker

    if batch == 0:
        # Broadcast the initial model and optimizer variables from rank 0 so that
        # every worker starts from the same state.
        dist.broadcast_variables(model.variables, root_rank=0)
        dist.broadcast_variables(optimizer.variables(), root_rank=0)

    # Reduce the per-worker losses to a single value: average the loss across workers
    loss_value = dist.oob_allreduce(loss_value)
    return loss_value
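# A minimal sketch (assumed, not from the original) of a driver loop for training_step()
# above: the dataset is expected to be pre-sharded per worker, and `batch` is the step
# index so variables are broadcast from rank 0 only on the very first step.
# `train_dataset` and the logging cadence are placeholders; dist.rank() is assumed to be
# the smdistributed.dataparallel rank() call.
for batch, (images, labels) in enumerate(train_dataset):
    loss_value = training_step(images, labels, batch)
    if dist.rank() == 0 and batch % 100 == 0:
        print(f"step {batch}: loss averaged across workers = {loss_value:.4f}")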
def training_step(mnist_model, loss, opt, images, labels, batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    ########################################################
    ####### 4. SageMaker Distributed Data Parallel   ######
    #######    - Optimize AllReduce operation        ######
    ########################################################
    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)
    ########################################################

    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    ########################################################
    ####### 5. SageMaker Distributed Data Parallel   ######
    #######    - Broadcast the initial model variables #####
    #######      from rank 0 to ranks 1 ~ n          ######
    ########################################################
    if batch == 0:
        # SMDataParallel: Broadcast model and optimizer variables
        smdp.broadcast_variables(mnist_model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)
    ########################################################

    return loss_value
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # Average the loss across workers
    loss_value = dist.oob_allreduce(loss_value)
    return loss_value
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Create a DistributedGradientTape, which uses TensorFlow's GradientTape under the hood
    # and runs an AllReduce to combine gradient values across workers before applying them
    # to the model weights.
    tape = smdataparallel.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Broadcast model and optimizer variables after the first pass so all workers are in sync
    if first_batch:
        smdataparallel.broadcast_variables(mnist_model.variables, root_rank=0)
        smdataparallel.broadcast_variables(opt.variables(), root_rank=0)

    return loss_value
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = dist.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # SMDataParallel: all_reduce call to average the loss across workers
    loss_value = dist.oob_allreduce(loss_value)
    return loss_value
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        train_pred = model(images, training=True)
        loss_value = loss(labels, train_pred)

    # Change: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    if first_batch:
        # Change: Broadcast model and optimizer variables
        smdp.broadcast_variables(model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    # Change: all_reduce call to average the loss across workers
    train_loss_value = smdp.oob_allreduce(loss_value)

    train_loss(train_loss_value)
    train_accuracy(labels, train_pred)
    return
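# train_loss and train_accuracy are updated above but not defined in the snippet.
# A typical (assumed) definition uses Keras metric objects that accumulate values
# across steps and are read and reset once per epoch, e.g. on rank 0:
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# ... after an epoch:
# print(f"loss={train_loss.result():.4f}, accuracy={train_accuracy.result():.4f}")
# train_loss.reset_states()
# train_accuracy.reset_states()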
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # TODO: Change to obfuscate smddpcommon. This code does not use GradientTape,
    # so the bucket size must be passed explicitly like this.
    if train_args.bucket_cap_mb:
        bucket_cap_bytes = int(train_args.bucket_cap_mb * 1024 * 1024)
    else:
        bucket_cap_bytes = int(64 * 1024 * 1024)
    hc.setBucketSize(bucket_cap_bytes)

    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[smddp.local_rank()], "GPU")

    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if smddp.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{smddp.size()}gpus"
                f"-{train_args.per_gpu_batch_size * smddp.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout"
            )
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        if not os.path.exists(path_args.log_dir):
            os.makedirs(path_args.log_dir)
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir, f"{run_name}.log")
            ),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments were passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
if train_args.optimizer == "lamb": optimizer = get_lamb_optimizer(train_args) elif train_args.optimizer == "adamw": optimizer = get_adamw_optimizer(train_args) if _PRE_TF_2_4_0: optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( optimizer, loss_scale="dynamic" ) else: optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer) gradient_accumulator = GradientAccumulator() loaded_optimizer_weights = None model = create_model(model_class=TFAutoModelForPreTraining, model_args=model_args) tokenizer = create_tokenizer(model_args.model_type) if model_args.load_from == "checkpoint": checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path) model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path) if smddp.rank() == 0: model.load_weights(model_ckpt) if model_args.load_optimizer_state == "true": loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True) # We do not set the weights yet, we have to do a first step to initialize the optimizer. # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories # Move to same folder structure and remove if/else train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord") validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord") train_filenames = glob.glob(train_glob) validation_filenames = glob.glob(validation_glob) train_dataset = get_dataset_from_tfrecords( model_type=model_args.model_type, filenames=train_filenames, max_seq_length=data_args.max_seq_length, max_predictions_per_seq=data_args.max_predictions_per_seq, per_gpu_batch_size=train_args.per_gpu_batch_size, ) # Of shape [per_gpu_batch_size, ...] # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...] train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps) # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps. 
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if smddp.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    train_start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale() if _PRE_TF_2_4_0 else optimizer.loss_scale
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if smddp.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            print(" RANK {} is broadcasting".format(smddp.rank()))
            # smddp.broadcast_variables(model.variables + optimizer.variables(), root_rank=0)
            smddp.broadcast_variables(model.variables, root_rank=0)
            smddp.broadcast_variables(optimizer.variables(), root_rank=0)
            print(" RANK {} is done broadcasting".format(smddp.rank()))
            # smddp.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step
        )
        # SQuAD requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            from albert.run_squad import get_squad_results_while_pretraining

            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if smddp.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if smddp.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step
            )
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step
            )

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)

            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                elif is_final_step:
                    total_time = time.perf_counter() - train_start_time
                    seq_per_sec = (
                        i
                        * train_args.per_gpu_batch_size
                        * smddp.size()
                        * train_args.gradient_accumulation_steps
                        / total_time
                    )
                    logger.info(
                        f"Final step {i}: {description} -- Average seq_per_sec: {seq_per_sec:.2f} -- Total Time: {total_time}"
                    )
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}")
                start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(
                    path_args.filesystem_prefix, path_args.checkpoint_dir, f"{run_name}-step{i}"
                )
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size": train_args.per_gpu_batch_size * smddp.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if smddp.rank() == 0:
        pbar.close()

    logger.info(f"Finished pretraining, job name {run_name}")