def main():
    """Pretrain a ``TFAutoModelForPreTraining`` model with Horovod.

    Parses command-line arguments into the ``ModelArguments``,
    ``DataTrainingArguments``, ``TrainingArguments``, ``LoggingArguments`` and
    ``PathArguments`` dataclasses. It then builds the model, tokenizer,
    optimizer and TFRecord datasets, and trains until
    ``train_args.total_steps`` (MLM and/or SOP losses, see ``--skip_mlm`` /
    ``--skip_sop``).

    Rank 0 owns logging, the progress bar, checkpointing, validation and
    TensorBoard / Weights & Biases metrics. SQuAD evaluation runs on every
    rank, but only rank 0 receives its results.

    Raises:
        AssertionError: if unparsed command-line strings remain.
        ValueError: if both ``--skip_sop`` and ``--skip_mlm`` are set, or if
            ``train_args.optimizer`` is neither ``"lamb"`` nor ``"adamw"``.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments,
         LoggingArguments, PathArguments))
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings
               ) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init. Boolean options arrive as strings; only the exact value
    # "true" counts as enabled.
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    # Published as a module-level global; presumably read by the globally
    # wrapped train-step functions (TODO confirm).
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init: one visible GPU per local rank, memory grown on demand.
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout")
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        log_format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir,
                             f"{run_name}.log")),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=log_format, handlers=handlers)

    # Check the flag combination after rank 0 has configured logging. The
    # check runs on every rank, so all workers fail, not just rank 0.
    if skip_sop and skip_mlm:
        raise ValueError("Cannot use --skip_sop and --skip_mlm")

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)
    else:
        # Fail fast. Otherwise `optimizer` stays unbound and the error only
        # appears a few lines later as an UnboundLocalError.
        raise ValueError(
            f"Unsupported optimizer {train_args.optimizer!r}; "
            "expected 'lamb' or 'adamw'.")

    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining,
                         model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix,
                                       model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(
            checkpoint_path)
        if hvd.rank() == 0:
            # Only rank 0 restores; the state reaches the other ranks through
            # the broadcast after the first step.
            model.load_weights(model_ckpt)
            if parse_bool(model_args.load_optimizer_state):
                # allow_pickle=True: only load optimizer checkpoints from
                # trusted locations.
                loaded_optimizer_weights = np.load(optimizer_ckpt,
                                                   allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir,
                              "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix,
                                   path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(
            step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            # Resync the step counter from the optimizer's first weight
            # (presumably its iteration count, so resumed runs continue
            # numbering).
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step)
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                # Pass the parsed booleans. The raw strings would make "false" truthy.
                fast=fast_squad,
                dummy_eval=dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if hvd.rank() == 0:
            # A frequency of 0 disables the corresponding action.
            do_log = (log_args.log_frequency != 0) and (
                i % log_args.log_frequency == 0)
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step)
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step)

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                # Restart the timer after every log line, including the first
                # step. The next It/s figure then covers only the newer steps.
                start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(path_args.filesystem_prefix,
                                                 path_args.checkpoint_dir,
                                                 f"{run_name}-step{i}")
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                model.save_weights(model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix,
                                 path_args.log_dir, run_name))
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    # Per-GPU batch x GPUs x gradient-accumulation steps,
                    # matching the "globalbatch" figure in the run name.
                    "global_batch_size":
                    train_args.per_gpu_batch_size * hvd.size() *
                    train_args.gradient_accumulation_steps,
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
# Example #2
# 0
def main():
    """Pretrain an ELECTRA generator/discriminator pair with Horovod.

    Parses command-line arguments into the ``ModelArguments``,
    ``DataTrainingArguments``, ``TrainingArguments``, ``LoggingArguments`` and
    ``PathArguments`` dataclasses. It builds the generator and discriminator
    from the ``google/electra-{model_size}-generator`` / ``-discriminator``
    configs and trains them until ``train_args.total_steps``.

    Rank 0 owns logging, validation, TensorBoard / Weights & Biases metrics
    and checkpointing.

    Raises:
        AssertionError: if unparsed command-line strings remain.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments, LoggingArguments, PathArguments)
    )
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings) == 0, f"The args {remaining_strings} could not be parsed."

    # Horovod init: one visible GPU per local rank, memory grown on demand.
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    if train_args.eager == "true":
        tf.config.experimental_run_functions_eagerly(True)

    tokenizer = ElectraTokenizerFast.from_pretrained("bert-base-uncased")

    gen_config = ElectraConfig.from_pretrained(f"google/electra-{model_args.model_size}-generator")
    dis_config = ElectraConfig.from_pretrained(
        f"google/electra-{model_args.model_size}-discriminator"
    )

    gen = TFElectraForMaskedLM(config=gen_config)
    dis = TFElectraForPreTraining(config=dis_config)
    optimizer = get_adamw_optimizer(train_args)

    # Tie the weights: the generator reuses the discriminator's embedding layer object.
    if model_args.electra_tie_weights == "true":
        gen.electra.embeddings = dis.electra.embeddings

    loaded_optimizer_weights = None
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix, model_args.checkpoint_path)
        dis_ckpt, gen_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(checkpoint_path)
        if hvd.rank() == 0:
            # Only rank 0 restores; the other ranks receive the state through
            # the broadcast in the step == 1 branch of the training loop.
            dis.load_weights(dis_ckpt)
            gen.load_weights(gen_ckpt)
            # allow_pickle=True: only load optimizer checkpoints from trusted locations.
            loaded_optimizer_weights = np.load(optimizer_ckpt, allow_pickle=True)

    start_time = time.perf_counter()

    if hvd.rank() == 0:
        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        log_format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            TqdmLoggingHandler(),
        ]
        summary_writer = None  # Only create a writer if we make it through a successful step
        logging.basicConfig(level=level, format=log_format, handlers=handlers)
        wandb_run_name = None  # Set once wandb.init() has run (rank 0 only).

        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if log_args.run_name is None:
            metadata = (
                f"electra-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.total_steps}steps"
            )
            run_name = (
                f"{current_time}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
            )
        else:
            run_name = log_args.run_name

    logger.info(f"Training with dataset at {path_args.train_dir}")
    logger.info(f"Validating with dataset at {path_args.val_dir}")

    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir, "*.tfrecord*")
    validation_glob = os.path.join(path_args.filesystem_prefix, path_args.val_dir, "*.tfrecord*")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)
    logger.info(
        f"Number of train files {len(train_filenames)}, number of validation files {len(validation_filenames)}"
    )

    tf_train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
        max_seq_length=data_args.max_seq_length,
    )

    tf_train_dataset = tf_train_dataset.prefetch(buffer_size=8)

    # Validation data is only needed on rank 0, which runs all validation.
    if hvd.rank() == 0:
        tf_val_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
            max_seq_length=data_args.max_seq_length,
        )
        tf_val_dataset = tf_val_dataset.prefetch(buffer_size=8)

    step = 1
    for batch in tf_train_dataset:
        learning_rate = optimizer.learning_rate(step=tf.constant(step, dtype=tf.float32))
        ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        train_result = train_step(
            optimizer=optimizer,
            gen=gen,
            dis=dis,
            ids=ids,
            attention_mask=attention_mask,
            mask_token_id=tokenizer.mask_token_id,
        )

        if step == 1:
            # Horovod broadcast: after the first step has created the
            # optimizer's variables, apply any restored optimizer state on
            # rank 0 and copy rank 0's state to every rank.
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(gen.variables, root_rank=0)
            hvd.broadcast_variables(dis.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            # Resync the step counter from the optimizer's first weight
            # (presumably its iteration count, so resumed runs continue numbering).
            step = optimizer.get_weights()[0]

        is_final_step = step >= train_args.total_steps
        if hvd.rank() == 0:
            # A frequency of 0 disables the corresponding action.
            do_log = (log_args.log_frequency != 0) and (step % log_args.log_frequency == 0)
            do_checkpoint = (
                (log_args.checkpoint_frequency != 0)
                and (step > 1)
                and ((step % log_args.checkpoint_frequency == 0) or is_final_step)
            )
            do_validation = (log_args.validation_frequency != 0) and (
                step % log_args.validation_frequency == 0
            )

            if do_log:
                elapsed_time = time.perf_counter() - start_time  # Off for first log
                it_s = log_args.log_frequency / elapsed_time
                start_time = time.perf_counter()
                description = f"Step {step} -- gen_loss: {train_result.gen_loss:.3f}, dis_loss: {train_result.dis_loss:.3f}, gen_acc: {train_result.gen_acc:.3f}, dis_acc: {train_result.dis_acc:.3f}, it/s: {it_s:.3f}\n"
                logger.info(description)

            # Stays None unless validation actually produced a batch this step.
            val_result = None
            if do_validation:
                # Distinct loop name so the outer training `batch` is not shadowed.
                for val_batch in tf_val_dataset.take(1):
                    val_ids = val_batch["input_ids"]
                    val_attention_mask = val_batch["attention_mask"]
                    val_result = val_step(
                        gen=gen,
                        dis=dis,
                        ids=val_ids,
                        attention_mask=val_attention_mask,
                        mask_token_id=tokenizer.mask_token_id,
                    )
                    log_example(
                        tokenizer,
                        val_ids,
                        val_result.masked_ids,
                        val_result.corruption_mask,
                        val_result.gen_ids,
                        val_result.dis_preds,
                    )
                    description = f"VALIDATION, Step {step} -- val_gen_loss: {val_result.gen_loss:.3f}, val_dis_loss: {val_result.dis_loss:.3f}, val_gen_acc: {val_result.gen_acc:.3f}, val_dis_acc: {val_result.dis_acc:.3f}\n"
                    logger.info(description)
                if val_result is None:
                    logger.warning(f"VALIDATION, Step {step} -- validation dataset yielded no batches, skipping")

            train_metrics = {
                "learning_rate": learning_rate,
                "train/loss": train_result.loss,
                "train/gen_loss": train_result.gen_loss,
                "train/dis_loss": train_result.dis_loss,
                "train/gen_acc": train_result.gen_acc,
                "train/dis_acc": train_result.dis_acc,
            }
            all_metrics = {**train_metrics}
            if val_result is not None:
                val_metrics = {
                    "val/loss": val_result.loss,
                    "val/gen_loss": val_result.gen_loss,
                    "val/dis_loss": val_result.dis_loss,
                    "val/gen_acc": val_result.gen_acc,
                    "val/dis_acc": val_result.dis_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_log:
                all_metrics = {"it_s": it_s, **all_metrics}

            if is_wandb_available():
                if wandb_run_name is None:
                    config = {
                        **asdict(model_args),
                        **asdict(data_args),
                        **asdict(train_args),
                        **asdict(log_args),
                        **asdict(path_args),
                        "global_batch_size": train_args.per_gpu_batch_size * hvd.size(),
                        "n_gpus": hvd.size(),
                    }
                    wandb.init(config=config, project="electra")
                    wandb.run.save()
                    wandb_run_name = wandb.run.name
                wandb.log({"step": step, **all_metrics})

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix, path_args.log_dir, run_name)
                )

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=step)

            if do_checkpoint:
                dis_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-discriminator.ckpt",
                )
                gen_model_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-generator.ckpt",
                )
                optimizer_ckpt = os.path.join(
                    path_args.filesystem_prefix,
                    path_args.checkpoint_dir,
                    f"{run_name}-step{step}-optimizer.npy",
                )
                logger.info(
                    f"Saving discriminator model at {dis_model_ckpt}, generator model at {gen_model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                dis.save_weights(dis_model_ckpt)
                gen.save_weights(gen_model_ckpt)
                np.save(optimizer_ckpt, optimizer.get_weights())

        step += 1
        if is_final_step:
            break