def load_qa_from_pretrained(
    model: Optional[tf.keras.Model] = None,
    name: Optional[str] = None,
    path: Optional[str] = None,  # path to checkpoint from TF...ForPreTraining
    config: Optional[PretrainedConfig] = None,
) -> tf.keras.Model:
    """
    Load a TF...QuestionAnswering model by taking the main layer of a pretrained model.
    Preserves the model.config attribute.
    """
    # Exactly one loading option must be provided (a chained XOR would also
    # accept all three being truthy).
    options = [bool(name), bool(model), bool(path) and bool(config)]
    assert sum(options) == 1, "Pass exactly one of name, model, or (path and config)"

    if name is not None:
        return TFAutoModelForQuestionAnswering.from_pretrained(name)
    elif model is not None:
        pretrained_model = model
    elif path is not None:
        pretrained_model = TFAutoModelForPreTraining.from_config(config)
        pretrained_model.load_weights(path)

    qa_model = TFAutoModelForQuestionAnswering.from_config(
        pretrained_model.config)
    pretrained_main_layer = getattr(pretrained_model,
                                    qa_model.base_model_prefix)
    assert (
        pretrained_main_layer is not None
    ), f"{pretrained_model} has no attribute '{model.base_model_prefix}'"
    # Generalized way of saying `model.albert = pretrained_model.albert`
    setattr(qa_model, qa_model.base_model_prefix, pretrained_main_layer)
    return qa_model
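

# A minimal usage sketch (the checkpoint path below is hypothetical;
# "albert-base-v2" is a real Hugging Face model name):
#
#   qa_model = load_qa_from_pretrained(name="albert-base-v2")
#   qa_model = load_qa_from_pretrained(
#       path="checkpoints/pretrain.ckpt",
#       config=AutoConfig.from_pretrained("albert-base-v2"),
#   )

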
def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments,
                               TrainingArguments, LoggingArguments))
    model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses(
    )

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
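    # CLI boolean flags arrive as the strings "true"/"false", hence the string compare.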
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    squad_steps = get_squad_steps(log_args.extra_squad_steps)
    is_sagemaker = data_args.fsx_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
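    # Pin each worker to its local GPU and enable memory growth so multiple
    # Horovod processes on one host don't each grab all of the device memory.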
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph: set_jit enables XLA JIT compilation globally;
    # experimental_run_functions_eagerly makes tf.functions run eagerly for debugging.
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        metadata = (f"{model_args.model_type}"
                    f"-{model_args.model_size}"
                    f"-{model_args.load_from}"
                    f"-{hvd.size()}gpus"
                    f"-{train_args.batch_size}batch"
                    f"-{train_args.gradient_accumulation_steps}accum"
                    f"-{train_args.learning_rate}maxlr"
                    f"-{train_args.end_learning_rate}endlr"
                    f"-{train_args.learning_rate_decay_power}power"
                    f"-{train_args.max_grad_norm}maxgrad"
                    f"-{train_args.optimizer}opt"
                    f"-{train_args.total_steps}steps"
                    f"-{data_args.max_seq_length}seq"
                    f"-{data_args.max_predictions_per_seq}preds"
                    f"-{'preln' if pre_layer_norm else 'postln'}"
                    f"{loss_str}"
                    f"-{model_args.hidden_dropout_prob}dropout"
                    f"-{train_args.seed}seed")
        run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                f"{data_args.fsx_prefix}/logs/albert/{run_name}.log"),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_args.model_type == "albert":
        model_desc = f"albert-{model_args.model_size}-v2"
    elif model_args.model_type == "bert":
        model_desc = f"bert-{model_args.model_size}-uncased"
    else:
        raise ValueError(f"Unsupported model_type '{model_args.model_type}'")

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.hidden_dropout_prob = model_args.hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    # Create optimizer and enable AMP loss scaling.
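    # Judging by its name and arguments, the schedule ramps the LR linearly to
    # max_learning_rate over warmup_steps, then decays it polynomially with
    # exponent `power` down to end_learning_rate at total_steps.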
    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=train_args.learning_rate,
        end_learning_rate=train_args.end_learning_rate,
        warmup_steps=train_args.warmup_steps,
        total_steps=train_args.total_steps,
        power=train_args.learning_rate_decay_power,
    )
    if train_args.optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif train_args.optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)
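    # Dynamic loss scaling multiplies the loss up before backprop and scales the
    # gradients back down, keeping small fp16 gradients from underflowing to zero.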
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_opt_weights = None
    if model_args.load_from == "scratch":
        pass
    elif model_args.load_from.startswith("huggingface"):
        assert (model_args.model_type == "albert"
                ), "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_args.model_size}-v2"
        if model_args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name,
                                                   config=config)
            model.albert = albert
    else:
        model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(
            model_args.checkpoint_path)

        model = TFAutoModelForPreTraining.from_config(config)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            loaded_opt_weights = np.load(opt_ckpt, allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    if model_args.model_type == "albert":
        train_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
    if model_args.model_type == "bert":
        train_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/training/*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/validation/*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        batch_size=train_args.batch_size,
    )  # Of shape [batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            batch_size=train_args.batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 0
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = opt.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            if hvd.rank() == 0 and loaded_opt_weights is not None:
                opt.set_weights(loaded_opt_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)
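            # get_weights()[0] is the optimizer's iteration count (assuming the
            # usual Keras weight ordering), so a run restored from a checkpoint
            # resumes its step counter instead of restarting at zero.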
            i = opt.get_weights()[0] - 1

        is_final_step = i >= train_args.total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                model_size=model_args.model_size,
                fsx_prefix=data_args.fsx_prefix,
                step=i,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (
                (i > 0) and
                (i % log_args.checkpoint_frequency == 0)) or is_final_step
            do_validation = (
                (i > 0) and
                (i % log_args.validation_frequency == 0)) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 0:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = f"{data_args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}"
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                opt_ckpt = f"{checkpoint_prefix}-opt.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {opt_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                opt_weights = opt.get_weights()
                np.save(opt_ckpt, opt_weights)
                # opt.set_weights(opt_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{data_args.fsx_prefix}/logs/albert/{run_name}")
                with summary_writer.as_default():
                    HP_MODEL_TYPE = hp.HParam("model_type",
                                              hp.Discrete(["albert", "bert"]))
                    HP_MODEL_SIZE = hp.HParam("model_size",
                                              hp.Discrete(["base", "large"]))
                    HP_LEARNING_RATE = hp.HParam("learning_rate",
                                                 hp.RealInterval(1e-5, 1e-1))
                    HP_BATCH_SIZE = hp.HParam("global_batch_size",
                                              hp.IntInterval(1, 64))
                    HP_PRE_LAYER_NORM = hp.HParam("pre_layer_norm",
                                                  hp.Discrete([True, False]))
                    HP_HIDDEN_DROPOUT = hp.HParam("hidden_dropout")
                    hparams = [
                        HP_MODEL_TYPE,
                        HP_MODEL_SIZE,
                        HP_BATCH_SIZE,
                        HP_LEARNING_RATE,
                        HP_PRE_LAYER_NORM,
                        HP_HIDDEN_DROPOUT,
                    ]

                    HP_F1 = hp.Metric("squad_f1")
                    HP_EXACT = hp.Metric("squad_exact")
                    HP_MLM = hp.Metric("val_mlm_acc")
                    HP_SOP = hp.Metric("val_sop_acc")
                    HP_TRAIN_LOSS = hp.Metric("train_loss")
                    HP_VAL_LOSS = hp.Metric("val_loss")
                    metrics = [
                        HP_TRAIN_LOSS, HP_VAL_LOSS, HP_F1, HP_EXACT, HP_MLM,
                        HP_SOP
                    ]

                    hp.hparams_config(
                        hparams=hparams,
                        metrics=metrics,
                    )
                    hp.hparams(
                        {
                            HP_MODEL_TYPE: model_args.model_type,
                            HP_MODEL_SIZE: model_args.model_size,
                            HP_LEARNING_RATE: train_args.learning_rate,
                            HP_BATCH_SIZE: train_args.batch_size * hvd.size(),
                            HP_PRE_LAYER_NORM: pre_layer_norm,
                            HP_HIDDEN_DROPOUT: model_args.hidden_dropout_prob,
                        },
                        trial_id=run_name,
                    )

            # Log to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("loss_scale", loss_scale, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
Example #3
def main(
    fsx_prefix: str,
    model_type: str,
    model_size: str,
    batch_size: int,
    max_seq_length: int,
    gradient_accumulation_steps: int,
    optimizer: str,
    name: str,
    learning_rate: float,
    end_learning_rate: float,
    warmup_steps: int,
    total_steps: int,
    skip_sop: bool,
    skip_mlm: bool,
    pre_layer_norm: bool,
    fast_squad: bool,
    dummy_eval: bool,
    squad_steps: List[int],
    hidden_dropout_prob: float,
):
    # Hard-coded values that don't need to be arguments
    max_predictions_per_seq = 20
    log_frequency = 1000
    checkpoint_frequency = 5000
    validate_frequency = 2000
    histogram_frequency = 100
    do_gradient_accumulation = gradient_accumulation_steps > 1

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "eks" if args.fsx_prefix == "/fsx" else "sm"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        amp_str = ("-skipamp"
                   if not tf.config.optimizer.get_experimental_options().get(
                       "auto_mixed_precision", False) else "")
        ln_str = "-preln" if pre_layer_norm else "-postln"
        dropout_str = f"-{hidden_dropout_prob}dropout" if hidden_dropout_prob != 0 else ""
        name_str = f"-{name}" if name else ""
        metadata = f"{model_type}-{model_size}-{args.load_from}-{hvd.size()}gpus-{batch_size}batch-{gradient_accumulation_steps}accum-{learning_rate}lr-{args.max_grad_norm}maxgrad-{optimizer}opt-{total_steps}steps-{max_seq_length}seq{amp_str}{ln_str}{loss_str}{dropout_str}{name_str}"
        run_name = f"{current_time}-{platform}-{metadata}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(f"{fsx_prefix}/logs/albert/{run_name}.log"),
            logging.StreamHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_type == "albert":
        model_desc = f"albert-{model_size}-v2"
    elif model_type == "bert":
        model_desc = f"bert-{model_size}-uncased"
    else:
        raise ValueError(f"Unsupported model_type '{model_type}'")

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.output_hidden_states = True
    config.hidden_dropout_prob = hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    if args.load_from == "scratch":
        pass
    else:
        assert model_type == "albert", "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_size}-v2"
        if args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name,
                                                   config=config)
            model.albert = albert
        elif args.load_from == "huggingfacepreds":
            mlm_model = TFAlbertForMaskedLM.from_pretrained(huggingface_name,
                                                            config=config)
            model.albert = mlm_model.albert
            model.cls.predictions = mlm_model.predictions

    tokenizer = get_tokenizer()
    schedule = LinearWarmupLinearDecaySchedule(
        max_learning_rate=learning_rate,
        end_learning_rate=end_learning_rate,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )
    if optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)

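    # Wrap the optimizer with dynamic loss scaling so small fp16 gradients
    # don't underflow to zero during mixed-precision training.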
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        opt, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    # Train filenames are [1, 2047]
    # Val filenames are [0]
    # Note the different subdirectories
    train_glob = f"{fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{max_seq_length}_max_predictions_per_seq_{max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
    validation_glob = f"{fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{max_seq_length}_max_predictions_per_seq_{max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=max_seq_length,
        max_predictions_per_seq=max_predictions_per_seq,
        batch_size=batch_size,
    )  # Of shape [batch_size, ...]
    train_dataset = train_dataset.batch(
        gradient_accumulation_steps
    )  # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    # train_dataset = (
    #    train_dataset.repeat()
    # )  # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=max_seq_length,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=total_steps)
        summary_writer = None  # Only create a writer if we make it through a successful step

    if hvd.rank() == 0:
        logger.info(f"Starting training, job name {run_name}")

    for i, batch in enumerate(train_dataset):
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)

        is_final_step = i >= total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results(
                model=model,
                model_size=model_size,
                step=i,
                fast=fast_squad,
                dummy_eval=dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_frequency == 0
            do_checkpoint = (i % checkpoint_frequency == 0) or is_final_step
            do_validation = (i % validate_frequency == 0) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                logger.info(f"Train step {i} -- {description}")

            if do_checkpoint:
                checkpoint_path = f"{fsx_prefix}/checkpoints/albert/{run_name}-step{i}.ckpt"
                logger.info(f"Saving checkpoint at {checkpoint_path}")
                model.save_weights(checkpoint_path)
                # model.load_weights(checkpoint_path)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{fsx_prefix}/logs/albert/{run_name}")
            # Log to TensorBoard
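            # Global weight norm: the square root of the summed squared L2 norms
            # of every trainable variable, a cheap health check for divergence.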
            weight_norm = tf.math.sqrt(
                tf.math.reduce_sum([
                    tf.norm(var, ord=2)**2 for var in model.trainable_variables
                ]))
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")