Code Example #1
File: translate.py Project: Yevgnen/transformer
    def __init__(self, src_vocab, tgt_vocab, checkpoint, opts):

        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

        hparams = checkpoint['hparams']

        transformer = Transformer(len(src_vocab),
                                  len(tgt_vocab),
                                  hparams.max_len + 2,
                                  n_layers=hparams.n_layers,
                                  d_model=hparams.d_model,
                                  d_emb=hparams.d_model,
                                  d_hidden=hparams.d_hidden,
                                  n_heads=hparams.n_heads,
                                  d_k=hparams.d_k,
                                  d_v=hparams.d_v,
                                  dropout=hparams.dropout,
                                  pad_id=src_vocab.pad_id)

        transformer.load_state_dict(checkpoint['model'])
        log_proj = torch.nn.LogSoftmax(dim=-1)  # explicit dim; the implicit-dim form is deprecated

        if hparams.cuda:
            transformer.cuda()
            log_proj.cuda()

        transformer.eval()

        self.hparams = hparams
        self.opts = opts
        self.model = transformer
        self.log_proj = log_proj
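
For reference, the checkpoint layout this constructor expects is a plain dict that carries the hyperparameters alongside the weights. A minimal save-side sketch, assuming transformer and hparams exist in the training loop (the key names mirror the loader above):

# Minimal save-side sketch; the key names mirror the loader above.
checkpoint = {
    'hparams': hparams,                 # hyperparameters needed to rebuild the model
    'model': transformer.state_dict(),  # weights only, not the module object
}
torch.save(checkpoint, 'checkpoint.pt')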
Code Example #2
def get_transformer(opt) -> Transformer:
    model = Transformer(embed_dim=opt.embed_dim,
                        src_vocab_size=opt.src_vocab_size,
                        trg_vocab_size=opt.trg_vocab_size,
                        src_pad_idx=opt.src_pad_idx,
                        trg_pad_idx=opt.trg_pad_idx,
                        n_head=opt.n_head)
    model = model.to(opt.device)
    checkpoint_file_path = get_best_checkpoint(opt)
    if checkpoint_file_path is not None:
        print(f'Checkpoint loaded - {checkpoint_file_path}')
        checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])
    return model
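
get_best_checkpoint is not shown in this snippet (Code Example #3 below reuses it). A plausible implementation, assuming checkpoints live under a hypothetical opt.checkpoint_dir attribute and the newest file wins; the real project may rank by validation score instead:

import glob
import os

def get_best_checkpoint(opt):
    # Hypothetical helper: newest *.pt file under opt.checkpoint_dir,
    # or None if no checkpoint has been written yet.
    candidates = glob.glob(os.path.join(opt.checkpoint_dir, '*.pt'))
    return max(candidates, key=os.path.getmtime) if candidates else None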
Code Example #3
def load_transformer(opt) -> Transformer:
    checkpoint_file_path = get_best_checkpoint(opt)
    checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)

    assert checkpoint is not None
    assert checkpoint['opt'] is not None
    assert checkpoint['weights'] is not None

    model_opt = checkpoint['opt']
    model = Transformer(embed_dim=model_opt.embed_dim,
                        src_vocab_size=model_opt.src_vocab_size,
                        trg_vocab_size=model_opt.trg_vocab_size,
                        src_pad_idx=model_opt.src_pad_idx,
                        trg_pad_idx=model_opt.trg_pad_idx,
                        n_head=model_opt.n_head)

    model.load_state_dict(checkpoint['weights'])
    print('model loaded:', checkpoint_file_path)
    return model.to(opt.device)
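
Unlike get_transformer above, this loader rebuilds the architecture from the options stored inside the checkpoint, so the weights always match the model shape even if the caller's opt has drifted since training. Note that the asserts only catch values explicitly saved as None; a missing key would raise KeyError before the assert runs. The save side implied by this loader would look roughly as follows (a sketch; only the key names 'opt' and 'weights' are fixed by the code above):

# Sketch: the architecture options travel with the weights so the model
# can always be reconstructed exactly as it was trained.
torch.save({'opt': model_opt, 'weights': model.state_dict()}, checkpoint_file_path)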
Code Example #4
def create_model(opt):
    data = torch.load(opt.data_path)
    opt.src_vocab_size = len(data['src_dict'])
    opt.tgt_vocab_size = len(data['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)  # Initialize a model state.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    # If opt.model_path exists, load model parameters.
    if os.path.exists(opt.model_path):
        print('Reloading model parameters..')
        model_state = torch.load(opt.model_path)
        model.load_state_dict(model_state['model_params'])

    if use_cuda:
        print('Using GPU..')
        model = model.cuda()

    return model, model_state
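
The model_state dict doubles as a resume record: curr_epochs and train_steps let a restarted run pick up where it stopped. A sketch of how the training loop would persist it, assuming the counters are updated each epoch (the 'model_params' key matches what create_model reads back):

# Sketch: persist weights plus training progress so create_model can resume.
model_state['model_params'] = model.state_dict()
model_state['curr_epochs'] += 1
torch.save(model_state, opt.model_path)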
Code Example #5
    def __init__(self, opt, use_cuda):
        self.opt = opt
        self.use_cuda = use_cuda
        self.tt = torch.cuda if use_cuda else torch

        checkpoint = torch.load(opt.model_path)
        model_opt = checkpoint['opt']

        self.model_opt = model_opt
        model = Transformer(model_opt)
        if use_cuda:
            print('Using GPU..')
            model = model.cuda()

        prob_proj = nn.LogSoftmax(dim=-1)
        model.load_state_dict(checkpoint['model_params'])
        print('Loaded pre-trained model_state..')

        self.model = model
        self.model.prob_proj = prob_proj
        self.model.eval()
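
Attaching prob_proj to the model gives the search code a single place to turn raw decoder scores into log-probabilities. A minimal usage sketch; the forward signature of model here is an assumption:

# Sketch: raw logits -> log-probabilities during decoding.
with torch.no_grad():
    logits = model(src_batch, partial_hyps)  # forward signature is an assumption
    log_probs = model.prob_proj(logits)      # (batch, seq_len, vocab) log-probabilities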
Code Example #6
def get_model(sh_path):
    # A script path beginning with ".." has two dots in its first two characters;
    # in that case "./" references inside the script are rewritten to "../" so
    # they still resolve relative to the script's parent directory.
    if sh_path.count(".", 0, 2) == 2:
        arguments = " ".join([s.strip() for s in Path(sh_path).read_text().replace("\\", "").replace('"', "").replace("./", "../").splitlines()[1:-1]])
    else:
        arguments = " ".join([s.strip() for s in Path(sh_path).read_text().replace("\\", "").replace('"', "").splitlines()[1:-1]])
    parser = argument_parsing(preparse=True)
    args = parser.parse_args(arguments.split())

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    (src, trg), (train, _, test), (train_loader, _, test_loader) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length
    pad_idx = src.vocab.stoi.get("<pad>") if args.pad_idx is None else args.pad_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx

    model = Transformer(enc_vocab_len=src_vocab_len, 
                        enc_max_seq_len=enc_max_seq_len, 
                        dec_vocab_len=trg_vocab_len, 
                        dec_max_seq_len=dec_max_seq_len, 
                        n_layer=args.n_layer, 
                        n_head=args.n_head, 
                        d_model=args.d_model, 
                        d_k=args.d_k, 
                        d_v=args.d_v, 
                        d_f=args.d_f, 
                        pad_idx=pad_idx,
                        pos_pad_idx=pos_pad_idx, 
                        drop_rate=args.drop_rate, 
                        use_conv=args.use_conv, 
                        linear_weight_share=args.linear_weight_share, 
                        embed_weight_share=args.embed_weight_share).to(device)
    # map_location works for both CPU and GPU targets, so one call suffices.
    model.load_state_dict(torch.load(args.save_path, map_location=torch.device(device)))
    
    return model, (src, trg), (test, test_loader)
Code Example #7
def main(args):
    # configs path to load data & save model
    from pathlib import Path
    if not Path(args.root_dir).exists():
        Path(args.root_dir).mkdir()

    p = Path(args.save_path).parent
    if not p.exists():
        p.mkdir()

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    import sys
    print(sys.version)
    print(f"Using {device}")
    print("Loading Data...")
    (src, trg), (train, valid, _), (train_loader, valid_loader,
                                    _) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    # check vocab size
    print(f"SRC vocab {src_vocab_len}, TRG vocab {trg_vocab_len}")
    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length
    pad_idx = src.vocab.stoi.get(
        "<pad>") if args.pad_idx is None else args.pad_idx
    enc_sos_idx = src.vocab.stoi.get(
        "<s>") if args.enc_sos_idx is None else args.enc_sos_idx
    enc_eos_idx = src.vocab.stoi.get(
        "</s>") if args.enc_eos_idx is None else args.enc_eos_idx
    dec_sos_idx = trg.vocab.stoi.get(
        "<s>") if args.dec_sos_idx is None else args.dec_sos_idx
    dec_eos_idx = trg.vocab.stoi.get(
        "</s>") if args.dec_eos_idx is None else args.dec_eos_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx

    print("Building Model...")
    model = Transformer(enc_vocab_len=src_vocab_len,
                        enc_max_seq_len=enc_max_seq_len,
                        dec_vocab_len=trg_vocab_len,
                        dec_max_seq_len=dec_max_seq_len,
                        n_layer=args.n_layer,
                        n_head=args.n_head,
                        d_model=args.d_model,
                        d_k=args.d_k,
                        d_v=args.d_v,
                        d_f=args.d_f,
                        pad_idx=pad_idx,
                        pos_pad_idx=pos_pad_idx,
                        drop_rate=args.drop_rate,
                        use_conv=args.use_conv,
                        linear_weight_share=args.linear_weight_share,
                        embed_weight_share=args.embed_weight_share).to(device)

    if args.load_path is not None:
        print(f"Load Model {args.load_path}")
        model.load_state_dict(torch.load(args.load_path))

    # build loss function using LabelSmoothing
    loss_function = LabelSmoothing(trg_vocab_size=trg_vocab_len,
                                   pad_idx=args.pad_idx,
                                   eps=args.smooth_eps)

    optimizer = WarmUpOptim(warmup_steps=args.warmup_steps,
                            d_model=args.d_model,
                            optimizer=optim.Adam(model.parameters(),
                                                 betas=(args.beta1,
                                                        args.beta2),
                                                 eps=10e-9))

    trainer = Trainer(optimizer=optimizer,
                      train_loader=train_loader,
                      test_loader=valid_loader,
                      n_step=args.n_step,
                      device=device,
                      save_path=args.save_path,
                      enc_sos_idx=enc_sos_idx,
                      enc_eos_idx=enc_eos_idx,
                      dec_sos_idx=dec_sos_idx,
                      dec_eos_idx=dec_eos_idx,
                      metrics_method=args.metrics_method,
                      verbose=args.verbose)
    print("Start Training...")
    trainer.main(model=model, loss_function=loss_function)
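
WarmUpOptim implements the inverse-square-root warmup schedule from "Attention Is All You Need": lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5). A minimal sketch of such a wrapper; only the formula is standard, the method names and internals here are assumptions:

class WarmUpOptim:
    # Sketch of a Noam-style warmup wrapper around an inner optimizer.
    def __init__(self, warmup_steps, d_model, optimizer):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.optimizer = optimizer
        self.step_num = 0

    def step(self):
        # lr rises linearly for warmup_steps, then decays as step**-0.5.
        self.step_num += 1
        lr = self.d_model ** -0.5 * min(self.step_num ** -0.5,
                                        self.step_num * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()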
Code Example #8
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-share_proj_weight', action='store_true')
    parser.add_argument('-share_embs_weight', action='store_true')
    parser.add_argument('-weighted_model', action='store_true')

    # training params
    parser.add_argument('-lr', type=float, default=0.002)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-max_src_seq_len', type=int, default=50)
    parser.add_argument('-max_tgt_seq_len', type=int, default=10)
    parser.add_argument('-max_grad_norm', type=float, default=None)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)
    parser.add_argument('-display_freq', type=int, default=100)
    parser.add_argument('-log', default=None)

    opt = parser.parse_args()

    data = torch.load(opt.data_path)
    opt.src_vocab_size = len(data['src_dict'])
    opt.tgt_vocab_size = len(data['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)  # Initialize a model state.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    print('Reloading model parameters..')
    model_state = torch.load('./train_log/emoji_model.pt', map_location=device)
    model.load_state_dict(model_state['model_params'])

    emojilize(opt, model)
Code Example #9
def main():
    import argparse
    parse = argparse.ArgumentParser(description="Basic parameter settings")
    parse.add_argument("--para_path",
                       type=str,
                       default=os.path.join(root, "data/para.json"),
                       help="所有配置参数")
    parse.add_argument("--model_path",
                       type=str,
                       default=os.path.join(
                           root, "model/transformer_0127/checkpoint_5.pt"),
                       help="所有配置参数")
    parse.add_argument("--no_sample",
                       action='store_true',
                       default=False,
                       help="Set to use greedy decoding instead of sampling")
    parse.add_argument("--repetition_penalty",
                       type=float,
                       default=0.01,
                       help="重复惩罚项")
    parse.add_argument("--temperature",
                       type=float,
                       default=0.7,
                       help="Sampling softmax temperature")
    parse.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parse.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parse.parse_args()

    with open(args.para_path, mode='r', encoding='utf-8') as fp:
        para_dict = json.load(fp)

    config = TransformerConfig(**para_dict)

    tokenizer = BertTokenizer(vocab_file=config.vocab_path)
    bos_token_id = tokenizer._convert_token_to_id("[CLS]")
    eos_token_id = tokenizer._convert_token_to_id("[SEP]")
    pad_token_id = tokenizer._convert_token_to_id("[PAD]")

    logger.info("Load model.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # standard device-selection idiom
    model = Transformer(config=config)
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"),
                          strict=False)
    for name, weights in model.named_parameters():  # named_parameters() already yields (name, param) pairs
        logger.info("{} --- {}".format(name, weights))
    model.to(device)

    history_tokens = []
    while True:
        user_text = input("User-->>")
        while not user_text:
            logger.info('Prompt should not be empty!')
            user_text = input("User-->>")
        tokens = tokenizer.tokenize(user_text)
        history_tokens.append(tokens)

        # Build the input token context from the dialogue history
        context_tokens = ["[SEP]"]
        for turn in history_tokens[::-1]:  # walk the history from most recent to oldest
            if len(context_tokens) + len(turn) < config.max_encode_len:
                context_tokens = turn + context_tokens
                context_tokens = ["[SEP]"] + context_tokens
            else:
                break
        context_tokens[0] = "[CLS]"  # 将头部[SEP] token替换为[CLS] token

        # Encode the context
        encode_input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
        encode_input_ids = torch.tensor(encode_input_ids).long().unsqueeze(
            dim=0).to(device)
        encode_outputs, encode_attention_mask = encoder(model.encoder,
                                                        encode_input_ids,
                                                        pad_idx=pad_token_id)

        # Decode: generate the response token by token
        index = 1
        generate_sequence_ids = [bos_token_id]
        while index <= config.max_decode_len:
            # decode_input_ids = torch.LongTensor([generate_sequence_ids])  # expand to a 2-D tensor
            decode_input_ids = torch.tensor(
                generate_sequence_ids).long().unsqueeze(dim=0).to(device)
            logits = decoder(model.decoder,
                             model.trg_word_prj,
                             decode_input_ids,
                             encode_outputs=encode_outputs,
                             encode_attention_mask=encode_attention_mask)
            next_token_logit = logits[0][-1, :]  # logits at the last position
            for token_id in set(generate_sequence_ids):
                next_token_logit[token_id] /= args.repetition_penalty
            next_token_logit = top_filtering(next_token_logit,
                                             top_k=args.top_k,
                                             top_p=args.top_p)
            probs = F.softmax(next_token_logit, dim=-1)

            next_token_id = torch.topk(
                probs, 1)[1] if args.no_sample else torch.multinomial(
                    probs, 1)
            next_token_id = next_token_id.item()

            if next_token_id == eos_token_id:
                generate_sequence_ids.append(next_token_id)
                break

            generate_sequence_ids.append(next_token_id)
            index += 1

        system_tokens = tokenizer.convert_ids_to_tokens(generate_sequence_ids)
        print("System-->>{}".format("".join(system_tokens[1:-1])))
        history_tokens.append(system_tokens[1:-1])  # drop the leading [CLS] and trailing [SEP] tokens
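
top_filtering is referenced above but not shown; it is the standard top-k / nucleus (top-p) filtering step. A self-contained sketch that matches the call site (a 1-D logits tensor; a threshold of 0 or below disables that filter):

import torch
import torch.nn.functional as F

def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    # Standard top-k / top-p filtering sketch for a 1-D logits tensor.
    if top_k > 0:
        # Drop everything below the k-th highest logit.
        threshold = torch.topk(logits, top_k)[0][-1]
        logits[logits < threshold] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once cumulative probability passes top_p, shifting by one
        # so the most probable token is always kept.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[1:] = sorted_mask[:-1].clone()
        sorted_mask[0] = False
        logits[sorted_indices[sorted_mask]] = filter_value
    return logits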