    def __init__(
        self,
        model: TFPreTrainedModel,
        args: TFTrainingArguments,
        train_dataset: Optional[tf.data.Dataset] = None,
        eval_dataset: Optional[tf.data.Dataset] = None,
        test_dataset: Optional[tf.data.Dataset] = None,
        dataset_info: Optional[DatasetInfo] = None,
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.test_dataset = test_dataset
        self.dataset_info = dataset_info

        self.gradient_accumulator = GradientAccumulator()
        self.accum_steps = 1
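        # The accumulator sums gradients over `accum_steps` micro-batches so a
        # single optimizer update can stand in for one larger batch.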

        if self.args.strategy_name == "mirrored":
            self.strategy = tf.distribute.MirroredStrategy()
        elif self.args.strategy_name == "onedevice":
            if len(tf.config.list_physical_devices('GPU')) >= 1:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/gpu:0")
            else:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/cpu:0")
        else:
            raise ValueError("The strategy {} does not exists.".format(
                self.args.strategy_name))

        # To conform with Trainer's API we call this from here.
        # All args should be in the `args` already.
        self._setup_training()
Example 2
    def __init__(self,
                 config_path: str = None,
                 config: TrainerConfig = None,
                 **kwargs):
        """
        The list of keys in kwargs here should be generic to all the possible models/architectures
        and not specific to such or such dataset/task.
        """
        if config and not config_path:
            if not isinstance(config, TrainerConfig):
                raise ValueError(
                    "Parameter config in `{}(config)` should be an instance of class `TrainerConfig`. "
                    .format(self.__class__.__name__))
            self.config = config
            unused_kwargs = kwargs
        elif config_path and not config:
            self.config, unused_kwargs = TrainerConfig.from_trainer(
                config_path,
                return_unused_kwargs=True,
                return_unused_config=True,
                **kwargs)
        else:
            raise ValueError(
                "Exactly one of the `config_path` and `config` parameters must be provided, not both and not neither."
            )

        self.strategy_name: str = unused_kwargs.pop("strategy_name",
                                                    "onedevice")
        self.data_processor_config: Dict = unused_kwargs.pop(
            "data_processor", None)

        assert len(
            unused_kwargs) == 0, "unrecognized params passed: %s" % ",".join(
                unused_kwargs.keys())

        self.datasets: Dict[str, tf.data.Dataset] = {}
        self.dataset_info: DatasetInfo
        self.gradient_accumulator = GradientAccumulator()
        self.accum_steps = 1

        if self.config.mode == "classification":
            self.processor = DataProcessorForSequenceClassification(
                **self.data_processor_config)
            self.model_class = TFAutoModelForSequenceClassification
        elif self.config.mode == "labelling":
            self.processor = DataProcessorForTokenClassification(
                **self.data_processor_config)
            self.model_class = TFAutoModelForTokenClassification

        if self.strategy_name == "mirrored":
            self.strategy = tf.distribute.MirroredStrategy()
        elif self.strategy_name == "onedevice":
            if len(tf.config.list_physical_devices('GPU')) >= 1:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/gpu:0")
            else:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/cpu:0")
        else:
            raise ValueError("The strategy {} does not exists.".format(
                self.strategy_name))
Example 3
    def testGradientAccumulatorDistributionStrategy(self):
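        # Reset the eager context so the logical CPU devices configured below
        # can still be set up on a fresh runtime.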
        context._context = None
        ops.enable_eager_execution_internal()
        physical_devices = tf.config.list_physical_devices("CPU")
        if len(physical_devices) == 1:
            tf.config.set_logical_device_configuration(physical_devices[0], [
                tf.config.LogicalDeviceConfiguration(),
                tf.config.LogicalDeviceConfiguration()
            ])
        devices = tf.config.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            optimizer, _ = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)
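            # One mirrored placeholder variable; accumulate() assigns a different
            # gradient to each replica's local copy before running accumulation.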

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(
                list(zip(accumulator.gradients, [variable])))

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                local_variables = strategy.experimental_local_results(
                    gradient_placeholder)
                local_variables[0].assign(grad1)
                local_variables[1].assign(grad2)
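                # experimental_run_v2 is the pre-TF 2.2 name of Strategy.run.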
                strategy.experimental_run_v2(accumulate_on_replica,
                                             args=(gradient_placeholder, ))

        @tf.function
        def apply_grad():
            with strategy.scope():
                strategy.experimental_run_v2(apply_on_replica)

        def _check_local_values(grad1, grad2):
            values = strategy.experimental_local_results(
                accumulator._gradients[0])
            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        _check_local_values([2.0, 3.0], [1.0, -2.0])
        apply_grad()
        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        _check_local_values([0.0, 0.0], [0.0, 0.0])
Example 4
    def testGradientAccumulator(self):
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])
        with self.assertRaises(ValueError):
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)
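
# A minimal sketch of the accumulate-then-apply pattern that the tests above
# exercise. The model, data, and hyper-parameters are illustrative placeholders;
# GradientAccumulator is assumed to come from transformers' TF optimization
# module (the exact import path may vary by version).
import tensorflow as tf
from transformers.optimization_tf import GradientAccumulator

accum_steps = 4
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
accumulator = GradientAccumulator()

def train_step(x, y, step):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x))
    # Sum this micro-batch's gradients into the accumulator.
    accumulator(tape.gradient(loss, model.trainable_variables))
    if step % accum_steps == 0:
        # Average over the accumulated micro-batches, apply once, then reset.
        grads = [g / accum_steps for g in accumulator.gradients]
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        accumulator.reset()

x = tf.random.normal((8, 3))
y = tf.random.normal((8, 1))
for step in range(1, 2 * accum_steps + 1):
    train_step(x, y, step)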
Example 5
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples,
          labels, train_batch_size, pad_token_label_id):
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args[
            "gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (math.ceil(num_train_examples / train_batch_size) //
                           args["gradient_accumulation_steps"] *
                           args["num_train_epochs"])

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps,
                                     args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d",
                 args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d",
                 args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
        grads_and_vars = []
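        # Scale each accumulated gradient down by the number of devices and
        # accumulation steps so the update matches a single large batch.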

        for gradient, variable in zip(gradient_accumulator.gradients,
                                      model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (
                    args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {
                "attention_mask": train_features["attention_mask"],
                "training": True
            }

            if "token_type_ids" in train_features:
                inputs["token_type_ids"] = train_features["token_type_ids"]

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                active_loss = tf.reshape(train_labels,
                                         (-1, )) != pad_token_label_id
                active_logits = tf.boolean_mask(
                    tf.reshape(logits, (-1, len(labels))), active_loss)
                active_labels = tf.boolean_mask(
                    tf.reshape(train_labels, (-1, )), active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)
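                # Accumulate instead of applying immediately; apply_gradients()
                # consumes the summed gradients every gradient_accumulation_steps.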

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn,
                                                          args=(train_features,
                                                                train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                    per_example_losses,
                                    axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(train_dataset,
                                      total=num_train_steps,
                                      parent=train_iterator,
                                      display=args["n_device"] > 1)
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args[
                            "logging_steps"] == 0:
                        # Log metrics
                        if (
                                args["n_device"] == 1
                                and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args,
                                strategy,
                                model,
                                tokenizer,
                                labels,
                                pad_token_label_id,
                                mode="dev")
                            report = metrics.classification_report(y_true,
                                                                   y_pred,
                                                                   digits=4)

                            logging.info("Eval at step " + str(global_step) +
                                         "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss,
                                                  global_step)
                                tf.summary.scalar("precision", precision,
                                                  global_step)
                                tf.summary.scalar("recall", recall,
                                                  global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) /
                                args["logging_steps"], global_step)

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss",
                                          loss_metric.result(),
                                          step=step)

                    if args["save_steps"] > 0 and global_step % args[
                            "save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args["output_dir"],
                            "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s",
                                     output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() -
                                                    current_time))
Example 6
    def _custom_train(
        self,
        train_dataset,
        tokenizer,
        model,
        num_train_examples,
        train_batch_size,
    ):
        config = self.parent.config._asdict()
        config["strategy"] = self.parent.config.strategy
        config["n_device"] = self.parent.config.n_device
        labels = config["ner_tags"]
        if config["max_steps"] > 0:
            num_train_steps = (
                config["max_steps"] * config["gradient_accumulation_steps"]
            )
            config["epochs"] = 1
        else:
            num_train_steps = (
                math.ceil(num_train_examples / train_batch_size)
                // config["gradient_accumulation_steps"]
                * config["epochs"]
            )

        with config["strategy"].scope():
            loss_fct = self.tf.keras.losses.SparseCategoricalCrossentropy(
                reduction=self.tf.keras.losses.Reduction.NONE
            )
            optimizer = create_optimizer(
                config["learning_rate"],
                num_train_steps,
                config["warmup_steps"],
            )

            if config["use_fp16"]:
                optimizer = self.tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer, "dynamic"
                )

            loss_metric = self.tf.keras.metrics.Mean(
                name="loss", dtype=self.tf.float32
            )
            gradient_accumulator = GradientAccumulator()

        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d", num_train_examples)
        self.logger.info("  Num Epochs = %d", config["epochs"])
        self.logger.info(
            "  Instantaneous batch size per device = %d",
            config["per_device_train_batch_size"],
        )
        self.logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            train_batch_size * config["gradient_accumulation_steps"],
        )
        self.logger.info(
            "  Gradient Accumulation steps = %d",
            config["gradient_accumulation_steps"],
        )
        self.logger.info("  Total training steps = %d", num_train_steps)

        self.logger.debug(model.summary())

        @self.tf.function
        def apply_gradients():
            grads_and_vars = []

            for gradient, variable in zip(
                gradient_accumulator.gradients, model.trainable_variables
            ):
                if gradient is not None:
                    scaled_gradient = gradient / (
                        config["n_device"]
                        * config["gradient_accumulation_steps"]
                    )
                    grads_and_vars.append((scaled_gradient, variable))
                else:
                    grads_and_vars.append((gradient, variable))

            optimizer.apply_gradients(grads_and_vars, config["max_grad_norm"])
            gradient_accumulator.reset()

        @self.tf.function
        def train_step(train_features, train_labels):
            def step_fn(train_features, train_labels):
                inputs = {
                    "attention_mask": train_features["input_mask"],
                    "training": True,
                }

                if config["model_architecture_type"] != "distilbert":
                    inputs["token_type_ids"] = (
                        train_features["segment_ids"]
                        if config["model_architecture_type"]
                        in ["bert", "xlnet"]
                        else None
                    )

                with self.tf.GradientTape() as tape:
                    logits = model(train_features["input_ids"], **inputs)[0]
                    logits = self.tf.reshape(logits, (-1, len(labels) + 1))
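                    # Restrict the loss to real (non-padding) positions, as
                    # given by the input mask.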
                    active_loss = self.tf.reshape(
                        train_features["input_mask"], (-1,)
                    )
                    active_logits = self.tf.boolean_mask(logits, active_loss)
                    train_labels = self.tf.reshape(train_labels, (-1,))
                    active_labels = self.tf.boolean_mask(
                        train_labels, active_loss
                    )
                    cross_entropy = loss_fct(active_labels, active_logits)
                    loss = self.tf.reduce_sum(cross_entropy) * (
                        1.0 / train_batch_size
                    )
                    grads = tape.gradient(loss, model.trainable_variables)

                    gradient_accumulator(grads)

                return cross_entropy

            per_example_losses = config["strategy"].run(
                step_fn, args=(train_features, train_labels)
            )
            mean_loss = config["strategy"].reduce(
                self.tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0
            )

            return mean_loss

        current_time = datetime.datetime.now()
        train_iterator = master_bar(range(config["epochs"]))
        global_step = 0
        self.logger_loss = 0.0

        for epoch in train_iterator:
            epoch_iterator = progress_bar(
                train_dataset,
                total=num_train_steps,
                parent=train_iterator,
                display=config["n_device"] > 1,
            )
            step = 1

            with config["strategy"].scope():
                for train_features, train_labels in epoch_iterator:
                    loss = train_step(train_features, train_labels)

                    if step % config["gradient_accumulation_steps"] == 0:
                        config["strategy"].run(apply_gradients)
                        loss_metric(loss)
                        global_step += 1
                        if (
                            config["save_steps"] > 0
                            and global_step % config["save_steps"] == 0
                        ):
                            # Save model checkpoint
                            output_dir = os.path.join(
                                config["output_dir"],
                                "checkpoint-{}".format(global_step),
                            )

                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)

                            model.save_pretrained(output_dir)
                            self.logger.info(
                                "Saving model checkpoint to %s", output_dir
                            )

                    train_iterator.child.comment = (
                        f"loss : {loss_metric.result()}"
                    )
                    step += 1

            train_iterator.write(
                f"loss epoch {epoch + 1}: {loss_metric.result()}"
            )
            loss_metric.reset_states()
        self.logger.debug(
            "  Training took time = {}".format(
                datetime.datetime.now() - current_time
            )
        )
def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments,
                               TrainingArguments, LoggingArguments))
    model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses(
    )

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    squad_steps = get_squad_steps(log_args.extra_squad_steps)
    is_sagemaker = data_args.fsx_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        metadata = (f"{model_args.model_type}"
                    f"-{model_args.model_size}"
                    f"-{model_args.load_from}"
                    f"-{hvd.size()}gpus"
                    f"-{train_args.batch_size}batch"
                    f"-{train_args.gradient_accumulation_steps}accum"
                    f"-{train_args.learning_rate}maxlr"
                    f"-{train_args.end_learning_rate}endlr"
                    f"-{train_args.learning_rate_decay_power}power"
                    f"-{train_args.max_grad_norm}maxgrad"
                    f"-{train_args.optimizer}opt"
                    f"-{train_args.total_steps}steps"
                    f"-{data_args.max_seq_length}seq"
                    f"-{data_args.max_predictions_per_seq}preds"
                    f"-{'preln' if pre_layer_norm else 'postln'}"
                    f"{loss_str}"
                    f"-{model_args.hidden_dropout_prob}dropout"
                    f"-{train_args.seed}seed")
        run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                f"{data_args.fsx_prefix}/logs/albert/{run_name}.log"),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_args.model_type == "albert":
        model_desc = f"albert-{model_args.model_size}-v2"
    elif model_args.model_type == "bert":
        model_desc = f"bert-{model_args.model_size}-uncased"

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.hidden_dropout_prob = model_args.hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    # Create optimizer and enable AMP loss scaling.
    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=train_args.learning_rate,
        end_learning_rate=train_args.end_learning_rate,
        warmup_steps=train_args.warmup_steps,
        total_steps=train_args.total_steps,
        power=train_args.learning_rate_decay_power,
    )
    if train_args.optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif train_args.optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)
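    # Wrap the optimizer with the AMP graph rewrite, which enables mixed
    # precision with dynamic loss scaling.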
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_opt_weights = None
    if model_args.load_from == "scratch":
        pass
    elif model_args.load_from.startswith("huggingface"):
        assert (model_args.model_type == "albert"
                ), "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_args.model_size}-v2"
        if model_args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name,
                                                   config=config)
            model.albert = albert
    else:
        model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(
            model_args.checkpoint_path)

        model = TFAutoModelForPreTraining.from_config(config)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            loaded_opt_weights = np.load(opt_ckpt, allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    if model_args.model_type == "albert":
        train_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
    if model_args.model_type == "bert":
        train_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/training/*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/validation/*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        batch_size=train_args.batch_size,
    )  # Of shape [batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            batch_size=train_args.batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 0
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = opt.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            if hvd.rank() == 0 and loaded_opt_weights is not None:
                opt.set_weights(loaded_opt_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)
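            # The first optimizer weight is its iteration counter; resume the
            # global step index from the (possibly loaded) optimizer state.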
            i = opt.get_weights()[0] - 1

        is_final_step = i >= train_args.total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                model_size=model_args.model_size,
                fsx_prefix=data_args.fsx_prefix,
                step=i,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (
                (i > 0) and
                (i % log_args.checkpoint_frequency == 0)) or is_final_step
            do_validation = (
                (i > 0) and
                (i % log_args.validation_frequency == 0)) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 0:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = f"{data_args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}"
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                opt_ckpt = f"{checkpoint_prefix}-opt.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {opt_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                opt_weights = opt.get_weights()
                np.save(opt_ckpt, opt_weights)
                # opt.set_weights(opt_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{data_args.fsx_prefix}/logs/albert/{run_name}")
                with summary_writer.as_default():
                    HP_MODEL_TYPE = hp.HParam("model_type",
                                              hp.Discrete(["albert", "bert"]))
                    HP_MODEL_SIZE = hp.HParam("model_size",
                                              hp.Discrete(["base", "large"]))
                    HP_LEARNING_RATE = hp.HParam("learning_rate",
                                                 hp.RealInterval(1e-5, 1e-1))
                    HP_BATCH_SIZE = hp.HParam("global_batch_size",
                                              hp.IntInterval(1, 64))
                    HP_PRE_LAYER_NORM = hp.HParam("pre_layer_norm",
                                                  hp.Discrete([True, False]))
                    HP_HIDDEN_DROPOUT = hp.HParam("hidden_dropout")
                    hparams = [
                        HP_MODEL_TYPE,
                        HP_MODEL_SIZE,
                        HP_BATCH_SIZE,
                        HP_LEARNING_RATE,
                        HP_PRE_LAYER_NORM,
                        HP_HIDDEN_DROPOUT,
                    ]

                    HP_F1 = hp.Metric("squad_f1")
                    HP_EXACT = hp.Metric("squad_exact")
                    HP_MLM = hp.Metric("val_mlm_acc")
                    HP_SOP = hp.Metric("val_sop_acc")
                    HP_TRAIN_LOSS = hp.Metric("train_loss")
                    HP_VAL_LOSS = hp.Metric("val_loss")
                    metrics = [
                        HP_TRAIN_LOSS, HP_VAL_LOSS, HP_F1, HP_EXACT, HP_MLM,
                        HP_SOP
                    ]

                    hp.hparams_config(
                        hparams=hparams,
                        metrics=metrics,
                    )
                    hp.hparams(
                        {
                            HP_MODEL_TYPE: model_args.model_type,
                            HP_MODEL_SIZE: model_args.model_size,
                            HP_LEARNING_RATE: train_args.learning_rate,
                            HP_BATCH_SIZE: train_args.batch_size * hvd.size(),
                            HP_PRE_LAYER_NORM: model_args.pre_layer_norm
                            == "true",
                            HP_HIDDEN_DROPOUT: model_args.hidden_dropout_prob,
                        },
                        trial_id=run_name,
                    )

            # Log to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("loss_scale", loss_scale, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments,
         LoggingArguments, PathArguments))
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings
               ) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout")
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir,
                             f"{run_name}.log")),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)

    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining,
                         model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix,
                                       model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(
            checkpoint_path)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            if model_args.load_optimizer_state == "true":
                loaded_optimizer_weights = np.load(optimizer_ckpt,
                                                   allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir,
                              "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix,
                                   path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(
            step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step)
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step)
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step)

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(path_args.filesystem_prefix,
                                                 path_args.checkpoint_dir,
                                                 f"{run_name}-step{i}")
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix,
                                 path_args.log_dir, run_name))
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size":
                    train_args.per_gpu_batch_size * hvd.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")