Example #1
    def build_gpt2(cls, args, tgt_dict):
        model = GPT2LMHeadModel.from_pretrained(args.gpt2_name)

        return model
Example #2
def collate_fn(batch_samples):
    max_len = max([len(s) for s in batch_samples])
    batch = [s + [-1] * (max_len - len(s)) for s in batch_samples]
    tensor = torch.tensor(batch, dtype=torch.long)
    inputs, labels = tensor[:, :-1], tensor[:, 1:]
    inputs[inputs == -1] = pad_id
    return inputs.to(dev), labels.to(dev)


def cal_ce_loss(logits, labels):
    crit = nn.CrossEntropyLoss(ignore_index=-1)
    return crit(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))


lm_model = GPT2LMHeadModel.from_pretrained("distilgpt2").to(dev)
lm_model.load_state_dict(torch.load(f"../dump/eval_lm_{ds}.pth"))
lm_model.eval()

sentences = [l.strip().split("\t")[1].strip() for l in open(hyp_file, 'r')]
test_loader = DataLoader(dataset=LMDataset(sentences),
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=collate_fn)

losses, counts = [], []
for inp, labels in test_loader:
    with torch.no_grad():
        logits = lm_model(input_ids=inp)[0]
    mask = (labels >= 0).float()  # B, L
    loss_bat_avg = cal_ce_loss(logits, labels)
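    # assumed completion of the loop body (the excerpt stops here): accumulate
    # the summed NLL and the non-padding token count for this batch
    losses.append((loss_bat_avg * mask.sum()).item())
    counts.append(mask.sum().item())

# corpus-level perplexity from the accumulated sums (also assumed; needs `import math`)
ppl = math.exp(sum(losses) / sum(counts))
print(f"perplexity: {ppl:.2f}")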
Example #3
 def __init__(self, pretrained_model):
     super().__init__()
     self.model = GPT2LMHeadModel.from_pretrained(pretrained_model)
Example #4
from transformers import GPT2LMHeadModel, GPT2Config
from new_tokenizer import MyTokenizer
import torch
import kss

vocab_file_path = '../tokenizer/vocab.json'
merge_file_path = '../tokenizer/merges.txt'

answer_tokenizer = MyTokenizer(vocab_file_path, merge_file_path)
question_tokenizer = MyTokenizer(vocab_file_path, merge_file_path)

answer_config = GPT2Config(vocab_size=52004)
question_config = GPT2Config(vocab_size=52005)

answer_model = GPT2LMHeadModel(answer_config)
question_model = GPT2LMHeadModel(question_config)

answer_model_dir = '../KorGPT-2SampleModel/answer_model.bin'
question_model_dir = '../KorGPT-2SampleModel/question_model.bin'

answer_model.load_state_dict(torch.load(answer_model_dir), strict=False)
question_model.load_state_dict(torch.load(question_model_dir), strict=False)

answer_model.to('cpu')
question_model.to('cpu')


def add_special_tokens_(model, tokenizer, added_tokens):
    orig_num_tokens = tokenizer.get_vocab_size()
    tokenizer.add_special_tokens(added_tokens)
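    # assumed remainder of the helper (the listing cuts it off here): grow the
    # model's embedding matrix to cover the newly added special tokens
    num_added_tokens = len(added_tokens)
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)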
Example #5
with T.no_grad():
    dialog_act_classifier = Classifier(D=bot_queries_embd.shape[-1],
                                       classes_num=len(labels2idx)).cuda()
    checkpoint = T.load("Classifier/Model_Backup/model.pt")
    dialog_act_classifier.load_state_dict(checkpoint['model_state_dict'])
    dialog_act_classifier = dialog_act_classifier.eval()

# LOAD DialoGPT Generator

with T.no_grad():
    tokenizer = GPT2Tokenizer.from_pretrained('Generator/DialoGPT/Configs/')
    weights = T.load('Generator/DialoGPT/Parameters/medium_ft.pkl')
    weights_reverse = T.load('Generator/DialoGPT/Parameters/small_reverse.pkl')
    cfg = GPT2Config.from_json_file('Generator/DialoGPT/Configs/config.json')
    model = GPT2LMHeadModel(cfg)
    model_reverse = GPT2LMHeadModel(cfg)

    # remap the checkpoint's legacy lm_head key to the name the current model expects
    weights["lm_head.weight"] = weights["lm_head.decoder.weight"]
    weights.pop("lm_head.decoder.weight", None)
    weights_reverse["lm_head.weight"] = weights_reverse[
        "lm_head.decoder.weight"]
    weights_reverse.pop("lm_head.decoder.weight", None)

    model.load_state_dict(weights)
    model.to('cuda')
    model.eval()

    model_reverse.load_state_dict(weights_reverse)
    model_reverse.to('cuda')
Example #6
from .models import ElmoSCLSTM
from .util import get_module_or_attr

"""
NEW: reranking snippets
"""
# (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
import torch
from torch.nn import CrossEntropyLoss

HFACE_batch_size = 8

from transformers import GPT2Tokenizer, GPT2LMHeadModel

gpt2Tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
gpt2LMHeadModel = GPT2LMHeadModel.from_pretrained('gpt2-medium')
gpt2Tokenizer.pad_token = gpt2Tokenizer.eos_token


def get_losses_from_gpt_lm(this_sents: "list[str]", gpt2LMHeadModel, gpt2Tokenizer, device):
    this_input_ids = gpt2Tokenizer.batch_encode_plus(this_sents, add_special_tokens=True, pad_to_max_length=True,
                                                     add_space_before_punct_symbol=True)["input_ids"]
    this_labels = torch.tensor(
        [[i if i != gpt2Tokenizer.pad_token_id else -100 for i in row] for row in this_input_ids]).to(device)
    this_input_ids = torch.tensor(this_input_ids).to(device)
    this_outputs = gpt2LMHeadModel(input_ids=this_input_ids)
    this_lm_logits = this_outputs[0]
    # Shift so that tokens < n predict n
    shift_logits2 = this_lm_logits[:, :-1, :]
    shift_labels2 = this_labels[:, 1:]
    # Flatten the tokens
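    # assumed continuation of the truncated function: per-token cross-entropy that
    # ignores the -100 padding labels, reduced to one score per candidate sentence
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
    token_losses = loss_fct(
        shift_logits2.reshape(-1, shift_logits2.size(-1)),
        shift_labels2.reshape(-1),
    ).reshape(shift_labels2.shape[0], -1)
    n_tokens = (shift_labels2 != -100).sum(dim=-1).clamp(min=1)
    return (token_losses.sum(dim=-1) / n_tokens).detach().cpu().numpy().tolist()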
Example #7
import sys
from flask import Flask, request, jsonify, send_file, current_app, make_response, redirect, url_for, render_template
from flask_cors import CORS
from . import utilities
from .config import API_TITLE
import logging
import ast
import time
from datetime import timedelta
from functools import update_wrapper
import random
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt_model = GPT2LMHeadModel.from_pretrained('gpt2')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpt_model.to(device)
gpt_model.eval()

logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
# creates a Flask application, named app
app = Flask(__name__, static_folder="static")
app.config['SECRET_KEY'] = 'secret'
app.config['CORS_HEADERS'] = 'Content-Type'

cors = CORS(app,
            resources={
                r"/action": {
                    "origins": "http://localhost:5000"
                },
Example #8
import torch
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
torch.set_grad_enabled(False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
model = GPT2LMHeadModel.from_pretrained('./gpt2').eval()
model = model.to(device)


def extend(text, size=20):
    tokens = tokenizer.encode(text)
    tokens = torch.tensor([tokens]).to(device)
    tokens = model.generate(tokens, max_length=size+tokens.shape[1], do_sample=True)
    tokens = tokens[0].tolist()
    return tokenizer.decode(tokens)


if __name__ == "__main__":
    test_text = 'Microsoft and Google'
    extended = extend(test_text, 25)
    print(extended)
Example #9
tokenizer = GPT2Tokenizer.from_pretrained("tokenizer")

tokenizer.add_special_tokens({
    "eos_token": "</s>",
    "bos_token": "<s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
    "mask_token": "<mask>"
})

t = tokenizer.encode(inp)
print(t)

print(tokenizer.decode(t))

model = GPT2LMHeadModel.from_pretrained("GPyT").to("cuda")

while True:
    inp = input(">>> ")
    input_ids = tokenizer.encode(inp, return_tensors="pt").to("cuda")
    beam_output = model.generate(input_ids,
                                 max_length=512,
                                 num_beams=10,
                                 temperature=0.7,
                                 no_repeat_ngram_size=5,
                                 num_return_sequences=1)
    for beam in beam_output:
        out = tokenizer.decode(beam)
        fout = out.replace("<N>", "\n")

        print(green(str(fout)))
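`green` is used above but never defined in the excerpt; a minimal stand-in (it would need to be defined before the loop) that simply wraps the text in an ANSI color escape:

def green(text):
    # print the generated code in green in the terminal
    return f"\033[92m{text}\033[0m"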
Example #10
    with torch.no_grad():

        for i in range(text_len):
            outputs = model(cur_ids, labels=cur_ids)
            loss, logits = outputs[:2]
            softmax_logits = torch.softmax(
                logits[0, -1], dim=0
            )  # take the only sequence in the batch and the logits at its last position
            next_token_id = choose_from_top(
                softmax_logits.to('cpu').numpy(), n=10
            )  # sample the next token id from the top n of this distribution
            cur_ids = torch.cat([
                cur_ids,
                torch.ones((1, 1)).long().to(device) * next_token_id
            ],
                                dim=1)  # append the sampled token

        output_list = list(cur_ids.squeeze().to('cpu').numpy())
        output_text = tokenizer.decode([output_list])
        print(output_text)


if __name__ == '__main__':

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = YTEncoder.from_pretrained(str(PATH_TO_MODELS / "yt.model"))
    model = GPT2LMHeadModel.from_pretrained(str(PATH_TO_MODELS /
                                                "m_gpt_2/")).to(device)

    generate_some_text(" хм...фестиваль в сочи...киберсборная сельца ")
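`choose_from_top` is called above but not shown in the excerpt; a plausible implementation matching the inline comments (keep the n most likely ids, renormalize, sample one), offered as an assumption rather than the original helper:

import numpy as np

def choose_from_top(probs, n=5):
    # keep the n most probable token ids, renormalize their probabilities, sample one
    top_ids = np.argpartition(probs, -n)[-n:]
    top_probs = probs[top_ids] / probs[top_ids].sum()
    return int(np.random.choice(top_ids, p=top_probs))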
Example #11
def main(
    gpu: Param("GPU to run on", int) = None,
    lr: Param("base Learning rate", float) = 1e-4,
    bs: Param("Batch size", int) = 8,
    sl: Param("Sequence length", int) = 1024,
    epochs: Param("Number of epochs", int) = 1,
    fp16: Param("Use mixed precision training", int) = 0,
    dump: Param("Print model; don't train", int) = 0,
    runs: Param("Number of times to repeat training", int) = 1,
):
    "Training of IMDB classifier."

    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
        if gpu is None: gpu = list(range(n_gpu))[0]
        torch.cuda.set_device(gpu)
    else:
        n_gpu = None

    # GPT2
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast
    pretrained_weights = 'gpt2'
    tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
    model = GPT2LMHeadModel.from_pretrained(pretrained_weights)

    # datasets
    path = rank0_first(lambda: untar_data(URLs.WIKITEXT_TINY))
    df_train = pd.read_csv(path / 'train.csv', header=None)
    df_valid = pd.read_csv(path / 'test.csv', header=None)
    all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

    # fastai v2 tokenizer
    class TransformersTokenizer(Transform):
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def encodes(self, x):
            toks = self.tokenizer.tokenize(x)
            return tensor(self.tokenizer.convert_tokens_to_ids(toks))

        def decodes(self, x):
            return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

    splits = [
        list(range_of(df_train)),
        list(range(len(df_train), len(all_texts)))
    ]
    tls = TfmdLists(all_texts,
                    TransformersTokenizer(tokenizer),
                    splits=splits,
                    dl_type=LMDataLoader)

    # get dataloaders
    workers = min(8, num_cpus())
    dls = tls.dataloaders(bs=bs, seq_len=sl, num_workers=workers)

    class DropOutput(Callback):
        def after_pred(self):
            self.learn.pred = self.pred[0]

    for run in range(runs):
        print(
            f'Rank[{rank_distrib()}] Run: {run}; epochs: {epochs}; lr: {lr}; bs: {bs}; sl: {sl}'
        )

        learn = rank0_first(lambda: Learner(dls,
                                            model,
                                            loss_func=CrossEntropyLossFlat(),
                                            cbs=[DropOutput],
                                            metrics=Perplexity()))

        if dump:
            print(learn.model)
            exit()
        if fp16: learn = learn.to_fp16()

        # TODO: DataParallel would hit floating point error, disabled for now.
        # if gpu is None and n_gpu: ctx = partial(learn.parallel_ctx, device_ids=list(range(n_gpu)))

        # Workaround: In PyTorch 1.4, need to set DistributedDataParallel() with find_unused_parameters=True,
        # to avoid a crash that only happens in distributed mode of text_classifier_learner.fine_tune()

        # if num_distrib() > 1 and torch.__version__.startswith("1.4"): DistributedTrainer.fup = True
        DistributedTrainer.fup = True

        with learn.distrib_ctx(
                cuda_id=gpu
        ):  # distributed training requires "-m fastai2.launch"
            print(
                f"Training in distributed data parallel context on GPU {gpu}",
                flush=True)
            learn.fit_one_cycle(epochs, lr)
Example #12
# SPDX-License-Identifier: Apache-2.0
# based on: https://huggingface.co/blog/how-to-generate

from transformers import GPT2LMHeadModel, GPT2Tokenizer


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='pt')

# generate text until the output length (which includes the context length) reaches 50
greedy_output = model.generate(input_ids, max_length=50)
print(f"Output ({greedy_output.shape}): {greedy_output}")
print(f"Detokenized: `{tokenizer.decode(greedy_output[0], skip_special_tokens=False)}`")


Example #13
 def __init__(self, run_params):
     super().__init__()
     self.run_params = run_params
     config = GPT2Config()
     self.model = GPT2LMHeadModel(config)
     self.loss = torch.nn.CrossEntropyLoss(reduction='none')
Example #14
    # Create tokenizers
    model_name = pretrained_models[0]
    gpt2_tokenizer = None
    if pretrained_models == [model_name]:
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    else:
        raise NotImplementedError("Only the following tokenizers are supported: {}".format(model_name))

    num_keywords = len(TARG)
    num_possible_labels = int(1 + num_keywords)

    model = None
    tokenizer = None
    if 'gpt2' in model_name:
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = gpt2_tokenizer
    else:
        raise NotImplementedError("model_name == {} not supported".format(model_name))

    model.transformer.output_hidden_states = True  # necessary to pull activation tensors
    device = torch.device("cpu")
    if torch.cuda.is_available():
        model = model.cuda()
        device = torch.device("cuda")

    try:
        # BEGINNING ##############################################################################################################################

        print("and so it begins", flush=True)
        dataset = []
Example #15
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large",
                                          cache_dir="/u/scr/mhahn/cache/")

# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained(
    "gpt2-large",
    pad_token_id=tokenizer.eos_token_id,
    cache_dir="/u/scr/mhahn/cache/").cuda()
print("Finished loading GPT2")

#text = "Replace me by any text you'd like."
#encoded_input = tokenizer.encode(text, return_tensors='pt')
#print(encoded_input)
#predictions, _ = model(encoded_input.cuda())
#print(predictions.size())
#
#sentences = {length : set() for length in range(20)}
#counter = 0
#with open("/jagupard27/scr0/mhahn/memory/char-lm-ud-stationary_12_SuperLong_WithAutoencoder_WithEx_Samples_Short_Combination_Subseq_VeryLong_WithSurp12_NormJudg_Short_CondGPT2.py_749792590_Model.txt", "r") as inFile:
#   next(inFile)
#   for line in inFile:
#     counter += 1
#     line = line.strip().split("\t")
#     _, _, _, _, _, _, _, sentence, _, _, nextWord = line
#     sentence = (sentence.strip().split(" ")[1:-1] + [nextWord.strip()])
#     sentences[len(sentence)].add(" ".join(sentence).strip())
#     if counter % 1000 == 0:
#        print(counter/9537985, sum([len(x) for _, x in sentences.items()])/counter)
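The commented-out block above sketches how the loaded model was meant to score sentences. A minimal, self-contained version of that idea (the example text is a placeholder):

text = "Replace me by any text you'd like."
encoded = tokenizer.encode(text, return_tensors="pt").cuda()
with torch.no_grad():
    loss = model(encoded, labels=encoded)[0]  # mean next-token NLL over the sequence
print("per-token NLL:", loss.item())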
Example #16
 def get_encoder_decoder_model(self, config, decoder_config):
     encoder_model = BertModel(config)
     decoder_model = GPT2LMHeadModel(decoder_config)
     return encoder_model, decoder_model
Example #17
else:
    gdown.download(model_location, datamodel)

# Code
tokenizer = ilm.tokenize_util.Tokenizer.GPT2
with open(os.path.join(MODEL_DIR, 'additional_ids_to_tokens.pkl'), 'rb') as f:
    additional_ids_to_tokens = pickle.load(f)
additional_tokens_to_ids = {v:k for k, v in additional_ids_to_tokens.items()}
try:
    ilm.tokenize_util.update_tokenizer(additional_ids_to_tokens, tokenizer)
except ValueError:
    print('Already updated')

# Load model
device = 'cpu'
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
model.eval()
_ = model.to(device)

# Create context
context = """
Interview
Chris had a job interview today. _ He ended up landing the job.
""".strip()

# context = """
# Chris had a _. He ended up enjoying it a lot.
# """.strip()


context_ids = ilm.tokenize_util.encode(context, tokenizer)
Example #18
    def __init__(
        self,
        model: str = None,
        config: Union[str, GPT2Config] = None,
        vocab_file: str = None,
        merges_file: str = None,
        cache_dir: str = "aitextgen",
        tf_gpt2: str = None,
        to_gpu: bool = False,
        to_fp16: bool = False,
        verbose: bool = False,
        torchscript: bool = False,
        ts_to_trace: bool = False,
        bos_token: str = None,
        eos_token: str = None,
        unk_token: str = None,
        **kwargs,
    ) -> None:

        if not verbose:
            for module in [
                    "transformers.file_utils",
                    "transformers.configuration_utils",
                    "transformers.tokenization_utils",
                    "filelock",
                    "transformers.modeling_gpt2",
            ]:
                logging.getLogger(module).setLevel(logging.WARN)
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.ERROR)

        if torchscript:
            assert model
            logger.info(f"Loading traced GPT-2 model from provided {model}.")
            if config is None:
                config = GPT2Config()
            self.torchscript = True
            self.model = GPT2LMHeadModel(config)

            # Transpose the traced model attributes to a GPT2LMHeadModel class
            # so it can inherit its functions
            pt_model = torch.jit.load(model)
            self.model.transformer = pt_model.transformer
            self.model.lm_head = pt_model.lm_head

        elif tf_gpt2:
            # Download + convert the TF weights if a PyTorch model has not been created
            if not os.path.isfile(
                    os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin")):
                assert tf_gpt2 in [
                    "124M",
                    "355M",
                    "774M",
                    "1558M",
                ], "Invalid TensorFlow GPT-2 model size."

                logger.info(
                    f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config "
                    + "from Google's servers")

                download_gpt2(cache_dir, tf_gpt2)

                logger.info(
                    f"Converting the {tf_gpt2} GPT-2 TensorFlow weights to PyTorch."
                )

                config_path = os.path.join(cache_dir, tf_gpt2, "hparams.json")

                convert_gpt2_checkpoint_to_pytorch(
                    os.path.join(cache_dir, tf_gpt2),
                    config_path,
                    cache_dir,
                )

                os.rename(
                    os.path.join(cache_dir, "pytorch_model.bin"),
                    os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin"),
                )

                os.rename(
                    os.path.join(cache_dir, "config.json"),
                    os.path.join(cache_dir, f"config_{tf_gpt2}.json"),
                )

            logger.info(f"Loading {tf_gpt2} GPT-2 model from /{cache_dir}.")
            model = os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin")
            config = os.path.join(cache_dir, f"config_{tf_gpt2}.json")

            self.model = GPT2LMHeadModel.from_pretrained(model, config=config)

        elif model and os.path.exists(model):
            # A pytorch_model.bin (+ optional config/config.json) is provided
            logger.info(f"Loading GPT-2 model from provided {model}.")
            if config is None:
                config = GPT2Config()
            if ts_to_trace:
                config.torchscript = True
            self.model = GPT2LMHeadModel.from_pretrained(model, config=config)
        elif config:
            if ts_to_trace:
                config.torchscript = True
            # Manually construct a GPT-2 model from scratch
            logger.info("Constructing GPT-2 model from provided config.")
            self.model = AutoModelWithLMHead.from_config(config=config)
        else:
            # Download and cache model from Huggingface
            if os.path.isdir(cache_dir) and len(os.listdir(cache_dir)) > 0:
                logger.info(
                    f"Loading {model or 'gpt2'} model from /{cache_dir}.")
            else:
                logger.info(
                    f"Downloading {model or 'gpt2'} model to /{cache_dir}.")
            self.model = GPT2LMHeadModel.from_pretrained(
                model or "gpt2", cache_dir=cache_dir, torchscript=ts_to_trace)
            if model and "gpt2" not in model:
                logger.info(f"Using the tokenizer for {model}.")
                self.tokenizer = GPT2Tokenizer.from_pretrained(
                    model,
                    cache_dir=cache_dir,
                )

        if self.tokenizer is None:
            # Update tokenizer settings (if not set already)
            args = locals()
            custom_tokenizer = False
            for attr in [
                    "vocab_file",
                    "merges_file",
                    "bos_token",
                    "eos_token",
                    "unk_token",
            ]:
                if args[attr] is not None:
                    custom_tokenizer = True
                    setattr(self, attr, args[attr])

            if custom_tokenizer:
                logger.info("Using a custom tokenizer.")
            else:
                logger.info("Using the default GPT-2 Tokenizer.")

            self.tokenizer = GPT2Tokenizer(
                vocab_file=self.vocab_file,
                merges_file=self.merges_file,
                bos_token=self.bos_token,
                eos_token=self.eos_token,
                unk_token=self.unk_token,
                pad_token=self.pad_token,
            )

        if to_gpu:
            if to_fp16:
                self.to_fp16()
            self.to_gpu()
Example #19
 def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.model = GPT2LMHeadModel.from_pretrained('gpt2')
     self.get_model()
     self.get_datasets()
Example #20
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--length', default=100, type=int, required=False, help='generation length')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature; higher is more random')
    parser.add_argument('--topk', default=8, type=int, required=False, help='sample from the top k candidates')
    parser.add_argument('--topp', default=0, type=float, required=False, help='cumulative-probability (nucleus) threshold for sampling')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='path to the model config')
    parser.add_argument('--tokenizer_path', default='cache/vocab.txt', type=str, required=False, help='path to the vocabulary file')
    parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='path to the model')
    parser.add_argument('--save_path', default='generated/', type=str, required=False, help='directory for the generated files')
    parser.add_argument('--articles_per_title', default=5, type=int, required=False, help='number of articles to generate per title')
    parser.add_argument('--titles', default='我喜欢', type=str, required=False, help='titles as a single space-separated string')
    parser.add_argument('--titles_file', default='', type=str, required=False,
                        help='file with one title per line; if set, --titles is ignored')
    parser.add_argument('--no_wordpiece', action='store_true', help='do not apply WordPiece tokenization')
    parser.add_argument('--segment', action='store_true', help='treat Chinese words (not characters) as units')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # 此处设置程序使用哪些显卡
    length = args.length
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    titles = args.titles.split()  # list of titles to generate from
    if args.titles_file:
        with open(args.titles_file, 'r') as f:
            titles = [line.strip('\n') for line in f.readlines()]
    articles_per_title = args.articles_per_title  # how many articles to generate for each title
    save_path = args.save_path  # where to store the output

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if not os.path.exists(save_path):
        os.mkdir(save_path)
    if length == -1:
        length = model.config.n_ctx

    for i, title in enumerate(titles):
        for j in range(articles_per_title):
            with open(save_path + str(i) + '-' + str(j) + '.txt', 'w',encoding='UTF8') as f:
                context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(title))
                generated = 0
                out = sample_sequence(
                    n_ctx=n_ctx,
                    model=model, length=length,
                    context=context_tokens, tokenizer=tokenizer,
                    temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty,
                    device=device
                )
                out = out.tolist()[0]

                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)

                for i, item in enumerate(text[:-1]):  # make sure English words are separated by spaces
                    if is_word(item) and is_word(text[i + 1]):
                        text[i] = item + ' '

                for i, item in enumerate(text):
                    if item == '[MASK]':
                        text[i] = ''
                    if item == '[CLS]' or item == '[SEP]':
                        text[i] = '\n'

                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                text = ''.join(text).replace('##', '').strip()
                # text = ''.join(text.split('\n')[:-1])
                print(text)
                f.write(text + '\n')
                print("=" * 80)
Example #21
CORS(app)
top = 5
""" BERT """
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Fill in the Blanks
bert_model_mlm = BertForMaskedLM.from_pretrained('bert-base-uncased',
                                                 output_attentions=True)
bert_model_mlm.eval()
# Question Answering
bert_model_qna = BertForQuestionAnswering.from_pretrained(
    'bert-base-uncased', output_attentions=True)
bert_model_qna.eval()
""" GPT2 """
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Language Modelling
gpt_model_lm = GPT2LMHeadModel.from_pretrained('gpt2', output_attentions=True)
gpt_model_lm.eval()


@app.route('/get_next_word', methods=['POST'])
def LM_predict():
    sentence_orig = request.form.get('text')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoded_prompt = gpt_tokenizer.encode(sentence_orig,
                                          add_special_tokens=False,
                                          return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)

    if encoded_prompt.size()[-1] == 0:
Example #22
from transformers import BertTokenizer
tokenizer = BertTokenizer('vocab_small.txt')

from transformers import GPT2LMHeadModel
device = "cuda"
model = GPT2LMHeadModel.from_pretrained('./doupoGPT2')
model.to(device)
model.train()
print(model.num_parameters())

from transformers import LineByLineTextDataset
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./doupo.txt",
    block_size=128,
)

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                mlm=False)  # causal LM objective; GPT-2 is not a masked LM

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./doupoGPT2",
    overwrite_output_dir=True,
    logging_steps=50,
    num_train_epochs=1,
    per_device_train_batch_size=64,
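    # assumed continuation -- the listing truncates the TrainingArguments call here
)

# wire everything into a Trainer and run the fine-tuning (a sketch, not part of the original)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model("./doupoGPT2")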
Example #23
	results_path = args.output + 'metrics_' + save_name + '.csv'
	use_context = False if args.no_context else True

	print("Initialize evaluation script ...")

	script = GPT2EvaluationScript(path_to_data_folder=args.data,
								  batch_size=args.batch_size,
								  path_to_bert_ner=args.ner,
								  use_context=use_context,
								  summarizer=args.sum)

	if not os.path.exists(generation_path):
		print("Load GPT2 model in memory ...")

		if args.no_context or args.model == 'gpt2':
			model = GPT2LMHeadModel.from_pretrained(args.model)
		else:
			model = GPT2LMSegmentModel.from_pretrained(args.model)

		tokenizer = GPT2Tokenizer.from_pretrained(args.model)
		# add_special_tokens(model, tokenizer)

		gpt_2 = FlexibleGPT2(model=model,
							 tokenizer=tokenizer,
							 decoding_strategy=DEFAULT_DECODING_STRATEGY)

		print("Begin text generation ...")
		script.generate_texts(generation_path, gpt_2, verbose=1)

	print("Compute metrics ...")
	script.compute_metrics(generation_path, results_path, args.metrics, verbose=1)
Example #24
from dotenv import load_dotenv

from nltk.tokenize import sent_tokenize
import nltk
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# FIX: Shouldn't need to be in this module but gunicorn needs it
load_dotenv('.env')

generate_blueprint = Blueprint('generate', __name__)

model_name = os.environ.get('MODEL_NAME') or 'gpt2'
cwd = os.getcwd()
model_path = cwd + '/' + model_name
model = GPT2LMHeadModel.from_pretrained(model_path)
tokenizer = GPT2Tokenizer.from_pretrained(model_path)


def get_sentences(input, split_sentences, tokenizer):
    text = tokenizer.decode(input.tolist(), skip_special_tokens=True)
    if split_sentences:
        text = sent_tokenize(text)
        return text
    else:
        return [text]


@generate_blueprint.route('/generate')
def generate():
    meta = {}
Example #25
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='gpt2',
        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--train_input_file",
                        type=str,
                        default='data/train.128len.db')
    parser.add_argument("--eval_input_file",
                        type=str,
                        default='./data/dummy_data.tsv')
    parser.add_argument("--output_dir", type=str, default='output')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_seq_length", type=int, default=128)

    parser.add_argument("--skip_eval",
                        action='store_true',
                        help='If true, skip evaluation.')

    parser.add_argument("--continue_from", type=int, default=0)

    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="batch size now means per GPU per step")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=2,
        help="to increase effective batch size and reduce synchronization")
    parser.add_argument("--eval_batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_optim_steps",
                        type=int,
                        default=1000000,
                        help="new API specifies num update steps")
    parser.add_argument("--valid_step",
                        type=int,
                        default=10000,
                        help="how many optim steps between validations")
    parser.add_argument("--warmup_proportion", type=float, default=0.1)
    parser.add_argument("--warmup_steps", type=int, default=16000)

    parser.add_argument("--normalize_data", type=boolean_string, default=True)
    parser.add_argument("--fp16", type=boolean_string, default=True)
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--lr_schedule",
                        type=str,
                        choices=['noam', 'noamwd', 'BERT', 'None'],
                        default='noam')
    parser.add_argument("--loss_scale", type=float, default=0)
    parser.add_argument("--no_token_id", type=boolean_string, default=True)

    parser.add_argument("--log_dir", type=str)
    parser.add_argument('--pbar',
                        type=boolean_string,
                        default=True,
                        help='turn on progress bar')

    # distributed
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='for torch.distributed')

    args = parser.parse_args()

    assert args.train_batch_size % args.gradient_accumulation_steps == 0, 'batch size % gradient accumulation steps != 0!'
    args.train_batch_size = (args.train_batch_size //
                             args.gradient_accumulation_steps)
    logger.info(
        f'train batch size = {args.train_batch_size*args.gradient_accumulation_steps}, '
        f'new train batch size (after gradient accumulation) = {args.train_batch_size}'
    )

    if args.local_rank == -1:
        logger.info(f'CUDA available? {str(torch.cuda.is_available())}')
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        args.device, args.n_gpu = device, n_gpu
    else:
        # distributed training
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = torch.distributed.get_world_size()
        args.device, args.n_gpu = device, 1
        logger.info(
            f"device: {device} n_gpu: {n_gpu}, distributed training: {bool(args.local_rank != -1)},16-bits training: {args.fp16}"
        )

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    if n_gpu > 0: torch.cuda.manual_seed_all(args.seed)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d%H%M%S')
    output_dir = join(
        args.output_dir,
        'GPT2.{}.{}.{}gpu.{}'.format(args.learning_rate, args.train_batch_size,
                                     n_gpu, timestamp))
    log_dir = args.log_dir if args.log_dir is not None and len(
        args.log_dir) > 0 else output_dir
    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
        train_logger = open(join(log_dir, 'train_log.txt'), 'a+', buffering=1)
        eval_logger = open(join(log_dir, 'eval_log.txt'), 'a+', buffering=1)
        print(
            'epoch,global_step,step,mean_loss,n_token_real,n_token_total,epoch_time',
            file=train_logger)
        print('epoch,global_step,step,eval_loss', file=eval_logger)

    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    config = GPT2Config.from_pretrained(args.model_name_or_path)

    if args.local_rank == -1:
        train_dataloader = BucketingDataLoader(args.train_input_file,
                                               args.train_batch_size,
                                               args.max_seq_length)
    else:
        train_dataloader = DistributedBucketingDataLoader(
            torch.distributed.get_rank(), torch.distributed.get_world_size(),
            args.train_input_file, args.train_batch_size, args.max_seq_length)

    model = GPT2LMHeadModel.from_pretrained(
        args.model_name_or_path,
        from_tf=bool('.ckpt' in args.model_name_or_path),
        config=config)
    model = model.to(args.device)

    global_step, tr_loss = train(args, train_dataloader, model, tokenizer,
                                 train_logger, eval_logger)
Example #26
def main(
    german_tokenizer,
    english_tokenizer,
    gpt2_model,
    german_freqs,
    english_freqs,
    german_vecs,
    english_vecs,
    out_path,
):
    german_freqs = json.load(open(german_freqs))
    english_freqs = json.load(open(english_freqs))

    de_vectors, de_ids = load_vectors(german_vecs)
    en_vectors, en_ids = load_vectors(english_vecs)

    german_tokenizer = GPT2TokenizerFast.from_pretrained(german_tokenizer)
    english_tokenizer = GPT2TokenizerFast.from_pretrained(english_tokenizer)

    model = GPT2LMHeadModel.from_pretrained(gpt2_model)

    en_tok_embs = get_tokenizer_embeddings(english_tokenizer, en_vectors,
                                           en_ids, english_freqs)
    de_tok_embs = get_tokenizer_embeddings(german_tokenizer, de_vectors,
                                           de_ids, german_freqs)

    def get_closest(token_id, similarities=None):
        if (de_tok_embs[token_id] == 0).all():
            return None, None

        if similarities is None:
            similarities = cosine_similarity(
                de_tok_embs[token_id][np.newaxis, :], en_tok_embs)[0]

        best = np.argsort(similarities)[::-1]

        best = english_tokenizer.convert_ids_to_tokens(best)
        de_token = german_tokenizer.convert_ids_to_tokens([token_id])[0]
        space_before = de_token.startswith("Ġ")

        best = [x for x in best if x.startswith("Ġ") == space_before]
        en_token = best[0]

        return en_token, de_token

    print("Some sample mappings:")

    for token_id in np.random.choice(
            list(german_tokenizer.get_vocab().values()), 30):
        en_token, de_token = get_closest(token_id)

        print(f"{de_token} -> {en_token}")

    german_wte_weight = torch.zeros_like(model.transformer.wte.weight)
    mean, std = (
        model.transformer.wte.weight.mean().item(),
        model.transformer.wte.weight.std().item(),
    )

    en_vocab = english_tokenizer.get_vocab()
    n_matched = 0

    batch_size = 1024
    for i in tqdm(range(int(math.ceil(len(german_wte_weight) / batch_size)))):
        start, end = i * batch_size, min((i + 1) * batch_size,
                                         len(german_wte_weight))

        similarities = cosine_similarity(de_tok_embs[start:end], en_tok_embs)
        for token_id in range(start, end):
            en_token, _ = get_closest(token_id,
                                      similarities=similarities[token_id -
                                                                start])

            if en_token is None:
                german_wte_weight[token_id] = torch.normal(
                    mean, std, size=(german_wte_weight.shape[1], ))
            else:
                en_token_id = en_vocab[en_token]
                german_wte_weight[token_id] = model.transformer.wte.weight[
                    en_token_id]

                n_matched += 1

    print(f"Matching token found for {n_matched} of {len(en_vocab)} tokens.")
    torch.save(german_wte_weight, out_path)
Example #27
def main():
    args = parse_args()

    if args.should_continue and args.mode == 'train':
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            logger.warning(
                "Used --should_continue but no checkpoint was found in --output_dir."
            )
        else:
            logger.info(f"Found existing checkpoint: {sorted_checkpoints[-1]}")
            args.model_name_or_path = sorted_checkpoints[-1]

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.mode == 'train' and not args.overwrite_output_dir
            and not args.should_continue):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup CUDA, GPU & distributed training
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device('cpu')
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    # logger.info(
    #     "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    #     args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16,
    # )

    if args.seed:
        set_seed(args)

    config = GPT2Config.from_pretrained(args.config_name)
    tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer)
    if args.model_name_or_path:
        logger.info(f'Loading pretrained from {args.model_name_or_path}...')
        model = GPT2LMHeadModel.from_pretrained(
            args.model_name_or_path,
            from_tf=False,
            config=config,
        )
    else:
        model = GPT2LMHeadModel(config)
        logger.info(f'Loading model state dict from {args.state_dict}...')
        state = torch.load(args.state_dict, map_location=torch.device('cpu'))
        model.load_state_dict(state, strict=False)

    model.to(args.device)

    if args.mode == 'train':
        logging.info(f'Tokenizer size: {len(tokenizer.encoder)}')
        logging.info(f'Model config: {config.to_dict()}')

    if args.mode == 'train':
        logger.info("Training/evaluation parameters %s", args)

    # Training
    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using
    # from_pretrained()
    if args.mode == 'train':
        df_trn, df_val = preprocess_dataset(args.dataset)
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                df_trn,
                                                df_val,
                                                evaluate=False)
        eval_dataset = load_and_cache_examples(args,
                                               tokenizer,
                                               df_trn,
                                               df_val,
                                               evaluate=True)
        global_step, tr_loss = train(args, train_dataset, eval_dataset, model,
                                     tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)
        # Create output directory if needed
        os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (model.module if hasattr(model, "module") else model
                         )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = GPT2LMHeadModel.from_pretrained(args.output_dir)
        tokenizer = GPT2Tokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.mode == 'eval' and args.local_rank in [-1, 0]:
        df_trn, df_val = preprocess_dataset(args.dataset)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                "-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split(
                "/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model = GPT2LMHeadModel.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args,
                              model,
                              tokenizer,
                              df_trn,
                              df_val,
                              prefix=prefix)
            result = dict(
                (k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

        print(results)

    if args.mode == 'interact':
        interact(args, model, tokenizer)
Example #28
def main_worker(gpu, ngpus_per_node, args):
    if args.model_type == 'cvae':
        args.learn_prior = True
    else:
        args.learn_prior = False

    # GPU
    args.gpu = gpu
    print("There are ", torch.cuda.device_count(), " available GPUs!")
    # print('Setting GPUs {}'.format(args.device))
    print('Using GPU devices {}'.format(devices))
    device = torch.device('cuda', args.gpu)
    torch.cuda.set_device(device)
    print('Current single GPU: {}'.format(torch.cuda.current_device()))

    # randomness
    np.random.seed(args.seed)
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # For multiprocessing distributed training, rank needs to be the global rank among all the processes
    args.rank = args.rank * ngpus_per_node + gpu
    print('Setting rank', args.rank)
    recon_attempt = 1
    connected = False
    if args.rank != 0:
        # Stall to have rank 0 node go first
        time.sleep(3)
    while not connected:
        try:
            dist.init_process_group(backend=args.dist_backend,
                                    init_method=args.dist_url,
                                    world_size=args.world_size,
                                    rank=args.rank)
            connected = True
            print('Established connection. Rank:', args.rank)
        except Exception as e:
            # Sometimes the head node launches after the worker, which would cause an issue
            print('Failed to init process group. Retrying...', recon_attempt,
                  e)
            recon_attempt += 1
            time.sleep(10)

    # logging
    if args.rank == 0:
        save_folder = os.path.join(args.out_dir, args.experiment)
        os.makedirs(save_folder, exist_ok=True)
        t_writer = SummaryWriter(os.path.join(save_folder, 'train'),
                                 flush_secs=5)
        v_writer = SummaryWriter(os.path.join(save_folder, 'val'),
                                 flush_secs=5)
        importlib.reload(logging)
        logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                            level=logging.INFO,
                            format='%(asctime)s--- %(message)s')
        logging.info(
            '\n*******************************************************************************\n'
        )
        logging.info("the configuration:")
        logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # # add special tokens
    # special_tokens_dict = {
    #     'pad_token': '<|startoftext|>',
    #     'cls_token': '<|startofcond|>',
    #     'sep_token': '<|sepofcond|>',
    #     'mask_token': '<|endofcond|>'
    # }
    # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    # print('We have added', num_added_toks, 'special tokens')
    # # Notice: resize_token_embeddings expect to receive the full size of the new vocab
    # gpt2_model.resize_token_embeddings(len(tokenizer))
    # assert tokenizer.pad_token == '<|startoftext|>'

    VAE = VAEModel(config,
                   add_input=args.add_input,
                   add_attn=args.add_attn,
                   add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary,
                   learn_prior=args.learn_prior)
    init_para_frompretrained(VAE.transformer,
                             gpt2_model.transformer,
                             share_para=True)
    init_para_frompretrained(VAE.encoder,
                             gpt2_model.transformer,
                             share_para=False)
    if args.learn_prior:
        init_para_frompretrained(VAE.encoder_prior,
                                 VAE.encoder,
                                 share_para=True)
        VAE.encoder_prior.averageSelfAttention.attention_weights = VAE.encoder.averageSelfAttention.attention_weights
    VAE.lm_head.weight = gpt2_model.lm_head.weight
    if VAE.add_softmax:
        VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
        # VAE.lm_head_rep = LM_head_rep(*gpt2_model.lm_head.weight.size()[::-1])
    print('VAE_params:', num_params(VAE))  # 286694400
    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(
            args.load, 'model_latest.pt'))  # , map_location='cpu'
        if 'module' in list(state.keys(
        ))[0]:  # model_path is data parallel model with attr 'module'
            state_copy = copy.copy(state)
            keys = state_copy.keys()
            for k in keys:
                state[k.replace('module.', '')] = state.pop(k)
        VAE.load_state_dict(state)
        gc.collect()
    print('Done.')

    # fix pre-trained parameters before certain iterations
    tuning_all_after_iters = 10000
    tuning_all = False
    for name, parameter in VAE.named_parameters():
        # print((name, parameter.requires_grad))
        new_pars = [
            'c_z', 'attention_weights', 'mean', 'logvar', 'input_proj',
            'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep'
        ]

        if not any([True if n in name else False for n in new_pars]):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(
        zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    assert len(
        batch_schedule) <= 2, 'Currently not supporting multiple schedules'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir,
        args.dataset,
        tokenizer,
        batch_schedule[cur_b_schedule][0],
        batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        batch_schedule[-1][0],
        batch_schedule[-1][1],
        make_test=True,
        num_workers=args.workers,
        data_type=args.data_type)
    print('Done.')

    ###
    val_loader = test_loader
    ###

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(
        linear_schedule(args),
        batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
        int(args.iterations * args.switch_time))
    if args.fp16:
        VAE = VAE.half()
    VAE = VAE.to(device)
    VAE = VAE.train()

    params = [p for p in VAE.parameters() if p.requires_grad]
    optimizer = FusedAdam(params, lr=args.lr)
    optimizer = FP16_Optimizer(optimizer,
                               dynamic_loss_scale=True,
                               verbose=False)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer.optimizer,
                                                  lr_schedule)
    loss_model = SimpleDistributedDataParallel(VAE, args.world_size)

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    max_val_batches = 20000  # max num. of val batches
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # number of epoch
    num_iters = 0
    optimizer.zero_grad()
    beta = args.beta_0
    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def val_step(val_loader):
        VAE.eval()

        n_words_bpe = 0
        n_words = 0
        logp_sum = 0.0
        kl_loss_sum = 0.0

        logging.info("Validation loop.         Batches: %d" % len(val_loader))
        logging.info("Validation loop. max_val_batches: %d" % max_val_batches)

        # val_iter = iter(val_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(val_iter)
        with tqdm(total=min(len(val_loader), max_val_batches)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(val_loader):
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        loss, ce_loss, kl_loss = compute_loss(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)
                    else:
                        loss, ce_loss, kl_loss = compute_loss_ae(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)

                if len(target_tokens.size()) == 1:
                    target_tokens = target_tokens.unsqueeze(0)
                n, l = target_tokens.size()

                text = target_tokens[0, :].tolist()
                logprob = ce_loss.tolist()
                assert len(text) == len(logprob)

                # only for story
                idx = text.index(endoftext)
                text = text[idx + 1:]
                logprob = logprob[idx + 1:]

                if endoftext in text:
                    idx = text.index(endoftext)
                    text = text[:idx]
                    logprob = logprob[:idx]

                logp_sum += sum(logprob)

                n_words_bpe += len(text)

                story = [
                    tokenizer.decode(target_tokens[i, :]) for i in range(n)
                ]
                story = [
                    s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                    for s in story
                ]
                story = [
                    s[:s.find("<|endoftext|>") +
                      len("<|endoftext|>")] if "<|endoftext|>" in s else s
                    for s in story
                ]
                words = sum([
                    len([
                        t for t in re.split(
                            '("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)', s)
                        if t != ' ' and t != ''
                    ]) for s in story
                ])
                n_words += words

                kl_loss_sum += kl_loss.item()

                if i > max_val_batches:
                    break
                pbar.update(1)

        loss_bpe = logp_sum / n_words_bpe
        ppl_bpe = round(math.exp(min(logp_sum / n_words_bpe, 100)), 3)
        ppl_word = round(math.exp(min(logp_sum / n_words, 100)), 3)
        kl = kl_loss_sum / len(val_loader)

        v_writer.add_scalar('loss', loss_bpe, num_iters)
        v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
        v_writer.add_scalar('ppl_word', ppl_word, num_iters)
        v_writer.add_scalar('kl', kl, num_iters)
        logging.info('val loss    : %.4f' % loss_bpe)
        logging.info('val ppl_bpe : %.4f' % ppl_bpe)
        logging.info('val ppl_word: %.4f' % ppl_word)
        logging.info('val   kl    : %.4f' % kl)

        VAE.train()

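    # Latent-space visualization: encode the test prompts, project the latent
    # means to 2D with t-SNE, and save a scatter plot colored by coarse labels.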
    def test_plot(test_loader, num_iters):
        VAE.eval()

        # get embedding
        X_emb = None
        y = None

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(test_loader):
                y_mask = y_mask.to(device)
                y_tokens = y_tokens.to(device)
                x_mask = x_mask.to(device)
                x_tokens = x_tokens.to(device)
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        latent_mean, _ = VAE.encoder_prior(
                            input_ids=x_tokens, attention_mask=x_mask)[:2]
                    else:
                        latent_mean, _ = VAE.encoder(input_ids=x_tokens,
                                                     attention_mask=x_mask)[:2]

                if args.dataset == 'ax' or args.dataset == 'yp':
                    label = [
                        tokenizer.decode(l)[:2] for l in x_tokens.tolist()
                    ]
                elif args.dataset == 'wp':
                    label = []
                    prompts = [
                        tokenizer.decode(l)[:6].lower()
                        for l in x_tokens.tolist()
                    ]
                    for prom in prompts:
                        if prom[0] in ['[', '('] and prom[5] in [']', ')']:
                            label.append(prom[2:4])
                        else:
                            label.append(None)
                elif args.dataset == 'wi':
                    # 0. TV, play, miniseries, telenovela; 1.film; 2. music; 3. manga, comic, 4. book, novel, story 5. game
                    label = []
                    prompts = [tokenizer.decode(l) for l in x_tokens.tolist()]
                    for prom in prompts:
                        if 'TV' in prom or 'play' in prom or 'miniseries' in prom or 'telenovela' in prom:
                            label.append(0)
                        elif 'film' in prom:
                            label.append(1)
                        elif 'music' in prom:
                            label.append(2)
                        elif 'manga' in prom or 'comic' in prom:
                            label.append(3)
                        elif 'book' in prom or 'novel' in prom or 'story' in prom:
                            label.append(4)
                        elif 'game' in prom:
                            label.append(5)
                        else:
                            label.append(None)
                else:
                    raise Exception

                if i == 0:
                    X_emb = latent_mean.data
                    y = label
                else:
                    X_emb = torch.cat((X_emb, latent_mean.data), dim=0)
                    y.extend(label)
                pbar.update(1)
        X_emb = X_emb.cpu().numpy()

        try:
            if args.dataset == 'yp':
                y = ['0' if l in ['0', '1'] else l for l in y]
                y = ['4' if l in ['3', '4'] else l for l in y]
                X_emb = X_emb[[l != '2' for l in y], :]
                y = [l for l in y if l != '2']

            if args.dataset == 'wp':
                topics = [['wp', 'sp', 'tt'], ['eu'], ['cw'], ['pm'],
                          ['mp', 'ip'], ['pi', 'cc'], ['ot'], ['rf']]
                match = [[True if l in t else False for t in topics]
                         for l in y]
                y = [m.index(True) if True in m else None for m in match]
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            if args.dataset == 'wi':
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            # to 2D
            # X_emb_2d = TSNE(n_components=2, init='pca', verbose=1).fit_transform(X_emb)
            X_emb_2d = TSNE(n_components=2, verbose=1,
                            perplexity=40).fit_transform(X_emb)

            def remove_outliers(data, r=2.0):
                outliers_data = abs(
                    data - np.mean(data, axis=0)) >= r * np.std(data, axis=0)
                outliers = np.any(outliers_data, axis=1)
                keep = np.logical_not(outliers)
                return outliers, keep

            outliers, keep = remove_outliers(X_emb_2d)
            X_emb_2d = X_emb_2d[keep, :]
            y = [l for l, k in zip(y, keep.tolist()) if k]

            # plot
            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_axes([0, 0, 1, 1])
            cc = ['r', 'b', 'g', 'y', 'k', 'c', 'm', 'tab:blue']
            for i, l in enumerate(sorted(set(y))):
                idx = [yl == l for yl in y]
                plt.scatter(X_emb_2d[idx, 0],
                            X_emb_2d[idx, 1],
                            c=cc[i],
                            s=10,
                            edgecolor='none',
                            alpha=0.5)
            ax.axis('off')  # hide the axes and ticks
            plt.savefig(
                os.path.join(save_folder,
                             'tSNE_' + '{:07d}'.format(num_iters) + '.png'))
            plt.close(fig)
        except:
            pass

        VAE.train()

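    # Generation check: sample one continuation per test outline, score it
    # against the reference story with BLEU-4 and ROUGE, and write the samples
    # to a generate-*.txt file in save_folder.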
    def generate(test_loader, num_iters):
        VAE.eval()

        n_samples = 0
        bleu4_sum = 0.0
        rouge_scores_values_sum = [0.0] * 9

        args.nsamples = 1
        args.batch_size = 1
        args.temperature = 0.95
        args.top_k = 100
        args.top_p = 0.95
        model_type = args.model_type

        # write samples to file
        samples_file = open(os.path.join(
            save_folder, 'generate-' + '%07d' % num_iters + '.txt'),
                            'w',
                            encoding='utf8')

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i_test, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                         target_tokens, mask) in enumerate(test_loader):

                if i_test >= 10: break

                length = -1
                if length == -1:
                    length = VAE.config.n_ctx - x_tokens.size(1) - 1
                elif length > VAE.config.n_ctx - x_tokens.size(1) - 1:
                    raise ValueError(
                        "Can't get samples longer than window size: %s" %
                        VAE.config.n_ctx)

                eff_samples = []
                n, l = target_tokens.size()
                storys = [
                    tokenizer.decode(target_tokens[i, :]) for i in range(n)
                ]
                storys = [
                    s[s.find("<|endoftext|>") + len("<|endoftext|>"):]
                    for s in storys
                ]
                storys_str = [
                    s[:s.find("<|endoftext|>") +
                      len("<|endoftext|>")] if "<|endoftext|>" in s else s
                    for s in storys
                ]

                for _ in range(args.nsamples // args.batch_size):
                    # model, batch_size, temperature, top_k, top_p, eos_token, sample = VAE, args.batch_size, args.temperature, args.top_k, args.top_p, tokenizer.encoder['<|endoftext|>'], True
                    out, _ = sample_sequence(
                        model=VAE,
                        tokenizer=tokenizer,
                        length=length,
                        batch_size=args.batch_size,
                        x_mask=x_mask,
                        x_tokens=x_tokens,
                        y_mask=y_mask,
                        y_tokens=y_tokens,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        device=device,
                        eos_token=tokenizer.encoder['<|endoftext|>'],
                        model_type=model_type)
                    out = out.tolist()

                    # extract story, check metrics
                    for i in range(len(out)):
                        text = out[i]
                        text = text[text.index(endoftext) + 1:]

                        if endoftext in text:
                            idx = text.index(endoftext)
                            text = text[:idx]

                        text = tokenizer.decode(text).strip()

                        # score for one long text, higher than 0.075 usually means repetition
                        # rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
                        # if rep_score > 0.075:
                        #     # print(rep_score)
                        #     continue

                        try:
                            # check bleu
                            bleu4 = sentence_bleu(
                                [storys_str[i].split()],
                                text.split(),  # tokenize so BLEU is over words, not characters
                                smoothing_function=SmoothingFunction().method7)

                            # check rouge
                            rouge = Rouge()
                            rouge_scores = rouge.get_scores(
                                text, storys_str[i])
                            rouge_scores_values = [
                                v for k in rouge_scores[0].keys()
                                for v in rouge_scores[0][k].values()
                            ]

                            bleu4_sum += bleu4
                            rouge_scores_values_sum = [
                                v1 + v2
                                for v1, v2 in zip(rouge_scores_values_sum,
                                                  rouge_scores_values)
                            ]
                            n_samples += 1
                        except:
                            bleu4 = 0.0
                            rouge_scores = [{
                                'rouge-1': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                },
                                'rouge-2': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                },
                                'rouge-l': {
                                    'f': 0.0,
                                    'p': 0.0,
                                    'r': 0.0
                                }
                            }]

                        eff_samples.append((text, bleu4, rouge_scores))

                    pbar.update(1)

                for i in range(len(eff_samples)):
                    samples_file.write("=" * 50 + " SAMPLE " + str(i_test) +
                                       " " + "=" * 50)
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Outlines  " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(
                        tokenizer.decode(
                            x_tokens[i, :][x_mask[i, :] == 1].tolist()))
                    samples_file.write('\n' * 2)
                    samples_file.write("=" * 40 + " Story " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(storys_str[i])
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Generated " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(eff_samples[i][0])
                    samples_file.write('\n' * 4)
                    samples_file.flush()

        print('Test complete with %05d samples.' % n_samples)
        logging.info("Test complete with %05d samples.", n_samples)
        logging.info("Iteration completed: %d" % num_iters)

        bleu4 = round(bleu4_sum / n_samples, 3)
        rouge_scores_values = [
            round(r / n_samples, 3) for r in rouge_scores_values_sum
        ]
        print(' bleu-4:', bleu4)
        print(' rouge :', rouge_scores_values)
        logging.info(' bleu-4: %f', bleu4)
        logging.info(' rouge : %s', str(rouge_scores_values))

        VAE.train()

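    # The rank-0 process runs an initial evaluation and saves a checkpoint
    # before training starts.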
    if args.rank == 0:
        test_plot(test_loader, num_iters)
        val_step(val_loader)
        generate(test_loader, num_iters)
        torch.save(
            loss_model.state_dict(),
            os.path.join(save_folder,
                         'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info(
            '\n----------------------------------------------------------------------'
        )
        logging.info("Training loop.       Batches: %d" % len(train_loader))

        # train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
        with tqdm(total=len(train_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(train_loader):

                # if num_iters % args.cycle >= args.cycle - args.beta_warmup:
                #     beta = min(1.0, beta + (1. - args.beta_0) / args.beta_warmup)

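                # Once tuning_all_after_iters is reached, unfreeze every VAE
                # parameter and fine-tune the full model.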
                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for name, parameter in VAE.named_parameters():
                        # print((name, parameter.requires_grad))
                        parameter.requires_grad = True
                    tuning_all = True

                output = train_step(device, loss_model, optimizer, x_mask,
                                    x_tokens, y_mask, y_tokens, input_tokens,
                                    target_tokens, mask, loss_fn, beta,
                                    args.model_type)

                if args.rank == 0:
                    loss, ce_loss, kl_loss = output[-1]
                    lr = scheduler.get_last_lr()[0]
                    # Log to Tensorboard
                    t_writer.add_scalar('loss', loss, num_iters)
                    t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)),
                                        num_iters)
                    t_writer.add_scalar('lr', lr, num_iters)
                    t_writer.add_scalar('iter_time',
                                        time.time() - st, num_iters)
                    t_writer.add_scalar('kl', kl_loss, num_iters)
                    t_writer.add_scalar('beta', beta, num_iters)

                    if args.model_type == 'ae_vae_fusion':
                        loss, ce_loss, kl_loss = output[0]
                        # Log to Tensorboard
                        t_writer.add_scalar('ae_loss', loss, num_iters)
                        t_writer.add_scalar('ae_kl', kl_loss, num_iters)

                st = time.time()
                end = num_iters >= args.iterations

                if args.warmup != -1:
                    scheduler.step()

                if end: break
                num_iters += 1
                pbar.update(1)

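                # Cyclical KL annealing: reset beta to its initial value at the
                # start of each cycle.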
                if num_iters % args.cycle == 0:
                    beta = args.beta_0
                    logging.info('KL annealing restart')

                if args.rank == 0 and num_iters % 10000 == 0:
                    test_plot(test_loader, num_iters)
                    val_step(val_loader)
                    generate(test_loader, num_iters)

                if args.rank == 0 and num_iters % 10000 == 0:
                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" %
                                 (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info(
                        '\n------------------------------------------------------'
                    )
                    torch.save(
                        loss_model.state_dict(),
                        os.path.join(
                            save_folder,
                            'model_' + '{:07d}'.format(num_iters) + '.pt'))

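                # Sequence-length curriculum: at the iteration set by
                # --switch_time, rebuild the data loaders with the next
                # (longer-sequence) entry of batch_schedule.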
                if args.switch_time > 0 and num_iters == int(
                        args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir,
                        args.dataset,
                        tokenizer,
                        batch_schedule[cur_b_schedule][0],
                        batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        batch_schedule[-1][0],
                        batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers,
                        data_type=args.data_type)
        if not end:
            e += 1
            logging.info("Training loop. The ith epoch completed: %d" % e)

    if args.rank == 0:
        torch.save(loss_model.state_dict(),
                   os.path.join(save_folder, 'model_latest.pt'))
    print('Training complete.')
    logging.info("Training complete.")
Пример #29
0
def main():
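    """Generate `nsamples` continuations of `--prefix` with a trained GPT-2
    model and optionally save them to disk."""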
    parser = argparse.ArgumentParser()
    parser.add_argument('--device',
                        default='0,1,2,3',
                        type=str,
                        required=False,
                        help='device(s) used for generation')
    parser.add_argument('--length',
                        default=-1,
                        type=int,
                        required=False,
                        help='length of the generated text')
    parser.add_argument('--batch_size',
                        default=1,
                        type=int,
                        required=False,
                        help='batch size for generation')
    parser.add_argument('--nsamples',
                        default=10,
                        type=int,
                        required=False,
                        help='number of samples to generate')
    parser.add_argument('--temperature',
                        default=1,
                        type=float,
                        required=False,
                        help='sampling temperature')
    parser.add_argument('--topk',
                        default=8,
                        type=int,
                        required=False,
                        help='top-k sampling: keep only the k most likely tokens')
    parser.add_argument('--topp',
                        default=0,
                        type=float,
                        required=False,
                        help='top-p (nucleus) cumulative probability')
    parser.add_argument('--model_config',
                        default='config/model_config_small.json',
                        type=str,
                        required=False,
                        help='path to the model config file')
    parser.add_argument('--tokenizer_path',
                        default='cache/vocab_small.txt',
                        type=str,
                        required=False,
                        help='path to the vocabulary file')
    parser.add_argument('--model_path',
                        default='model/final_model',
                        type=str,
                        required=False,
                        help='path to the trained model')
    parser.add_argument('--prefix',
                        default='萧炎',
                        type=str,
                        required=False,
                        help='opening text (prefix) of the generated article')
    parser.add_argument('--no_wordpiece',
                        action='store_true',
                        help='do not apply WordPiece tokenization')
    parser.add_argument('--segment',
                        action='store_true',
                        help='tokenize Chinese text at the word level')
    parser.add_argument('--fast_pattern',
                        action='store_true',
                        help='use the faster text generation routine')
    parser.add_argument('--save_samples',
                        action='store_true',
                        help='save the generated samples')
    parser.add_argument('--save_samples_path',
                        default='.',
                        type=str,
                        required=False,
                        help="directory in which to save the samples")
    parser.add_argument('--repetition_penalty',
                        default=1.0,
                        type=float,
                        required=False)

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_util
    else:
        from tokenizations import tokenization_chars as tokenization_util

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # 此处设置程序使用哪些显卡
    length = args.length
    batch_size = args.batch_size
    nsamples = args.nsamples
    temperature = args.temperature
    topk = args.topk
    topp = args.topp
    repetition_penalty = args.repetition_penalty

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = tokenization_util.BertTokenizer(vocab_file=args.tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if length == -1:
        length = model.config.n_ctx
    if args.save_samples:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt',
                            'w',
                            encoding='utf8')
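    # Keep sampling batches until `nsamples` samples have been generated.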
    while True:
        raw_text = args.prefix
        context_tokens = tokenizer.convert_tokens_to_ids(
            tokenizer.tokenize(raw_text))
        #pdb.set_trace()
        generated = 0
        for _ in range(nsamples // batch_size):
            out = generate(n_ctx=n_ctx,
                           model=model,
                           context=context_tokens,
                           length=length,
                           is_fast_pattern=args.fast_pattern,
                           tokenizer=tokenizer,
                           temperature=temperature,
                           top_k=topk,
                           top_p=topp,
                           repitition_penalty=repetition_penalty,  # keyword spelling follows the callee's parameter name
                           device=device)
            for i in range(batch_size):
                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)
                for i, item in enumerate(text[:-1]):  # insert spaces between consecutive English tokens
                    if is_word(item) and is_word(text[i + 1]):
                        text[i] = item + ' '
                for i, item in enumerate(text):
                    if item == '[MASK]':
                        text[i] = ''
                    elif item == '[CLS]':
                        text[i] = '\n\n'
                    elif item == '[SEP]':
                        text[i] = '\n'
                info = "=" * 40 + " SAMPLE " + str(
                    generated) + " " + "=" * 40 + "\n"
                print(info)
                text = ''.join(text).replace('##', '').strip()
                print(text)
                if args.save_samples:
                    samples_file.write(info)
                    samples_file.write(text)
                    samples_file.write('\n')
                    samples_file.write('=' * 90)
                    samples_file.write('\n' * 2)
        print("=" * 80)
        if generated == nsamples:
            # close file when finish writing.
            if args.save_samples:
                samples_file.close()
            break
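
# Convenience helper: load a pretrained GPT-2 LM head model and its tokenizer
# for inference; pad_token_id 50256 is the id of GPT-2's <|endoftext|> token.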
def init_model(name='gpt2'):
    model = GPT2LMHeadModel.from_pretrained(name, pad_token_id=50256)
    tokenizer = GPT2Tokenizer.from_pretrained(name, pad_token_id=50256)
    model.eval()
    return model, tokenizer
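
# A minimal usage sketch for init_model above (the prompt text and generation
# settings here are illustrative assumptions, not part of the original code):
if __name__ == '__main__':
    model, tokenizer = init_model('gpt2')
    input_ids = tokenizer.encode('Once upon a time', return_tensors='pt')
    output_ids = model.generate(input_ids, max_length=40, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))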