Example no. 1
    def test_model_for_masked_lm(self):
        for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = TFAutoModelForMaskedLM.from_pretrained(model_name)
            model, loading_info = TFAutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, TFBertForMaskedLM)
Example no. 2
    def vectorize(self, list_of_texts):
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = TFAutoModelForMaskedLM.from_pretrained(
            "bert-base-uncased", output_hidden_states=True)

        BERT_vectors = get_BERT_vectors(list_of_texts, self.model,
                                        self.tokenizer)

        return BERT_vectors
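
The `get_BERT_vectors` helper is not part of this snippet. A minimal sketch of what it could look like, assuming it mean-pools the last hidden state exposed via `output_hidden_states=True` (the pooling strategy and signature are assumptions, not the original implementation):

import tensorflow as tf

# Hypothetical helper for illustration: encode the texts and mean-pool the
# last hidden layer into one fixed-size vector per text.
def get_BERT_vectors(list_of_texts, model, tokenizer, max_length=128):
    encoded = tokenizer(list_of_texts, padding=True, truncation=True,
                        max_length=max_length, return_tensors="tf")
    outputs = model(encoded)
    # With output_hidden_states=True, hidden_states[-1] has shape
    # (batch_size, seq_len, hidden_size).
    last_hidden = outputs.hidden_states[-1]
    mask = tf.cast(encoded["attention_mask"][:, :, None], last_hidden.dtype)
    # Masked mean over the sequence dimension, ignoring padding tokens.
    vectors = tf.reduce_sum(last_hidden * mask, axis=1) / tf.reduce_sum(mask, axis=1)
    return vectors.numpy()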
Example no. 3
def build_model():
    print(f"Using pretrained {pretrain_model}")
    model = TFAutoModelForMaskedLM.from_pretrained(pretrain_model)
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

    model.compile(optimizer, loss=loss, metrics=[metric])
    return model
Example no. 4
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForMaskedLM


def train_language_model(model_name, data):

    if model_name == 'bt':
        pt_path = "vinai/bertweet-base"
    elif model_name == 'rob':
        pt_path = "roberta-base"
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(pt_path)
    model = TFAutoModelForMaskedLM.from_pretrained(pt_path, return_dict=True)

    text = pd.read_csv(data).to_numpy()[:, 1]
    text, labels = preprocessing(text, 70, tokenizer)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    optimizer = tf.keras.optimizers.Adam(1e-5)

    batch_size = 32
    history = []
    for i in range(4):

        batch_splits = splits_batch(batch_size, len(text))
        loss = 0
        perc = 0

        for j in range(len(batch_splits)):
            with tf.GradientTape() as g:

                if j * 100.0 / len(batch_splits) - perc >= 1:
                    perc = j * 100.0 / len(batch_splits)
                    print('\r Epoch:{} step {} of {}. {}%'.format(
                        i + 1, j, len(batch_splits), np.round(perc,
                                                              decimals=2)),
                          end="")

                out = model(text[batch_splits[j], :])
                loss_value = loss_object(y_true=labels[batch_splits[j], :],
                                         y_pred=out.logits)
            gradients = g.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,
                                          model.trainable_variables))
            loss += loss_value.numpy() / len(batch_splits)
        history.append(loss)
        print('\repoch: {} Loss: {}'.format(i + 1, loss))

    plt.plot(history)
    plt.legend(['train'], loc='upper left')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.show()
    os.makedirs(f'../data/lm_fine_tunning_{model_name}', exist_ok=True)
    model.save_pretrained(f'../data/lm_fine_tunning_{model_name}')
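
The `preprocessing` and `splits_batch` helpers used above are defined elsewhere in the original project. As a rough illustration only, `splits_batch` presumably shuffles the example indices and cuts them into batch-sized chunks; a minimal sketch under that assumption:

import numpy as np

# Hypothetical reimplementation: shuffle all example indices and split them
# into lists of batch_size indices (the last chunk may be smaller).
def splits_batch(batch_size, n_examples):
    indices = np.random.permutation(n_examples)
    return [indices[i:i + batch_size] for i in range(0, n_examples, batch_size)]

Each returned chunk can then be used for NumPy fancy indexing, as in text[batch_splits[j], :] above.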
Example no. 5
def main():
    # region Argument Parsing
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TFTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Sanity checks
    if data_args.dataset_name is None and data_args.train_file is None and data_args.validation_file is None:
        raise ValueError(
            "Need either a dataset name or a training/validation file.")
    else:
        if data_args.train_file is not None:
            extension = data_args.train_file.split(".")[-1]
            assert extension in [
                "csv", "json", "txt"
            ], "`train_file` should be a csv, json or txt file."
        if data_args.validation_file is not None:
            extension = data_args.validation_file.split(".")[-1]
            assert extension in [
                "csv", "json", "txt"
            ], "`validation_file` should be a csv, json or txt file."

    if training_args.output_dir is not None:
        training_args.output_dir = Path(training_args.output_dir)
        os.makedirs(training_args.output_dir, exist_ok=True)

    if isinstance(
            training_args.strategy,
            tf.distribute.TPUStrategy) and not data_args.pad_to_max_length:
        logger.warning("We are training on TPU - forcing pad_to_max_length")
        data_args.pad_to_max_length = True
    # endregion

    # region Checkpoints
    # Detecting last checkpoint.
    checkpoint = None
    if len(os.listdir(training_args.output_dir)
           ) > 0 and not training_args.overwrite_output_dir:
        config_path = training_args.output_dir / CONFIG_NAME
        weights_path = training_args.output_dir / TF2_WEIGHTS_NAME
        if config_path.is_file() and weights_path.is_file():
            checkpoint = training_args.output_dir
            logger.warning(
                f"Checkpoint detected, resuming training from checkpoint in {training_args.output_dir}. To avoid this"
                " behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
        else:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to continue regardless.")

    # endregion

    # region Setup logging
    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
    # endregion

    # If passed along, set the training seed now.
    if training_args.seed is not None:
        set_seed(training_args.seed)

    # region Load datasets
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name,
                                    data_args.dataset_config_name)
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        raw_datasets = load_dataset(extension, data_files=data_files)

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    # endregion

    # region Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if checkpoint is not None:
        config = AutoConfig.from_pretrained(checkpoint)
    elif model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )
    # endregion

    # region Dataset preprocessing
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can reduce that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 1024
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length,
                             tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # Remove empty lines
            examples[text_column_name] = [
                line for line in examples[text_column_name]
                if len(line) > 0 and not line.isspace()
            ]
            return tokenizer(
                examples[text_column_name],
                padding=padding,
                truncation=True,
                max_length=max_seq_length,
                # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                # receives the `special_tokens_mask`.
                return_special_tokens_mask=True,
            )

        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[text_column_name],
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset line_by_line",
        )
    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
        # efficient when it receives the `special_tokens_mask`.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name],
                             return_special_tokens_mask=True)

        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on every text in dataset",
        )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {
                k: list(chain(*examples[k]))
                for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; we could pad instead of dropping if the model supported it. You can
            # customize this part to your needs.
            if total_length >= max_seq_length:
                total_length = (total_length //
                                max_seq_length) * max_seq_length
            # Split by chunks of max_len.
            result = {
                k: [
                    t[i:i + max_seq_length]
                    for i in range(0, total_length, max_seq_length)
                ]
                for k, t in concatenated_examples.items()
            }
            return result

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {max_seq_length}",
        )

    train_dataset = tokenized_datasets["train"]

    if data_args.validation_file is not None:
        eval_dataset = tokenized_datasets["validation"]
    else:
        logger.info(
            f"No validation file provided: holding out {data_args.validation_split_percentage}% of the training data as the validation set, as specified in data_args."
        )
        train_indices, val_indices = train_test_split(
            list(range(len(train_dataset))),
            test_size=data_args.validation_split_percentage / 100)

        eval_dataset = train_dataset.select(val_indices)
        train_dataset = train_dataset.select(train_indices)

    if data_args.max_train_samples is not None:
        max_train_samples = min(len(train_dataset),
                                data_args.max_train_samples)
        train_dataset = train_dataset.select(range(max_train_samples))
    if data_args.max_eval_samples is not None:
        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
        eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(
            f"Sample {index} of the training set: {train_dataset[index]}.")
    # endregion

    with training_args.strategy.scope():
        # region Prepare model
        if checkpoint is not None:
            model = TFAutoModelForMaskedLM.from_pretrained(checkpoint,
                                                           config=config)
        elif model_args.model_name_or_path:
            model = TFAutoModelForMaskedLM.from_pretrained(
                model_args.model_name_or_path, config=config)
        else:
            logger.info("Training new model from scratch")
            model = TFAutoModelForMaskedLM.from_config(config)

        model.resize_token_embeddings(len(tokenizer))
        # endregion

        # region TF Dataset preparation
        num_replicas = training_args.strategy.num_replicas_in_sync
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm_probability=data_args.mlm_probability,
            return_tensors="tf")
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

        tf_train_dataset = train_dataset.to_tf_dataset(
            # labels are passed as input, as we will use the model's internal loss
            columns=[
                col for col in train_dataset.features
                if col != "special_tokens_mask"
            ] + ["labels"],
            shuffle=True,
            batch_size=num_replicas *
            training_args.per_device_train_batch_size,
            collate_fn=data_collator,
            drop_remainder=True,
        ).with_options(options)

        tf_eval_dataset = eval_dataset.to_tf_dataset(
            # labels are passed as input, as we will use the model's internal loss
            columns=[
                col for col in eval_dataset.features
                if col != "special_tokens_mask"
            ] + ["labels"],
            shuffle=False,
            batch_size=num_replicas *
            training_args.per_device_train_batch_size,
            collate_fn=data_collator,
            drop_remainder=True,
        ).with_options(options)
        # endregion

        # region Optimizer and loss
        batches_per_epoch = len(train_dataset) // (
            num_replicas * training_args.per_device_train_batch_size)
        # Bias and layernorm weights are automatically excluded from the decay
        optimizer, lr_schedule = create_optimizer(
            init_lr=training_args.learning_rate,
            num_train_steps=int(training_args.num_train_epochs *
                                batches_per_epoch),
            num_warmup_steps=training_args.warmup_steps,
            adam_beta1=training_args.adam_beta1,
            adam_beta2=training_args.adam_beta2,
            adam_epsilon=training_args.adam_epsilon,
            weight_decay_rate=training_args.weight_decay,
        )

        # no user-specified loss = will use the model internal loss
        model.compile(optimizer=optimizer)
        # endregion

        # region Training and validation
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {training_args.num_train_epochs}")
        logger.info(
            f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
        )
        logger.info(
            f"  Total train batch size = {training_args.per_device_train_batch_size * num_replicas}"
        )

        history = model.fit(
            tf_train_dataset,
            validation_data=tf_eval_dataset,
            epochs=int(training_args.num_train_epochs),
            steps_per_epoch=len(train_dataset) //
            (training_args.per_device_train_batch_size * num_replicas),
            callbacks=[
                SavePretrainedCallback(output_dir=training_args.output_dir)
            ],
        )
        try:
            train_perplexity = math.exp(history.history["loss"][-1])
        except OverflowError:
            train_perplexity = math.inf
        try:
            validation_perplexity = math.exp(history.history["val_loss"][-1])
        except OverflowError:
            validation_perplexity = math.inf
        logger.warning(
            f"  Final train loss: {history.history['loss'][-1]:.3f}")
        logger.warning(f"  Final train perplexity: {train_perplexity:.3f}")
        logger.warning(
            f"  Final validation loss: {history.history['val_loss'][-1]:.3f}")
        logger.warning(
            f"  Final validation perplexity: {validation_perplexity:.3f}")
        # endregion

        if training_args.output_dir is not None:
            model.save_pretrained(training_args.output_dir)

    if training_args.push_to_hub:
        # You'll probably want to append some of your own metadata here!
        model.push_to_hub()
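
`SavePretrainedCallback` is defined elsewhere in the original script; in the Hugging Face TensorFlow examples it is a small Keras callback that writes the model out in `save_pretrained` format at the end of every epoch. A sketch along those lines:

import tensorflow as tf

class SavePretrainedCallback(tf.keras.callbacks.Callback):
    # Save the model in Hugging Face format after each epoch so training can
    # be resumed from output_dir.
    def __init__(self, output_dir, **kwargs):
        super().__init__(**kwargs)
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        self.model.save_pretrained(self.output_dir)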
Example no. 6
def main() -> NoReturn:
    parser = ArgumentParser(description="Runner")
    parser.add_argument("--act", default="preprocess", type=str, required=False, help="execution mode")
    parser.add_argument("--vocab_size", default=6932, type=int, required=False, help="vocabulary size")
    parser.add_argument("--epochs", default=50, type=int, required=False, help="number of training epochs")
    parser.add_argument("--num_layers", default=12, type=int, required=False, help="number of block layers")
    parser.add_argument("--units", default=1024, type=int, required=False, help="number of units")
    parser.add_argument("--first_kernel_size", default=3, type=int, required=False, help="kernel size of the first convolution")
    parser.add_argument("--second_kernel_size", default=3, type=int, required=False, help="kernel size of the second convolution")
    parser.add_argument("--first_strides_size", default=3, type=int, required=False, help="stride of the first convolution")
    parser.add_argument("--second_strides_size", default=3, type=int, required=False, help="stride of the second convolution")
    parser.add_argument("--first_output_dim", default=32, type=int, required=False, help="output channels of the first convolution")
    parser.add_argument("--second_output_dim", default=16, type=int, required=False, help="output channels of the second convolution")
    parser.add_argument("--embedding_dim", default=768, type=int, required=False, help="word embedding size")
    parser.add_argument("--num_heads", default=12, type=int, required=False, help="number of attention heads")
    parser.add_argument("--dropout", default=0.1, type=float, required=False, help="dropout rate")
    parser.add_argument("--batch_size", default=32, type=int, required=False, help="batch size")
    parser.add_argument("--buffer_size", default=100000, type=int, required=False, help="buffer size")
    parser.add_argument("--max_sentence_length", default=32, type=int, required=False, help="maximum sentence length")
    parser.add_argument("--checkpoint_save_size", default=10, type=int, required=False, help="maximum number of checkpoints to keep")
    parser.add_argument("--train_data_size", default=0, type=int, required=False, help="training data size")
    parser.add_argument("--valid_data_size", default=0, type=int, required=False, help="validation data size")
    parser.add_argument("--max_train_steps", default=-1, type=int, required=False, help="maximum amount of training data, -1 for all")
    parser.add_argument("--checkpoint_save_freq", default=1, type=int, required=False, help="checkpoint save frequency")
    parser.add_argument("--data_dir", default="./tcdata/", type=str, required=False, help="directory of the raw data")
    parser.add_argument("--raw_train_data_path", default="./tcdata/gaiic_track3_round1_train_20210228.tsv", type=str,
                        required=False, help="relative path to the raw training data")
    parser.add_argument("--raw_test_data_path", default="./tcdata/gaiic_track3_round1_testA_20210228.tsv", type=str,
                        required=False, help="relative path to the raw test data")
    parser.add_argument("--train_data_path", default="./user_data/train.tsv", type=str, required=False, help="relative path to the training data")
    parser.add_argument("--valid_data_path", default="./user_data/valid.tsv", type=str, required=False, help="relative path to the validation data")
    parser.add_argument("--train_record_data_path", default="./user_data/train.tfrecord", type=str, required=False,
                        help="relative path of the training data in TFRecord format")
    parser.add_argument("--valid_record_data_path", default="./user_data/valid.tfrecord", type=str, required=False,
                        help="relative path of the validation data in TFRecord format")
    parser.add_argument("--test_record_data_path", default="./user_data/test.tfrecord", type=str, required=False,
                        help="relative path of the test data in TFRecord format")
    parser.add_argument("--checkpoint_dir", default="./user_data/checkpointv1/", type=str, required=False,
                        help="checkpoint save directory")
    parser.add_argument("--result_save_path", default="./user_data/result.tsv", type=str, required=False,
                        help="result file for the test data")
    parser.add_argument("--config_path", default="./tcdata/bert/config.json", type=str, required=False,
                        help="config file path")
    parser.add_argument("--bert_path", default="./tcdata/bert/tf_model.h5", type=str, required=False,
                        help="path to the BERT weights")
    parser.add_argument("--dict_path", default="./tcdata/bert/vocab.txt", type=str, required=False, help="path to the vocabulary file")

    options = parser.parse_args()
    # bert_model = model(vocab_size=)
    # model_path = "../tcdata/bert/"
    # tokenizer = BertTokenizer.from_pretrained("../tcdata/bert/vocab.txt")
    model_config = BertConfig.from_pretrained("./tcdata/bert/config.json")
    model_config.output_attentions = False
    model_config.output_hidden_states = False
    model_config.use_cache = False
    # model = TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path=model_path, from_pt=False,
    #                                           config=model_config, cache_dir="../user_data/temp")
    # model.resize_token_embeddings(len(tokenizer))

    # tokenizer = AutoTokenizer.from_pretrained("./tcdata/bert")
    bert = TFAutoModelForMaskedLM.from_pretrained("./tcdata/bert", config=model_config, cache_dir="../user_data/temp")
    # token = tokenizer.encode("生活的真谛是[MASK]。")
    # print(tokenizer.decode(token))
    # input = tf.convert_to_tensor([token])
    # print(input)
    # outputs = bert(input)[0]
    # print(tokenizer.decode(tf.argmax(outputs[0],axis=-1)))
    # exit(0)

    # model_config = BertConfig.from_pretrained(options.config_path)
    # bert = TFBertModel.from_pretrained(pretrained_model_name_or_path=options.bert_path, from_pt=False,
    #                                    config=model_config, cache_dir="../user_data/temp")
    # bert.resize_token_embeddings(new_num_tokens=options.vocab_size)

    model = bert_model(vocab_size=options.vocab_size, bert=bert)

    checkpoint_manager = load_checkpoint(checkpoint_dir=options.checkpoint_dir, execute_type=options.act,
                                         checkpoint_save_size=options.checkpoint_save_size, model=model)

    if options.act == "train":
        history = train(
            model=model, checkpoint=checkpoint_manager, batch_size=options.batch_size, buffer_size=options.buffer_size,
            epochs=options.epochs, train_data_path=options.raw_train_data_path,
            test_data_path=options.raw_test_data_path, dict_path=options.dict_path,
            max_sentence_length=options.max_sentence_length, checkpoint_save_freq=options.checkpoint_save_freq
        )
    elif options.act == "evaluate":
        pass
    elif options.act == "inference":
        pass
    else:
        parser.error(f"unknown --act value: {options.act}")
Example no. 7
    def load_from_disk(self, target_path):
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = TFAutoModelForMaskedLM.from_pretrained(
            "bert-base-uncased", output_hidden_states=True)
Example no. 8

import tensorflow as tf
import transformers
from transformers import AutoTokenizer, TFAutoModelForMaskedLM


MAX_LEN = 128
BATCH_SIZE = 16


pretrain_model = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(pretrain_model)
model = TFAutoModelForMaskedLM.from_pretrained(pretrain_model)



def build_model():
    print(f"Using pretrained {pretrain_model}")
    model = TFAutoModelForMaskedLM.from_pretrained(pretrain_model)
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

    model.compile(optimizer, loss=loss, metrics=[metric])
    return model
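
For reference, a quick masked-token sanity check with the tokenizer and model loaded above could look like this (the sentence is a made-up placeholder; PhoBERT normally expects word-segmented Vietnamese input):

# Predict the token at the mask position for a single sentence.
text = f"Thủ đô của Việt Nam là {tokenizer.mask_token} ."
inputs = tokenizer(text, return_tensors="tf")
logits = model(inputs).logits

# Locate the mask position(s) and decode the highest-scoring token at each.
mask_positions = tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)
predicted_ids = tf.argmax(logits[0], axis=-1)
for pos in mask_positions[:, 0].numpy():
    print(tokenizer.decode([int(predicted_ids[pos])]))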