def get_transformer(opt) -> Transformer:
    model = Transformer(embed_dim=opt.embed_dim,
                        src_vocab_size=opt.src_vocab_size,
                        trg_vocab_size=opt.trg_vocab_size,
                        src_pad_idx=opt.src_pad_idx,
                        trg_pad_idx=opt.trg_pad_idx,
                        n_head=opt.n_head)
    model = model.to(opt.device)
    checkpoint_file_path = get_best_checkpoint(opt)
    if checkpoint_file_path is not None:
        print(f'Checkpoint loaded - {checkpoint_file_path}')
        checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])
    return model
def load_transformer(opt) -> Transformer:
    checkpoint_file_path = get_best_checkpoint(opt)
    checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)
    assert checkpoint is not None
    assert checkpoint['opt'] is not None
    assert checkpoint['weights'] is not None
    model_opt = checkpoint['opt']
    model = Transformer(embed_dim=model_opt.embed_dim,
                        src_vocab_size=model_opt.src_vocab_size,
                        trg_vocab_size=model_opt.trg_vocab_size,
                        src_pad_idx=model_opt.src_pad_idx,
                        trg_pad_idx=model_opt.trg_pad_idx,
                        n_head=model_opt.n_head)
    model.load_state_dict(checkpoint['weights'])
    print('model loaded:', checkpoint_file_path)
    return model.to(opt.device)
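# Both loaders above depend on a get_best_checkpoint() helper that is defined
# elsewhere. A minimal sketch, assuming checkpoints are saved as '*.pt' files
# under a hypothetical opt.checkpoint_dir and that "best" means most recently
# written (a real implementation might compare validation scores instead):
import glob
import os

def get_best_checkpoint(opt):
    """Return the path of the newest checkpoint, or None if none exist."""
    candidates = glob.glob(os.path.join(opt.checkpoint_dir, '*.pt'))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)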
def create_model(opt):
    data = torch.load(opt.data_path)
    opt.src_vocab_size = len(data['src_dict'])
    opt.tgt_vocab_size = len(data['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)
    # Initialize a model state.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    # If opt.model_path exists, load model parameters.
    if os.path.exists(opt.model_path):
        print('Reloading model parameters..')
        model_state = torch.load(opt.model_path)
        model.load_state_dict(model_state['model_params'])

    # use_cuda and device are assumed to be module-level globals.
    if use_cuda:
        print('Using GPU..')
        model = model.to(device)

    return model, model_state
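# create_model() reloads a dict with the keys 'opt', 'curr_epochs',
# 'train_steps' and 'model_params'. The matching save routine is not shown in
# this repo; a hypothetical sketch of what it might look like, assuming it is
# called at the end of each epoch:
def save_model(model, model_state, opt):
    # Store the current weights alongside the training progress so that
    # create_model() can resume from opt.model_path.
    model_state['model_params'] = model.state_dict()
    torch.save(model_state, opt.model_path)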
def main():
    import argparse
    parse = argparse.ArgumentParser(description="Basic configuration")
    # model parameters
    parse.add_argument("--vocab_size", type=int, default=1000, help="vocabulary size")
    parse.add_argument("--n_position", type=int, default=256, help="number of positions (maximum sequence length)")
    parse.add_argument("--word_vec_size", type=int, default=512, help="embedding output size")
    parse.add_argument("--d_model", type=int, default=512, help="hidden size")
    parse.add_argument("--d_inner", type=int, default=1024, help="feed-forward inner layer size")
    parse.add_argument("--n_head", type=int, default=8, help="number of self-attention heads")
    parse.add_argument("--d_k", type=int, default=64, help="per-head key size (d_model / n_head)")
    parse.add_argument("--d_v", type=int, default=64, help="per-head value size (d_model / n_head)")
    parse.add_argument("--encoder_n_layers", type=int, default=6, help="number of encoder layers")
    parse.add_argument("--decoder_n_layers", type=int, default=6, help="number of decoder layers")
    parse.add_argument("--dropout", type=float, default=0.1, help="dropout probability")
    parse.add_argument("--pad_idx", type=int, default=-1, help="padding index")
    # Note: with action="store_true" and default=True these two flags are
    # always True; use default=False if they should be switchable.
    parse.add_argument("--trg_emb_prj_weight_sharing", action="store_true", default=True)
    parse.add_argument("--emb_src_trg_weight_sharing", action="store_true", default=True)
    # data parameters
    parse.add_argument("--vocab_path", type=str, default=os.path.join(root, "vocabulary/vocab.txt"), help="vocabulary file path")
    parse.add_argument("--train_data_path", type=str, default=os.path.join(root, "data/train_small.txt"), help="training data path")
    parse.add_argument("--evaluate_data_path", type=str, default=None, help="evaluation data path")
    parse.add_argument("--max_encode_len", type=int, default=192, help="maximum encoder sequence length")
    parse.add_argument("--max_decode_len", type=int, default=64, help="maximum decoder sequence length")
    parse.add_argument("--history_turns", type=int, default=3, help="number of history dialogue turns")
    parse.add_argument("--max_lines", type=int, default=525106, help="maximum number of lines to process")
    parse.add_argument("--batch_size", type=int, default=32, help="batch size")
    # training parameters
    parse.add_argument("--epochs", type=int, default=20, help="number of training epochs")
    parse.add_argument("--save_epoch", type=int, default=5, help="save the model every N epochs")
    parse.add_argument("--save_dir", type=str, default=os.path.join(root, "model/transformer_0127"), help="model save directory")
    parse.add_argument("--init_lr", type=float, default=1.0, help="initial learning rate")
    parse.add_argument("--n_warmup_steps", type=int, default=100, help="number of warmup steps")
    parse.add_argument("--label_smoothing", action="store_true", default=False)
    args = parse.parse_args()

    tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    args.vocab_size = tokenizer.vocab_size
    args.pad_idx = tokenizer._convert_token_to_id("[PAD]")
    args_dict = vars(args)
    config = TransformerConfig(**args_dict)
    if not os.path.exists(config.save_dir):
        os.makedirs(config.save_dir)  # create the model save directory

    logger.info("Load dataset.")
    train_dataset = ChatDataset(config.train_data_path,
                                tokenizer=tokenizer,
                                max_encode_len=config.max_encode_len,
                                max_decode_len=config.max_decode_len,
                                history_turns=config.history_turns,
                                max_lines=config.max_lines)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    if config.evaluate_data_path is not None:
        eval_dataset = ChatDataset(config.evaluate_data_path,
                                   tokenizer=tokenizer,
                                   max_encode_len=config.max_encode_len,
                                   max_decode_len=config.max_decode_len,
                                   history_turns=config.history_turns,
                                   max_lines=config.max_lines)
        eval_loader = DataLoader(eval_dataset, batch_size=config.batch_size, shuffle=False)
    else:
        eval_loader = None

    logger.info("Load model.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Transformer(config=config)
    model.to(device)

    logger.info("Load optimizer.")
    optimizer = ScheduledOptim(
        optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        config.init_lr, config.d_model, config.n_warmup_steps)

    logger.info("Save all config parameters.")
    config.save_para_to_json_file(os.path.join(root, "data/para.json"))

    logger.info("Training model.")
    train(config, model, optimizer,
          train_loader=train_loader,
          eval_loader=eval_loader,
          device=device)
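# ScheduledOptim implements the inverse-square-root warmup schedule from
# "Attention Is All You Need": lr = init_lr * d_model^-0.5 *
# min(step^-0.5, step * n_warmup_steps^-1.5). The class itself is defined
# elsewhere in this repo; a minimal sketch matching the constructor call
# above (the method names step_and_update_lr/zero_grad are assumptions):
class ScheduledOptim:
    """Wrap an optimizer with the Transformer warmup learning-rate schedule."""

    def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = init_lr
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step_and_update_lr(self):
        # Update the learning rate, then take an optimizer step.
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _update_learning_rate(self):
        self.n_steps += 1
        scale = (self.d_model ** -0.5) * min(
            self.n_steps ** -0.5, self.n_steps * self.n_warmup_steps ** -1.5)
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = self.init_lr * scale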
def main():
    import argparse
    parse = argparse.ArgumentParser(description="Basic configuration")
    parse.add_argument("--para_path", type=str, default=os.path.join(root, "data/para.json"), help="path to the saved config parameters")
    parse.add_argument("--model_path", type=str, default=os.path.join(root, "model/transformer_0127/checkpoint_5.pt"), help="model checkpoint path")
    parse.add_argument("--no_sample", action='store_true', default=False, help="Set to use greedy decoding instead of sampling")
    parse.add_argument("--repetition_penalty", type=float, default=0.01, help="repetition penalty")
    parse.add_argument("--temperature", type=float, default=0.7, help="Sampling softmax temperature")
    parse.add_argument("--top_k", type=int, default=0, help="Filter top-k tokens before sampling (<=0: no filtering)")
    parse.add_argument("--top_p", type=float, default=0.9, help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parse.parse_args()

    with open(args.para_path, mode='r', encoding='utf-8') as fp:
        para_dict = json.load(fp)
    config = TransformerConfig(**para_dict)

    tokenizer = BertTokenizer(vocab_file=config.vocab_path)
    bos_token_id = tokenizer._convert_token_to_id("[CLS]")
    eos_token_id = tokenizer._convert_token_to_id("[SEP]")
    pad_token_id = tokenizer._convert_token_to_id("[PAD]")

    logger.info("Load model.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Transformer(config=config)
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"), strict=False)
    for name, weights in model.named_parameters():
        logger.info("{} --- {}".format(name, weights))
    model.to(device)

    history_tokens = []
    while True:
        user_text = input("User-->>")
        while not user_text:
            logger.info('Prompt should not be empty!')
            user_text = input("User-->>")
        tokens = tokenizer.tokenize(user_text)
        history_tokens.append(tokens)

        # Build the input tokens from the most recent history turns.
        context_tokens = ["[SEP]"]
        for turn in history_tokens[::-1]:  # iterate in reverse order
            if len(context_tokens) + len(turn) < config.max_encode_len:
                context_tokens = turn + context_tokens
                context_tokens = ["[SEP]"] + context_tokens
            else:
                break
        context_tokens[0] = "[CLS]"  # replace the leading [SEP] token with [CLS]

        # Encode.
        encode_input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
        encode_input_ids = torch.tensor(encode_input_ids).long().unsqueeze(dim=0).to(device)
        encode_outputs, encode_attention_mask = encoder(model.encoder,
                                                        encode_input_ids,
                                                        pad_idx=pad_token_id)

        # Decode: generate text token by token.
        index = 1
        generate_sequence_ids = [bos_token_id]
        while index <= config.max_decode_len:
            decode_input_ids = torch.tensor(generate_sequence_ids).long().unsqueeze(dim=0).to(device)
            logits = decoder(model.decoder, model.trg_word_prj, decode_input_ids,
                             encode_outputs=encode_outputs,
                             encode_attention_mask=encode_attention_mask)
            next_token_logit = logits[0][-1, :]  # logits of the last position
            for token_id in set(generate_sequence_ids):
                next_token_logit[token_id] /= args.repetition_penalty
            next_token_logit = next_token_logit / args.temperature  # apply the sampling temperature
            next_token_logit = top_filtering(next_token_logit, top_k=args.top_k, top_p=args.top_p)
            probs = F.softmax(next_token_logit, dim=-1)
            next_token_id = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
            next_token_id = next_token_id.item()
            if next_token_id == eos_token_id:
                generate_sequence_ids.append(next_token_id)
                break
            generate_sequence_ids.append(next_token_id)
            index += 1

        system_tokens = tokenizer.convert_ids_to_tokens(generate_sequence_ids)
        print("System-->>{}".format("".join(system_tokens[1:-1])))
        # Strip the leading [CLS] and trailing [SEP] tokens before storing.
        history_tokens.append(system_tokens[1:-1])
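# top_filtering() used in the decode loop above is not shown here. It is
# assumed to be the standard top-k / nucleus (top-p) filtering of Holtzman et
# al. (2019); a minimal sketch matching the 1-D call site above:
import torch
import torch.nn.functional as F

def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    """Mask logits outside the top-k set and/or the top-p probability mass."""
    if top_k > 0:
        # Keep only the k largest logits.
        kth_value = torch.topk(logits, top_k)[0][-1]
        logits[logits < kth_value] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p, shifting
        # the mask right so the first token crossing the threshold is kept.
        mask = cumulative_probs > top_p
        mask[1:] = mask[:-1].clone()
        mask[0] = False
        logits[sorted_indices[mask]] = filter_value
    return logits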