    def __init__(
        self,
        model: TFPreTrainedModel,
        args: TFTrainingArguments,
        train_dataset: Optional[tf.data.Dataset] = None,
        eval_dataset: Optional[tf.data.Dataset] = None,
        test_dataset: Optional[tf.data.Dataset] = None,
        dataset_info: Optional[DatasetInfo] = None,
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.test_dataset = test_dataset
        self.dataset_info = dataset_info

        self.gradient_accumulator = GradientAccumulator()
        self.accum_steps = 1

        if self.args.strategy_name == "mirrored":
            self.strategy = tf.distribute.MirroredStrategy()
        elif self.args.strategy_name == "onedevice":
            if len(tf.config.list_physical_devices('GPU')) >= 1:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/gpu:0")
            else:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/cpu:0")
        else:
            raise ValueError("The strategy {} does not exists.".format(
                self.args.strategy_name))

        # To conform with Trainer's API we call this from here.
        # All args should be in the `args` already.
        self._setup_training()
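
The strategy-selection branch above is repeated verbatim in several snippets below; for reference, a minimal standalone sketch of the same pattern, assuming only TensorFlow 2.x:

import tensorflow as tf

def get_strategy(strategy_name: str) -> tf.distribute.Strategy:
    # Same selection logic as above: mirror across all local devices, or pin to a single device.
    if strategy_name == "mirrored":
        return tf.distribute.MirroredStrategy()
    if strategy_name == "onedevice":
        # Fall back to the CPU when no GPU is visible.
        device = "/gpu:0" if tf.config.list_physical_devices("GPU") else "/cpu:0"
        return tf.distribute.OneDeviceStrategy(device=device)
    raise ValueError("The strategy {} does not exist.".format(strategy_name))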
Example #2
    def __init__(self,
                 config_path: str = None,
                 config: TrainerConfig = None,
                 **kwargs):
        """
        The list of keys in kwargs here should be generic to all the possible models/architectures
        and not specific to such or such dataset/task.
        """
        if config and not config_path:
            if not isinstance(config, TrainerConfig):
                raise ValueError(
                    "Parameter config in `{}(config)` should be an instance of class `TrainerConfig`. "
                    .format(self.__class__.__name__))
            self.config = config
            unused_kwargs = kwargs
        elif config_path and not config:
            self.config, unused_kwargs = TrainerConfig.from_trainer(
                config_path,
                return_unused_kwargs=True,
                return_unused_config=True,
                **kwargs)
        else:
            raise ValueError(
                "Exactly one of config_path and config must be provided; "
                "they cannot both be set or both be None."
            )

        self.strategy_name: str = unused_kwargs.pop("strategy_name",
                                                    "onedevice")
        self.data_processor_config: Dict = unused_kwargs.pop(
            "data_processor", None)

        assert len(
            unused_kwargs) == 0, "unrecognized params passed: %s" % ",".join(
                unused_kwargs.keys())

        self.datasets: Dict[str, tf.data.Dataset] = {}
        self.dataset_info: DatasetInfo
        self.gradient_accumulator = GradientAccumulator()
        self.accum_steps = 1

        if self.config.mode == "classification":
            self.processor = DataProcessorForSequenceClassification(
                **self.data_processor_config)
            self.model_class = TFAutoModelForSequenceClassification
        elif self.config.mode == "labelling":
            self.processor = DataProcessorForTokenClassification(
                **self.data_processor_config)
            self.model_class = TFAutoModelForTokenClassification
        else:
            raise ValueError("The mode {} is not recognized.".format(
                self.config.mode))

        if self.strategy_name == "mirrored":
            self.strategy = tf.distribute.MirroredStrategy()
        elif self.strategy_name == "onedevice":
            if len(tf.config.list_physical_devices('GPU')) >= 1:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/gpu:0")
            else:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/cpu:0")
        else:
            raise ValueError("The strategy {} does not exists.".format(
                self.strategy_name))
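
The constructor above pops its optional settings out of the unused kwargs and rejects anything left over; a self-contained sketch of that pattern (the function name and keyword values shown are illustrative):

def consume_extra_kwargs(**kwargs):
    # Pop the recognized optional settings, falling back to defaults.
    strategy_name = kwargs.pop("strategy_name", "onedevice")
    data_processor_config = kwargs.pop("data_processor", None)
    # Anything still left in kwargs was not recognized.
    assert len(kwargs) == 0, "unrecognized params passed: %s" % ",".join(kwargs.keys())
    return strategy_name, data_processor_config

print(consume_extra_kwargs(strategy_name="mirrored", data_processor={"max_seq_length": 128}))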
Example #3
    def testGradientAccumulatorDistributionStrategy(self):
        context._context = None
        ops.enable_eager_execution_internal()
        physical_devices = tf.config.list_physical_devices("CPU")
        if len(physical_devices) == 1:
            tf.config.set_logical_device_configuration(physical_devices[0], [
                tf.config.LogicalDeviceConfiguration(),
                tf.config.LogicalDeviceConfiguration()
            ])
        devices = tf.config.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            optimizer, _ = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(
                list(zip(accumulator.gradients, [variable])))

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                local_variables = strategy.experimental_local_results(
                    gradient_placeholder)
                local_variables[0].assign(grad1)
                local_variables[1].assign(grad2)
                strategy.experimental_run_v2(accumulate_on_replica,
                                             args=(gradient_placeholder, ))

        @tf.function
        def apply_grad():
            with strategy.scope():
                strategy.experimental_run_v2(apply_on_replica)

        def _check_local_values(grad1, grad2):
            values = strategy.experimental_local_results(
                accumulator._gradients[0])
            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        _check_local_values([2.0, 3.0], [1.0, -2.0])
        apply_grad()
        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        _check_local_values([0.0, 0.0], [0.0, 0.0])
Example #4
    def testGradientAccumulator(self):
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])
        with self.assertRaises(ValueError):
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)
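
Putting the two tests together, here is a minimal sketch of the usual accumulate-then-apply training loop; GradientAccumulator is assumed to be importable exactly as in the tests above, and the toy model, optimizer, and dataset are illustrative only:

import tensorflow as tf

# Toy setup so the sketch is self-contained; shapes and values are illustrative.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
optimizer = tf.keras.optimizers.SGD(0.1)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((32, 4)),
     tf.random.uniform((32,), maxval=2, dtype=tf.int32))).batch(8)

accumulator = GradientAccumulator()  # assumed import, as in the tests above
accum_steps = 2

for step, (features, labels) in enumerate(dataset, start=1):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(features, training=True))
    accumulator(tape.gradient(loss, model.trainable_variables))
    if step % accum_steps == 0:
        # Average the accumulated gradients before applying them, then reset.
        grads = [g / accum_steps if g is not None else None
                 for g in accumulator.gradients]
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        accumulator.reset()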
Example #5
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples,
          labels, train_batch_size, pad_token_label_id):
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args[
            "gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (math.ceil(num_train_examples / train_batch_size) //
                           args["gradient_accumulation_steps"] *
                           args["num_train_epochs"])

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps,
                                     args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d",
                 args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d",
                 args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients,
                                      model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (
                    args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {
                "attention_mask": train_features["attention_mask"],
                "training": True
            }

            if "token_type_ids" in train_features:
                inputs["token_type_ids"] = train_features["token_type_ids"]

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                active_loss = tf.reshape(train_labels,
                                         (-1, )) != pad_token_label_id
                active_logits = tf.boolean_mask(
                    tf.reshape(logits, (-1, len(labels))), active_loss)
                active_labels = tf.boolean_mask(
                    tf.reshape(train_labels, (-1, )), active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn,
                                                          args=(train_features,
                                                                train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                    per_example_losses,
                                    axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(train_dataset,
                                      total=num_train_steps,
                                      parent=train_iterator,
                                      display=args["n_device"] > 1)
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args[
                            "logging_steps"] == 0:
                        # Log metrics
                        if (
                                args["n_device"] == 1
                                and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args,
                                strategy,
                                model,
                                tokenizer,
                                labels,
                                pad_token_label_id,
                                mode="dev")
                            report = metrics.classification_report(y_true,
                                                                   y_pred,
                                                                   digits=4)

                            logging.info("Eval at step " + str(global_step) +
                                         "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss,
                                                  global_step)
                                tf.summary.scalar("precision", precision,
                                                  global_step)
                                tf.summary.scalar("recall", recall,
                                                  global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) /
                                args["logging_steps"], global_step)

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss",
                                          loss_metric.result(),
                                          step=step)

                    if args["save_steps"] > 0 and global_step % args[
                            "save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args["output_dir"],
                            "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s",
                                     output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() -
                                                    current_time))
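
For reference, the train() function above reads all of its hyperparameters from a plain dictionary; a hedged sketch of that args dict, listing exactly the keys accessed in the function body with illustrative values:

args = {
    "max_steps": -1,                    # <= 0 means the step count is derived from the epochs
    "num_train_epochs": 3,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "warmup_steps": 0,
    "fp16": False,
    "max_grad_norm": 1.0,
    "n_device": 1,
    "per_device_train_batch_size": 8,
    "logging_steps": 100,
    "save_steps": 500,
    "evaluate_during_training": True,
    "output_dir": "output",
}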
Example #6
    def _custom_train(
        self,
        train_dataset,
        tokenizer,
        model,
        num_train_examples,
        train_batch_size,
    ):
        config = self.parent.config._asdict()
        config["strategy"] = self.parent.config.strategy
        config["n_device"] = self.parent.config.n_device
        labels = config["ner_tags"]
        if config["max_steps"] > 0:
            num_train_steps = (
                config["max_steps"] * config["gradient_accumulation_steps"]
            )
            config["epochs"] = 1
        else:
            num_train_steps = (
                math.ceil(num_train_examples / train_batch_size)
                // config["gradient_accumulation_steps"]
                * config["epochs"]
            )

        with config["strategy"].scope():
            loss_fct = self.tf.keras.losses.SparseCategoricalCrossentropy(
                reduction=self.tf.keras.losses.Reduction.NONE
            )
            optimizer = create_optimizer(
                config["learning_rate"],
                num_train_steps,
                config["warmup_steps"],
            )

            if config["use_fp16"]:
                optimizer = self.tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                    optimizer, "dynamic"
                )

            loss_metric = self.tf.keras.metrics.Mean(
                name="loss", dtype=self.tf.float32
            )
            gradient_accumulator = GradientAccumulator()

        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d", num_train_examples)
        self.logger.info("  Num Epochs = %d", config["epochs"])
        self.logger.info(
            "  Instantaneous batch size per device = %d",
            config["per_device_train_batch_size"],
        )
        self.logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            train_batch_size * config["gradient_accumulation_steps"],
        )
        self.logger.info(
            "  Gradient Accumulation steps = %d",
            config["gradient_accumulation_steps"],
        )
        self.logger.info("  Total training steps = %d", num_train_steps)

        self.logger.debug(model.summary())

        @self.tf.function
        def apply_gradients():
            grads_and_vars = []

            for gradient, variable in zip(
                gradient_accumulator.gradients, model.trainable_variables
            ):
                if gradient is not None:
                    scaled_gradient = gradient / (
                        config["n_device"]
                        * config["gradient_accumulation_steps"]
                    )
                    grads_and_vars.append((scaled_gradient, variable))
                else:
                    grads_and_vars.append((gradient, variable))

            optimizer.apply_gradients(grads_and_vars, config["max_grad_norm"])
            gradient_accumulator.reset()

        @self.tf.function
        def train_step(train_features, train_labels):
            def step_fn(train_features, train_labels):
                inputs = {
                    "attention_mask": train_features["input_mask"],
                    "training": True,
                }

                if config["model_architecture_type"] != "distilbert":
                    inputs["token_type_ids"] = (
                        train_features["segment_ids"]
                        if config["model_architecture_type"]
                        in ["bert", "xlnet"]
                        else None
                    )

                with self.tf.GradientTape() as tape:
                    logits = model(train_features["input_ids"], **inputs)[0]
                    logits = self.tf.reshape(logits, (-1, len(labels) + 1))
                    active_loss = self.tf.reshape(
                        train_features["input_mask"], (-1,)
                    )
                    active_logits = self.tf.boolean_mask(logits, active_loss)
                    train_labels = self.tf.reshape(train_labels, (-1,))
                    active_labels = self.tf.boolean_mask(
                        train_labels, active_loss
                    )
                    cross_entropy = loss_fct(active_labels, active_logits)
                    loss = self.tf.reduce_sum(cross_entropy) * (
                        1.0 / train_batch_size
                    )
                    grads = tape.gradient(loss, model.trainable_variables)

                    gradient_accumulator(grads)

                return cross_entropy

            per_example_losses = config["strategy"].run(
                step_fn, args=(train_features, train_labels)
            )
            mean_loss = config["strategy"].reduce(
                self.tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0
            )

            return mean_loss

        current_time = datetime.datetime.now()
        train_iterator = master_bar(range(config["epochs"]))
        global_step = 0
        self.logger_loss = 0.0

        for epoch in train_iterator:
            epoch_iterator = progress_bar(
                train_dataset,
                total=num_train_steps,
                parent=train_iterator,
                display=config["n_device"] > 1,
            )
            step = 1

            with config["strategy"].scope():
                for train_features, train_labels in epoch_iterator:
                    loss = train_step(train_features, train_labels)

                    if step % config["gradient_accumulation_steps"] == 0:
                        config["strategy"].run(apply_gradients)
                        loss_metric(loss)
                        global_step += 1
                        if (
                            config["save_steps"] > 0
                            and global_step % config["save_steps"] == 0
                        ):
                            # Save model checkpoint
                            output_dir = os.path.join(
                                config["output_dir"],
                                "checkpoint-{}".format(global_step),
                            )

                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)

                            model.save_pretrained(output_dir)
                            self.logger.info(
                                "Saving model checkpoint to %s", output_dir
                            )

                    train_iterator.child.comment = (
                        f"loss : {loss_metric.result()}"
                    )
                    step += 1

            train_iterator.write(
                f"loss epoch {epoch + 1}: {loss_metric.result()}"
            )
            loss_metric.reset_states()
        self.logger.debug(
            "  Training took time = {}".format(
                datetime.datetime.now() - current_time
            )
        )
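
Both training steps above restrict the loss to the active (non-padded) token positions before reducing it; a self-contained sketch of that masking pattern with toy tensors (the shapes, label values, and the -100 padding label id are illustrative):

import tensorflow as tf

pad_token_label_id = -100
num_labels = 3
logits = tf.random.normal((2, 4, num_labels))   # [batch, seq_len, num_labels]
labels = tf.constant([[0, 2, -100, -100],
                      [1, -100, -100, -100]])   # -100 marks padded positions

active = tf.reshape(labels, (-1,)) != pad_token_label_id
active_logits = tf.boolean_mask(tf.reshape(logits, (-1, num_labels)), active)
active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
per_token_loss = loss_fct(active_labels, active_logits)  # one loss value per active token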
def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments,
                               TrainingArguments, LoggingArguments))
    model_args, data_args, train_args, log_args = parser.parse_args_into_dataclasses(
    )

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    squad_steps = get_squad_steps(log_args.extra_squad_steps)
    is_sagemaker = data_args.fsx_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        metadata = (f"{model_args.model_type}"
                    f"-{model_args.model_size}"
                    f"-{model_args.load_from}"
                    f"-{hvd.size()}gpus"
                    f"-{train_args.batch_size}batch"
                    f"-{train_args.gradient_accumulation_steps}accum"
                    f"-{train_args.learning_rate}maxlr"
                    f"-{train_args.end_learning_rate}endlr"
                    f"-{train_args.learning_rate_decay_power}power"
                    f"-{train_args.max_grad_norm}maxgrad"
                    f"-{train_args.optimizer}opt"
                    f"-{train_args.total_steps}steps"
                    f"-{data_args.max_seq_length}seq"
                    f"-{data_args.max_predictions_per_seq}preds"
                    f"-{'preln' if pre_layer_norm else 'postln'}"
                    f"{loss_str}"
                    f"-{model_args.hidden_dropout_prob}dropout"
                    f"-{train_args.seed}seed")
        run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                f"{data_args.fsx_prefix}/logs/albert/{run_name}.log"),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    if model_args.model_type == "albert":
        model_desc = f"albert-{model_args.model_size}-v2"
    elif model_args.model_type == "bert":
        model_desc = f"bert-{model_args.model_size}-uncased"

    config = AutoConfig.from_pretrained(model_desc)
    config.pre_layer_norm = pre_layer_norm
    config.hidden_dropout_prob = model_args.hidden_dropout_prob
    model = TFAutoModelForPreTraining.from_config(config)

    # Create optimizer and enable AMP loss scaling.
    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=train_args.learning_rate,
        end_learning_rate=train_args.end_learning_rate,
        warmup_steps=train_args.warmup_steps,
        total_steps=train_args.total_steps,
        power=train_args.learning_rate_decay_power,
    )
    if train_args.optimizer == "lamb":
        opt = LAMB(
            learning_rate=schedule,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        )
    elif train_args.optimizer == "adam":
        opt = AdamW(weight_decay=0.0, learning_rate=schedule)
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_opt_weights = None
    if model_args.load_from == "scratch":
        pass
    elif model_args.load_from.startswith("huggingface"):
        assert (model_args.model_type == "albert"
                ), "Only loading pretrained albert models is supported"
        huggingface_name = f"albert-{model_args.model_size}-v2"
        if model_args.load_from == "huggingface":
            albert = TFAlbertModel.from_pretrained(huggingface_name,
                                                   config=config)
            model.albert = albert
    else:
        model_ckpt, opt_ckpt = get_checkpoint_paths_from_prefix(
            model_args.checkpoint_path)

        model = TFAutoModelForPreTraining.from_config(config)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            loaded_opt_weights = np.load(opt_ckpt, allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    if model_args.model_type == "albert":
        train_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/train/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/albert_pretraining/tfrecords/validation/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/albert_*.tfrecord"
    if model_args.model_type == "bert":
        train_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/training/*.tfrecord"
        validation_glob = f"{data_args.fsx_prefix}/bert_pretraining/max_seq_len_{data_args.max_seq_length}_max_predictions_per_seq_{data_args.max_predictions_per_seq}_masked_lm_prob_15/validation/*.tfrecord"

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_mlm_dataset(
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        batch_size=train_args.batch_size,
    )  # Of shape [batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_mlm_dataset(
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            batch_size=train_args.batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 0
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = opt.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            opt=opt,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 0:
            if hvd.rank() == 0 and loaded_opt_weights is not None:
                opt.set_weights(loaded_opt_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)
            i = opt.get_weights()[0] - 1

        is_final_step = i >= train_args.total_steps - 1
        do_squad = i in squad_steps or is_final_step
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                model_size=model_args.model_size,
                fsx_prefix=data_args.fsx_prefix,
                step=i,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (
                (i > 0) and
                (i % log_args.checkpoint_frequency == 0)) or is_final_step
            do_validation = (
                (i > 0) and
                (i % log_args.validation_frequency == 0)) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 0:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = f"{data_args.fsx_prefix}/checkpoints/albert/{run_name}-step{i}"
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                opt_ckpt = f"{checkpoint_prefix}-opt.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {opt_ckpt}")
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                opt_weights = opt.get_weights()
                np.save(opt_ckpt, opt_weights)
                # opt.set_weights(opt_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{data_args.fsx_prefix}/logs/albert/{run_name}")
                with summary_writer.as_default():
                    HP_MODEL_TYPE = hp.HParam("model_type",
                                              hp.Discrete(["albert", "bert"]))
                    HP_MODEL_SIZE = hp.HParam("model_size",
                                              hp.Discrete(["base", "large"]))
                    HP_LEARNING_RATE = hp.HParam("learning_rate",
                                                 hp.RealInterval(1e-5, 1e-1))
                    HP_BATCH_SIZE = hp.HParam("global_batch_size",
                                              hp.IntInterval(1, 64))
                    HP_PRE_LAYER_NORM = hp.HParam("pre_layer_norm",
                                                  hp.Discrete([True, False]))
                    HP_HIDDEN_DROPOUT = hp.HParam("hidden_dropout")
                    hparams = [
                        HP_MODEL_TYPE,
                        HP_MODEL_SIZE,
                        HP_BATCH_SIZE,
                        HP_LEARNING_RATE,
                        HP_PRE_LAYER_NORM,
                        HP_HIDDEN_DROPOUT,
                    ]

                    HP_F1 = hp.Metric("squad_f1")
                    HP_EXACT = hp.Metric("squad_exact")
                    HP_MLM = hp.Metric("val_mlm_acc")
                    HP_SOP = hp.Metric("val_sop_acc")
                    HP_TRAIN_LOSS = hp.Metric("train_loss")
                    HP_VAL_LOSS = hp.Metric("val_loss")
                    metrics = [
                        HP_TRAIN_LOSS, HP_VAL_LOSS, HP_F1, HP_EXACT, HP_MLM,
                        HP_SOP
                    ]

                    hp.hparams_config(
                        hparams=hparams,
                        metrics=metrics,
                    )
                    hp.hparams(
                        {
                            HP_MODEL_TYPE: model_args.model_type,
                            HP_MODEL_SIZE: model_args.model_size,
                            HP_LEARNING_RATE: train_args.learning_rate,
                            HP_BATCH_SIZE: train_args.batch_size * hvd.size(),
                            HP_PRE_LAYER_NORM: model_args.pre_layer_norm
                            == "true",
                            HP_HIDDEN_DROPOUT: model_args.hidden_dropout_prob,
                        },
                        trial_id=run_name,
                    )

            # Log to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar("weight_norm", weight_norm, step=i)
                tf.summary.scalar("loss_scale", loss_scale, step=i)
                tf.summary.scalar("learning_rate", learning_rate, step=i)
                tf.summary.scalar("train_loss", loss, step=i)
                tf.summary.scalar("train_mlm_loss", mlm_loss, step=i)
                tf.summary.scalar("train_mlm_acc", mlm_acc, step=i)
                tf.summary.scalar("train_sop_loss", sop_loss, step=i)
                tf.summary.scalar("train_sop_acc", sop_acc, step=i)
                tf.summary.scalar("grad_norm", grad_norm, step=i)
                if do_validation:
                    tf.summary.scalar("val_loss", val_loss, step=i)
                    tf.summary.scalar("val_mlm_loss", val_mlm_loss, step=i)
                    tf.summary.scalar("val_mlm_acc", val_mlm_acc, step=i)
                    tf.summary.scalar("val_sop_loss", val_sop_loss, step=i)
                    tf.summary.scalar("val_sop_acc", val_sop_acc, step=i)
                if do_squad:
                    tf.summary.scalar("squad_f1", squad_f1, step=i)
                    tf.summary.scalar("squad_exact", squad_exact, step=i)

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
def main():
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments,
         LoggingArguments, PathArguments))
    (
        model_args,
        data_args,
        train_args,
        log_args,
        path_args,
        remaining_strings,
    ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # SageMaker may have some extra strings. TODO: Test this on SM.
    assert len(remaining_strings
               ) == 0, f"The args {remaining_strings} could not be parsed."

    tf.random.set_seed(train_args.seed)
    tf.autograph.set_verbosity(0)

    # Settings init
    parse_bool = lambda arg: arg == "true"
    do_gradient_accumulation = train_args.gradient_accumulation_steps > 1
    do_xla = not parse_bool(train_args.skip_xla)
    do_eager = parse_bool(train_args.eager)
    skip_sop = parse_bool(train_args.skip_sop)
    skip_mlm = parse_bool(train_args.skip_mlm)
    pre_layer_norm = parse_bool(model_args.pre_layer_norm)
    fast_squad = parse_bool(log_args.fast_squad)
    dummy_eval = parse_bool(log_args.dummy_eval)
    is_sagemaker = path_args.filesystem_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker
    global max_grad_norm
    max_grad_norm = train_args.max_grad_norm

    # Horovod init
    hvd.init()
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
    # XLA, AutoGraph
    tf.config.optimizer.set_jit(do_xla)
    tf.config.experimental_run_functions_eagerly(do_eager)

    if hvd.rank() == 0:
        # Run name should only be used on one process to avoid race conditions
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        platform = "sm" if is_sagemaker else "eks"
        if skip_sop:
            loss_str = "-skipsop"
        elif skip_mlm:
            loss_str = "-skipmlm"
        else:
            loss_str = ""

        if log_args.run_name is None:
            metadata = (
                f"{model_args.model_type}"
                f"-{model_args.model_size}"
                f"-{model_args.load_from}"
                f"-{hvd.size()}gpus"
                f"-{train_args.per_gpu_batch_size * hvd.size() * train_args.gradient_accumulation_steps}globalbatch"
                f"-{train_args.learning_rate}maxlr"
                f"-{train_args.learning_rate_decay_power}power"
                f"-{train_args.optimizer}opt"
                f"-{train_args.total_steps}steps"
                f"-{'preln' if pre_layer_norm else 'postln'}"
                f"{loss_str}"
                f"-{model_args.hidden_dropout_prob}dropout")
            run_name = f"{current_time}-{platform}-{metadata}-{train_args.name if train_args.name else 'unnamed'}"
        else:
            run_name = log_args.run_name

        # Logging should only happen on a single process
        # https://stackoverflow.com/questions/9321741/printing-to-screen-and-writing-to-a-file-at-the-same-time
        level = logging.INFO
        format = "%(asctime)-15s %(name)-12s: %(levelname)-8s %(message)s"
        handlers = [
            logging.FileHandler(
                os.path.join(path_args.filesystem_prefix, path_args.log_dir,
                             f"{run_name}.log")),
            TqdmLoggingHandler(),
        ]
        logging.basicConfig(level=level, format=format, handlers=handlers)

        # Check that arguments passed in properly, only after registering the alert_func and logging
        assert not (skip_sop
                    and skip_mlm), "Cannot use --skip_sop and --skip_mlm"

    wrap_global_functions(do_gradient_accumulation)

    # Create optimizer and enable AMP loss scaling.
    if train_args.optimizer == "lamb":
        optimizer = get_lamb_optimizer(train_args)
    elif train_args.optimizer == "adamw":
        optimizer = get_adamw_optimizer(train_args)

    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic")
    gradient_accumulator = GradientAccumulator()

    loaded_optimizer_weights = None

    model = create_model(model_class=TFAutoModelForPreTraining,
                         model_args=model_args)
    tokenizer = create_tokenizer(model_args.model_type)
    if model_args.load_from == "checkpoint":
        checkpoint_path = os.path.join(path_args.filesystem_prefix,
                                       model_args.checkpoint_path)
        model_ckpt, optimizer_ckpt = get_checkpoint_paths_from_prefix(
            checkpoint_path)
        if hvd.rank() == 0:
            model.load_weights(model_ckpt)
            if model_args.load_optimizer_state == "true":
                loaded_optimizer_weights = np.load(optimizer_ckpt,
                                                   allow_pickle=True)
            # We do not set the weights yet, we have to do a first step to initialize the optimizer.

    # Train filenames are [1, 2047], Val filenames are [0]. Note the different subdirectories
    # Move to same folder structure and remove if/else
    train_glob = os.path.join(path_args.filesystem_prefix, path_args.train_dir,
                              "*.tfrecord")
    validation_glob = os.path.join(path_args.filesystem_prefix,
                                   path_args.val_dir, "*.tfrecord")

    train_filenames = glob.glob(train_glob)
    validation_filenames = glob.glob(validation_glob)

    train_dataset = get_dataset_from_tfrecords(
        model_type=model_args.model_type,
        filenames=train_filenames,
        max_seq_length=data_args.max_seq_length,
        max_predictions_per_seq=data_args.max_predictions_per_seq,
        per_gpu_batch_size=train_args.per_gpu_batch_size,
    )  # Of shape [per_gpu_batch_size, ...]
    # Batch of batches, helpful for gradient accumulation. Shape [grad_steps, per_gpu_batch_size, ...]
    train_dataset = train_dataset.batch(train_args.gradient_accumulation_steps)
    # One iteration with 10 dupes, 8 nodes seems to be 60-70k steps.
    train_dataset = train_dataset.prefetch(buffer_size=8)

    # Validation should only be done on one node, since Horovod doesn't allow allreduce on a subset of ranks
    if hvd.rank() == 0:
        validation_dataset = get_dataset_from_tfrecords(
            model_type=model_args.model_type,
            filenames=validation_filenames,
            max_seq_length=data_args.max_seq_length,
            max_predictions_per_seq=data_args.max_predictions_per_seq,
            per_gpu_batch_size=train_args.per_gpu_batch_size,
        )
        # validation_dataset = validation_dataset.batch(1)
        validation_dataset = validation_dataset.prefetch(buffer_size=8)

        pbar = tqdm.tqdm(total=train_args.total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        logger.info(f"Starting training, job name {run_name}")

    i = 1
    start_time = time.perf_counter()
    for batch in train_dataset:
        learning_rate = optimizer.learning_rate(
            step=tf.constant(i, dtype=tf.float32))
        # weight_decay = wd_schedule(step=tf.constant(i, dtype=tf.float32))
        loss_scale = optimizer.loss_scale()
        loss, mlm_loss, mlm_acc, sop_loss, sop_acc, grad_norm, weight_norm = train_step(
            model=model,
            optimizer=optimizer,
            gradient_accumulator=gradient_accumulator,
            batch=batch,
            gradient_accumulation_steps=train_args.gradient_accumulation_steps,
            skip_sop=skip_sop,
            skip_mlm=skip_mlm,
        )

        # Don't want to wrap broadcast_variables() in a tf.function, can lead to asynchronous errors
        if i == 1:
            if hvd.rank() == 0 and loaded_optimizer_weights is not None:
                optimizer.set_weights(loaded_optimizer_weights)
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)
            i = optimizer.get_weights()[0]

        is_final_step = i >= train_args.total_steps
        do_squad = (log_args.squad_frequency != 0) and (
            (i % log_args.squad_frequency == 0) or is_final_step)
        # Squad requires all the ranks to train, but results are only returned on rank 0
        if do_squad:
            squad_results = get_squad_results_while_pretraining(
                model=model,
                tokenizer=tokenizer,
                model_size=model_args.model_size,
                filesystem_prefix=path_args.filesystem_prefix,
                step=i,
                dataset=data_args.squad_version,
                fast=log_args.fast_squad,
                dummy_eval=log_args.dummy_eval,
            )
            if hvd.rank() == 0:
                squad_exact, squad_f1 = squad_results["exact"], squad_results[
                    "f1"]
                logger.info(
                    f"SQuAD step {i} -- F1: {squad_f1:.3f}, Exact: {squad_exact:.3f}"
                )
            # Re-wrap autograph so it doesn't get arg mismatches
            wrap_global_functions(do_gradient_accumulation)
            gc.collect()

        if hvd.rank() == 0:
            do_log = i % log_args.log_frequency == 0
            do_checkpoint = (log_args.checkpoint_frequency != 0) and (
                (i % log_args.checkpoint_frequency == 0) or is_final_step)
            do_validation = (log_args.validation_frequency != 0) and (
                (i % log_args.validation_frequency == 0) or is_final_step)

            pbar.update(1)
            description = f"Loss: {loss:.3f}, MLM: {mlm_loss:.3f}, SOP: {sop_loss:.3f}, MLM_acc: {mlm_acc:.3f}, SOP_acc: {sop_acc:.3f}"
            pbar.set_description(description)
            if do_log:
                elapsed_time = time.perf_counter() - start_time
                if i == 1:
                    logger.info(f"First step: {elapsed_time:.3f} secs")
                else:
                    it_per_sec = log_args.log_frequency / elapsed_time
                    logger.info(
                        f"Train step {i} -- {description} -- It/s: {it_per_sec:.2f}"
                    )
                    start_time = time.perf_counter()

            if do_checkpoint:
                checkpoint_prefix = os.path.join(path_args.filesystem_prefix,
                                                 path_args.checkpoint_dir,
                                                 f"{run_name}-step{i}")
                model_ckpt = f"{checkpoint_prefix}.ckpt"
                optimizer_ckpt = f"{checkpoint_prefix}-optimizer.npy"
                logger.info(
                    f"Saving model at {model_ckpt}, optimizer at {optimizer_ckpt}"
                )
                model.save_weights(model_ckpt)
                # model.load_weights(model_ckpt)

                optimizer_weights = optimizer.get_weights()
                np.save(optimizer_ckpt, optimizer_weights)
                # optimizer.set_weights(optimizer_weights)

            if do_validation:
                val_loss, val_mlm_loss, val_mlm_acc, val_sop_loss, val_sop_acc = run_validation(
                    model=model,
                    validation_dataset=validation_dataset,
                    skip_sop=skip_sop,
                    skip_mlm=skip_mlm,
                )
                description = f"Loss: {val_loss:.3f}, MLM: {val_mlm_loss:.3f}, SOP: {val_sop_loss:.3f}, MLM_acc: {val_mlm_acc:.3f}, SOP_acc: {val_sop_acc:.3f}"
                logger.info(f"Validation step {i} -- {description}")

            # Create summary_writer after the first step
            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    os.path.join(path_args.filesystem_prefix,
                                 path_args.log_dir, run_name))
                config = {
                    **asdict(model_args),
                    **asdict(data_args),
                    **asdict(train_args),
                    **asdict(log_args),
                    "global_batch_size":
                    train_args.per_gpu_batch_size * hvd.size(),
                }
                if is_wandb_available():
                    wandb.init(config=config, project=model_args.model_type)
                    wandb.run.save()
                    wandb_run_name = wandb.run.name

            train_metrics = {
                "weight_norm": weight_norm,
                "grad_norm": grad_norm,
                "loss_scale": loss_scale,
                "learning_rate": learning_rate,
                "train/loss": loss,
                "train/mlm_loss": mlm_loss,
                "train/mlm_acc": mlm_acc,
                "train/sop_loss": sop_loss,
                "train/sop_acc": sop_acc,
            }
            all_metrics = {**train_metrics}
            if do_validation:
                val_metrics = {
                    "val/loss": val_loss,
                    "val/mlm_loss": val_mlm_loss,
                    "val/mlm_acc": val_mlm_acc,
                    "val/sop_loss": val_sop_loss,
                    "val/sop_acc": val_sop_acc,
                }
                all_metrics = {**all_metrics, **val_metrics}
            if do_squad:
                squad_metrics = {
                    "squad/f1": squad_f1,
                    "squad/exact": squad_exact,
                }
                all_metrics = {**all_metrics, **squad_metrics}

            # Log to TensorBoard
            with summary_writer.as_default():
                for name, val in all_metrics.items():
                    tf.summary.scalar(name, val, step=i)
            # Log to Weights & Biases
            if is_wandb_available():
                wandb.log({"step": i, **all_metrics})

        i += 1
        if is_final_step:
            break

    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished pretraining, job name {run_name}")
class TFTrainer:
    model: TFPreTrainedModel
    args: TFTrainingArguments
    train_dataset: Optional[tf.data.Dataset]
    eval_dataset: Optional[tf.data.Dataset]
    test_dataset: Optional[tf.data.Dataset]
    dataset_info: DatasetInfo

    strategy: Strategy

    def __init__(
        self,
        model: TFPreTrainedModel,
        args: TFTrainingArguments,
        train_dataset: Optional[tf.data.Dataset] = None,
        eval_dataset: Optional[tf.data.Dataset] = None,
        test_dataset: Optional[tf.data.Dataset] = None,
        dataset_info: Optional[DatasetInfo] = None,
    ):
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.test_dataset = test_dataset
        self.dataset_info = dataset_info

        self.gradient_accumulator = GradientAccumulator()
        self.accum_steps = 1

        if self.args.strategy_name == "mirrored":
            self.strategy = tf.distribute.MirroredStrategy()
        elif self.args.strategy_name == "onedevice":
            if len(tf.config.list_physical_devices('GPU')) >= 1:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/gpu:0")
            else:
                self.strategy = tf.distribute.OneDeviceStrategy(
                    device="/cpu:0")
        else:
            raise ValueError("The strategy {} does not exists.".format(
                self.args.strategy_name))

        # To conform with Trainer's API we call this from here.
        # All args should be in the `args` already.
        self._setup_training()

    def _setup_training(self,
                        checkpoint_path: str = "checkpoints",
                        log_path: str = "logs") -> None:
        """
        Setup the different steps to train a model:
          - check if all the data are given
          - create the proper strategy
          - create the features
          - prepare the model settings

        Args:
          checkpoint_path: the directory path where the model checkpoints will be saved, "./checkpoints" folder by default.
          log_path: the directory path where the Tensorboard logs will be saved, "./logs" folder by default.
          data_cache_dir: the directory path where the data will be cached, "./cache" folder by default.
          model_cache_dir (optional): the directory path where the pretrained model will be cached.
        """
        self._prepare_dataset()

        with self.strategy.scope():
            self._create_optimizer()
            _ = self.optimizer.iterations
            self._set_loss_and_metric()
            self._create_checkpoint_manager(checkpoint_path)
            self._create_summary_writer(log_path)

    def _set_loss_and_metric(self) -> None:
        """
        Create the training loss and metric from their names. Allowed names are those listed
        in the TensorFlow documentation and those contained in the transformers library.
        """
        try:
            self.loss = tf.keras.losses.get({
                "class_name": self.args.loss_name,
                "config": {
                    "from_logits": True,
                    "reduction": tf.keras.losses.Reduction.NONE
                }
            })
        except TypeError:
            self.loss = tf.keras.losses.get({
                "class_name": self.args.loss_name,
                "config": {
                    "reduction": tf.keras.losses.Reduction.NONE
                }
            })

        self.train_acc_metric = tf.keras.metrics.get({
            "class_name": self.args.metric_name,
            "config": {
                "name": "train_accuracy"
            }
        })
        self.test_acc_metric = tf.keras.metrics.get({
            "class_name": self.args.metric_name,
            "config": {
                "name": "test_accuracy"
            }
        })

    def _create_summary_writer(self, log_path: str) -> None:
        """
        Create the summary writers so the training and evaluation logs can be visualized in TensorBoard.
        Args:
          log_path: the directory path where the Tensorboard logs will be saved.
        """
        self.log_path = log_path
        self.train_writer = tf.summary.create_file_writer(log_path + "/train")
        self.test_writer = tf.summary.create_file_writer(log_path + "/test")

    def _prepare_dataset(self) -> None:
        """
        Prepare, batch and distribute the training, validation and test datasets.
        """
        train_batch = self.args.per_gpu_train_batch_size * self.strategy.num_replicas_in_sync
        eval_batch = self.args.per_gpu_eval_batch_size * self.strategy.num_replicas_in_sync
        test_batch = self.args.per_gpu_eval_batch_size
        self.train_steps = math.ceil(self.dataset_info.sizes["train"] /
                                     train_batch)
        self.train_dataset = self.train_dataset.shuffle(128).batch(
            train_batch).repeat(-1)
        self.train_dataset = self.strategy.experimental_distribute_dataset(
            self.train_dataset)
        self.validation_steps = math.ceil(
            self.dataset_info.sizes["validation"] / eval_batch)
        self.eval_dataset = self.eval_dataset.batch(eval_batch)
        self.eval_dataset = self.strategy.experimental_distribute_dataset(
            self.eval_dataset)
        self.test_steps = math.ceil(self.dataset_info.sizes["test"] /
                                    test_batch)
        self.test_dataset = self.test_dataset.batch(test_batch)

    def _create_optimizer(self) -> None:
        """
        Create the training optimizer from its name. Allowed names are those listed
        in the TensorFlow documentation and the optimizers provided by the transformers library.
        """
        if self.args.optimizer_name == "adamw":
            learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args.learning_rate,
                decay_steps=self.train_steps,
                end_learning_rate=0.0)
            if self.args.warmup_steps:
                learning_rate_fn = WarmUp(
                    initial_learning_rate=self.args.learning_rate,
                    decay_schedule_fn=learning_rate_fn,
                    warmup_steps=self.args.warmup_steps)

            self.optimizer = AdamWeightDecay(
                learning_rate=learning_rate_fn,
                weight_decay_rate=0.01,
                epsilon=self.args.adam_epsilon,
                exclude_from_weight_decay=["layer_norm", "bias"])
        else:
            try:
                self.optimizer = tf.keras.optimizers.get({
                    "class_name":
                    self.args.optimizer_name,
                    "config": {
                        "learning_rate": self.args.learning_rate,
                        "epsilon": self.args.adam_epsilon
                    }
                })
            except TypeError:
                # This is for the case where the optimizer is not Adam-like such as SGD
                self.optimizer = tf.keras.optimizers.get({
                    "class_name":
                    self.args.optimizer_name,
                    "config": {
                        "learning_rate": self.args.learning_rate
                    }
                })

    def _create_checkpoint_manager(self,
                                   checkpoint_path: str,
                                   max_to_keep: int = 5,
                                   load_model: bool = True) -> None:
        """
        Create a checkpoint manager so that training is fault-tolerant and can be
        resumed from the latest checkpoint.
        Args:
          checkpoint_path: the directory path where the model checkpoints will be saved.
          max_to_keep: the maximum number of checkpoints to keep in the checkpoint path.
          load_model: whether to restore the weights from the latest checkpoint before training.
        """
        ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
        self.model.ckpt_manager = tf.train.CheckpointManager(
            ckpt, checkpoint_path, max_to_keep=max_to_keep)

        if load_model:
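            # On a fresh run `latest_checkpoint` is None and `restore` is a no-op,
            # so training starts from the freshly initialized weights.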
            ckpt.restore(self.model.ckpt_manager.latest_checkpoint)

    def _evaluate_steps(self, per_replica_features, per_replica_labels):
        """
        Run one evaluation step across all replicas.
        Args:
          per_replica_features: the batched features, one slice per replica.
          per_replica_labels: the batched labels, one slice per replica.
        Returns:
          The loss corresponding to the given batch.
        """
        per_replica_loss = self.strategy.experimental_run_v2(
            self._run_model,
            args=(per_replica_features, per_replica_labels, False))

        return self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                    per_replica_loss, None)

    def _evaluate(self) -> tf.Tensor:
        """
        Evaluate the model on the validation set at the end of each training epoch.
        Returns:
          The validation loss of the last evaluated batch.
        """
        step = 1
        loss = 0.0

        for features, labels in self.eval_dataset:
            step = tf.convert_to_tensor(step, dtype=tf.int64)
            loss = self._evaluate_steps(features, labels)
            loss = tf.reduce_mean(loss)

            with self.test_writer.as_default():
                tf.summary.scalar("loss", loss, step=step)

            if step % self.validation_steps == 0:
                break

            step += 1

        return loss

    def train(self) -> None:
        """
        Train the model for `num_train_epochs` epochs.
        """
        tf.summary.trace_on(graph=True, profiler=True)
        self.gradient_accumulator.reset()

        iterations = self.optimizer.iterations
        tf.summary.experimental.set_step(iterations)

        for epoch in range(int(self.args.num_train_epochs)):
            for training_loss in self._training_steps():
                step = iterations.numpy()
                training_loss = tf.reduce_mean(training_loss)

                with self.train_writer.as_default():
                    tf.summary.scalar("loss", training_loss, step=step)

                if step == 1:
                    with self.train_writer.as_default():
                        tf.summary.trace_export(name="training",
                                                step=step,
                                                profiler_outdir=self.log_path)

                if step % 10 == 0:
                    logger.info(
                        "Epoch {} Step {} Loss {:.4f} Train Accuracy {:.4f}".
                        format(epoch, step, training_loss.numpy(),
                               self.train_acc_metric.result()))

                if step % 100 == 0:
                    ckpt_save_path = self.model.ckpt_manager.save()
                    logger.info("Saving checkpoint for step {} at {}".format(
                        step, ckpt_save_path))

                if step % self.train_steps == 0:
                    break

            test_loss = self._evaluate()

            logger.info(
                "Epoch {} Step {} Train Loss {:.4f} Train Accuracy {:.4f}".
                format(epoch, step, training_loss.numpy(),
                       self.train_acc_metric.result()))
            logger.info(
                "Epoch {} Validation Loss {:.4f} Validation Accuracy {:.4f}".
                format(epoch, test_loss.numpy(),
                       self.test_acc_metric.result()))

            self.train_acc_metric.reset_states()
            self.test_acc_metric.reset_states()

    def _training_steps(self):
        """
        Returns a generator over training steps (i.e. parameter updates).
        Returns:
          A generator that yields the loss value to report for each step.
        """
        for i, loss in enumerate(self._accumulate_next_gradients()):
            if i % self.accum_steps == 0:
                self._apply_gradients()
                yield loss

    @tf.function
    def _apply_gradients(self):
        """Applies the gradients (cross-replica)."""
        self.strategy.experimental_run_v2(self._step)

    def _step(self):
        """Applies gradients and resets accumulation."""
        gradient_scale = self.gradient_accumulator.step * self.strategy.num_replicas_in_sync
        gradients = [
            gradient / tf.cast(gradient_scale, gradient.dtype)
            for gradient in self.gradient_accumulator.gradients
        ]
        gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm,
                                       self.args.max_grad_norm))
                     for grad in gradients]
        vars = self.model.trainable_variables
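        # For token classification ("labelling") the pooler head is unused, so its
        # variables are excluded from the update (no gradient flows through them).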
        if self.args.mode == "labelling":
            vars = [
                var for var in self.model.trainable_variables
                if "pooler" not in var.name
            ]
        self.optimizer.apply_gradients(list(zip(gradients, vars)))
        self.gradient_accumulator.reset()

    def _accumulate_next_gradients(self):
        """Accumulates the gradients from the next element in dataset."""
        iterator = iter(self.train_dataset)
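        # Fetching the next batch and accumulating its gradients is wrapped in a
        # tf.function so it runs as a single graph call; an exhausted iterator
        # raises OutOfRangeError, which terminates the generator below.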

        @tf.function
        def _accumulate_next():
            per_replica_features, per_replica_labels = next(iterator)

            return self._accumulate_gradients(per_replica_features,
                                              per_replica_labels)

        while True:
            try:
                yield _accumulate_next()
            except tf.errors.OutOfRangeError:
                break

    def _accumulate_gradients(self, per_replica_features, per_replica_labels):
        """Accumulates the gradients across all the replica."""
        per_replica_loss = self.strategy.experimental_run_v2(
            self._forward, args=(per_replica_features, per_replica_labels))

        return self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                    per_replica_loss, None)

    def _forward(self, features, labels):
        """Forwards a training example and accumulates the gradients."""
        per_example_loss = self._run_model(features, labels, True)
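        # `compute_average_loss` sums the per-example losses and divides by the
        # given global batch size, so the loss is scaled consistently under the
        # distribution strategy.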
        loss = tf.nn.compute_average_loss(
            per_example_loss,
            global_batch_size=self.args.per_gpu_train_batch_size)
        vars = self.model.trainable_variables
        if self.args.mode == "labelling":
            vars = [
                var for var in self.model.trainable_variables
                if "pooler" not in var.name
            ]
        gradients = self.optimizer.get_gradients(loss, vars)

        self.gradient_accumulator(gradients)

        return per_example_loss

    def _run_model(self, features, labels, training):
        """
        Computes the per-example loss for the given features/labels pair.
        Args:
          features: the batched features.
          labels: the batched labels.
          training: whether the model is run in training mode (e.g. with dropout enabled).
        """
        if self.args.mode == "classification" or self.args.mode == "labelling":
            logits = self.model(features, training=training)[0]
        else:
            logits = self.model(features, training=training)

        if self.args.mode == "labelling":
            active_loss = tf.reshape(labels, (-1, )) != -1
            logits = tf.boolean_mask(
                tf.reshape(logits, (-1, len(self.dataset_info.labels))),
                active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1, )), active_loss)

        loss = self.loss(labels, logits)

        if training:
            self.train_acc_metric(labels, logits)
        else:
            self.test_acc_metric(labels, logits)

        return loss

    def test(self) -> None:
        """
        Evaluate the model on the test dataset and log a classification report.
        """
        y_true = []
        results = self.model.predict(self.test_dataset, steps=self.test_steps)

        if self.args.mode == "classification":
            for batch in self.test_dataset:
                y_true.extend(batch[1].numpy().tolist())

            y_pred = np.reshape(np.argmax(results, axis=-1), (-1, 1)).tolist()
            y_true = list(itertools.chain.from_iterable(y_true))
            y_pred = list(itertools.chain.from_iterable(y_pred))

            logger.info(
                classification_report(y_true,
                                      y_pred,
                                      target_names=self.dataset_info.labels))

    def save_model(self, save_path: str) -> None:
        """
        Save the pretrained model and create a Tensorflow saved model.
        Args:
          save_path: directory path where the pretrained model and the
            TensorFlow SavedModel will be saved.
        """
        logger.info("Saving model in {}".format(save_path))

        path = os.path.join(save_path, "saved_model")

        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(save_path)
        tf.saved_model.save(self.model, path)