Example #1
# Imports inferred from the body below; BertModel, convert_model,
# BertWordPieceTokenizer, SequenceGenerator, and sample_sequence are
# project-local helpers assumed to be importable.
from argparse import ArgumentParser

import torch


def model_init(app):
    ArgsSet = type('ArgsSet',(object,),{})
    client = ArgsSet()
    parser = ArgumentParser()
    parser.add_argument("--model-config", type=str, default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available()
                        else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--outlens", type=int, default=30)
    parser.add_argument("--beam", type=int, default=1)
    parser.add_argument("--gpt-checkpoints", type=str)
    parser.add_argument("--port", type=int, default=8866)

    args = parser.parse_args()
    args.load_model = True
    args.fp32_embedding = False
    args.fp32_layernorm = False
    args.fp32_tokentypes = False
    args.layernorm_epsilon = 1e-12

    gpt = BertModel(None, args)
    state_dict = convert_model(torch.load(args.gpt_checkpoints)['sd'])
    gpt.load_state_dict(state_dict)
    gpt.to(args.device)
    gpt.eval()
    tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
    print(" Load model from {}".format(args.gpt_checkpoints))

    client.tokenizer = tokenizer
    client.gpt = gpt
    client.gpt_beam = SequenceGenerator(gpt, tokenizer, beam_size=args.beam, max_lens=args.outlens)
    client.device = args.device
    client.port = args.port
    client.generator = sample_sequence

    return client
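For orientation, here is a minimal sketch of how the returned client might be wired into a small web service. The Flask wiring and the gpt_beam.generate call are illustrative assumptions, not part of the original project:

# Hypothetical usage sketch -- the generate() method name and the Flask
# wiring are assumptions, not taken from the original project.
from flask import Flask, request, jsonify

app = Flask(__name__)
client = model_init(app)

@app.route("/generate", methods=["POST"])
def generate():
    text = request.json["text"]
    output = client.gpt_beam.generate(text)  # beam-search decoding (API assumed)
    return jsonify({"output": output})

if __name__ == "__main__":
    app.run(port=client.port)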
Example #2
# Imports inferred from the body below; BertModel, convert_model,
# BertWordPieceTokenizer, SequenceGenerator, and sample_sequence are
# project-local helpers assumed to be importable.
from argparse import ArgumentParser

import torch


def model_init(app):
    ArgsSet = type('ArgsSet', (object, ), {})
    client = ArgsSet()
    parser = ArgumentParser()
    parser.add_argument("--model-config",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--outlens", type=int, default=30)
    parser.add_argument("--beam", type=int, default=1)
    parser.add_argument("--fuse-checkpoints", type=str)
    parser.add_argument("--gpt-checkpoints", type=str)
    parser.add_argument("--qa-style-checkpoints", type=str)
    parser.add_argument("--multi-task", type=str)
    parser.add_argument("--split-sentence-with-task-embedding-checkpoints",
                        type=str)
    parser.add_argument("--special-cls-checkpoints", type=str)

    parser.add_argument("--port", type=int, default=8866)

    args = parser.parse_args()
    args.load_model = True
    args.fp32_embedding = False
    args.fp32_layernorm = False
    args.fp32_tokentypes = False
    args.layernorm_epsilon = 1e-12

    fuse_model = BertModel(None, args)
    state_dict = convert_model(torch.load(args.fuse_checkpoints)['sd'])
    fuse_model.load_state_dict(state_dict)
    fuse_model.to(args.device)
    fuse_model.eval()
    print("| Load model from {}".format(args.fuse_checkpoints))

    gpt = BertModel(None, args)
    state_dict = convert_model(torch.load(args.gpt_checkpoints)['sd'])
    gpt.load_state_dict(state_dict)
    gpt.to(args.device)
    gpt.eval()
    tokenizer = BertWordPieceTokenizer("bert-base-chinese",
                                       cache_dir="temp_cache_dir")
    print(" Load model from {}".format(args.gpt_checkpoints))

    # Plain BERT encoder (load_model is False: no checkpoint file is read here)
    args.load_model = False
    args.fp32_embedding = False
    args.fp32_layernorm = False
    args.fp32_tokentypes = False
    args.layernorm_epsilon = 1e-12
    bert = BertModel(None, args)
    bert.to(args.device)
    bert.eval()

    client.tokenizer = tokenizer
    client.fuse_model = fuse_model
    client.fuse_beam = SequenceGenerator(fuse_model,
                                         tokenizer,
                                         beam_size=args.beam,
                                         max_lens=args.outlens)
    client.gpt = gpt
    client.gpt_beam = SequenceGenerator(gpt,
                                        tokenizer,
                                        beam_size=args.beam,
                                        max_lens=args.outlens)
    client.bert = bert
    client.device = args.device
    client.port = args.port
    client.generator = sample_sequence

    # multi-task model
    multi_task = BertModel(None, args)
    state_dict = convert_model(torch.load(args.multi_task)['sd'])
    print("| Load model from {}".format(args.multi_task))
    multi_task.load_state_dict(state_dict)
    multi_task.to(args.device)
    multi_task.eval()
    client.multi_task_model = multi_task
    client.multi_task_beam = SequenceGenerator(multi_task,
                                               tokenizer,
                                               beam_size=args.beam,
                                               max_lens=args.outlens)

    # QA-style model
    qa_style = BertModel(None, args)
    state_dict = convert_model(torch.load(args.qa_style_checkpoints)['sd'])
    qa_style.load_state_dict(state_dict)
    qa_style.to(args.device)
    qa_style.eval()
    print(" Load model from {}".format(args.qa_style_checkpoints))
    client.qa_task_model = qa_style

    # special [CLS]-token model
    special_cls_model = BertModel(None, args)
    state_dict = convert_model(torch.load(args.special_cls_checkpoints)['sd'])
    special_cls_model.load_state_dict(state_dict)
    special_cls_model.to(args.device)
    special_cls_model.eval()
    print("| Load model from {}".format(args.special_cls_checkpoints))
    client.special_cls_model = special_cls_model
    client.special_beam = SequenceGenerator(special_cls_model,
                                            tokenizer,
                                            beam_size=args.beam,
                                            max_lens=args.outlens)

    # split sentence model with task embedding
    split_sentence_model = BertModel(None, args)
    state_dict = convert_model(
        torch.load(args.split_sentence_with_task_embedding_checkpoints)['sd'])
    split_sentence_model.load_state_dict(state_dict)
    split_sentence_model.to(args.device)
    split_sentence_model.eval()
    print("| Load model from {}".format(
        args.split_sentence_with_task_embedding_checkpoints))
    client.split_sentence_model = split_sentence_model
    client.split_sentence_beam = SequenceGenerator(split_sentence_model,
                                                   tokenizer,
                                                   beam_size=args.beam,
                                                   max_lens=args.outlens)

    return client
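Example #2 repeats the same load, convert, move, eval sequence for every checkpoint. A small helper, sketched below under the assumption that BertModel, convert_model, and the {'sd': state_dict} checkpoint layout behave exactly as in the code above, would remove the duplication:

def load_checkpoint(checkpoint_path, args, device):
    # Build a BertModel, load a converted checkpoint, and prepare it for
    # inference; mirrors the repeated pattern in model_init above.
    model = BertModel(None, args)
    state_dict = convert_model(torch.load(checkpoint_path)['sd'])
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("| Load model from {}".format(checkpoint_path))
    return model

Each checkpoint block in model_init then collapses to a single call, e.g. gpt = load_checkpoint(args.gpt_checkpoints, args, args.device).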
Example #3
File: main.py    Project: sxjpage/ChineseBert
"""
import re
import torch
import sentencepiece as spm

from model import BertModel

sp = spm.SentencePieceProcessor()
sp.load('resource/sentencepiece.unigram.35000.model')
vocab_size = sp.get_piece_size()

n_embedding = 512
n_layer = 8

model = BertModel(vocab_size, n_embedding, n_layer)
model.eval()
model.load_state_dict(torch.load('resource/model.{}.{}.th'.format(n_embedding, n_layer),
                                 map_location='cpu'))

# Enable CUDA if it is available:
# model.cuda()

# If you are using a GPU with tensor cores (NVIDIA Volta or Turing
# architecture), you can enable half-precision inference and training.
# We recommend NVIDIA's official apex package to keep everything as
# clean as possible:
# from apex import amp
# [model] = amp.initialize([model], opt_level="O2")
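# As an alternative to apex, newer PyTorch versions (>= 1.6; an assumption,
# not part of the original project) ship a built-in AMP API:
# with torch.no_grad(), torch.cuda.amp.autocast():
#     ...  # run the forward pass here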
device = model.embedding.weight.data.device


def clean_text(txt):
    txt = txt.lower()
    txt = re.sub(r'\s+', '', txt)  # strip all whitespace; raw string avoids an invalid escape