Пример #1
0
def main():
    """Train the Transformer chat model from command-line settings.

    Steps performed, in order:

    1. Parse model, data and training options from the command line.
    2. Build a ``BertTokenizer`` from ``--vocab_path`` and overwrite
       ``vocab_size`` / ``pad_idx`` with the values taken from that vocabulary.
    3. Wrap all options in a ``TransformerConfig`` and make sure the
       checkpoint directory exists.
    4. Build the training loader (and the evaluation loader when
       ``--evaluate_data_path`` is given).
    5. Build the model and the ``ScheduledOptim``-wrapped Adam optimizer.
    6. Save the resolved configuration to ``data/para.json`` under ``root``
       and call ``train``.
    """
    import argparse
    parse = argparse.ArgumentParser(description="设置基本参数")
    # model parameter
    parse.add_argument("--vocab_size", type=int, default=1000, help="字典大小")
    parse.add_argument("--n_position",
                       type=int,
                       default=256,
                       help="位置数量序列最大长度")
    parse.add_argument("--word_vec_size",
                       type=int,
                       default=512,
                       help="embedding输出大小")
    parse.add_argument("--d_model", type=int, default=512, help="隐层大小")
    parse.add_argument("--d_inner", type=int, default=1024, help="隐层中间层大小")
    parse.add_argument("--n_head", type=int, default=8, help="自注意力头的数量")
    parse.add_argument("--d_k",
                       type=int,
                       default=64,
                       help="d_model/n_head每个头隐层的大小")
    parse.add_argument("--d_v",
                       type=int,
                       default=64,
                       help="d_model/n_head每个头隐层的大小")
    parse.add_argument("--encoder_n_layers", type=int, default=6, help="编码的层数")
    parse.add_argument("--decoder_n_layers", type=int, default=6, help="解码的层数")
    parse.add_argument("--dropout", type=float, default=0.1, help="dropout概率")
    parse.add_argument("--pad_idx", type=int, default=-1, help="padding index")
    # Both sharing options default to True. A store_true flag with
    # default=True can only re-assert that default, so each original flag is
    # kept for backward compatibility and paired with a "--no_*" flag that
    # writes False to the same destination.
    parse.add_argument("--trg_emb_prj_weight_sharing",
                       action="store_true",
                       default=True)
    parse.add_argument(
        "--no_trg_emb_prj_weight_sharing",
        dest="trg_emb_prj_weight_sharing",
        action="store_false",
        help="Disable --trg_emb_prj_weight_sharing (enabled by default)")
    parse.add_argument("--emb_src_trg_weight_sharing",
                       action="store_true",
                       default=True)
    parse.add_argument(
        "--no_emb_src_trg_weight_sharing",
        dest="emb_src_trg_weight_sharing",
        action="store_false",
        help="Disable --emb_src_trg_weight_sharing (enabled by default)")

    # data parameter
    parse.add_argument("--vocab_path",
                       type=str,
                       default=os.path.join(root, "vocabulary/vocab.txt"),
                       help="词汇表路径")
    parse.add_argument("--train_data_path",
                       type=str,
                       default=os.path.join(root, "data/train_small.txt"),
                       help="训练数据路径")
    parse.add_argument("--evaluate_data_path",
                       type=str,
                       default=None,
                       help="评估数据路径")
    parse.add_argument("--max_encode_len",
                       type=int,
                       default=192,
                       help="最大编码序列长度")
    parse.add_argument("--max_decode_len",
                       type=int,
                       default=64,
                       help="最大解码序列长度")
    parse.add_argument("--history_turns", type=int, default=3, help="历史对话轮数")
    parse.add_argument("--max_lines", type=int, default=525106, help="最多处理数据量")
    parse.add_argument("--batch_size",
                       type=int,
                       default=32,
                       help="batch size 大小")

    # train parameter
    parse.add_argument("--epochs", type=int, default=20, help="训练epoch数量")
    parse.add_argument("--save_epoch",
                       type=int,
                       default=5,
                       help="每训练多少epoch保存一次模型")
    parse.add_argument("--save_dir",
                       type=str,
                       default=os.path.join(root, "model/transformer_0127"),
                       help="模型保存路径")
    parse.add_argument("--init_lr", type=float, default=1.0, help="初始学习率")
    parse.add_argument("--n_warmup_steps", type=int, default=100, help="热身步长")
    parse.add_argument("--label_smoothing", action="store_true", default=False)

    args = parse.parse_args()

    # The vocabulary file is the source of truth: whatever was passed for
    # --vocab_size / --pad_idx is replaced by the tokenizer's values.
    tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    args.vocab_size = tokenizer.vocab_size
    # Public tokenizer API instead of the private _convert_token_to_id.
    args.pad_idx = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

    # Every parsed option becomes a keyword argument of the config object.
    config = TransformerConfig(**vars(args))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config.save_dir, exist_ok=True)

    logger.info("Load dataset.")
    train_dataset = ChatDataset(config.train_data_path,
                                tokenizer=tokenizer,
                                max_encode_len=config.max_encode_len,
                                max_decode_len=config.max_decode_len,
                                history_turns=config.history_turns,
                                max_lines=config.max_lines)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)
    if config.evaluate_data_path is not None:
        eval_dataset = ChatDataset(config.evaluate_data_path,
                                   tokenizer=tokenizer,
                                   max_encode_len=config.max_encode_len,
                                   max_decode_len=config.max_decode_len,
                                   history_turns=config.history_turns,
                                   max_lines=config.max_lines)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=config.batch_size,
                                 shuffle=False)
    else:
        # train() is handed False (not None) when no evaluation set is
        # configured; kept as-is because train() is defined elsewhere.
        eval_loader = False

    logger.info("Load model.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # standard idiom
    model = Transformer(config=config)
    model.to(device)

    logger.info("Load optimizer.")
    optimizer = ScheduledOptim(
        optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        config.init_lr, config.d_model, config.n_warmup_steps)

    logger.info("Save all config parameter.")
    config.save_para_to_json_file(os.path.join(root, "data/para.json"))

    logger.info("Training model.")
    train(config,
          model,
          optimizer,
          train_loader=train_loader,
          eval_loader=eval_loader,
          device=device)
Пример #2
0
def main(args):
    """Build data, model, loss and optimizer from ``args`` and run training.

    Args:
        args: Namespace-like object. Attributes read here: ``root_dir``,
            ``save_path``, ``use_cuda``, ``max_length``, ``pad_idx``,
            ``enc_sos_idx``, ``enc_eos_idx``, ``dec_sos_idx``,
            ``dec_eos_idx``, ``pos_pad_idx``, ``n_layer``, ``n_head``,
            ``d_model``, ``d_k``, ``d_v``, ``d_f``, ``drop_rate``,
            ``use_conv``, ``linear_weight_share``, ``embed_weight_share``,
            ``load_path``, ``smooth_eps``, ``warmup_steps``, ``beta1``,
            ``beta2``, ``n_step``, ``metrics_method`` and ``verbose``.
            The whole object is also forwarded to ``get_data``.
    """
    import sys
    from pathlib import Path

    # configs path to load data & save model
    # parents/exist_ok: nested paths are created and an already existing
    # directory (or one created concurrently) is not an error.
    Path(args.root_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    print(sys.version)
    print(f"Using {device}")
    print("Loading Data...")
    # Only the fields and the train/valid loaders are used below; the
    # datasets themselves and the third entries are discarded.
    (src, trg), (_train, _valid, _), (train_loader, valid_loader,
                                      _) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    # check vocab size
    print(f"SRC vocab {src_vocab_len}, TRG vocab {trg_vocab_len}")
    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length

    # Special-token indices: an explicit value in args wins, otherwise the
    # index is looked up in the corresponding vocabulary.
    pad_idx = src.vocab.stoi.get(
        "<pad>") if args.pad_idx is None else args.pad_idx
    enc_sos_idx = src.vocab.stoi.get(
        "<s>") if args.enc_sos_idx is None else args.enc_sos_idx
    enc_eos_idx = src.vocab.stoi.get(
        "</s>") if args.enc_eos_idx is None else args.enc_eos_idx
    dec_sos_idx = trg.vocab.stoi.get(
        "<s>") if args.dec_sos_idx is None else args.dec_sos_idx
    dec_eos_idx = trg.vocab.stoi.get(
        "</s>") if args.dec_eos_idx is None else args.dec_eos_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx

    print("Building Model...")
    model = Transformer(enc_vocab_len=src_vocab_len,
                        enc_max_seq_len=enc_max_seq_len,
                        dec_vocab_len=trg_vocab_len,
                        dec_max_seq_len=dec_max_seq_len,
                        n_layer=args.n_layer,
                        n_head=args.n_head,
                        d_model=args.d_model,
                        d_k=args.d_k,
                        d_v=args.d_v,
                        d_f=args.d_f,
                        pad_idx=pad_idx,
                        pos_pad_idx=pos_pad_idx,
                        drop_rate=args.drop_rate,
                        use_conv=args.use_conv,
                        linear_weight_share=args.linear_weight_share,
                        embed_weight_share=args.embed_weight_share).to(device)

    if args.load_path is not None:
        print(f"Load Model {args.load_path}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(args.load_path, map_location=device))

    # build loss function using LabelSmoothing
    # Use the resolved pad_idx (not args.pad_idx, which may be None) so the
    # loss uses the same padding index as the model.
    loss_function = LabelSmoothing(trg_vocab_size=trg_vocab_len,
                                   pad_idx=pad_idx,
                                   eps=args.smooth_eps)

    # NOTE(review): 10e-9 equals 1e-8, not 1e-9 -- confirm which epsilon was
    # intended; the value is left unchanged here.
    optimizer = WarmUpOptim(warmup_steps=args.warmup_steps,
                            d_model=args.d_model,
                            optimizer=optim.Adam(model.parameters(),
                                                 betas=(args.beta1,
                                                        args.beta2),
                                                 eps=10e-9))

    trainer = Trainer(optimizer=optimizer,
                      train_loader=train_loader,
                      test_loader=valid_loader,
                      n_step=args.n_step,
                      device=device,
                      save_path=args.save_path,
                      enc_sos_idx=enc_sos_idx,
                      enc_eos_idx=enc_eos_idx,
                      dec_sos_idx=dec_sos_idx,
                      dec_eos_idx=dec_eos_idx,
                      metrics_method=args.metrics_method,
                      verbose=args.verbose)
    print("Start Training...")
    trainer.main(model=model, loss_function=loss_function)
Пример #3
0
def main():
    """Run an interactive chat loop with a trained Transformer checkpoint.

    The function parses decoding options, rebuilds ``TransformerConfig`` from
    the JSON file at ``--para_path``, loads the checkpoint at
    ``--model_path`` and then loops forever: it reads a user line from stdin,
    builds the encoder input from the most recent dialogue turns, generates a
    reply token by token and prints it. Both the user turn and the reply are
    appended to the dialogue history.
    """
    import argparse
    parse = argparse.ArgumentParser(description="设置基本参数")
    parse.add_argument("--para_path",
                       type=str,
                       default=os.path.join(root, "data/para.json"),
                       help="所有配置参数")
    parse.add_argument("--model_path",
                       type=str,
                       default=os.path.join(
                           root, "model/transformer_0127/checkpoint_5.pt"),
                       help="所有配置参数")
    parse.add_argument("--no_sample",
                       action='store_true',
                       default=False,
                       help="Set to use greedy decoding instead of sampling")
    parse.add_argument("--repetition_penalty",
                       type=float,
                       default=0.01,
                       help="重复惩罚项")
    parse.add_argument("--temperature",
                       type=float,
                       default=0.7,
                       help="Sampling softmax temperature")
    parse.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parse.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parse.parse_args()

    with open(args.para_path, mode='r', encoding='utf-8') as fp:
        para_dict = json.load(fp)

    config = TransformerConfig(**para_dict)

    tokenizer = BertTokenizer(vocab_file=config.vocab_path)
    # Public tokenizer API instead of the private _convert_token_to_id.
    bos_token_id, eos_token_id, pad_token_id = tokenizer.convert_tokens_to_ids(
        ["[CLS]", "[SEP]", "[PAD]"])

    logger.info("Load model.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # standard idiom
    model = Transformer(config=config)
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"),
                          strict=False)
    # named_parameters() already yields (name, tensor) pairs; zipping it with
    # parameters() made `name` a tuple and logged every tensor twice.
    for name, weights in model.named_parameters():
        logger.info("{} --- {}".format(name, weights))
    model.to(device)
    # Inference only: switch the module to evaluation mode.
    model.eval()

    history_tokens = []
    while True:
        user_text = input("User-->>")
        while not user_text:
            logger.info('Prompt should not be empty!')
            user_text = input("User-->>")
        history_tokens.append(tokenizer.tokenize(user_text))

        # Build the encoder input, newest turn first, as
        # [CLS] turn [SEP] turn [SEP] ... and stop before the sequence would
        # exceed max_encode_len.
        context_tokens = ["[SEP]"]
        for turn in reversed(history_tokens):
            if len(context_tokens) + len(turn) >= config.max_encode_len:
                break
            context_tokens = ["[SEP]"] + turn + context_tokens
        context_tokens[0] = "[CLS]"  # the leading [SEP] becomes [CLS]

        # No gradients are needed while generating.
        with torch.no_grad():
            # Encoding
            encode_input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
            encode_input_ids = torch.tensor(encode_input_ids).long().unsqueeze(
                dim=0).to(device)
            encode_outputs, encode_attention_mask = encoder(
                model.encoder, encode_input_ids, pad_idx=pad_token_id)

            # Decoding: generate at most max_decode_len tokens after [CLS].
            generate_sequence_ids = [bos_token_id]
            for _ in range(config.max_decode_len):
                decode_input_ids = torch.tensor(
                    generate_sequence_ids).long().unsqueeze(dim=0).to(device)
                logits = decoder(model.decoder,
                                 model.trg_word_prj,
                                 decode_input_ids,
                                 encode_outputs=encode_outputs,
                                 encode_attention_mask=encode_attention_mask)
                next_token_logit = logits[0][-1, :]  # logits of the last position

                # NOTE(review): dividing by the default penalty of 0.01
                # multiplies the logit of already generated tokens by 100,
                # which raises positive logits rather than penalising them --
                # confirm the intended formula/default. Arithmetic unchanged.
                for token_id in set(generate_sequence_ids):
                    next_token_logit[token_id] /= args.repetition_penalty

                # --temperature was parsed but never applied. Non-positive
                # values are ignored to avoid dividing by zero.
                if args.temperature > 0:
                    next_token_logit = next_token_logit / args.temperature

                next_token_logit = top_filtering(next_token_logit,
                                                 top_k=args.top_k,
                                                 top_p=args.top_p)
                probs = F.softmax(next_token_logit, dim=-1)

                if args.no_sample:
                    next_token_id = torch.topk(probs, 1)[1]  # greedy
                else:
                    next_token_id = torch.multinomial(probs, 1)  # sampling
                next_token_id = next_token_id.item()

                generate_sequence_ids.append(next_token_id)
                if next_token_id == eos_token_id:
                    break

        # Drop the leading [CLS]; drop the trailing [SEP] only when it was
        # actually generated, so a reply cut off at max_decode_len keeps its
        # last token.
        response_ids = generate_sequence_ids[1:]
        if response_ids and response_ids[-1] == eos_token_id:
            response_ids.pop()
        system_tokens = tokenizer.convert_ids_to_tokens(response_ids)
        print("System-->>{}".format("".join(system_tokens)))
        history_tokens.append(system_tokens)