def load_qa_from_pretrained(
    model: Optional[tf.keras.Model] = None,
    name: Optional[str] = None,
    path: Optional[str] = None,  # path to checkpoint from TF...ForPreTraining
    config: Optional[PretrainedConfig] = None,
) -> tf.keras.Model:
    """Build a TF...QuestionAnswering model from a pretrained model's main layer.

    Exactly one source must be given:

    * ``name`` -- passed straight to
      ``TFAutoModelForQuestionAnswering.from_pretrained``;
    * ``model`` -- an already-instantiated pretraining model;
    * ``path`` together with ``config`` -- weights loaded into a
      ``TFAutoModelForPreTraining`` built from ``config``.

    For the ``model`` and ``path`` sources, a fresh QA model is created from
    the pretrained model's config, and its main layer (the attribute named by
    ``qa_model.base_model_prefix``) is replaced by the pretrained one, so the
    returned model carries ``pretrained_model.config``.

    Raises:
        ValueError: if not exactly one of ``name``/``model``/``path`` is given,
            or ``path`` is given without ``config``.
        AttributeError: if the pretrained model has no attribute named
            ``qa_model.base_model_prefix``.
    """
    # The previous `a ^ b ^ c` check accepted all three sources at once;
    # count the sources explicitly instead (and don't rely on `assert`,
    # which is stripped under `python -O`).
    num_sources = sum(source is not None for source in (name, model, path))
    if num_sources != 1 or (path is not None and config is None):
        raise ValueError("Pass either name, model, or (path and config)")

    if name is not None:
        return TFAutoModelForQuestionAnswering.from_pretrained(name)

    if model is not None:
        pretrained_model = model
    else:
        pretrained_model = TFAutoModelForPreTraining.from_config(config)
        pretrained_model.load_weights(path)

    qa_model = TFAutoModelForQuestionAnswering.from_config(pretrained_model.config)
    prefix = qa_model.base_model_prefix
    # Use a default so a missing main layer reaches the explicit error below
    # instead of raising from inside getattr.
    pretrained_main_layer = getattr(pretrained_model, prefix, None)
    if pretrained_main_layer is None:
        raise AttributeError(f"{pretrained_model} has no attribute '{prefix}'")

    # Generalized way of saying `qa_model.albert = pretrained_model.albert`
    setattr(qa_model, prefix, pretrained_main_layer)
    return qa_model
def main():
    """Pretrain an ALBERT/BERT model under Horovod.

    Arguments are parsed from the command line into ``ModelArguments``,
    ``DataTrainingArguments``, ``TrainingArguments`` and ``LoggingArguments``.
    Every rank trains and participates in SQuAD evaluation; only rank 0 logs,
    validates, writes checkpoints and writes TensorBoard summaries.

    Raises:
        ValueError: if ``model_type`` or ``optimizer`` is not a supported value.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments)
    )
    model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses()

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init. Boolean CLI flags arrive as the strings "true"/"false".
    def parse_bool(arg):
        return arg == "true"

    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    squad_steps = get_squad_steps(log_args.extra_squad_steps)
    is_sagemaker = data_args.fsx_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker

    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        metadata = (
            f"{model_args.model_type}"
            f"-{model_args.model_size}"
            f"-{model_args.load_from}"
            f"-{hvd.size()}gpus"
            f"-{train_args.batch_size}batch"
            f"-{train_args.gradient_accumulation_steps}accum"
            f"-{train_args.learning_rate}maxlr"
            f"-{train_args.end_learning_rate}endlr"
            f"-{train_args.learning_rate_decay_power}power"
            f"-{train_args.max_grad_norm}maxgrad"
            f"-{train_args.optimizer}opt"
            f"-{train_args.total_steps}steps"
            f"-{data_args.max_seq_length}seq"
            f"-{data_args.max_predictions_per_seq}preds"
            f"-{'preln' if pre_layer_norm else 'postln'}"
            f"{loss_str}"
            f"-{model_args.hidden_dropout_prob}dropout"
            f"-{train_args.seed}seed"
        )
        run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        log_format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(f"{data_args.fsx_prefix}/logs/albert/{run_name}.log"),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=log_format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_args.model_type == "albert":
        model_desc = f"albert-{model_args.model_size}-v2"
    elif model_args.model_type == "bert":
        model_desc = f"bert-{model_args.model_size}-uncased"
    else:
        # Previously fell through and crashed later on an unbound `model_desc`.
        raise ValueError(f"Unsupported model_type: {model_args.model_type}")

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.hidden_dropout_prob = model_args.hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    # Create optimizer and enable AMP loss scaling.
    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=train_args.learning_rate,
        end_learning_rate=train_args.end_learning_rate,
        warmup_steps=train_args.warmup_steps,
        total_steps=train_args.total_steps,
        power=train_args.learning_rate_decay_power,
    )
    if train_args.optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif train_args.optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)
    else:
        raise ValueError(f"Unsupported optimizer: {train_args.optimizer}")
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale="dynamic"
    )
    gradient_accumulator = GradientAccumulator()

    loaded_opt_weights = None
    if model_args.load_from == "scratch":
        pass
    elif model_args.load_from.startswith("huggingface"):
        assert (
            model_args.model_type == "albert"
        ), "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_args.model_size}-v2"
        if model_args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name, config=config)
            model.albert = albert
    else:
        model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(model_args.checkpoint_path)
        model = TFAutoModelForPreTraining.from_config(config)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            # NOTE(review): allow_pickle=True unpickles the file contents; only
            # resume from checkpoints you trust.
            loaded_opt_weights = np.load(opt_ckpt, allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    seq_dir = (
        f"max_seq_len_{data_args.max_seq_length}"
        f"_max_predictions_per_seq_{data_args.max_predictions_per_seq}"
        f"_masked_lm_prob_15"
    )
    if model_args.model_type == "albert":
        train_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/train/{seq_dir}/albert_*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/validation/{seq_dir}/albert_*.tfrecord"
    else:  # "bert" -- any other model_type was rejected above
        train_glob = f"{data_args.fsx_prefix}/bert_pretraining/{seq_dir}/training/*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/bert_pretraining/{seq_dir}/validation/*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        batch_size=train_args.batch_size,
    )  # Of shape [batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            batch_size=train_args.batch_size,
        )
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        # tqdm's first positional parameter is the iterable; the step count
        # must be passed as `total` for the bar to show progress/ETA.
        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 0
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = opt.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            if hvd.rank() == 0 and loaded_opt_weights is not None:
                opt.set_weights(loaded_opt_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)
            # Resume the step counter from the optimizer's first weight (its
            # iteration count after the broadcast above).
            i = opt.get_weights()[0] - 1

        is_final_step = i >= train_args.total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                model_size=model_args.model_size,
                fsx_prefix=data_args.fsx_prefix,
                step=i,
                # Pass the parsed booleans, not the raw "true"/"false" strings
                # ("false" is truthy).
                fast=fast_squad,
                dummy_eval=dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = ((i > 0) and (i % log_args.checkpoint_frequency == 0)) or is_final_step
            do_validation = ((i > 0) and (i % log_args.validation_frequency == 0)) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)

            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 0:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}")
                start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = f"{data_args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}"
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                opt_ckpt = f"{checkpoint_prefix}-opt.npy"
                logger.info(f"Saving model at {model_ckpt}, optimizer at {opt_ckpt}")
                model.save_weights(model_ckpt)
                opt_weights = opt.get_weights()
                # Optimizer weights have mixed shapes; build the 1-D object array
                # explicitly because implicit ragged-array creation raises in
                # NumPy >= 1.24.
                opt_weights_array = np.empty(len(opt_weights), dtype=object)
                for idx, weight in enumerate(opt_weights):
                    opt_weights_array[idx] = weight
                np.save(opt_ckpt, opt_weights_array, allow_pickle=True)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{data_args.fsx_prefix}/logs/albert/{run_name}"
                )
                with summary_writer.as_default():
                    HP_MODEL_TYPE = hp.HParam("model_type", hp.Discrete(["albert", "bert"]))
                    HP_MODEL_SIZE = hp.HParam("model_size", hp.Discrete(["base", "large"]))
                    HP_LEARNING_RATE = hp.HParam("learning_rate", hp.RealInterval(1e-5, 1e-1))
                    HP_BATCH_SIZE = hp.HParam("global_batch_size", hp.IntInterval(1, 64))
                    HP_PRE_LAYER_NORM = hp.HParam("pre_layer_norm", hp.Discrete([True, False]))
                    HP_HIDDEN_DROPOUT = hp.HParam("hidden_dropout")
                    hparams = [
                        HP_MODEL_TYPE,
                        HP_MODEL_SIZE,
                        HP_BATCH_SIZE,
                        HP_LEARNING_RATE,
                        HP_PRE_LAYER_NORM,
                        HP_HIDDEN_DROPOUT,
                    ]

                    HP_F1 = hp.Metric("squad_f1")
                    HP_EXACT = hp.Metric("squad_exact")
                    HP_MLM = hp.Metric("val_mlm_acc")
                    HP_SOP = hp.Metric("val_sop_acc")
                    HP_TRAIN_LOSS = hp.Metric("train_loss")
                    HP_VAL_LOSS = hp.Metric("val_loss")
                    metrics = [HP_TRAIN_LOSS, HP_VAL_LOSS, HP_F1, HP_EXACT, HP_MLM, HP_SOP]

                    hp.hparams_config(hparams=hparams, metrics=metrics)
                    hp.hparams(
                        {
                            HP_MODEL_TYPE: model_args.model_type,
                            HP_MODEL_SIZE: model_args.model_size,
                            HP_LEARNING_RATE: train_args.learning_rate,
                            HP_BATCH_SIZE: train_args.batch_size * hvd.size(),
                            HP_PRE_LAYER_NORM: pre_layer_norm,
                            HP_HIDDEN_DROPOUT: model_args.hidden_dropout_prob,
                        },
                        trial_id=run_name,
                    )

            # Log to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("loss_scale", loss_scale, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
def main(
    fsx_prefix: str,
    model_type: str,
    model_size: str,
    batch_size: int,
    max_seq_length: int,
    gradient_accumulation_steps: int,
    optimizer: str,
    name: str,
    learning_rate: float,
    end_learning_rate: float,
    warmup_steps: int,
    total_steps: int,
    skip_sop: bool,
    skip_mlm: bool,
    pre_layer_norm: bool,
    fast_squad: bool,
    dummy_eval: bool,
    squad_steps: List[int],
    hidden_dropout_prob: float,
):
    """Pretrain an ALBERT/BERT model under Horovod from explicit arguments.

    Every rank trains and participates in SQuAD evaluation at ``squad_steps``
    and at the final step; only rank 0 logs, validates, writes checkpoints and
    writes TensorBoard summaries.

    NOTE(review): ``args.load_from`` and ``args.max_grad_norm`` are still read
    from a module-level ``args`` object that is not a parameter of this function.

    Raises:
        ValueError: if ``model_type`` or ``optimizer`` is not a supported value.
    """
    # Hard-coded values that don't need to be arguments
    max_predictions_per_seq = 20
    log_frequency = 1000
    checkpoint_frequency = 5000
    validate_frequency = 2000

    do_gradient_accumulation = gradient_accumulation_steps > 1

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        # Use the `fsx_prefix` parameter rather than the global `args` copy.
        platform = "eks" if fsx_prefix == "/fsx" else "sm"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        amp_enabled = tf.config.optimizer.get_experimental_options().get(
            "auto_mixed_precision", False
        )
        amp_str = "" if amp_enabled else "-skipamp"
        ln_str = "-preln" if pre_layer_norm else "-postln"
        dropout_str = f"-{hidden_dropout_prob}dropout" if hidden_dropout_prob != 0 else ""
        name_str = f"-{name}" if name else ""
        metadata = f"{model_type}-{model_size}-{args.load_from}-{hvd.size()}gpus-{batch_size}batch-{gradient_accumulation_steps}accum-{learning_rate}lr-{args.max_grad_norm}maxgrad-{optimizer}opt-{total_steps}steps-{max_seq_length}seq{amp_str}{ln_str}{loss_str}{dropout_str}{name_str}"
        run_name = f"{current_time}-{platform}-{metadata}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        log_format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(f"{fsx_prefix}/logs/albert/{run_name}.log"),
            logging.StreamHandler(),
        ]
        logging.basicConfig(level=level, format=log_format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_type == "albert":
        model_desc = f"albert-{model_size}-v2"
    elif model_type == "bert":
        model_desc = f"bert-{model_size}-uncased"
    else:
        # Previously fell through and crashed later on an unbound `model_desc`.
        raise ValueError(f"Unsupported model_type: {model_type}")

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.output_hidden_states = True
    config.hidden_dropout_prob = hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    if args.load_from == "scratch":
        pass
    else:
        assert model_type == "albert", "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_size}-v2"
        if args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name, config=config)
            model.albert = albert
        elif args.load_from == "huggingfacepreds":
            mlm_model = TFAlbertForMaskedLM.from_pretrained(huggingface_name, config=config)
            model.albert = mlm_model.albert
            model.cls.predictions = mlm_model.predictions

    # NOTE(review): the returned tokenizer is not used in this function; the
    # call is kept in case get_tokenizer() has side effects callers rely on.
    get_tokenizer()

    schedule = LinearWarmupLinearDecaySchedule(
        max_learning_rate=learning_rate,
        end_learning_rate=end_learning_rate,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )
    if optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer}")
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    # Train filenames are [1, 2047]
    # Val filenames are [0]
    # Note the different subdirectories
    seq_dir = (
        f"max_seq_len_{max_seq_length}"
        f"_max_predictions_per_seq_{max_predictions_per_seq}"
        f"_masked_lm_prob_15"
    )
    train_glob = f"{fsx_prefix}/albert_pretraining/tfrecords/train/{seq_dir}/albert_*.tfrecord"
    validation_glob = f"{fsx_prefix}/albert_pretraining/tfrecords/validation/{seq_dir}/albert_*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=max_seq_length,
        max_predictions_per_seq=max_predictions_per_seq,
        batch_size=batch_size,
    )  # Of shape [batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    train_dataset = train_dataset.batch(gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=max_seq_length,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        # The bar is only updated on rank 0, so only create it there. tqdm's
        # first positional parameter is the iterable; the step count must be
        # passed as `total`.
        pbar = tqdm.tqdm(total=total_steps)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    for i, batch in enumerate(train_dataset):
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))

        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)

        is_final_step = i >= total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results(
                model=model,
                model_size=model_size,
                step=i,
                fast=fast_squad,
                dummy_eval=dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results["f1"]
                logger.info(f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}")
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_frequency == 0
            do_checkpoint = (i % checkpoint_frequency == 0) or is_final_step
            do_validation = (i % validate_frequency == 0) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)

            if do_log:
                logger.info(f"Train step {i} -- {description}")

            if do_checkpoint:
                checkpoint_path = f"{fsx_prefix}/checkpoints/albert/{run_name}-step{i}.ckpt"
                logger.info(f"Saving checkpoint at {checkpoint_path}")
                model.save_weights(checkpoint_path)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(f"{fsx_prefix}/logs/albert/{run_name}")

            # Log to TensorBoard: global L2 norm over all trainable variables.
            weight_norm = tf.math.sqrt(
                tf.math.reduce_sum(
                    [tf.norm(var, ord=2) ** 2 for var in model.trainable_variables]
                )
            )
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")