# Example #1
# 0
    def save_to_disk(self, path: str):
        """Write the train, validation and test splits to ``path`` as one ``datasets.DatasetDict``.

        Each split is taken from the ``_dataset`` attribute of the matching
        wrapper attribute: ``train_dataset`` -> ``'train'``,
        ``val_dataset`` -> ``'validation'``, ``test_dataset`` -> ``'test'``.
        """
        template = datasets.DatasetDict()
        split_sources = (
            ('train', 'train_dataset'),
            ('validation', 'val_dataset'),
            ('test', 'test_dataset'),
        )
        for split_name, attr_name in split_sources:
            template[split_name] = getattr(self, attr_name)._dataset

        template.save_to_disk(path)
 def __init__(self, conf=None):
     """Set up the pretrained models, tokenizers and dataset bookkeeping.

     ``conf`` is stored through ``save_hyperparameters`` (readable later as
     ``self.hparams.conf``). Although it defaults to ``None``, the body reads
     ``conf.model.*``, so a config object must be supplied in practice.
     """
     super().__init__()
     # save conf, accessible in self.hparams.conf
     self.save_hyperparameters()
     # Both the classifier checkpoint and the dataset directory are resolved
     # relative to the parent of Hydra's original working directory.
     parent_dir = os.path.dirname(hydra.utils.get_original_cwd())
     # pretrained models
     self.sentence_classifier = BertClf.load_from_checkpoint(
         checkpoint_path=os.path.join(parent_dir, 'outputs', conf.model.classifier_path))
     # tokenizers
     self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
     self.bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
     # classification threshold
     self.threshold = conf.model.threshold
     # dataset: a DatasetDict placeholder plus one self._init_dataset() result per split
     self.dataset = datasets.DatasetDict({'train': [], 'val': [], 'test': []})
     self._dataset = {split: self._init_dataset() for split in ('train', 'val', 'test')}
     # e.g. <parent>/datasets/supporting_facts_augmented_<threshold>
     augmented_suffix = '_augmented' if self.hparams.conf.model.data_augmentation else ''
     self.dataset_path = (os.path.join(parent_dir, 'datasets', 'supporting_facts')
                          + augmented_suffix + f'_{self.threshold}')
# Example #3
# 0
    def map(self, batch_size=1000, cache_file_name=None, **kwargs):
        """Run this aggregation transform over every split in ``self.dsets``.

        Args:
          batch_size (int): See :meth:`datasets.Dataset.map`; must not be None
            here because the transform works on whole batches.
          cache_file_name: Either one name/template used for every split, or a
            dict mapping split key -> name. A ``'{split}'`` placeholder is
            replaced by the split key, a missing ``.arrow`` suffix is appended,
            names without a ``/`` are placed in the split's
            ``cache_directory()``, and ``None`` disables the cache file for
            that split. A dict argument is copied, never modified.
          kwargs: Passed through to :meth:`datasets.Dataset.map`
            (``remove_columns`` is not allowed).

        Returns:
          The mapped ``'Single'`` split when ``self.single`` is true, otherwise
          a :class:`datasets.DatasetDict` with every mapped split.
        """
        # The transform itself decides which columns survive (only the output
        # columns), so callers must not pass their own remove_columns.
        assert 'remove_columns' not in kwargs, "Aggregation type transform will only leave output columns for output dataset."

        # Resolve one cache file name per split. A caller-supplied dict used to
        # leave `cache_names` unassigned (NameError); copy it so the
        # normalisation below never mutates the caller's object.
        if isinstance(cache_file_name, dict):
            cache_names = dict(cache_file_name)
        else:
            cache_names = {k: cache_file_name for k in self.dsets.keys()}
        for k, dset in self.dsets.items():
            name = cache_names[k]
            if name is None:
                continue
            if not name.endswith('.arrow'):
                name += '.arrow'
            if '{split}' in name:
                name = name.format(split=k)
            if '/' not in name:
                name = os.path.join(dset.cache_directory(), name)
            cache_names[k] = name

        # map every dataset
        mapped_dsets = {}
        for k, dset in self.dsets.items():
            # Index of the final row of this split; stored on self for use
            # while the transform (self) is being applied.
            self.last_idx = len(dset) - 1
            mapped_dset = dset.map(
                function=self,
                batched=True,
                batch_size=batch_size,
                with_indices=True,
                num_proc=1,
                cache_file_name=cache_names[k],
                # The output column has fewer (combined) rows than the input
                # columns, so the input columns must be dropped.
                remove_columns=self.in_cols,
                **kwargs)
            mapped_dset.set_format(None, columns=self.out_cols)
            mapped_dsets[k] = mapped_dset

        if self.single:
            return mapped_dsets['Single']
        return datasets.DatasetDict(mapped_dsets)
# Example #4
# 0
import torch
import numpy as np


## 1) Data preprocessing and loading
# Download the full AG News dataset from the Hugging Face datasets hub.
ag_dataset = load_dataset('ag_news')

# Carve small train/dev/test subsets out of the hub splits with slice syntax.
ag_dev_dataset = load_dataset('ag_news', split='train[10%:11%]')
ag_train_dataset = load_dataset('ag_news', split='train[:10%]')
ag_test_dataset = load_dataset('ag_news', split='test[11%:12%]')

# Merge the subsets into a single `datasets.DatasetDict` object.
ag_split = {
    'train': ag_train_dataset,
    'test': ag_test_dataset,
    'dev': ag_dev_dataset,
}
ag_dataset_split = datasets.DatasetDict(ag_split)

# Count the distinct labels over *all* splits, so a label that is missing
# from one split is still counted.
num_labels = len(set().union(*(ag_dataset_split[name].unique('label')
                               for name in ('train', 'test', 'dev'))))

## 2) Prepare the features: tokenizing, padding and truncate
# Define a tokenizer

model_pretrained = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_pretrained)

"""🤔 **Understanding BERT tokenizer**"""
        chars = [chr(i) for i in range(97, 123, 1)]
        for label_name, label_size in zip(label_names, label_sizes):
            d[label_name] = []
            c[label_name] = list(range(label_size))
        for i in range(size):
            length = np.random.randint(1, max_length)
            d['tokens'].append([''.join(np.random.choice(chars, size=max_token_length))
                                for _ in range(length)])
            for label_name in label_names:
                dummy_labels = np.random.choice(c[label_name],
                                                size=length)
                d[label_name].append(dummy_labels)
        return Dataset.from_dict(d)
    dataset = datasets.DatasetDict(
        {'train': generat_dummy_dataset(50, 50, 8, ['ner_tags', 'pos_tags'], [10, 20]),
         'validation': generat_dummy_dataset(10, 50, 8, ['ner_tags', 'pos_tags'], [10, 20]),
         'test': generat_dummy_dataset(10, 50, 8, ['ner_tags', 'pos_tags'], [10, 20])
         })
    label_maps = {i: str(name) for i, name in
                  enumerate(range(20))}
    if 'ner' in label_col.lower():
        label_names = ['O'] + ['B-' + str(name) for i, name in enumerate(range(1, 20))]
    else:
        label_names = [str(name) for i, name in enumerate(range(20))]
    num_labels = 20
else:
    raise NotImplementedError


def pre_tokenize(token, space_token=custom_args.space_token):
    token = token.replace(' ', space_token)
    def prepare_data(self):
        """Populate ``self.dataset`` according to ``self.conf.dataset``.

        The modes are checked in this order:

        1. ``setup.from_disk`` is true: nothing is prepared here and
           ``self.dataset`` is set to ``None``.
        2. ``preprocessing.from_disk`` is true and ``self.dataset_path``
           exists: the preprocessed dataset is loaded from that path.
        3. Otherwise the raw ``'distractor'`` configuration of
           ``dataset.name`` is loaded and split into ``train``/``val``/``test``.
           It is then mapped through ``self._preprocess_dataset`` and saved to
           ``self.dataset_path``.
        """
        ds_conf = self.conf.dataset

        if ds_conf.setup.from_disk:  # do not load preprocessed dataset
            self.dataset = None
            return

        if ds_conf.preprocessing.from_disk and os.path.exists(self.dataset_path):
            # load preprocessed dataset
            log.info(f'Loading preprocessed dataset from {self.dataset_path}')
            self.dataset = datasets.load_from_disk(self.dataset_path)
            return

        # preprocess dataset from the raw source
        prep = ds_conf.preprocessing
        log.info(
            f'Loading dataset {ds_conf.name} with splits {list(prep.splits)} (combine: {prep.combine})'
        )
        if prep.combine:
            # Load all requested splits as one dataset, then cut off the test
            # part. val_split is rescaled because it is taken from what remains
            # once the test portion has been removed.
            combined = datasets.load_dataset(
                ds_conf.name,
                name='distractor',
                split='+'.join(prep.splits))
            outer = combined.train_test_split(
                test_size=prep.test_split,
                shuffle=prep.shuffle)
            inner = outer['train'].train_test_split(
                test_size=prep.val_split / (1 - prep.test_split),
                shuffle=prep.shuffle)
            dataset = datasets.DatasetDict({
                'train': inner['train'],
                'val': inner['test'],
                'test': outer['test'],
            })
        else:
            # Exactly two splits are expected: (train source, test source).
            dataset_train, dataset_test = datasets.load_dataset(
                ds_conf.name,
                name='distractor',
                split=list(prep.splits))
            train_val = dataset_train.train_test_split(
                test_size=prep.val_split,
                shuffle=prep.shuffle)
            dataset = datasets.DatasetDict({
                'train': train_val['train'],
                'val': train_val['test'],
                'test': dataset_test,
            })

        log.info('Preprocessing dataset')
        dataset = dataset.map(self._preprocess_dataset)

        log.info(f'Saving dataset to: {self.dataset_path}')
        dataset.save_to_disk(self.dataset_path)

        self.dataset = dataset
# Example #7
# 0
def load_training_data(
    data_dir: Path,
    dataset: str,
    tokenizer: PreTrainedTokenizer,
    pseudo_label_file: Optional[Path] = None,
    max_seq_length=128,
    with_dev=True,
    cache_dir: Optional[Path] = None,
    overwrite_cache=False,
) -> Tuple[hf_datasets.DatasetDict, Dict[str, int]]:
  """Load text data with minimal supervision.

  Reads ``train_docs.csv`` and ``train_labels.csv`` (and, if ``with_dev``,
  ``dev_docs.csv``) from ``data_dir/dataset``. Each file is indexed by its
  ``ID`` column. The data is tokenized through
  ``get_process_dataset_func(label_map, tokenizer, max_seq_length)`` and
  returned as a ``DatasetDict`` with these splits:

  * ``"train"``: documents that have a label. The labels come from
    ``train_labels.csv``, or from ``pseudo_label_file`` when one is given.
  * ``"full_train"``: every training document.
  * ``"unlabeled"``: documents without a gold label (``text`` column only).
    In pseudo-label mode this can overlap with ``"train"``.
  * ``"dev"``: only when ``with_dev`` is true and ``dev_docs.csv`` exists.

  When ``cache_dir`` is given, the processed splits are saved to and reused
  from ``cache_dir/train_dev.datasets.cache``; ``overwrite_cache`` deletes an
  existing cache first. A file lock, plus a ``torch.distributed`` barrier when
  a process group is initialized, lets only one process build the data.

  Returns:
    ``(datasets, label_map)`` where ``label_map`` comes from
    ``utils.load_label_map(data_dir, dataset)``.
  """

  # always load label map
  label_map = utils.load_label_map(data_dir, dataset)

  # deal with cache; exist_ok avoids a race when several processes create it
  use_cache = cache_dir is not None
  if use_cache:
    Path(cache_dir).mkdir(parents=True, exist_ok=True)

  cache_file = Path(cache_dir, "train_dev.datasets.cache") if use_cache else None
  lock_file = cache_file.with_suffix(".lock") if use_cache else Path("/tmp/train_dev.dataset.lock")

  # Make sure only the first process in distributed training processes the dataset,
  # and the others will use the cache.
  with FileLock(lock_file):
    if use_cache and overwrite_cache:
      if cache_file.is_file():
        logger.info("Overwrite cache file!")
        cache_file.unlink()
      elif cache_file.is_dir():
        logger.info("Overwrite cache dir!")
        shutil.rmtree(cache_file)
  # Wait until every rank has passed the cleanup above. Without a process group
  # (single-process runs) there is nothing to wait for; previously a bare
  # `except:` hid every error here, including KeyboardInterrupt.
  if torch.distributed.is_available() and torch.distributed.is_initialized():
    torch.distributed.barrier()

  with FileLock(lock_file):
    if use_cache and (cache_file.is_file() or cache_file.is_dir()):
      logger.warning(f"######## Loading from cache {cache_file} ##########")
      s = time.time()
      datasets = hf_datasets.DatasetDict.load_from_disk(cache_file)
      e = time.time()
      logger.info(f"Time {e - s:.4f}")
    else:
      # load raw input data
      data_path = Path(data_dir, dataset)
      df_train_unlabeled = pd.read_csv(data_path / "train_docs.csv", dtype="str").set_index("ID")
      df_train_labels = pd.read_csv(data_path / "train_labels.csv", dtype="str").set_index("ID")
      df_pseudo_labels = None
      if pseudo_label_file:
        df_pseudo_labels = pd.read_csv(pseudo_label_file, dtype="str").set_index("ID")
      df_dev = None
      if with_dev:
        try:
          df_dev = pd.read_csv(data_path / "dev_docs.csv", dtype=str).set_index("ID")
        except FileNotFoundError:
          logger.warning("Try loading dev labels but not found!")

      # Every document joined with its gold label (NaN label = no gold label).
      df_gold = df_train_unlabeled.join(df_train_labels, rsuffix="_train", how="left")
      if df_pseudo_labels is not None:
        # if pseudo label file provided, use them as training set
        logger.info("Using pseudo label file")
        df_labeled = df_train_unlabeled.join(df_pseudo_labels, rsuffix="_pseudo", how="left")
      else:
        df_labeled = df_gold
      df_train = df_labeled[~df_labeled.label.isnull()]
      df_train_full = df_train_unlabeled
      # The unlabeled set is always the original (gold-unlabeled) set, which
      # means it may overlap with the pseudo-labeled set. It used to be taken
      # from the pseudo-label join, which contradicted this intent.
      df_unlabeled = df_gold[df_gold.label.isnull()]
      train_dataset = hf_datasets.Dataset.from_pandas(df_train, split="train")
      full_train_dataset = hf_datasets.Dataset.from_pandas(df_train_full, split="full_train")
      unlabeled_dataset = hf_datasets.Dataset.from_pandas(df_unlabeled[["text"]], split="unlabeled")

      process_func = get_process_dataset_func(label_map, tokenizer, max_seq_length)

      train_dataset = train_dataset.map(process_func, batched=True)
      full_train_dataset = full_train_dataset.map(process_func, batched=True)
      unlabeled_dataset = unlabeled_dataset.map(process_func, batched=True)

      dev_dataset = None
      if df_dev is not None:
        dev_dataset = hf_datasets.Dataset.from_pandas(df_dev, split="dev").map(process_func, batched=True)

      datasets = hf_datasets.DatasetDict({
          "train": train_dataset,
          "full_train": full_train_dataset,
          "unlabeled": unlabeled_dataset,
      })
      if dev_dataset is not None:
        datasets["dev"] = dev_dataset

      if use_cache:
        logger.info(f"Saving dataset to cache {cache_file}")
        s = time.time()
        datasets.save_to_disk(cache_file)
        e = time.time()
        logger.info(f"Time {e - s:.4f}")

  # print info
  logger.info("Unlabeled dataset")
  logger.info(datasets["unlabeled"])
  logger.info("(Pseudo) Labeled dataset for training")
  logger.info(datasets["train"])
  logger.info("Full training document set")
  logger.info(datasets["full_train"])
  dev_dataset = datasets.get("dev", None)
  if with_dev and dev_dataset is not None:
    logger.info("Dev dataset")
    logger.info(dev_dataset)
  elif with_dev and dev_dataset is None:
    logger.warning("with_dev = True but dev dataset not found!")

  # Log one random example per split; an empty split (e.g. every document is
  # labeled) would make random.choice(range(0)) raise IndexError.
  for split_name, description in (
      ("unlabeled", "unlabeled set"),
      ("train", "training set"),
      ("full_train", "full training set"),
  ):
    split_data = datasets[split_name]
    if len(split_data) == 0:
      logger.warning(f"The {description} is empty.")
      continue
    index = random.randrange(len(split_data))
    logger.info(f"Sample {index} of the {description}: {split_data[index]}.")

  return datasets, label_map
# Example #8
# 0
def main():
    """Train or fine-tune a causal language model with Hugging Face Accelerate.

    The function:

    * parses the CLI arguments;
    * sets up the accelerator, logging, the seed and the optional Hub repo;
    * loads the raw text, tokenizes it and groups it into ``block_size`` chunks;
    * trains with gradient accumulation, optional step/epoch checkpointing and
      resumption, evaluating perplexity after every epoch.

    The final model, tokenizer and ``all_results.json`` are written to
    ``args.output_dir`` when it is set.
    """
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
    accelerator = Accelerator(
        log_with="all",
        logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name,
                                               token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            # Keep intermediate checkpoints out of the repo. Read the existing
            # .gitignore first: opening it with "w+" truncated it (dropping the
            # cloned repo's entries) before the membership check ran.
            gitignore_path = os.path.join(args.output_dir, ".gitignore")
            existing = ""
            if os.path.exists(gitignore_path):
                with open(gitignore_path, encoding="utf-8") as gitignore:
                    existing = gitignore.read()
            present = {line.strip() for line in existing.splitlines()}
            missing = [p for p in ("step_*", "epoch_*") if p not in present]
            if missing:
                with open(gitignore_path, "a", encoding="utf-8") as gitignore:
                    if existing and not existing.endswith("\n"):
                        gitignore.write("\n")
                    gitignore.write("".join(f"{p}\n" for p in missing))
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Download the dataset from the hub and keep only its first
        # n_train + n_val training examples: the first n_train form the
        # training split, the following n_val the validation split.
        # Previously this subset was built and then always discarded, because
        # the "validation" check that followed re-loaded the full dataset.
        hub_train = load_dataset(args.dataset_name,
                                 args.dataset_config_name)["train"]
        n_total = min(args.n_train + args.n_val, len(hub_train))
        n_train = min(args.n_train, n_total)
        raw_datasets = datasets.DatasetDict({
            "train": hub_train.select(range(n_train)),
            "validation": hub_train.select(range(n_train, n_total)),
        })
    else:
        data_files = {}
        dataset_args = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
        raw_datasets = load_dataset(extension,
                                    data_files=data_files,
                                    **dataset_args)
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{args.validation_split_percentage}%]",
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{args.validation_split_percentage}%:]",
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    with accelerator.main_process_first():
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            # Only cap overly large values; smaller model limits are kept.
            block_size = 1024
    else:
        if args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k:
            [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with accelerator.main_process_first():
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )

    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  collate_fn=default_data_collator,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=default_data_collator,
                                 batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps /
                                          num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)

    # Figure out how many steps we should save the Accelerator states
    if hasattr(args.checkpointing_steps, "isdigit"):
        checkpointing_steps = args.checkpointing_steps
        if args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
    else:
        checkpointing_steps = None

    # We need to initialize the trackers we use, and also store our configuration
    if args.with_tracking:
        # Copy: vars(args) is args.__dict__ itself, so editing it would change args.
        experiment_config = dict(vars(args))
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config[
            "lr_scheduler_type"].value
        accelerator.init_trackers("clm_no_trainer", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(
        f"  Total optimization steps = {int(args.max_train_steps/accelerator.num_processes)}"
    )
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(
        int(args.max_train_steps / accelerator.num_processes)),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    # Number of batches of `starting_epoch` already consumed before a step
    # checkpoint was saved (None = skip nothing).
    resume_step = None
    # Optimizer updates per epoch on this process. The loader was sharded by
    # accelerator.prepare above, so its length is per process here.
    updates_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)

    # Potentially load in the weights and states from a previous save.
    # Checkpoints are written below as "[output_dir/]epoch_{epoch}" or
    # "[output_dir/]step_{completed_steps}"; only the basename says where
    # training stopped. (The old "most recent checkpoint" branch was
    # unreachable: its guard `x is not None or x != ""` was always true.)
    if args.resume_from_checkpoint:
        accelerator.print(
            f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        checkpoint_name = os.path.basename(
            os.path.normpath(args.resume_from_checkpoint))
        if checkpoint_name.startswith("epoch_"):
            # epoch_{e} is saved after epoch e has finished: continue at e + 1.
            starting_epoch = int(checkpoint_name[len("epoch_"):]) + 1
            completed_steps = starting_epoch * updates_per_epoch
        else:
            completed_steps = int(checkpoint_name.replace("step_", ""))
            starting_epoch = completed_steps // updates_per_epoch
            resume_step = (completed_steps - starting_epoch * updates_per_epoch
                           ) * args.gradient_accumulation_steps
        progress_bar.update(completed_steps)

    perplexity = None  # stays None if no epoch is run
    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if args.with_tracking:
            total_loss = 0
        for step, batch in enumerate(train_dataloader):
            # Skip the batches consumed before the step checkpoint was saved.
            if (epoch == starting_epoch and resume_step is not None
                    and step < resume_step):
                continue
            outputs = model(**batch)
            loss = outputs.loss
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # Update after every `gradient_accumulation_steps` batches, and on
            # the last batch so no gradients are left over. The old
            # `step % n == 0` test fired on the very first batch.
            if ((step + 1) % args.gradient_accumulation_steps == 0
                    or step == len(train_dataloader) - 1):
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                # Checkpoint only right after an update, so each step_N is
                # saved once (not on every accumulation batch, nor as step_0).
                if (isinstance(checkpointing_steps, int)
                        and completed_steps % checkpointing_steps == 0):
                    output_dir = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        output_dir = os.path.join(args.output_dir, output_dir)
                    accelerator.save_state(output_dir)
            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        losses = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(
                accelerator.gather(loss.repeat(
                    args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        losses = losses[:len(eval_dataset)]
        try:
            perplexity = math.exp(torch.mean(losses))
        except OverflowError:
            perplexity = float("inf")

        logger.info(f"epoch {epoch}: perplexity: {perplexity}")

        if args.with_tracking:
            accelerator.log(
                {
                    "perplexity": perplexity,
                    # Log a plain number rather than a (possibly on-device) tensor.
                    "train_loss": float(total_loss),
                    "epoch": epoch,
                    "step": completed_steps
                }, )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir,
                                            save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}",
                    blocking=False,
                    auto_lfs_prune=True)

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir,
                                        save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training",
                                 auto_lfs_prune=True)
            # Written by the main process only, to avoid concurrent writes.
            with open(os.path.join(args.output_dir, "all_results.json"),
                      "w") as f:
                json.dump({"perplexity": perplexity}, f)