Example #1
0
def build_dataloader(location, shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer, continuous_iter=True, world_size=1, num_workers=1):
    size_dicts = {128: 64*8, 256: 32*8, 512: 16*8, 768: 8*8, 1024: 8*8}
    # TODO: num workers based on dataset size, only top 16 datasets get 2 workers, next 16 get 1 worker and rest are done in main process
    single_node = world_size == 1
    try:
        train_dataset = Dataset.load_from_disk(location)
        train_dataset = TokenizerDataset(config, tokenizer, char_to_id, dict(padding="max_length", truncation=True, return_tensors="pt", max_length=config.tokenizer_length), train_dataset)
        if num_workers > 0:
            train_loader = DataLoader(train_dataset, sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset), batch_size=8*8, collate_fn=None, prefetch_factor=4 if num_workers > 0 else None, num_workers=(2*num_workers) if single_node else num_workers)
        else:
            train_loader = DataLoader(train_dataset, sampler=None if single_node else DistributedSampler(train_dataset, shuffle=shuffle_dataset), batch_size=8*8,
                                      collate_fn=None,
                                      num_workers=(2 * num_workers) if single_node else num_workers)
        train_loader = custom_batching_fn(train_loader, size_dicts, continuous_iter)
    except Exception:
        # Not a single Dataset on disk; fall back to loading a DatasetDict of datasets.
        train_dataset = DatasetDict.load_from_disk(location)
        train_dataset = {k: v for k, v in train_dataset.items() if len(v) >= world_size}
        train_dataset_sampling_proba = {k: len(v) ** sampling_fraction for k, v in train_dataset.items()}
        lsum = sum(train_dataset_sampling_proba.values())
        train_dataset_sampling_proba = {k: v / lsum for k, v in train_dataset_sampling_proba.items()}
        train_dataset = {k: TokenizerDataset(config, tokenizer, char_to_id, dict(padding="max_length", truncation=True, return_tensors="pt", max_length=config.tokenizer_length), v) for k, v in train_dataset.items()}
        # for v in train_dataset.values():
        #     v.training = False
        if num_workers > 0:
            train_loader = {k: DataLoader(v, sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset, ), batch_size=8*8, collate_fn=collate_fn, prefetch_factor=2, num_workers=(2*num_workers) if single_node else num_workers) for k, v in train_dataset.items()}
        else:
            train_loader = {
                k: DataLoader(v, sampler=None if single_node else DistributedSampler(v, shuffle=shuffle_dataset, ), batch_size=8*8, collate_fn=collate_fn,
                              num_workers=(2 * num_workers) if single_node else num_workers) for k, v in train_dataset.items()}
        train_loader = {k: custom_batching_fn(dataloader, size_dicts, continuous_iter) for k, dataloader in train_loader.items()}
        train_loader = datadict_iterator(train_loader, train_dataset_sampling_proba)
    return train_loader
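custom_batching_fn and the size_dicts mapping drive length-dependent batch sizes here, but the helper itself is not shown. The sketch below is one assumed implementation: it re-groups the fixed-size micro-batches coming out of the DataLoader into larger batches whose target size depends on sequence length, and it assumes each batch is a dict containing input_ids and attention_mask tensors.

import torch

def custom_batching_fn(loader, size_dicts, continuous_iter=True):
    # size_dicts maps a maximum sequence length to the batch size to use for it,
    # e.g. {128: 512, 256: 256, 512: 128, 768: 64, 1024: 64}.
    buckets = sorted(size_dicts.items())
    while True:
        pending, count = [], 0
        for micro_batch in loader:
            pending.append(micro_batch)
            count += micro_batch["input_ids"].shape[0]
            # The longest non-padded sequence in the current micro-batch picks the bucket.
            seq_len = int(micro_batch["attention_mask"].sum(dim=-1).max())
            target = next((bs for max_len, bs in buckets if seq_len <= max_len),
                          buckets[-1][1])
            if count >= target:
                yield {k: torch.cat([b[k] for b in pending])
                       if torch.is_tensor(pending[0][k]) else pending[0][k]
                       for k in pending[0]}
                pending, count = [], 0
        if not continuous_iter:
            break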
Example #2
0
 def load_eval_data(self, force_reload=False, save_datasets=True) -> None:
     eval_save_dir = self.save_dir / "eval"
     try:
         if force_reload:
             # Fall through to the except branch to force regeneration.
             raise Exception()
         self.datasets["eval"] = DatasetDict.load_from_disk(eval_save_dir)
         print("Evaluation data loaded from disk.")
     except Exception:
         print("Regenerating evaluation data.")
         eval_df_dict = self._parse_eval_data(self.eval_dir)
         self.datasets["eval"] = DatasetDict({
             "far":
             Dataset.from_pandas(eval_df_dict["far"]),
             "obj":
             Dataset.from_pandas(eval_df_dict["obj"]),
         })
         if save_datasets:
             print(f"Saving evaluation dataset to {eval_save_dir}")
             self.datasets["eval"].save_to_disk(eval_save_dir)
Example #3
0
 def load_training_data(self,
                        force_reload=False,
                        save_datasets=True) -> None:
     training_save_dir = self.save_dir / "training"
     try:
         if force_reload:
             # Fall through to the except branch to force regeneration.
             raise Exception()
         # If the training data has already been saved, load it from the save directory.
         self.datasets["train"] = DatasetDict.load_from_disk(
             training_save_dir)
         print("Training data loaded from disk.")
     except Exception:
         # If it hasn't, regenerate the training data.
         print("Regenerating training data.")
         train_df_dict = self._parse_training_data(self.train_dir)
         self.datasets["train"] = DatasetDict({
             "far":
             Dataset.from_pandas(train_df_dict["far"]),
             "obj":
             Dataset.from_pandas(train_df_dict["obj"]),
         })
         if save_datasets:
             print(f"Saving training dataset to {training_save_dir}")
             self.datasets["train"].save_to_disk(training_save_dir)
Example #4
0
dataset_filtered = dataset_filtered.map(
    lambda x: dict(text=list(map(lambda y: clean_text(y), x["text"]))),
    batched=True,
    batch_size=256)
dataset_filtered.save_to_disk("/home/ahemf/processed/c4_256")

fmap = get_filter_mapper(448)
dataset_448 = dataset_filtered.map(fmap,
                                   batched=True,
                                   batch_size=256,
                                   remove_columns=['timestamp'])
dataset_448 = dataset_448.map(
    lambda x: dict(text=list(map(lambda y: clean_text(y), x["text"]))),
    batched=True,
    batch_size=256)
dataset_448.save_to_disk("/home/ahemf/processed/c4_448")

c4 = DatasetDict.load_from_disk("/home/ahemf/processed/c4_448")
dsets = Dataset.load_from_disk("/home/ahemf/processed/dsets_448")

c4['train'] = c4['train'].add_column('dataset', ['c4'] * len(c4['train']))
c4['train'] = c4['train'].remove_columns(['url', 'timestamp'])
c4['validation'] = c4['validation'].remove_columns(['url', 'timestamp'])
c4['validation'] = c4['validation'].add_column('dataset',
                                               ['c4'] * len(c4['validation']))

dataset_col = dsets['dataset']
dsets = dsets.remove_columns(["dataset"])
dsets = dsets.add_column("dataset", dataset_col)

c4["train"] = concatenate_datasets([c4["train"], dsets])
c4["train"].save_to_disk("/home/ahemf/processed/c4_extended")
Example #5
0
    def __call__(self):
        datadict = DatasetDict.load_from_disk(self.location)
        print("Loaded Validation Data")
        tokenizer = self.tokenizer
        model = self.model.to(self.device)
        model = model.eval()
        collate_fn = get_collate_fn(self.config.num_highway_cls_tokens, tokenizer.pad_token_id)
        results = dict()
        for k, dataset in datadict.items():
            cns = dataset.column_names
            predictions = []
            if 'answer' in cns:
                labels = [dataset[i]["answer"] for i in range(len(dataset))]
            dataset = TokenizerDataset(self.config, tokenizer, char_to_id,
                                       dict(padding="max_length", truncation=True, return_tensors="pt", max_length=self.config.tokenizer_length),
                                       dataset)
            dataset.training = False
            record_accuracy = False
            if 'answer' not in cns:
                dataset.training = True
                record_accuracy = True
            loader = DataLoader(dataset, sampler=None, batch_size=8, collate_fn=collate_fn, prefetch_factor=2, num_workers=4)
            # loader = custom_batching_fn(loader, size_dicts, collate_fn, False)
            for pt_batch in loader:
                pt_batch["record_accuracy"] = record_accuracy
                pt_batch = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in pt_batch.items()}
                if 'answer' in cns:
                    with autocast():
                        with torch.no_grad():
                            funnel_inputs = dict(input_ids=pt_batch["input_ids"],
                                                 attention_mask=pt_batch["attention_mask"],
                                                 token_type_ids=pt_batch["token_type_ids"],
                                                 inputs_embeds=None,
                                                 char_ids=pt_batch["char_ids"], char_offsets=pt_batch["char_offsets"],
                                                 run_decoder=False,
                                                 run_answering=True)
                            output = model.module.module.backbone(**funnel_inputs)
                            answering_predictions = output["answering_logits"].argmax(dim=-1)
                            answering_predictions = answer_decoder(answering_predictions, tokenizer)
                            predictions.extend(answering_predictions)

                else:
                    labels = pt_batch["label_mlm_input_ids"] if "label_mlm_input_ids" in pt_batch else pt_batch["input_ids"]
                    labels = labels.to(self.device)
                    with autocast():
                        with torch.no_grad():
                            output = model(**pt_batch, labels=labels)["accuracy_hist"]
                            predictions.append(output)
            if 'answer' in cns:
                final_labels, final_predictions = [], []
                for lbl, prd in zip(labels, predictions):
                    if len(prd) > len(lbl):
                        prd = prd[:len(lbl)]
                    if len(prd) < len(lbl):
                        prd = prd + ([''] * (len(lbl) - len(prd)))
                    final_labels.extend(lbl)
                    final_predictions.extend(prd)
                score = accuracy_score(final_labels, final_predictions)
                results[k] = dict(accuracy=score)
            else:
                results[k] = pd.DataFrame.from_records(predictions).mean().to_dict()
        model = model.train()
        return results
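answer_decoder is another helper that is not shown; a rough sketch of its assumed behaviour, turning each row of predicted token ids back into a list of answer tokens so they can be compared against the labels above:

def answer_decoder(predictions, tokenizer):
    # predictions: LongTensor of shape (batch, answer_length) holding token ids.
    decoded = []
    for row in predictions.tolist():
        tokens = []
        for token_id in row:
            if token_id == tokenizer.pad_token_id:
                break
            tokens.append(tokenizer.decode([token_id]).strip())
        decoded.append(tokens)
    return decoded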
Example #6
0
     for name, param in model.named_parameters():
         if 'classifier' not in name:  # freeze everything except the classifier head
             param.requires_grad = False
     
 train_pipeline = HuggingFaceRoBERTaBase(tokenizer, 
                                         model, args.task_name, 
                                         TASK_CONFIG[args.task_name])
 logger.info(f"***** TASK NAME: {args.task_name} *****")
 # we use panda loader now, to make sure it is backward compatible
 # with our file writer.
 pd_format = True
 if args.inoculation_data_path.split(".")[-1] != "tsv":
     if len(args.inoculation_data_path.split(".")) > 1:
         logger.info(f"***** Loading pre-loaded datasets from the disk directly! *****")
         pd_format = False
         datasets = DatasetDict.load_from_disk(args.inoculation_data_path)
         inoculation_step_sample_size = int(len(datasets["train"]) * args.inoculation_step_sample_size)
         logger.info(f"***** Inoculation Sample Count: %s *****"%(inoculation_step_sample_size))
         # this may not always start for zero inoculation
         training_args = generate_training_args(args, inoculation_step=inoculation_step_sample_size)
         datasets["train"] = datasets["train"].shuffle(seed=args.seed)
         inoculation_train_df = datasets["train"].select(range(inoculation_step_sample_size))
         eval_df = datasets["validation"]
         datasets["validation"] = datasets["validation"].shuffle(seed=args.seed)
         if args.eval_sample_limit != -1:
             datasets["validation"] = datasets["validation"].select(range(args.eval_sample_limit))
     else:
         logger.info(f"***** Loading downloaded huggingface datasets: {args.inoculation_data_path}! *****")
         pd_format = False
         if args.inoculation_data_path in ["sst3", "cola", "mnli", "snli", "mrps", "qnli"]:
             pass
Example #7
0
def build_dataloader(location,
                     shuffle_dataset,
                     sampling_fraction,
                     config,
                     collate_fn,
                     tokenizer,
                     size_dicts,
                     continuous_iter=True,
                     world_size=1,
                     num_workers=1):
    assert max(size_dicts.values()) % min(size_dicts.values()) == 0
    single_node = world_size == 1
    from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
    min_size = gcd_array(size_dicts.values())
    prefetch_factor = 2 * (max(size_dicts.values()) // min_size)
    try:
        train_dataset = Dataset.load_from_disk(location)
        train_dataset = TokenizerDataset(
            config, tokenizer, char_to_id,
            dict(padding="max_length",
                 truncation=True,
                 return_tensors="pt",
                 max_length=config.tokenizer_length), train_dataset)
        if num_workers > 0:
            train_loader = DataLoader(
                train_dataset,
                sampler=None if single_node else DistributedSampler(
                    train_dataset, shuffle=shuffle_dataset),
                batch_size=min_size,
                collate_fn=collate_fn,
                shuffle=shuffle_dataset and single_node,
                prefetch_factor=prefetch_factor,
                num_workers=num_workers,
                pin_memory=True)
        else:
            train_loader = DataLoader(
                train_dataset,
                sampler=None if single_node else DistributedSampler(
                    train_dataset, shuffle=shuffle_dataset),
                batch_size=min_size,
                collate_fn=collate_fn,
                shuffle=shuffle_dataset and single_node,
                num_workers=0,
                pin_memory=True)
        train_loader = custom_batching_fn(train_loader, size_dicts,
                                          continuous_iter)
    except Exception:
        # Not a single Dataset on disk; fall back to loading a DatasetDict of datasets.
        train_dataset = DatasetDict.load_from_disk(location)
        train_dataset = {
            k: v
            for k, v in train_dataset.items() if len(v) >= world_size
        }
        train_dataset_sampling_proba = {
            k: len(v)**sampling_fraction
            for k, v in train_dataset.items()
        }
        lsum = sum(train_dataset_sampling_proba.values())
        train_dataset_sampling_proba = {
            k: v / lsum
            for k, v in train_dataset_sampling_proba.items()
        }
        train_dataset = {
            k: TokenizerDataset(
                config, tokenizer, char_to_id,
                dict(padding="max_length",
                     truncation=True,
                     return_tensors="pt",
                     max_length=config.tokenizer_length), v)
            for k, v in train_dataset.items()
        }
        # for v in train_dataset.values():
        #     v.training = False
        if num_workers > 0:
            train_loader = {
                k:
                DataLoader(v,
                           sampler=None if single_node else DistributedSampler(
                               v,
                               shuffle=shuffle_dataset,
                           ),
                           shuffle=shuffle_dataset and single_node,
                           batch_size=min_size,
                           collate_fn=collate_fn,
                           prefetch_factor=prefetch_factor,
                           num_workers=num_workers)
                for k, v in train_dataset.items()
            }
        else:
            train_loader = {
                k:
                DataLoader(v,
                           sampler=None if single_node else DistributedSampler(
                               v,
                               shuffle=shuffle_dataset,
                           ),
                           shuffle=shuffle_dataset and single_node,
                           batch_size=min_size,
                           collate_fn=collate_fn,
                           num_workers=0)
                for k, v in train_dataset.items()
            }
        train_loader = {
            k: custom_batching_fn(dataloader, size_dicts, continuous_iter)
            for k, dataloader in train_loader.items()
        }
        train_loader = datadict_iterator(train_loader,
                                         train_dataset_sampling_proba)
    return train_loader
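gcd_array and datadict_iterator are again assumed helpers. One plausible sketch is a reduce-based gcd plus an endless iterator that picks which per-dataset loader to draw the next batch from according to the sampling probabilities computed above:

import random
from functools import reduce
from math import gcd

def gcd_array(values):
    # Greatest common divisor of an iterable of positive ints.
    return reduce(gcd, list(values))

def datadict_iterator(loaders, sampling_proba):
    # Endlessly yield batches, choosing the source dataset for each draw in
    # proportion to its sampling probability. Assumes every loader is an
    # infinite iterator (continuous_iter=True in custom_batching_fn).
    keys = list(loaders.keys())
    weights = [sampling_proba[k] for k in keys]
    iterators = {k: iter(v) for k, v in loaders.items()}
    while True:
        k = random.choices(keys, weights=weights, k=1)[0]
        yield next(iterators[k])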
Example #8
0
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Detecting last checkpoint.
    last_checkpoint = None
    if (os.path.isdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(
                training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome.")
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
        if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
    # behavior (see below)
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.

    if data_args.tokenized_dataset_dict_path is not None:
        datasets = DatasetDict.load_from_disk(
            dataset_dict_path=data_args.tokenized_dataset_dict_path)
        logger.info(f"Loaded tokenized datasets dict: {datasets}")
    elif data_args.dataset_dict_path is not None:
        datasets = DatasetDict.load_from_disk(
            dataset_dict_path=data_args.dataset_dict_path)
        logger.info(f"Loaded datasets dict: {datasets}")
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name,
                                data_args.dataset_config_name)
        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name,
                                                  **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if data_args.tokenized_dataset_dict_path is not None:
        tokenized_datasets = datasets
        logger.info(f"Datasets are already tokenized")
    else:
        if training_args.do_train:
            column_names = datasets["train"].column_names
        else:
            column_names = datasets["validation"].column_names
        text_column_name = "text" if "text" in column_names else column_names[0]

        if data_args.line_by_line:
            # When using line_by_line, we just tokenize each nonempty line.
            padding = "max_length" if data_args.pad_to_max_length else False

            def tokenize_function(examples):
                # Remove empty lines
                examples["text"] = [
                    line for line in examples["text"]
                    if len(line) > 0 and not line.isspace()
                ]
                return tokenizer(
                    examples["text"],
                    padding=padding,
                    truncation=True,
                    max_length=data_args.max_seq_length,
                    # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                    # receives the `special_tokens_mask`.
                    return_special_tokens_mask=True,
                )

            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=[text_column_name],
                load_from_cache_file=not data_args.overwrite_cache,
            )
        else:
            # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
            # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
            # efficient when it receives the `special_tokens_mask`.
            def tokenize_function(examples):
                return tokenizer(examples[text_column_name],
                                 return_special_tokens_mask=True)

            tokenized_datasets = datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
            )

            if data_args.max_seq_length is None:
                max_seq_length = tokenizer.model_max_length
                if max_seq_length > 1024:
                    logger.warning(
                        f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                        "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
                    )
                    max_seq_length = 1024
            else:
                if data_args.max_seq_length > tokenizer.model_max_length:
                    logger.warning(
                        f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
                    )
                max_seq_length = min(data_args.max_seq_length,
                                     tokenizer.model_max_length)

            # Main data processing function that will concatenate all texts from our dataset and generate chunks of
            # max_seq_length.
            def group_texts(examples):
                # Concatenate all texts.
                concatenated_examples = {
                    k: sum(examples[k], [])
                    for k in examples.keys()
                }
                total_length = len(concatenated_examples[list(
                    examples.keys())[0]])
                # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
                # customize this part to your needs.
                total_length = (total_length //
                                max_seq_length) * max_seq_length
                # Split by chunks of max_len.
                result = {
                    k: [
                        t[i:i + max_seq_length]
                        for i in range(0, total_length, max_seq_length)
                    ]
                    for k, t in concatenated_examples.items()
                }
                return result

            # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
            # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
            # might be slower to preprocess.
            #
            # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
            tokenized_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"]
        if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"]
        if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            model_path = last_checkpoint
        elif model_args.model_name_or_path is not None and os.path.isdir(
                model_args.model_name_or_path):
            model_path = model_args.model_name_or_path
        else:
            model_path = None
        train_result = trainer.train(model_path=model_path)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        output_train_file = os.path.join(training_args.output_dir,
                                         "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(
                os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        perplexity = math.exp(eval_output["eval_loss"])
        results["perplexity"] = perplexity

        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results_mlm.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

    return results