Example No. 1
def evaluate(config: TransformerConfig,
             model: Transformer,
             eval_loader: DataLoader,
             device='cpu'):
    """evaluate function"""
    model.eval()
    with torch.no_grad():
        total_eval_loss, n_word_correct, n_word_total = 0.0, 0, 0
        for ids, sample in enumerate(tqdm(eval_loader)):
            input_ids, decoder_input_ids, decoder_target_ids = (
                sample['input_ids'].to(device),
                sample['decode_input_ids'].to(device),
                sample['decode_label_ids'].to(device))
            logits = model(input_ids, decoder_input_ids)
            loss, n_correct, n_word = cal_performance(
                logits,
                gold=decoder_target_ids,
                trg_pad_idx=config.pad_idx,
                smoothing=config.label_smoothing)
            total_eval_loss += loss.item()
            n_word_correct += n_correct
            n_word_total += n_word

        average_loss = total_eval_loss / n_word_total
        accuracy = n_word_correct / n_word_total

        return average_loss, accuracy
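Examples 1 and 3 call cal_performance, which never appears in this listing. The sketch below is a plausible reconstruction rather than the original: the signature follows the calls above, while the smoothing epsilon (0.1) and the internals are assumptions modeled on common Transformer training code.

import torch
import torch.nn.functional as F

def cal_performance(logits, gold, trg_pad_idx, smoothing=False, eps=0.1):
    """Hypothetical sketch: summed cross-entropy (optionally label-smoothed)
    plus token-level accuracy counts, ignoring padding positions."""
    logits = logits.view(-1, logits.size(-1))
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(trg_pad_idx)

    if smoothing:
        n_class = logits.size(1)
        # Smoothed one-hot target: 1 - eps on the gold class, eps elsewhere
        one_hot = torch.full_like(logits, eps / (n_class - 1))
        one_hot.scatter_(1, gold.unsqueeze(1), 1.0 - eps)
        loss = -(one_hot * F.log_softmax(logits, dim=1)).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()
    else:
        loss = F.cross_entropy(logits, gold, ignore_index=trg_pad_idx,
                               reduction='sum')

    pred = logits.max(dim=1)[1]
    n_correct = pred.eq(gold).masked_select(non_pad_mask).sum().item()
    n_word = non_pad_mask.sum().item()
    return loss, n_correct, n_word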
Example No. 2
def evaluate_epoch(opt: Namespace, model: Transformer, val_data):
    model.eval()
    start_time = datetime.now()
    total_loss = total_word = total_corrected_word = 0

    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_data),
                             total=len(val_data),
                             leave=False):
            # Prepare validation data
            src_input, trg_input, y_true = _prepare_batch_data(
                batch, opt.device)

            # Forward
            y_pred = model(src_input, trg_input)
            loss = calculate_loss(y_pred, y_true, opt.trg_pad_idx)
            n_word, n_corrected = calculate_performance(
                y_pred, y_true, opt.trg_pad_idx)

            # Validation Logs
            total_loss += loss.item()
            total_word += n_word
            total_corrected_word += n_corrected

    loss_per_word = total_loss / total_word
    accuracy = total_corrected_word / total_word

    return {
        'total_seconds': (datetime.now() - start_time).total_seconds(),
        'total_loss': total_loss,
        'total_word': total_word,
        'total_corrected_word': total_corrected_word,
        'loss_per_word': loss_per_word,
        'accuracy': accuracy
    }
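Examples 2 and 8 depend on _prepare_batch_data, which is also missing from the listing. A minimal sketch, assuming torchtext-style batches with batch-first src/trg attributes; note that Example 8's batch.trg[:, 0] indexing hints the real layout may be time-major, so the slicing dimension is an assumption.

def _prepare_batch_data(batch, device):
    """Hypothetical sketch: move a batch to the device and shift the target
    sequence by one position for teacher forcing."""
    src_input = batch.src.to(device)
    trg = batch.trg.to(device)
    trg_input = trg[:, :-1]  # decoder input: drop the final token
    y_true = trg[:, 1:]      # gold labels: drop the leading <s> token
    return src_input, trg_input, y_true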
Example No. 3
def train(config: TransformerConfig,
          model: Transformer,
          optimizer: ScheduledOptim,
          train_loader: DataLoader,
          eval_loader: DataLoader = None,
          device='cpu'):
    """train function"""
    best_eval_accuracy = -float('Inf')
    for epoch in range(config.epochs):
        model.train()  # evaluate() switches to eval mode; re-enable training
        logger.info("Epoch: {}".format(epoch))
        total_loss, n_word_total, n_word_correct = 0, 0, 0
        for ids, sample in enumerate(tqdm(train_loader)):
            for k, v in sample.items():
                sample[k] = v.to(device)
            input_ids, decoder_input_ids, decoder_target_ids = (
                sample['input_ids'], sample['decode_input_ids'],
                sample['decode_label_ids'])
            optimizer.zero_grad()
            logits = model(input_ids, decoder_input_ids)
            loss, n_correct, n_word = cal_performance(
                logits,
                gold=decoder_target_ids,
                trg_pad_idx=config.pad_idx,
                smoothing=config.label_smoothing)
            loss.backward()
            optimizer.step_and_update_lr()

            # note keeping
            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()
        loss_per_word = total_loss / n_word_total
        accuracy = n_word_correct / n_word_total
        logger.info("The {} epoch train loss: {}, train accuray: {}".format(
            epoch, loss_per_word, accuracy))

        if eval_loader is not None:
            eval_loss, eval_accuracy = evaluate(config,
                                                model,
                                                eval_loader=eval_loader,
                                                device=device)
            if eval_accuracy > best_eval_accuracy:
                best_eval_accuracy = eval_accuracy

                # Save the best model
                model_save = model.module if hasattr(model,
                                                     "module") else model
                model_file = os.path.join(config.save_dir,
                                          "checkpoint_{}.pt".format(epoch))
                torch.save(model_save.state_dict(), f=model_file)

        if epoch % config.save_epoch == 0:
            model_save = model.module if hasattr(model, "module") else model
            model_file = os.path.join(config.save_dir,
                                      "checkpoint_{}.pt".format(epoch))
            torch.save(model_save.state_dict(), f=model_file)

    return model
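Example 3 drives its optimizer through zero_grad() and step_and_update_lr(), and Example 21 constructs a ScheduledOptim around optim.Adam. The class itself is not included here; below is a minimal sketch using the inverse-square-root warmup schedule from "Attention Is All You Need" (lr = init_lr * d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)). The exact formula used by the original is an assumption.

class ScheduledOptim:
    """Hypothetical sketch of an optimizer wrapper with the warmup
    ('Noam') learning-rate schedule."""

    def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = init_lr
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        # Recompute the learning rate, then take an optimizer step.
        self.n_steps += 1
        scale = (self.d_model ** -0.5) * min(
            self.n_steps ** -0.5,
            self.n_steps * self.n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group['lr'] = self.init_lr * scale
        self._optimizer.step()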
Example No. 4
    def __init__(self, in_dim, mem_dim, opt):
        super(ChildSumTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        '''
        self.ioux = nn.Linear(self.in_dim, self.mem_dim)
        self.iouh = nn.Linear(self.mem_dim, self.mem_dim)
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
        self.Wv = nn.Linear(self.mem_dim, self.mem_dim)
        '''
        self.transformer = Transformer(opt)
Example No. 5
    def __init__(self, src_vocab, tgt_vocab, checkpoint, opts):

        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

        hparams = checkpoint['hparams']

        transformer = Transformer(len(src_vocab),
                                  len(tgt_vocab),
                                  hparams.max_len + 2,
                                  n_layers=hparams.n_layers,
                                  d_model=hparams.d_model,
                                  d_emb=hparams.d_model,
                                  d_hidden=hparams.d_hidden,
                                  n_heads=hparams.n_heads,
                                  d_k=hparams.d_k,
                                  d_v=hparams.d_v,
                                  dropout=hparams.dropout,
                                  pad_id=src_vocab.pad_id)

        transformer.load_state_dict(checkpoint['model'])
        log_proj = torch.nn.LogSoftmax(dim=-1)

        if hparams.cuda:
            transformer.cuda()
            log_proj.cuda()

        transformer.eval()

        self.hparams = hparams
        self.opts = opts
        self.model = transformer
        self.log_proj = log_proj
Example No. 6
def get_transformer(opt) -> Transformer:
    model = Transformer(embed_dim=opt.embed_dim,
                        src_vocab_size=opt.src_vocab_size,
                        trg_vocab_size=opt.trg_vocab_size,
                        src_pad_idx=opt.src_pad_idx,
                        trg_pad_idx=opt.trg_pad_idx,
                        n_head=opt.n_head)
    model = model.to(opt.device)
    checkpoint_file_path = get_best_checkpoint(opt)
    if checkpoint_file_path is not None:
        checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)
        model.load_state_dict(checkpoint['model'])
        print(f'Checkpoint loaded - {checkpoint_file_path}')
    return model
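get_best_checkpoint, used here and in Example 10, is not shown in the listing either. A hedged sketch that keys off the checkpoint_{epoch:04}_{loss:.4f}.chkpt naming from Example 7's training loop; the real selection rule may differ.

import glob
import os

def get_best_checkpoint(opt):
    """Hypothetical sketch: return the checkpoint file with the lowest
    validation loss encoded in its filename, or None if none exist."""
    candidates = glob.glob(
        os.path.join(opt.checkpoint_path, 'checkpoint_*.chkpt'))
    if not candidates:
        return None

    def loss_of(path):
        # The loss is the last underscore-separated field of the stem.
        stem = os.path.splitext(os.path.basename(path))[0]
        return float(stem.split('_')[-1])

    return min(candidates, key=loss_of)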
Example No. 7
def train(opt: Namespace, model: Transformer, optimizer: ScheduledAdam):
    if not os.path.exists(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)

    train_data, val_data, src_vocab, trg_vocab = load_preprocessed_data(opt)
    min_loss = float('inf')

    for epoch in range(opt.epoch):
        # Training and Evaluation
        _t = train_per_epoch(opt, model, optimizer, train_data, src_vocab,
                             trg_vocab)
        _v = evaluate_epoch(opt, model, val_data, src_vocab, trg_vocab)

        # Save a checkpoint whenever the validation loss improves
        is_checkpointed = _v['loss_per_word'] < min_loss
        if is_checkpointed:
            min_loss = _v['loss_per_word']
            checkpoint = {
                'epoch': epoch,
                'opt': opt,
                'weights': model.state_dict(),
                'loss': min_loss,
                '_t': _t,
                '_v': _v
            }
            model_name = os.path.join(
                opt.checkpoint_path,
                f'checkpoint_{epoch:04}_{min_loss:.4f}.chkpt')
            torch.save(checkpoint, model_name)

        # Print performance
        _show_performance(epoch=epoch,
                          step=optimizer.n_step,
                          lr=optimizer.lr,
                          t=_t,
                          v=_v,
                          checkpoint=is_checkpointed)
Example No. 8
def train_per_epoch(opt: Namespace, model: Transformer,
                    optimizer: ScheduledAdam, train_data, src_vocab,
                    trg_vocab) -> dict:
    model.train()
    start_time = datetime.now()
    total_loss = total_word = total_corrected_word = 0

    for i, batch in tqdm(enumerate(train_data),
                         total=len(train_data),
                         leave=False):
        src_input, trg_input, y_true = _prepare_batch_data(batch, opt.device)

        # Forward
        optimizer.zero_grad()
        y_pred = model(src_input, trg_input)

        # Backward and update parameters
        loss = calculate_loss(y_pred, y_true, opt.trg_pad_idx, trg_vocab)
        n_word, n_corrected = calculate_performance(y_pred, y_true,
                                                    opt.trg_pad_idx)
        loss.backward()
        optimizer.step()

        # Training Logs
        total_loss += loss.item()
        total_word += n_word
        total_corrected_word += n_corrected

    loss_per_word = total_loss / total_word
    accuracy = total_corrected_word / total_word

    return {
        'total_seconds': (datetime.now() - start_time).total_seconds(),
        'total_loss': total_loss,
        'total_word': total_word,
        'total_corrected_word': total_corrected_word,
        'loss_per_word': loss_per_word,
        'accuracy': accuracy
    }
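Examples 2 and 8 lean on calculate_loss and calculate_performance, which never appear in the listing. A minimal sketch following the three-argument form used in Example 2 (Example 8 passes an extra trg_vocab argument, so the original signatures evidently vary); the internals are assumptions.

import torch.nn.functional as F

def calculate_loss(y_pred, y_true, trg_pad_idx):
    """Hypothetical sketch: summed cross-entropy that ignores padding."""
    return F.cross_entropy(y_pred.view(-1, y_pred.size(-1)),
                           y_true.contiguous().view(-1),
                           ignore_index=trg_pad_idx,
                           reduction='sum')

def calculate_performance(y_pred, y_true, trg_pad_idx):
    """Hypothetical sketch: count non-pad tokens and correct predictions."""
    y_true = y_true.contiguous().view(-1)
    pred = y_pred.view(-1, y_pred.size(-1)).max(dim=1)[1]
    non_pad = y_true.ne(trg_pad_idx)
    n_word = non_pad.sum().item()
    n_corrected = pred.eq(y_true).masked_select(non_pad).sum().item()
    return n_word, n_corrected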
Example No. 9
    def __init__(self, config, transformer_opt):
        super(NLINet, self).__init__()
        use_Transformer = False
        # classifier
        self.nonlinear_fc = config['nonlinear_fc']
        self.fc_dim = config['fc_dim']
        self.n_classes = config['n_classes']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.encoder_type = config['encoder_type']
        self.dpout_fc = config['dpout_fc']

        if use_Transformer:
            self.encoder_type = 'transformer'
            self.encoder = Transformer(transformer_opt)
        else:
            self.encoder = eval(self.encoder_type)(config)

        self.inputdim = 4 * 1 * self.enc_lstm_dim
        self.inputdim = 4*self.inputdim if self.encoder_type in \
                        ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim
        self.inputdim = int(self.inputdim/2) if self.encoder_type == "LSTMEncoder" \
                                        else self.inputdim
        if self.encoder_type == "transformer":
            self.inputdim = 300
            self.w_kp = torch.rand(5)
            self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
#            self.w1 = nn.Parameter(torch.FloatTensor(5, 300, 1500))
#            self.w_a = torch.rand(5)
#            self.w_a = nn.Parameter(self.w_a/self.w_a.sum())

        if self.nonlinear_fc:
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(int(self.inputdim), self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes))
Example No. 10
def load_transformer(opt) -> Transformer:
    checkpoint_file_path = get_best_checkpoint(opt)
    checkpoint = torch.load(checkpoint_file_path, map_location=opt.device)

    assert checkpoint is not None
    assert checkpoint['opt'] is not None
    assert checkpoint['weights'] is not None

    model_opt = checkpoint['opt']
    model = Transformer(embed_dim=model_opt.embed_dim,
                        src_vocab_size=model_opt.src_vocab_size,
                        trg_vocab_size=model_opt.trg_vocab_size,
                        src_pad_idx=model_opt.src_pad_idx,
                        trg_pad_idx=model_opt.trg_pad_idx,
                        n_head=model_opt.n_head)

    model.load_state_dict(checkpoint['weights'])
    print('model loaded:', checkpoint_file_path)
    return model.to(opt.device)
Example No. 11
def create_model(opt):
    data = torch.load(opt.data_path)
    opt.src_vocab_size = len(data['src_dict'])
    opt.tgt_vocab_size = len(data['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)  # Initialize a model state.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    # If opt.model_path exists, load model parameters.
    if os.path.exists(opt.model_path):
        print('Reloading model parameters..')
        model_state = torch.load(opt.model_path)
        model.load_state_dict(model_state['model_params'])

    if use_cuda:
        print('Using GPU..')
        model = model.cuda()

    return model, model_state
Example No. 12
    def __init__(self, opt, use_cuda):
        self.opt = opt
        self.use_cuda = use_cuda
        self.tt = torch.cuda if use_cuda else torch

        checkpoint = torch.load(opt.model_path)
        model_opt = checkpoint['opt']

        self.model_opt = model_opt
        model = Transformer(model_opt)
        if use_cuda:
            print('Using GPU..')
            model = model.cuda()

        prob_proj = nn.LogSoftmax(dim=-1)
        model.load_state_dict(checkpoint['model_params'])
        print('Loaded pre-trained model_state..')

        self.model = model
        self.model.prob_proj = prob_proj
        self.model.eval()
Example No. 13
def get_model(sh_path):
    # Strip shell line-continuations and quotes, then join the argument
    # lines (everything between the first and last line of the script).
    text = Path(sh_path).read_text().replace("\\", "").replace('"', "")
    if sh_path.count(".", 0, 2) == 2:  # sh_path starts with ".."
        text = text.replace("./", "../")
    arguments = " ".join(s.strip() for s in text.splitlines()[1:-1])
    parser = argument_parsing(preparse=True)
    args = parser.parse_args(arguments.split())

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    (src, trg), (train, _, test), (train_loader, _, test_loader) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length
    pad_idx = src.vocab.stoi.get("<pad>") if args.pad_idx is None else args.pad_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx

    model = Transformer(enc_vocab_len=src_vocab_len, 
                        enc_max_seq_len=enc_max_seq_len, 
                        dec_vocab_len=trg_vocab_len, 
                        dec_max_seq_len=dec_max_seq_len, 
                        n_layer=args.n_layer, 
                        n_head=args.n_head, 
                        d_model=args.d_model, 
                        d_k=args.d_k, 
                        d_v=args.d_v, 
                        d_f=args.d_f, 
                        pad_idx=pad_idx,
                        pos_pad_idx=pos_pad_idx, 
                        drop_rate=args.drop_rate, 
                        use_conv=args.use_conv, 
                        linear_weight_share=args.linear_weight_share, 
                        embed_weight_share=args.embed_weight_share).to(device)
    if device == "cuda":
        model.load_state_dict(torch.load(args.save_path))
    else:
        model.load_state_dict(torch.load(args.save_path, map_location=torch.device(device)))
    
    return model, (src, trg), (test, test_loader)
Example No. 14
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-share_proj_weight', action='store_true')
    parser.add_argument('-share_embs_weight', action='store_true')
    parser.add_argument('-weighted_model', action='store_true')

    # training params
    parser.add_argument('-lr', type=float, default=0.002)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-max_src_seq_len', type=int, default=50)
    parser.add_argument('-max_tgt_seq_len', type=int, default=10)
    parser.add_argument('-max_grad_norm', type=float, default=None)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)
    parser.add_argument('-display_freq', type=int, default=100)
    parser.add_argument('-log', default=None)

    opt = parser.parse_args()

    data = torch.load(opt.data_path)
    opt.src_vocab_size = len(data['src_dict'])
    opt.tgt_vocab_size = len(data['tgt_dict'])

    print('Creating new model parameters..')
    model = Transformer(opt)  # Initialize a model state.
    model_state = {'opt': opt, 'curr_epochs': 0, 'train_steps': 0}

    print('Reloading model parameters..')
    model_state = torch.load('./train_log/emoji_model.pt', map_location=device)
    model.load_state_dict(model_state['model_params'])

    emojilize(opt, model)
Example No. 15
class ChildSumTreeLSTM(nn.Module):
    def __init__(self, in_dim, mem_dim, opt):
        super(ChildSumTreeLSTM, self).__init__()
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        '''
        self.ioux = nn.Linear(self.in_dim, self.mem_dim)
        self.iouh = nn.Linear(self.mem_dim, self.mem_dim)
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
        self.Wv = nn.Linear(self.mem_dim, self.mem_dim)
        '''
        self.transformer = Transformer(opt)
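        # NOTE: node_forward and forward_MVRNN below depend on the layers and
        # parameters commented out in this constructor (ioux, iouh, fx, fh,
        # W_mv, W_mv_M); enable them before calling those methods.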
        #self.W_mv = nn.Parameter(torch.randn(50, 100))
        #self.W_mv_M = nn.Parameter(torch.randn(50, 100))

    def node_forward(self, inputs, child_c, child_h):
        child_h_sum = torch.sum(child_h, dim=0, keepdim=True)

        iou = self.ioux(inputs) + self.iouh(child_h_sum)
        i, o, u = torch.split(iou, iou.size(1) // 3, dim=1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)

        f = torch.sigmoid(
            self.fh(child_h) + self.fx(inputs).repeat(len(child_h), 1))
        fc = torch.mul(f, child_c)

        c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True)
        h = torch.mul(o, torch.tanh(c))
        return c, h

    def forward(self, tree, inputs, arcs, S, ttype):
        '''
        num_words = 1
        child_words = []
        residual = []
        residual.append(inputs[tree.idx].unsqueeze(0))
        for idx in range(tree.num_children):
            self.forward(tree.children[idx], inputs, arc, S)
            num_words += tree.children[idx].words
            child_words.append(tree.children[idx].words)
            residual.append(inputs[tree.children[idx].idx].unsqueeze(0))
        
        tree.words = num_words
        child_words.append(tree.words)

        if tree.num_children == 0:
            tree.state = inputs[tree.idx].view(1,-1) #child_h
            tree.words = 1
            return tree.words
        else:
            states = []
            for x in tree.children:
                states.append(x.state)
            child_h = torch.cat(states, dim=0)
        
        x_hat = inputs[tree.idx].view(1,-1)
        tree.state = self.transformer.tree_encode(x_hat, child_h.unsqueeze(0), S, child_words, residual)
        
        return tree.state
        '''
        num_words = 1
        child_words = []
        residual = []
        residual.append(inputs[tree.idx].unsqueeze(0))

        for idx in range(tree.num_children):
            self.forward(tree.children[idx], inputs, arcs, S, ttype)
            num_words += tree.children[idx].words
            child_words.append(tree.children[idx].words)
            residual.append(inputs[tree.children[idx].idx].unsqueeze(0))

        tree.words = num_words
        child_words.append(tree.words)

        if tree.num_children == 0:
            tree.state = inputs[tree.idx].view(1, -1)  #child_h
            tree.arc = arcs[tree.idx].view(1, -1)
            tree.words = 1
            return tree.words
        else:
            states = []
            arc_labels = []
            for x in tree.children:
                states.append(x.state)
                arc_labels.append(x.arc)
            child_h = torch.cat(
                states, dim=0)  #+ self.Wv(torch.cat(arc_labels, dim=0))
            child_arcs = torch.cat(arc_labels, dim=0)

        x_hat = inputs[tree.idx].view(1, -1)
        tree.state = self.transformer.tree_encode(x_hat, child_h.unsqueeze(0),
                                                  child_arcs.unsqueeze(0), S,
                                                  child_words, residual, ttype)
        tree.arc = arcs[tree.idx].view(1, -1)
        return tree.state

    def forward1(self, tree, inputs, S):
        if tree.num_children == 0:
            tree.state = inputs[tree.idx].view(1, -1)  #child_h
            return [tree.state]
        subtree_list = []
        for idx in range(tree.num_children):
            subtree_list += self.forward1(tree.children[idx], inputs, S)
        dummy = torch.cat(subtree_list, dim=0)
        word_vec = self.transformer.tree_encode1(dummy.unsqueeze(0), S)
        return [word_vec.squeeze(0)]

    def forward_MVRNN(self, tree, inputs, Minputs, S):  # for dependency RNNs
        for idx in range(tree.num_children):
            self.forward_MVRNN(tree.children[idx], inputs, Minputs, S)

        if tree.num_children == 0:
            tree.Vstate = inputs[tree.idx].view(1, -1)  #child_h
            tree.Mstate = Minputs[tree.idx].view(1, 50, -1)  #child_h
            return
        else:
            states = []
            matrix = []
            for x in tree.children:
                states.append(x.Vstate.view(1, -1))
                matrix.append(x.Mstate.view(1, 50, -1))
            child_hV = torch.cat(states, dim=0)
            child_hM = torch.cat(matrix, dim=0)

        term1 = torch.mm(child_hM[1].view(50, -1),
                         child_hV[0].view(-1, 1)).view(1, -1)
        term2 = torch.mm(child_hM[0].view(50, -1),
                         child_hV[1].view(-1, 1)).view(1, -1)
        tree.Vstate = torch.tanh(
            torch.mm(self.W_mv,
                     torch.cat([term1, term2], dim=1).t()).t())
        tree.Mstate = torch.mm(
            self.W_mv_M,
            torch.cat([child_hM[0], child_hM[1]], dim=1).t())

        return tree.Vstate.view(1, -1)
Example No. 16
def main():
    import argparse
    parse = argparse.ArgumentParser(description="Basic parameters")
    parse.add_argument("--para_path",
                       type=str,
                       default=os.path.join(root, "data/para.json"),
                       help="Path to the JSON file with all config parameters")
    parse.add_argument("--model_path",
                       type=str,
                       default=os.path.join(
                           root, "model/transformer_0127/checkpoint_5.pt"),
                       help="Path to the trained model checkpoint")
    parse.add_argument("--no_sample",
                       action='store_true',
                       default=False,
                       help="Set to use greedy decoding instead of sampling")
    parse.add_argument("--repetition_penalty",
                       type=float,
                       default=0.01,
                       help="Repetition penalty")
    parse.add_argument("--temperature",
                       type=float,
                       default=0.7,
                       help="Sampling softmax temperature")
    parse.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parse.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    args = parse.parse_args()

    with open(args.para_path, mode='r', encoding='utf-8') as fp:
        para_dict = json.load(fp)

    config = TransformerConfig(**para_dict)

    tokenizer = BertTokenizer(vocab_file=config.vocab_path)
    bos_token_id = tokenizer._convert_token_to_id("[CLS]")
    eos_token_id = tokenizer._convert_token_to_id("[SEP]")
    pad_token_id = tokenizer._convert_token_to_id("[PAD]")

    logger.info("Load model.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # standard idiom
    model = Transformer(config=config)
    model.load_state_dict(torch.load(args.model_path, map_location="cpu"),
                          strict=False)
    # named_parameters() already yields (name, parameter) pairs
    for name, weights in model.named_parameters():
        logger.info("{} --- {}".format(name, weights))
    model.to(device)

    history_tokens = []
    while True:
        user_text = input("User-->>")
        while not user_text:
            logger.info('Prompt should not be empty!')
            user_text = input("User-->>")
        tokens = tokenizer.tokenize(user_text)
        history_tokens.append(tokens)

        # Build the encoder input from the dialogue history
        context_tokens = ["[SEP]"]
        for turn in history_tokens[::-1]:  # walk the history backwards
            if len(context_tokens) + len(turn) < config.max_encode_len:
                context_tokens = turn + context_tokens
                context_tokens = ["[SEP]"] + context_tokens
            else:
                break
        context_tokens[0] = "[CLS]"  # replace the leading [SEP] with [CLS]

        # Encoding
        encode_input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
        encode_input_ids = torch.tensor(encode_input_ids).long().unsqueeze(
            dim=0).to(device)
        encode_outputs, encode_attention_mask = encoder(model.encoder,
                                                        encode_input_ids,
                                                        pad_idx=pad_token_id)

        # Decoding: generate the reply token by token
        index = 1
        generate_sequence_ids = [bos_token_id]
        while index <= config.max_decode_len:
            # decode_input_ids = torch.LongTensor([generate_sequence_ids])  # expand to a 2-D batch
            decode_input_ids = torch.tensor(
                generate_sequence_ids).long().unsqueeze(dim=0).to(device)
            logits = decoder(model.decoder,
                             model.trg_word_prj,
                             decode_input_ids,
                             encode_outputs=encode_outputs,
                             encode_attention_mask=encode_attention_mask)
            next_token_logit = logits[0][-1, :]  # logits at the last position
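            # NOTE: dividing the logit behaves asymmetrically: with the
            # default penalty of 0.01, repeated tokens with positive logits
            # are boosted rather than penalized.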
            for id in set(generate_sequence_ids):
                next_token_logit[id] /= args.repetition_penalty
            next_token_logit = top_filtering(next_token_logit,
                                             top_k=args.top_k,
                                             top_p=args.top_p)
            probs = F.softmax(next_token_logit, dim=-1)

            next_token_id = torch.topk(
                probs, 1)[1] if args.no_sample else torch.multinomial(
                    probs, 1)
            next_token_id = next_token_id.item()

            if next_token_id == eos_token_id:
                generate_sequence_ids.append(next_token_id)
                break

            generate_sequence_ids.append(next_token_id)
            index += 1

        system_tokens = tokenizer.convert_ids_to_tokens(generate_sequence_ids)
        print("System-->>{}".format("".join(system_tokens[1:-1])))
        history_tokens.append(system_tokens[1:-1])  # drop the leading [CLS] and trailing [SEP]
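top_filtering, called above before sampling, is not part of the listing. The sketch below follows the widely used top-k / nucleus (top-p) filtering recipe and, as the caller assumes, modifies the 1-D logits tensor in place; the original implementation is an assumption.

import torch
import torch.nn.functional as F

def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    """Hypothetical sketch: mask logits outside the top-k set and outside
    the smallest set whose cumulative probability exceeds top_p."""
    if top_k > 0:
        # Drop everything below the k-th largest logit.
        threshold = torch.topk(logits, top_k)[0][-1]
        logits[logits < threshold] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once the cumulative probability passes top_p,
        # always keeping at least the most probable token.
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        logits[sorted_indices[remove]] = filter_value
    return logits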
Example No. 17
def main(args):
    # configs path to load data & save model
    from pathlib import Path
    if not Path(args.root_dir).exists():
        Path(args.root_dir).mkdir()

    p = Path(args.save_path).parent
    if not p.exists():
        p.mkdir()

    device = "cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu"
    import sys
    print(sys.version)
    print(f"Using {device}")
    print("Loading Data...")
    (src, trg), (train, valid, _), (train_loader, valid_loader,
                                    _) = get_data(args)
    src_vocab_len = len(src.vocab.stoi)
    trg_vocab_len = len(trg.vocab.stoi)
    # check vocab size
    print(f"SRC vocab {src_vocab_len}, TRG vocab {trg_vocab_len}")
    enc_max_seq_len = args.max_length
    dec_max_seq_len = args.max_length
    pad_idx = src.vocab.stoi.get(
        "<pad>") if args.pad_idx is None else args.pad_idx
    enc_sos_idx = src.vocab.stoi.get(
        "<s>") if args.enc_sos_idx is None else args.enc_sos_idx
    enc_eos_idx = src.vocab.stoi.get(
        "</s>") if args.enc_eos_idx is None else args.enc_eos_idx
    dec_sos_idx = trg.vocab.stoi.get(
        "<s>") if args.dec_sos_idx is None else args.dec_sos_idx
    dec_eos_idx = trg.vocab.stoi.get(
        "</s>") if args.dec_eos_idx is None else args.dec_eos_idx
    pos_pad_idx = 0 if args.pos_pad_idx is None else args.pos_pad_idx

    print("Building Model...")
    model = Transformer(enc_vocab_len=src_vocab_len,
                        enc_max_seq_len=enc_max_seq_len,
                        dec_vocab_len=trg_vocab_len,
                        dec_max_seq_len=dec_max_seq_len,
                        n_layer=args.n_layer,
                        n_head=args.n_head,
                        d_model=args.d_model,
                        d_k=args.d_k,
                        d_v=args.d_v,
                        d_f=args.d_f,
                        pad_idx=pad_idx,
                        pos_pad_idx=pos_pad_idx,
                        drop_rate=args.drop_rate,
                        use_conv=args.use_conv,
                        linear_weight_share=args.linear_weight_share,
                        embed_weight_share=args.embed_weight_share).to(device)

    if args.load_path is not None:
        print(f"Load Model {args.load_path}")
        model.load_state_dict(torch.load(args.load_path))

    # build loss function using LabelSmoothing
    loss_function = LabelSmoothing(trg_vocab_size=trg_vocab_len,
                                   pad_idx=pad_idx,  # args.pad_idx may be None; use the resolved index
                                   eps=args.smooth_eps)

    optimizer = WarmUpOptim(warmup_steps=args.warmup_steps,
                            d_model=args.d_model,
                            optimizer=optim.Adam(model.parameters(),
                                                 betas=(args.beta1,
                                                        args.beta2),
                                                 eps=1e-9))  # Adam epsilon from the Transformer paper

    trainer = Trainer(optimizer=optimizer,
                      train_loader=train_loader,
                      test_loader=valid_loader,
                      n_step=args.n_step,
                      device=device,
                      save_path=args.save_path,
                      enc_sos_idx=enc_sos_idx,
                      enc_eos_idx=enc_eos_idx,
                      dec_sos_idx=dec_sos_idx,
                      dec_eos_idx=dec_eos_idx,
                      metrics_method=args.metrics_method,
                      verbose=args.verbose)
    print("Start Training...")
    trainer.main(model=model, loss_function=loss_function)
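The LabelSmoothing criterion constructed above is likewise absent from the listing. Below is a common KL-divergence formulation as a hedged sketch: the constructor mirrors the call above, while the forward contract (log-probabilities and flattened targets) and the internals are assumptions.

import torch
import torch.nn as nn

class LabelSmoothing(nn.Module):
    """Hypothetical sketch: KL divergence against a smoothed target
    distribution, with padding rows zeroed out."""

    def __init__(self, trg_vocab_size, pad_idx, eps=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.trg_vocab_size = trg_vocab_size
        self.pad_idx = pad_idx
        self.eps = eps

    def forward(self, log_probs, target):
        # log_probs: (n_tokens, vocab); target: (n_tokens,)
        # Spread eps over every class except the gold one and padding.
        true_dist = torch.full_like(log_probs,
                                    self.eps / (self.trg_vocab_size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.eps)
        true_dist[:, self.pad_idx] = 0
        true_dist[target.eq(self.pad_idx)] = 0  # ignore pad positions
        return self.criterion(log_probs, true_dist)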
Example No. 18
class NLINet(nn.Module):
    def __init__(self, config, transformer_opt):
        super(NLINet, self).__init__()
        use_Transformer = True
        # classifier
        self.nonlinear_fc = config['nonlinear_fc']
        self.fc_dim = config['fc_dim']
        self.n_classes = config['n_classes']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.encoder_type = config['encoder_type']
        self.dpout_fc = config['dpout_fc']

        if use_Transformer:
            self.encoder_type = 'transformer'
            self.encoder = Transformer(transformer_opt)
        else:
            self.encoder = eval(self.encoder_type)(config)

        self.inputdim = 4 * 2 * self.enc_lstm_dim
        self.inputdim = 4*self.inputdim if self.encoder_type in \
                        ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim
        self.inputdim = int(self.inputdim/2) if self.encoder_type == "LSTMEncoder" \
                                        else self.inputdim
        if self.encoder_type == "transformer":
            self.inputdim = 300
            self.w_kp = torch.rand(5)
            self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
#            self.w1 = nn.Parameter(torch.FloatTensor(5, 300, 1500))
#            self.w_a = torch.rand(5)
#            self.w_a = nn.Parameter(self.w_a/self.w_a.sum())

        if self.nonlinear_fc:
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(int(self.inputdim), self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes))

    def get_trainable_parameters(self):
        ''' Avoid updating the position encoding '''
        enc_freezed_param_ids = set(
            map(id, self.encoder.encoder.pos_emb.parameters()))
        #        enc_embeddings = set(map(id, self.encoder.embedding_table.parameters()))
        #        dec_freezed_param_ids = set(map(id, self.encoder.decoder.pos_emb.parameters()))
        #        freezed_param_ids = enc_freezed_param_ids | dec_freezed_param_ids
        freezed_param_ids = enc_freezed_param_ids
        return (p for p in self.parameters() if id(p) not in freezed_param_ids)

    def filter_parameters(self):
        ''' Avoid updating the position encoding '''
        self.encoder.encoder.pos_emb.weight.requires_grad = False


#        self.encoder.decoder.pos_emb.weight.requires_grad = False

    def node_forward(self, tree):
        if type(tree[0]) != Tree:
            #            print("leaf_found")
            return tree[0]
        else:
            idd = 0
            ids = []
            vectors = []
            for subtree in tree:
                vec = self.node_forward(subtree).cuda()
                #                print("vector size: ", vec.size())
                vectors.append(vec.unsqueeze(0).unsqueeze(0))
                idd += 1
                ids.append(idd)
            input_tree = torch.cat(vectors, dim=1)
            mask = torch.ones(len(ids)).unsqueeze(0)
            position = torch.tensor(ids).unsqueeze(0)
            #            print(input_tree.size(), position.size(), mask.size())
            u, _ = self.encoder.encode(input_tree.cuda(), position.cuda(),
                                       mask.cuda())
            sum_a = torch.zeros(u.size(0), 300).cuda()  # 16 300
            for i in range(sum_a.size(0)):  # 16
                temp = 0
                for k in range(len(ids)):  # 21
                    temp += u[i][k]
                sum_a[i] = temp
                sum_a[i] /= math.sqrt(len(ids))
            u = sum_a

        return u.squeeze(0)

    def forward(self, s1, s2):
        # s1 : (s1, s1_len)
        u = self.node_forward(s1).unsqueeze(0)
        v = self.node_forward(s2).unsqueeze(0)
        u = torch.tanh(u)
        v = torch.tanh(v)

        #        features = torch.cat((u, v, torch.abs(u-v), u*v), 1)
        bs = 1

        features = [u, torch.abs(u - v), v, u * v, (u + v) / 2]  # 16x1500
        outputs = [
            kappa * feature for feature, kappa in zip(features, self.w_kp)
        ]
        outputs = torch.cat(outputs, dim=0).view(5, -1, 300)
        features = torch.sum(outputs, dim=0).view(bs, 300)

        output = self.classifier(features)

        return output

    def encode(self, s1):
        emb = self.encoder(s1)
        return emb
Example No. 19
def main():
    args = parse_args()

    loader = DataLoader(MachineTranslationDataLoader,
                        args.src,
                        args.tgt,
                        max_vocab_size=args.max_vocab_size,
                        min_word_count=args.min_word_count,
                        max_len=args.max_len,
                        cuda=args.cuda)

    src_vocab, tgt_vocab = loader.loader.src.vocab, loader.loader.tgt_in.vocab
    print(len(src_vocab), len(tgt_vocab))

    torch.save(src_vocab, os.path.join(args.logdir, 'src_vocab.pt'))
    torch.save(tgt_vocab, os.path.join(args.logdir, 'tgt_vocab.pt'))

    transformer = Transformer(len(src_vocab),
                              len(tgt_vocab),
                              args.max_len + 2,
                              n_layers=args.n_layers,
                              d_model=args.d_model,
                              d_emb=args.d_model,
                              d_hidden=args.d_hidden,
                              n_heads=args.n_heads,
                              d_k=args.d_k,
                              d_v=args.d_v,
                              dropout=args.dropout,
                              pad_id=src_vocab.pad_id)

    weights = torch.ones(len(tgt_vocab))
    weights[tgt_vocab.pad_id] = 0

    optimizer = torch.optim.Adam(transformer.get_trainable_parameters(),
                                 lr=args.lr)

    loss_fn = torch.nn.CrossEntropyLoss(weights)

    if args.cuda:
        transformer = transformer.cuda()
        loss_fn = loss_fn.cuda()

    def loss_fn_wrap(src, tgt_in, tgt_out, src_pos, tgt_pos, logits):
        return loss_fn(logits, tgt_out.contiguous().view(-1))

    def get_performance(gold, logits, pad_id):
        gold = gold.contiguous().view(-1)
        logits = logits.max(dim=1)[1]

        n_corrects = logits.data.eq(gold.data)
        n_corrects = n_corrects.masked_select(gold.ne(pad_id).data).sum()

        return n_corrects

    def epoch_fn(epoch, stats):
        (n_corrects,
         n_words) = list(zip(*[(x['n_corrects'], x['n_words'])
                               for x in stats]))

        train_acc = sum(n_corrects) / sum(n_words)

        return {'train_acc': train_acc}

    def step_fn(step, src, tgt_in, tgt_out, src_pos, tgt_pos, logits):
        # .item() turns the 0-dim tensors into numbers so epoch_fn can sum them
        n_corrects = get_performance(tgt_out, logits, tgt_vocab.pad_id).item()
        n_words = tgt_out.data.ne(tgt_vocab.pad_id).sum().item()

        return {'n_corrects': n_corrects, 'n_words': n_words}

    trainer = Trainer(transformer,
                      loss_fn_wrap,
                      optimizer,
                      logdir=args.logdir,
                      hparams=args,
                      save_mode=args.save_mode)

    trainer.train(
        lambda: loader.iter(batch_size=args.batch_size, with_pos=True),
        epochs=args.epochs,
        epoch_fn=epoch_fn,
        step_fn=step_fn,
        metric='train_acc')
Example No. 20
class NLINet(nn.Module):
    def __init__(self, config, transformer_opt):
        super(NLINet, self).__init__()
        use_Transformer = False
        # classifier
        self.nonlinear_fc = config['nonlinear_fc']
        self.fc_dim = config['fc_dim']
        self.n_classes = config['n_classes']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.encoder_type = config['encoder_type']
        self.dpout_fc = config['dpout_fc']

        if use_Transformer:
            self.encoder_type = 'transformer'
            self.encoder = Transformer(transformer_opt)
        else:
            self.encoder = eval(self.encoder_type)(config)

        self.inputdim = 4 * 1 * self.enc_lstm_dim
        self.inputdim = 4*self.inputdim if self.encoder_type in \
                        ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim
        self.inputdim = int(self.inputdim/2) if self.encoder_type == "LSTMEncoder" \
                                        else self.inputdim
        if self.encoder_type == "transformer":
            self.inputdim = 300
            self.w_kp = torch.rand(5)
            self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
#            self.w1 = nn.Parameter(torch.FloatTensor(5, 300, 1500))
#            self.w_a = torch.rand(5)
#            self.w_a = nn.Parameter(self.w_a/self.w_a.sum())

        if self.nonlinear_fc:
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(int(self.inputdim), self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes))

    def get_trainable_parameters(self):
        ''' Avoid updating the position encoding '''
        enc_freezed_param_ids = set(
            map(id, self.encoder.encoder.pos_emb.parameters()))
        #        enc_embeddings = set(map(id, self.encoder.embedding_table.parameters()))
        dec_freezed_param_ids = set(
            map(id, self.encoder.decoder.pos_emb.parameters()))
        freezed_param_ids = enc_freezed_param_ids | dec_freezed_param_ids
        #        freezed_param_ids = enc_freezed_param_ids
        return (p for p in self.parameters() if id(p) not in freezed_param_ids)

    def filter_parameters(self):
        ''' Avoid updating the position encoding '''
        pass
        #self.encoder.encoder.pos_emb.weight.requires_grad = False
#        self.encoder.decoder.pos_emb.weight.requires_grad = False

    def forward(self, s1, s2):
        # s1 : (s1, s1_len)

        s1, s1_len, position_a, mask_a = s1
        s2, s2_len, position_b, mask_b = s2

        if self.encoder_type == "transformer":

            #            u_hat, _ = self.encoder.encode(s1, position_a, mask_a)
            #            v_hat, _ = self.encoder.encode(s2, position_b, mask_b)

            #            print(u_hat.size(), v_hat.size(), position_a, s1,mask_b, mask_a)

            #            u, _, _ = self.encoder.decode(s1, position_a, mask_b, v_hat, mask_a)
            #            v, _, _ = self.encoder.decode(s2, position_b, mask_a, u_hat, mask_b)

            #            u, _, _ = self.encoder.decode(s1, position_a, mask_b, s2, mask_a)
            #            v, _, _ = self.encoder.decode(s2, position_b, mask_a, s1, mask_b)

            #            u = u_hat
            #            v = v_hat

            #            pool_type = "max"
            #            max_pad = True
            #            if pool_type == "mean":
            #                s1_len = torch.FloatTensor(s1_len.copy()).unsqueeze(1).cuda()
            #                emb = torch.sum(u, 0).squeeze(0)
            #                emb = emb / s1_len.expand_as(emb)
            #            elif pool_type == "max":
            #                if not max_pad:
            #                    u[u == 0] = -1e9
            #                emb = torch.max(u, 0)[0]
            #                if emb.ndimension() == 3:
            #                    emb = emb.squeeze(0)
            #                    assert emb.ndimension() == 2
            #            u = emb
            #
            #            if pool_type == "mean":
            #                s2_len = torch.FloatTensor(s2_len.copy()).unsqueeze(1).cuda()
            #                emb = torch.sum(v, 0).squeeze(0)
            #                emb = emb / s2_len.expand_as(emb)
            #            elif pool_type == "max":
            #                if not max_pad:
            #                    v[v == 0] = -1e9
            #                emb = torch.max(v, 0)[0]
            #                if emb.ndimension() == 3:
            #                    emb = emb.squeeze(0)
            #                    assert emb.ndimension() == 2
            #            v = emb

            v, u = self.encoder.ed(s1, position_a, mask_a, s2, position_b,
                                   mask_b)
            #            v, _, _ = self.encoder.ed(s2, position_b, mask_b, s1, position_a, mask_a)
            #            print(u.size(), s1.size(), v.size(), s2.size())
            #            asasas
            sum_a = torch.zeros(u.size(0), 300).cuda()  # 16 300
            for i in range(sum_a.size(0)):  # 16
                temp = 0
                for k in range(s1_len[i]):  # 21
                    temp += u[i][k]
                sum_a[i] = temp
                sum_a[i] /= math.sqrt(s1_len[i])
            u = sum_a
            #            u = F.tanh(sum_a)

            sum_b = torch.zeros(v.size(0), 300).cuda()
            for i in range(sum_b.size(0)):
                temp = 0
                for k in range(s2_len[i]):
                    temp += v[i][k]
                sum_b[i] = temp
                sum_b[i] /= math.sqrt(s2_len[i])
            v = sum_b


#            v = F.tanh(sum_b)

        else:
            s1 = (s1.transpose(0, 1), s1_len)
            s2 = (s2.transpose(0, 1), s2_len)
            u = self.encoder(s1)
            v = self.encoder(s2)

        features = torch.cat((u, v, torch.abs(u - v), u * v), 1)
        #features = torch.cat((u, v, torch.abs(u-v), u*v, (u+v)/2), 1)
        #        features = torch.cat((u,torch.abs(u - v),v,u*v, (u+v)/2), 1) # 16x1500
        #        features = features.unsqueeze(1).view(5,bs,300)
        #        features = torch.bmm(features, self.w1)
        #        features = features.sum(0).view(bs,1,1500)
        #        outputs = [kappa * features for kappa in self.w_kp]
        #        outputs = torch.cat(outputs, dim=0).view(5, -1, 1500)
        #        outputs = self.w_a.view(-1, 1, 1) * outputs
        #        features = torch.sum(outputs, dim=0).view(bs, 1500)
        '''
        bs = s1.size(0)
        features = [u,torch.abs(u - v),v,u*v, (u+v)/2] # 16x1500
        outputs = [kappa * feature for feature,kappa in zip(features,self.w_kp)]
        outputs = torch.cat(outputs, dim=0).view(5, -1, 300)
        features = torch.sum(outputs, dim=0).view(bs, 300)
        '''

        output = self.classifier(features)

        return output

    def encode(self, s1):
        emb = self.encoder(s1)
        return emb
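The per-sample loops in Examples 18 and 20 compute a length-normalized sum over valid time steps (sum over the first s_len positions, divided by the square root of the length). An equivalent vectorized form, as a sketch assuming u is a (batch, seq_len, dim) tensor and lengths holds the per-sample lengths:

import torch

def scaled_sum(u, lengths):
    """Sum each sequence over its valid time steps and scale by
    1/sqrt(length); vectorized equivalent of the loops above."""
    lengths = torch.as_tensor(lengths, device=u.device)
    steps = torch.arange(u.size(1), device=u.device)
    mask = (steps[None, :] < lengths[:, None]).float()  # (batch, seq_len)
    summed = (u * mask.unsqueeze(-1)).sum(dim=1)        # (batch, dim)
    return summed / lengths.float().sqrt().unsqueeze(1)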
Example No. 21
def main():
    import argparse
    parse = argparse.ArgumentParser(description="Basic parameters")
    # model parameter
    parse.add_argument("--vocab_size", type=int, default=1000,
                       help="vocabulary size")
    parse.add_argument("--n_position",
                       type=int,
                       default=256,
                       help="number of positions (maximum sequence length)")
    parse.add_argument("--word_vec_size",
                       type=int,
                       default=512,
                       help="embedding output size")
    parse.add_argument("--d_model", type=int, default=512, help="hidden size")
    parse.add_argument("--d_inner", type=int, default=1024,
                       help="feed-forward inner size")
    parse.add_argument("--n_head", type=int, default=8,
                       help="number of self-attention heads")
    parse.add_argument("--d_k",
                       type=int,
                       default=64,
                       help="per-head key size (d_model / n_head)")
    parse.add_argument("--d_v",
                       type=int,
                       default=64,
                       help="per-head value size (d_model / n_head)")
    parse.add_argument("--encoder_n_layers", type=int, default=6,
                       help="number of encoder layers")
    parse.add_argument("--decoder_n_layers", type=int, default=6,
                       help="number of decoder layers")
    parse.add_argument("--dropout", type=float, default=0.1,
                       help="dropout probability")
    parse.add_argument("--pad_idx", type=int, default=-1, help="padding index")
    parse.add_argument("--trg_emb_prj_weight_sharing",
                       action="store_true",
                       default=True)
    parse.add_argument("--emb_src_trg_weight_sharing",
                       action="store_true",
                       default=True)

    # data parameter
    parse.add_argument("--vocab_path",
                       type=str,
                       default=os.path.join(root, "vocabulary/vocab.txt"),
                       help="词汇表路径")
    parse.add_argument("--train_data_path",
                       type=str,
                       default=os.path.join(root, "data/train_small.txt"),
                       help="训练数据路径")
    parse.add_argument("--evaluate_data_path",
                       type=str,
                       default=None,
                       help="评估数据路径")
    parse.add_argument("--max_encode_len",
                       type=int,
                       default=192,
                       help="最大编码序列长度")
    parse.add_argument("--max_decode_len",
                       type=int,
                       default=64,
                       help="最大解码序列长度")
    parse.add_argument("--history_turns", type=int, default=3, help="历史对话轮数")
    parse.add_argument("--max_lines", type=int, default=525106, help="最多处理数据量")
    parse.add_argument("--batch_size",
                       type=int,
                       default=32,
                       help="batch size 大小")

    # train parameter
    parse.add_argument("--epochs", type=int, default=20, help="训练epoch数量")
    parse.add_argument("--save_epoch",
                       type=int,
                       default=5,
                       help="每训练多少epoch保存一次模型")
    parse.add_argument("--save_dir",
                       type=str,
                       default=os.path.join(root, "model/transformer_0127"),
                       help="模型保存路径")
    parse.add_argument("--init_lr", type=float, default=1.0, help="初始学习率")
    parse.add_argument("--n_warmup_steps", type=int, default=100, help="热身步长")
    parse.add_argument("--label_smoothing", action="store_true", default=False)

    args = parse.parse_args()

    tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    args.vocab_size = tokenizer.vocab_size
    args.pad_idx = tokenizer._convert_token_to_id("[PAD]")

    args_dict = vars(args)
    config = TransformerConfig(**args_dict)

    if not os.path.exists(config.save_dir):
        os.makedirs(config.save_dir)  # create the model save directory

    logger.info("Load dataset.")
    train_dataset = ChatDataset(config.train_data_path,
                                tokenizer=tokenizer,
                                max_encode_len=config.max_encode_len,
                                max_decode_len=config.max_decode_len,
                                history_turns=config.history_turns,
                                max_lines=config.max_lines)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)
    if config.evaluate_data_path is not None:
        eval_dataset = ChatDataset(config.evaluate_data_path,
                                   tokenizer=tokenizer,
                                   max_encode_len=config.max_encode_len,
                                   max_decode_len=config.max_decode_len,
                                   history_turns=config.history_turns,
                                   max_lines=config.max_lines)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=config.batch_size,
                                 shuffle=False)
    else:
        eval_loader = None  # train() in Example 3 tests "is not None"

    logger.info("Load model.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # standard idiom
    model = Transformer(config=config)
    model.to(device)

    logger.info("Load optimizer.")
    optimizer = ScheduledOptim(
        optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        config.init_lr, config.d_model, config.n_warmup_steps)

    logger.info("Save all config parameter.")
    config.save_para_to_json_file(os.path.join(root, "data/para.json"))

    logger.info("Training model.")
    train(config,
          model,
          optimizer,
          train_loader=train_loader,
          eval_loader=eval_loader,
          device=device)
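Examples 16 and 21 round-trip their settings through TransformerConfig and a para.json file (save_para_to_json_file here, TransformerConfig(**para_dict) in Example 16). The class itself is not shown; below is a minimal sketch sufficient for that round trip, with the caveat that the real class presumably also defines defaults and validation.

import json

class TransformerConfig:
    """Hypothetical sketch: store keyword arguments as attributes and
    round-trip them through a JSON file."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def save_para_to_json_file(self, path):
        with open(path, mode='w', encoding='utf-8') as fp:
            json.dump(self.__dict__, fp, ensure_ascii=False, indent=2)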