Example #1
    def __init__(self,
                 gpt2_pretrained_model="gpt2-medium",
                 gpt2_gpu_id=-1,
                 **kargs):
        """Initialize GPT2 model."""
        super(GPT2GrammarQualityMetric, self).__init__()

        logger.info("load gpt2 model.")
        self._tokenizer = GPT2TokenizerFast.from_pretrained(
            utils.get_transformers(gpt2_pretrained_model))
        if gpt2_gpu_id == -1:
            logger.warning("GPT2 metric is running on CPU.")
            self._device = torch.device("cpu")
        else:
            logger.info("GPT2 metric is running on GPU %d.", gpt2_gpu_id)
            self._device = torch.device("cuda:%d" % gpt2_gpu_id)
        self._model = GPT2LMHeadModel.from_pretrained(
            utils.get_transformers(gpt2_pretrained_model)).to(self._device)
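
The metric's measure method is not shown in this snippet. A minimal, hypothetical sketch (the helper name and call pattern are assumptions, for a recent transformers version) of how the loaded tokenizer and model could score a sentence by perplexity:

    def _perplexity(self, text):
        # Hypothetical helper, not part of the original class: score one
        # sentence with the loaded GPT-2 language model.
        ids = self._tokenizer(text, return_tensors="pt").input_ids.to(self._device)
        with torch.no_grad():
            loss = self._model(ids, labels=ids)[0]
        return float(torch.exp(loss))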
Example #2
    def __init__(self, dataset, model_init):
        """Initialize the tokenizer, GloVe embeddings, and training data."""
        self._tokenizer = BertTokenizer.from_pretrained(
            utils.get_transformers(model_init), do_lower_case="uncased" in model_init)

        self._glove = get_glove_emb()
        stopwords = get_stopwords()

        for word in stopwords:
            word = word.lower().strip()
            if word in self._glove["tok2id"]:
                self._glove["emb_table"][self._glove["tok2id"][word], :] = 0

        data = []
        logger.info("processing data for wordpiece embedding training")
        for item in tqdm.tqdm(dataset["data"]):
            text = item["text0"]
            if "text1" in item:
                text += " " + item["text1"]

            text_toks = word_tokenize(text)
            data += [x for x in text_toks if x.lower() in self._glove["tok2id"]]
        self._data = data
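
For reference, a minimal sketch (relying only on the ``tok2id`` and ``emb_table`` layout used above; the helper name is an assumption) of how a token collected in ``self._data`` maps to its GloVe vector, with stopword rows already zeroed out in ``__init__``:

    def _glove_vector(self, word):
        # Hypothetical helper: return the GloVe embedding row for a token,
        # or None if it is out of vocabulary.
        idx = self._glove["tok2id"].get(word.lower())
        return None if idx is None else self._glove["emb_table"][idx]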
Example #3
def load_or_train_bert_clf(model_init, dataset_name, trainset, testset,
                           bert_clf_steps, bert_clf_bs, bert_clf_lr,
                           bert_clf_optimizer, bert_clf_weight_decay,
                           bert_clf_period_summary, bert_clf_period_val,
                           bert_clf_period_save, bert_clf_val_steps, device):
    """Train BERT classification model on a dataset.

    The trained model will be stored at ``<fibber_root_dir>/bert_clf/<dataset_name>/``. If there's
    a saved model, load and return the model. Otherwise, train the model using the given data.

    Args:
        model_init (str): pretrained model name. Choose from ``["bert-base-cased",
            "bert-base-uncased", "bert-large-cased", "bert-large-uncased"]``.
        dataset_name (str): the name of the dataset. This is also the directory used to save
            the trained model.
        trainset (dict): a fibber dataset.
        testset (dict): a fibber dataset.
        bert_clf_steps (int): steps to train a classifier.
        bert_clf_bs (int): the batch size.
        bert_clf_lr (float): the learning rate.
        bert_clf_optimizer (str): the optimizer name.
        bert_clf_weight_decay (float): the weight decay.
        bert_clf_period_summary (int): the period in steps to write training summary.
        bert_clf_period_val (int): the period in steps to run validation and write validation
            summary.
        bert_clf_period_save (int): the period in steps to save current model.
        bert_clf_val_steps (int): number of batches in each validation.
        device (torch.device): the device to run the model.

    Returns:
        (transformers.BertForSequenceClassification): a torch BERT model.
    """
    model_dir = os.path.join(get_root_dir(), "bert_clf", dataset_name)
    ckpt_path = os.path.join(model_dir,
                             model_init + "-%04dk" % (bert_clf_steps // 1000))

    if os.path.exists(ckpt_path):
        logger.info("Load BERT classifier from %s.", ckpt_path)
        model = BertForSequenceClassification.from_pretrained(ckpt_path)
        model.eval()
        model.to(device)
        return model

    num_labels = len(trainset["label_mapping"])
    model = BertForSequenceClassification.from_pretrained(
        utils.get_transformers(model_init), num_labels=num_labels).to(device)
    model.train()

    logger.info("Use %s tokenizer and classifier.", model_init)
    logger.info("Num labels: %s", num_labels)

    summary = SummaryWriter(os.path.join(model_dir, "summary"))

    dataloader = torch.utils.data.DataLoader(DatasetForBert(
        trainset, model_init, bert_clf_bs),
                                             batch_size=None,
                                             num_workers=2)

    dataloader_val = torch.utils.data.DataLoader(DatasetForBert(
        testset, model_init, bert_clf_bs),
                                                 batch_size=None,
                                                 num_workers=1)
    dataloader_val_iter = iter(dataloader_val)

    params = model.parameters()

    opt, sche = get_optimizer(bert_clf_optimizer, bert_clf_lr,
                              bert_clf_weight_decay, bert_clf_steps, params)

    global_step = 0
    correct_train, count_train = 0, 0
    for seq, mask, tok_type, label in tqdm.tqdm(dataloader,
                                                total=bert_clf_steps):
        global_step += 1
        seq = seq.to(device)
        mask = mask.to(device)
        tok_type = tok_type.to(device)
        label = label.to(device)

        outputs = model(seq, mask, tok_type, labels=label)
        loss, logits = outputs[:2]

        count_train += seq.size(0)
        correct_train += (logits.argmax(
            dim=1).eq(label).float().sum().detach().cpu().numpy())

        opt.zero_grad()
        loss.backward()
        opt.step()
        sche.step()

        if global_step % bert_clf_period_summary == 0:
            summary.add_scalar("clf_train/loss", loss, global_step)
            summary.add_scalar("clf_train/error_rate",
                               1 - correct_train / count_train, global_step)
            correct_train, count_train = 0, 0

        if global_step % bert_clf_period_val == 0:
            run_evaluate(model, dataloader_val_iter, bert_clf_val_steps,
                         summary, global_step, device)

        if global_step % bert_clf_period_save == 0 or global_step == bert_clf_steps:
            ckpt_path = os.path.join(
                model_dir, model_init + "-%04dk" % (global_step // 1000))
            if not os.path.exists(ckpt_path):
                os.makedirs(ckpt_path)
            model.save_pretrained(ckpt_path)
            logger.info("BERT classifier saved at %s.", ckpt_path)

        if global_step >= bert_clf_steps:
            break
    model.eval()
    return model
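
A hedged usage sketch for this function; the hyperparameter values are the defaults used in the next example, while the dataset name and the ``trainset``/``testset`` dicts are assumptions:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clf = load_or_train_bert_clf(
    model_init="bert-base-uncased",
    dataset_name="ag_news",          # hypothetical dataset name
    trainset=trainset,               # fibber dataset dicts, assumed loaded
    testset=testset,
    bert_clf_steps=20000,
    bert_clf_bs=32,
    bert_clf_lr=2e-5,
    bert_clf_optimizer="adamw",
    bert_clf_weight_decay=0.001,
    bert_clf_period_summary=100,
    bert_clf_period_val=500,
    bert_clf_period_save=5000,
    bert_clf_val_steps=10,
    device=device)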
Example #4
    def __init__(self,
                 dataset_name,
                 trainset,
                 testset,
                 bert_gpu_id=-1,
                 bert_clf_steps=20000,
                 bert_clf_bs=32,
                 bert_clf_lr=0.00002,
                 bert_clf_optimizer="adamw",
                 bert_clf_weight_decay=0.001,
                 bert_clf_period_summary=100,
                 bert_clf_period_val=500,
                 bert_clf_period_save=5000,
                 bert_clf_val_steps=10,
                 initial_model_with_trainset=False,
                 **kargs):

        super(BertClassifier, self).__init__()

        if trainset["cased"]:
            model_init = "bert-base-cased"
            logger.info(
                "Use cased model in BERT classifier prediction metric.")
        else:
            model_init = "bert-base-uncased"
            logger.info(
                "Use uncased model in BERT classifier prediction metric.")

        self._tokenizer = BertTokenizerFast.from_pretrained(
            utils.get_transformers(model_init),
            do_lower_case="uncased" in model_init)

        if bert_gpu_id == -1:
            logger.warning("BERT metric is running on CPU.")
            self._device = torch.device("cpu")
        else:
            logger.info("BERT metric is running on GPU %d.", bert_gpu_id)
            self._device = torch.device("cuda:%d" % bert_gpu_id)

        self._model_init = model_init
        self._dataset_name = dataset_name

        if initial_model_with_trainset:
            print("Initial model with trainset")
            self._model = load_or_train_bert_clf(
                model_init=model_init,
                dataset_name=dataset_name,
                trainset=trainset,
                testset=testset,
                bert_clf_steps=bert_clf_steps,
                bert_clf_lr=bert_clf_lr,
                bert_clf_bs=bert_clf_bs,
                bert_clf_optimizer=bert_clf_optimizer,
                bert_clf_weight_decay=bert_clf_weight_decay,
                bert_clf_period_summary=bert_clf_period_summary,
                bert_clf_period_val=bert_clf_period_val,
                bert_clf_period_save=bert_clf_period_save,
                bert_clf_val_steps=bert_clf_val_steps,
                device=self._device)
        else:
            print("Directly use the pre-trained model")
            self._model = BertForSequenceClassification.from_pretrained(
                utils.get_transformers(model_init),
                num_labels=len(trainset["label_mapping"])).to(self._device)

        self._fine_tune_sche = None
        self._fine_tune_opt = None

        self._bert_clf_bs = bert_clf_bs
        self._testset = testset
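
The prediction method of this metric is not shown here. A rough sketch (method name, batching, and output handling are assumptions, for a recent transformers version) of how the stored tokenizer and model could classify one sentence:

    def _predict_example(self, text):
        # Hypothetical helper: predict a label id for a single sentence.
        batch = self._tokenizer([text], padding=True, truncation=True,
                                return_tensors="pt").to(self._device)
        with torch.no_grad():
            logits = self._model(**batch)[0]
        return int(logits.argmax(dim=-1))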
Example #5
        SETTINGS['EPOCHS'] = args.epochs
    if args.mini == "True":
        SETTINGS['NUM_CLASSES'] = 11
        SETTINGS['DATA_PATHS']['TRAIN_CSV'] = "mini_train.csv"
        SETTINGS['DATA_PATHS']['TEST_CSV'] = "mini_test.csv"
    if args.lr:
        SETTINGS["LR"] = args.lr
    if args.decay:
        SETTINGS["DECAY"] = args.decay
    SETTINGS["WLOSS"] = args.wloss == "True"
    SETTINGS["TRANSFORMER"] = args.transformer
    if args.batch:
        SETTINGS["BATCH_SIZE"] = args.batch

    TIME = get_current_time()

    # Load and transform data
    transform = get_transformers()[SETTINGS["TRANSFORMER"]]
    dataset = Loader(SETTINGS['DATA_PATHS']['TRAIN_CSV'],
                     SETTINGS['DATA_PATHS']['DATASET_PATH'],
                     transform=transform)
    test_dataset = Loader(SETTINGS['DATA_PATHS']['TEST_CSV'],
                          SETTINGS['DATA_PATHS']['DATASET_PATH'],
                          transform=transform)

    # Train k models and keep the best
    logging.info("Settings: {}".format(str(SETTINGS)))
    best_model = train(dataset)
    plot_loss(best_model, SETTINGS)
    test(best_model, test_dataset)
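
The command-line parsing for this script is not shown. A hedged sketch of an argparse setup that would produce the ``args`` fields used above (flag defaults and types are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int)
parser.add_argument("--mini", default="False")        # passed as a string, compared to "True"
parser.add_argument("--lr", type=float)
parser.add_argument("--decay", type=float)
parser.add_argument("--wloss", default="False")       # weighted-loss toggle, string-valued
parser.add_argument("--transformer", default="base")  # hypothetical key into get_transformers()
parser.add_argument("--batch", type=int)
args = parser.parse_args()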
Example #6
    def __init__(self,
                 dataset,
                 model_init,
                 batch_size,
                 label_num=54,
                 exclude=-1,
                 masked_lm=False,
                 masked_lm_ratio=0.2,
                 dynamic_masked_lm=False,
                 include_raw_text=False,
                 seed=0,
                 clf_type="multi_label_classify"):
        """Initialize.

        Args:
            dataset (dict): a dataset dict.
            model_init (str): the pre-trained model name. Select from ``['bert-base-cased',
                'bert-base-uncased', 'bert-large-cased', 'bert-large-uncased']``.
            batch_size (int): the batch size in each step.
            label_num (int): the number of labels.
            exclude (int): exclude one category from the data.
                Use -1 (default) to include all categories.
            masked_lm (bool): whether to randomly replace words with mask tokens.
            masked_lm_ratio (float): the ratio of random masks. Ignored when masked_lm is False.
            dynamic_masked_lm (bool): whether to use dynamic masking for the language model; the
                masking ratio is sampled randomly. ``dynamic_masked_lm`` and ``masked_lm`` should
                not both be set to True.
            include_raw_text (bool): whether to return the raw text.
            seed (int): random seed.
            clf_type (str): the classification type, e.g. ``"multi_label_classify"``.
        """
        self._buckets = [30, 50, 100, 200]
        self._max_len = self._buckets[-1]
        self._data = [[] for i in range(len(self._buckets))]

        self._batch_size = batch_size
        self._label_num = label_num
        self._tokenizer = BertTokenizerFast.from_pretrained(
            utils.get_transformers(model_init),
            do_lower_case="uncased" in model_init)

        self._seed = seed
        self._pad_tok_id = self._tokenizer.pad_token_id

        self._masked_lm = masked_lm
        self._masked_lm_ratio = masked_lm_ratio
        self._mask_tok_id = self._tokenizer.mask_token_id

        if dynamic_masked_lm and masked_lm:
            raise RuntimeError(
                "Cannot have dynamic_masked_lm and masked_lm both True.")

        self._dynamic_masked_lm = dynamic_masked_lm
        self._include_raw_text = include_raw_text

        self._clf_type = clf_type

        counter = 0
        logger.info("DatasetForBert is processing data.")

        if isinstance(dataset, list):
            load_data = dataset
        elif isinstance(dataset, dict):
            load_data = dataset["data"]
        else:
            raise TypeError("dataset should be a list or a dict.")

        for item in tqdm.tqdm(load_data):
            y = item["label"]
            s0 = "[CLS] " + item["text0"] + " [SEP]"
            if "text1" in item:
                s1 = item["text1"] + " [SEP]"
            else:
                s1 = ""

            if y == exclude:
                continue

            counter += 1

            s0_ids = self._tokenizer.convert_tokens_to_ids(
                self._tokenizer.tokenize(s0))
            s1_ids = self._tokenizer.convert_tokens_to_ids(
                self._tokenizer.tokenize(s1))
            text_ids = (s0_ids + s1_ids)[:self._max_len]

            for bucket_id in range(len(self._buckets)):
                if self._buckets[bucket_id] >= len(text_ids):
                    self._data[bucket_id].append(
                        (text_ids, y, len(s0_ids), len(s1_ids), s0 + s1))
                    break

        logger.info("Load %d documents. with filter %d.", counter, exclude)
        self._bucket_prob = np.asarray([len(x) for x in self._data])
        self._bucket_prob = self._bucket_prob / np.sum(self._bucket_prob)
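
The iteration logic of DatasetForBert is not included in this snippet. A rough sketch (the method name is an assumption; ``rng`` would be, e.g., ``np.random.RandomState(self._seed)``) of how one padded batch could be drawn from the buckets built above:

    def _sample_batch(self, rng):
        # Hypothetical helper: pick a bucket according to self._bucket_prob,
        # sample a batch from it, and pad every sequence to the batch maximum.
        bucket_id = rng.choice(len(self._buckets), p=self._bucket_prob)
        idx = rng.choice(len(self._data[bucket_id]), self._batch_size)
        batch = [self._data[bucket_id][i] for i in idx]
        max_len = max(len(text_ids) for text_ids, _, _, _, _ in batch)
        seq, mask, label = [], [], []
        for text_ids, y, _, _, _ in batch:
            pad = max_len - len(text_ids)
            seq.append(text_ids + [self._pad_tok_id] * pad)
            mask.append([1] * len(text_ids) + [0] * pad)
            label.append(y)
        return np.asarray(seq), np.asarray(mask), np.asarray(label)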
Example #7
def get_lm(lm_option,
           output_dir,
           trainset,
           device,
           lm_steps=5000,
           lm_bs=32,
           lm_opt="adamw",
           lm_lr=0.0001,
           lm_decay=0.01,
           lm_period_summary=100,
           lm_period_save=5000,
           **kwargs):
    """Returns a BERT language model or a list of language models on a given dataset.

    The language model will be stored at ``<output_dir>/lm_all`` if lm_option is finetune,
    or at ``<output_dir>/lm_filter_?`` if lm_option is adv.

    If filter is not -1, the pretrained language model will first be pretrained on the whole
    dataset, then finetuned on the data excluding the filter category.

    Args:
        lm_option (str): choose from `["pretrain", "finetune", "adv", "nartune"]`.
            pretrain means the pretrained BERT model without fine-tuning on the current
            dataset.
            finetune means fine-tuning the BERT model on the current dataset.
            adv means adversarial tuning on the current dataset.
            nartune means non-autoregressive fine-tuning of the BERT model on the current
            dataset.
        output_dir (str): a directory to store pretrained language model.
        trainset (DatasetForBert): the training set for finetune the language model.
        device (torch.device): a device to train the model.
        lm_steps (int): finetuning steps.
        lm_bs (int): finetuning batch size.
        lm_opt (str): optimizer name. Choose from ["sgd", "adam", "adamw"].
        lm_lr (float): learning rate.
        lm_decay (float): weight decay for the optimizer.
        lm_period_summary (int): number of steps to write training summary.
        lm_period_save (int): number of steps to save the finetuned model.
    Returns:
        (BertTokenizerFast): the tokenizer for the language model.
        (BertForMaskedLM): a finetuned language model if lm_option is pretrain or finetune.
        ([BertForMaskedLM]): a list of finetuned language models if lm_option is adv. The i-th
            language model in the list is fine-tuned on data not having label i.
    """
    if trainset["cased"]:
        model_init = "bert-base-cased"
    else:
        model_init = "bert-base-uncased"

    tokenizer = BertTokenizerFast.from_pretrained(
        utils.get_transformers(model_init),
        do_lower_case="uncased" in model_init)
    tokenizer.do_lower_case = "uncased" in model_init

    # tokenizer = BertTokenizer.from_pretrained("../datasets/bert-base-uncased")

    if lm_option == "pretrain":
        bert_lm = BertForMaskedLM.from_pretrained(
            utils.get_transformers(model_init))
        bert_lm.eval()
        for item in bert_lm.parameters():
            item.requires_grad = False

    elif lm_option == "finetune":
        bert_lm = fine_tune_lm(output_dir,
                               trainset,
                               -1,
                               device,
                               lm_steps=lm_steps,
                               lm_bs=lm_bs,
                               lm_opt=lm_opt,
                               lm_lr=lm_lr,
                               lm_decay=lm_decay,
                               lm_period_summary=lm_period_summary,
                               lm_period_save=lm_period_save)
        bert_lm.eval()
        for item in bert_lm.parameters():
            item.requires_grad = False

    elif lm_option == "adv":
        bert_lm = []
        for i in range(len(trainset["label_mapping"])):
            lm = fine_tune_lm(output_dir,
                              trainset,
                              i,
                              device,
                              lm_steps=lm_steps,
                              lm_bs=lm_bs,
                              lm_opt=lm_opt,
                              lm_lr=lm_lr,
                              lm_decay=lm_decay,
                              lm_period_summary=lm_period_summary,
                              lm_period_save=lm_period_save)
            lm.eval()
            for item in lm.parameters():
                item.requires_grad = False
            bert_lm.append(lm)

    elif lm_option == "nartune":
        bert_lm = non_autoregressive_fine_tune_lm(
            output_dir,
            trainset,
            -1,
            device,
            lm_steps=lm_steps,
            lm_bs=lm_bs,
            lm_opt=lm_opt,
            lm_lr=lm_lr,
            lm_decay=lm_decay,
            lm_period_summary=lm_period_summary,
            lm_period_save=lm_period_save,
            **kwargs)
        bert_lm.eval()
        for item in bert_lm.parameters():
            item.requires_grad = False
    else:
        raise RuntimeError("unsupported lm_option")

    return tokenizer, bert_lm
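
A hedged usage sketch for the ``finetune`` option, filling one masked token with the returned model (the output directory is hypothetical and ``trainset`` is assumed to be a loaded fibber dataset dict):

import torch

device = torch.device("cpu")
tokenizer, bert_lm = get_lm("finetune", "exp-output", trainset, device,
                            lm_steps=5000, lm_bs=32)
bert_lm.to(device)

batch = tokenizer("The movie was [MASK] good.", return_tensors="pt").to(device)
with torch.no_grad():
    logits = bert_lm(**batch)[0]
mask_pos = batch["input_ids"][0].tolist().index(tokenizer.mask_token_id)
print(tokenizer.decode([int(logits[0, mask_pos].argmax())]))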
Example #8
def non_autoregressive_fine_tune_lm(output_dir,
                                    trainset,
                                    filter,
                                    device,
                                    use_metric,
                                    lm_steps=5000,
                                    lm_bs=32,
                                    lm_opt="adamw",
                                    lm_lr=0.0001,
                                    lm_decay=0.01,
                                    lm_period_summary=100,
                                    lm_period_save=5000):
    """Returns a finetuned BERT language model on a given dataset.

    The language model will be stored at ``<output_dir>/lm_all`` if filter is -1, or
    ``<output_dir>/lm_filter_?`` if filter is not -1.

    If filter is not -1, the pretrained language model will first be pretrained on the whole
    dataset, then finetuned on the data excluding the filter category.

    Args:
        output_dir (str): a directory to store pretrained language model.
        trainset (DatasetForBert): the training set for finetune the language model.
        filter (int): a category to exclude from finetuning.
        device (torch.device): a device to train the model.
        use_metric (USESemanticSimilarityMetric): a sentence encoder metric used to compute
            sentence embeddings.
        lm_steps (int): finetuning steps.
        lm_bs (int): finetuning batch size.
        lm_opt (str): optimizer name. Choose from ["sgd", "adam", "adamw"].
        lm_lr (float): learning rate.
        lm_decay (float): weight decay for the optimizer.
        lm_period_summary (int): number of steps to write training summary.
        lm_period_save (int): number of steps to save the finetuned model.
    Returns:
        (BertForMaskedLM): a finetuned language model.
    """
    if filter == -1:
        output_dir_t = os.path.join(output_dir, "narlm_all")
    else:
        output_dir_t = os.path.join(output_dir, "narlm_filter_%d" % filter)

    summary = SummaryWriter(output_dir_t + "/summary")

    if trainset["cased"]:
        model_init = "bert-base-cased"
    else:
        model_init = "bert-base-uncased"

    ckpt_path_pattern = output_dir_t + "/checkpoint-%04dk"
    ckpt_path = ckpt_path_pattern % (lm_steps // 1000)

    if os.path.exists(ckpt_path):
        logger.info("Language model <%s> exists.", ckpt_path)
        return NonAutoregressiveBertLM.from_pretrained(
            ckpt_path, sentence_embed_size=512).eval()

    if filter == -1:
        lm_model = NonAutoregressiveBertLM.from_pretrained(
            utils.get_transformers(model_init), sentence_embed_size=512)
        lm_model.train()
    else:
        # First obtain a model fine-tuned on the whole dataset, then continue on the filtered data.
        lm_model = non_autoregressive_fine_tune_lm(
            output_dir, trainset, -1, device, use_metric, lm_steps=lm_steps)
        lm_model.train()
    lm_model.to(device)

    dataset = DatasetForBert(trainset,
                             model_init,
                             lm_bs,
                             exclude=filter,
                             masked_lm=False,
                             dynamic_masked_lm=True,
                             include_raw_text=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=2)

    params = list(lm_model.parameters())
    opt, sche = get_optimizer(lm_opt, lm_lr, lm_decay, lm_steps, params)

    global_step = 0
    stats = new_stats()
    for seq, mask, tok_type, label, lm_label, raw_text in tqdm.tqdm(
            dataloader, total=lm_steps):
        opt.zero_grad()

        global_step += 1

        raw_text = [
            x.replace("[CLS]", "").replace("[SEP]", "") for x in raw_text
        ]
        sentence_embs = use_metric.model(raw_text).numpy()

        seq = seq.to(device)
        mask = mask.to(device)
        tok_type = tok_type.to(device)
        label = label.to(device)
        lm_label = lm_label.to(device)

        # print("seq", seq.size())
        # print("mask", mask.size())
        # print("toktype", tok_type.size())
        # print("label", label.size())
        # print("lmlabel", lm_label.size())

        lm_loss = compute_non_autoregressive_lm_loss(
            lm_model=lm_model,
            sentence_embeds=sentence_embs,
            seq=seq,
            mask=mask,
            tok_type=tok_type,
            lm_label=lm_label,
            stats=stats)
        lm_loss.backward()

        opt.step()
        sche.step()

        if global_step % lm_period_summary == 0:
            write_summary(stats, summary, global_step)
            stats = new_stats()

        if global_step % lm_period_save == 0 or global_step == lm_steps:
            lm_model.to(torch.device("cpu")).eval()
            lm_model.save_pretrained(ckpt_path_pattern % (global_step // 1000))
            lm_model.to(device)

        if global_step >= lm_steps:
            break

    lm_model.eval()
    lm_model.to(torch.device("cpu"))
    return lm_model
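
A hedged usage sketch; the output directory is hypothetical, ``trainset`` is assumed to be a loaded fibber dataset dict, and ``use_metric`` an already-constructed USESemanticSimilarityMetric instance:

import torch

# use_metric: an existing USESemanticSimilarityMetric, constructed elsewhere.
lm_model = non_autoregressive_fine_tune_lm(
    output_dir="exp-output",       # hypothetical output directory
    trainset=trainset,
    filter=-1,                     # -1 trains on the whole dataset
    device=torch.device("cuda:0"),
    use_metric=use_metric,
    lm_steps=5000,
    lm_bs=32)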
Example #9
def fine_tune_lm(output_dir,
                 trainset,
                 filter,
                 device,
                 lm_steps=5000,
                 lm_bs=32,
                 lm_opt="adamw",
                 lm_lr=0.0001,
                 lm_decay=0.01,
                 lm_period_summary=100,
                 lm_period_save=5000):
    """Returns a finetuned BERT language model on a given dataset.

    The language model will be stored at ``<output_dir>/lm_all`` if filter is -1, or
    ``<output_dir>/lm_filter_?`` if filter is not -1.

    If filter is not -1, the pretrained language model will first be pretrained on the whole
    dataset, then finetuned on the data excluding the filter category.

    Args:
        output_dir (str): a directory to store pretrained language model.
        trainset (DatasetForBert): the training set for finetune the language model.
        filter (int): a category to exclude from finetuning.
        device (torch.device): a device to train the model.
        lm_steps (int): finetuning steps.
        lm_bs (int): finetuning batch size.
        lm_opt (str): optimizer name. Choose from ["sgd", "adam", "adamw"].
        lm_lr (float): learning rate.
        lm_decay (float): weight decay for the optimizer.
        lm_period_summary (int): number of steps to write training summary.
        lm_period_save (int): number of steps to save the finetuned model.
    Returns:
        (BertForMaskedLM): a finetuned language model.
    """

    print("Active Learning Procedure:")

    if filter == -1:
        output_dir_t = os.path.join(output_dir, "lm_all")
    else:
        output_dir_t = os.path.join(output_dir, "lm_filter_%d" % filter)

    summary = SummaryWriter(output_dir_t + "/summary")

    if trainset["cased"]:
        model_init = "bert-base-cased"
    else:
        model_init = "bert-base-uncased"

    ckpt_path_pattern = output_dir_t + "/checkpoint-%04dk"
    ckpt_path = ckpt_path_pattern % (lm_steps // 1000)

    if os.path.exists(ckpt_path):
        logger.info("Language model <%s> exists.", ckpt_path)
        return BertForMaskedLM.from_pretrained(ckpt_path).eval()
    if filter == -1:
        lm_model = BertForMaskedLM.from_pretrained(
            utils.get_transformers(model_init))
        lm_model.train()
    else:
        # First obtain a model fine-tuned on the whole dataset, then continue on the filtered data.
        lm_model = fine_tune_lm(output_dir, trainset, -1, device, lm_steps=lm_steps)
        lm_model.train()
    lm_model.to(device)

    dataset = DatasetForBert(trainset,
                             model_init,
                             lm_bs,
                             exclude=filter,
                             masked_lm=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=2)

    params = list(lm_model.parameters())
    opt, sche = get_optimizer(lm_opt, lm_lr, lm_decay, lm_steps, params)

    global_step = 0
    stats = new_stats()
    for seq, mask, tok_type, lm_label in tqdm.tqdm(  #label,
            dataloader, total=lm_steps):
        opt.zero_grad()

        global_step += 1

        seq = seq.to(device)
        mask = mask.to(device)
        tok_type = tok_type.to(device)
        # label = label.to(device)
        lm_label = lm_label.to(device)

        lm_loss = compute_lm_loss(lm_model, seq, mask, tok_type, lm_label,
                                  stats)
        lm_loss.backward()

        opt.step()
        sche.step()

        if global_step % lm_period_summary == 0:
            write_summary(stats, summary, global_step)
            stats = new_stats()

        if global_step % lm_period_save == 0 or global_step == lm_steps:
            lm_model.to(torch.device("cpu")).eval()
            print(ckpt_path_pattern % (global_step // 1000))
            lm_model.save_pretrained(ckpt_path_pattern % (global_step // 1000))
            lm_model.to(device)

        if global_step >= lm_steps:
            break

    lm_model.eval()
    lm_model.to(torch.device("cpu"))
    return lm_model
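
``compute_lm_loss`` and ``new_stats`` are defined elsewhere in the module. A plausible minimal version of the loss helper, relying on the fact that recent transformers versions of BertForMaskedLM compute the masked-LM cross-entropy when ``labels`` is passed (the stats keys are assumptions):

def compute_lm_loss(lm_model, seq, mask, tok_type, lm_label, stats):
    # Sketch of the helper used above: a forward pass with labels so the model
    # computes the masked-LM cross-entropy itself.
    outputs = lm_model(input_ids=seq, attention_mask=mask,
                       token_type_ids=tok_type, labels=lm_label)
    lm_loss = outputs[0]
    stats["lm_loss"] = stats.get("lm_loss", 0) + float(lm_loss.detach())  # assumed stats format
    stats["count"] = stats.get("count", 0) + 1
    return lm_loss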