    def test_training(self):
        if not self.model_tester.is_training:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        for model_class in self.all_model_classes:
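            # Base models from MODEL_MAPPING have no task head and therefore
            # return no loss to train against.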
            # DeiTForImageClassificationWithTeacher supports inference only,
            # so it is skipped for the training test as well.
            if model_class in MODEL_MAPPING.values() or model_class.__name__ == "DeiTForImageClassificationWithTeacher":
                continue
            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
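            # One forward pass with labels and one backward pass are enough to
            # check that the model returns a loss and that gradients flow.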
            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
            return

        config.gradient_checkpointing = True
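        # With checkpointing enabled, activations are recomputed during the
        # backward pass instead of being cached, trading compute for memory.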
        config.return_dict = True

        for model_class in self.all_model_classes:
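            # Base models and the *WithHeads variants have no task head set up
            # in this test, so they return no loss to backpropagate; skip them.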
            if model_class in MODEL_MAPPING.values() or model_class in MODEL_WITH_HEADS_MAPPING.values():
                continue
            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

import argparse
import logging

from accelerate import Accelerator
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AdamW,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    get_scheduler,
    set_seed,
)

logger = logging.getLogger(__name__)
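# The keys of MODEL_MAPPING are config classes; each one carries a
# `model_type` string identifying its architecture family.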
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Finetune a transformers model on a Masked Language Modeling task"
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",