def __init__(self, opt, device):
        self.opt = opt
        self.device = device

        checkpoint = torch.load(opt.model)
        model_opt = checkpoint['settings']
        self.model_opt = model_opt

        model = Transformer(model_opt.input_dim,
                            model_opt.output_dim,
                            model_opt.n_inputs_max_seq,
                            model_opt.n_outputs_max_seq,
                            d_k=model_opt.d_k,
                            d_v=model_opt.d_v,
                            d_model=model_opt.d_model,
                            d_inner_hid=model_opt.d_inner_hid,
                            n_layers=model_opt.n_layers,
                            n_head=model_opt.n_head,
                            dropout=model_opt.dropout,
                            device=device,
                            is_train=False)

        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        # Projection onto output log-probabilities; assumed to be a log-softmax,
        # matching the word_prob_prj used by the other loaders below.
        prob_projection = nn.LogSoftmax(dim=1)

        model.to(device)
        prob_projection.to(device)

        model.prob_projection = prob_projection

        self.model = model
        self.model.eval()
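Each constructor in this listing expects an argparse-style opt object plus a checkpoint path. A minimal driver sketch for the wrapper above, with hypothetical flag names and assuming the surrounding class is a Translator-style wrapper:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('-model', required=True, help='path to a trained checkpoint')
parser.add_argument('-no_cuda', action='store_true')
opt = parser.parse_args()
opt.cuda = not opt.no_cuda

device = torch.device('cuda' if opt.cuda else 'cpu')
translator = Translator(opt, device)  # Translator is assumed to be the class owning the __init__ above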
Example #2
    def __init__(self, opt):
        self.opt = opt
        print(opt, "\n")
        self.device = torch.device('cuda' if opt.cuda else 'cpu')

        checkpoint = torch.load(opt.model)
        model_opt = checkpoint['settings']
        self.model_opt = model_opt
        print(model_opt)
        model = Transformer(
            model_opt.src_vocab_size,
            model_opt.tgt_vocab_size,
            model_opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
            emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            d_inner=model_opt.d_inner_hid,
            n_layers=model_opt.n_layers,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout)

        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        model.word_prob_prj = nn.LogSoftmax(dim=1)

        model = model.to(self.device)

        self.model = model
        self.model.eval()
    def __init__(self, model):

        self.device = torch.device('cuda')
        checkpoint = torch.load(model)
        checkpoint_copy = checkpoint['model'].copy()

        # Strip the 'module.model.' prefix (left by a checkpoint saved from a
        # wrapped model, e.g. nn.DataParallel) so the keys match this bare Transformer.
        for k in list(checkpoint_copy.keys()):
            new_key = k.replace('module.model.', '')
            checkpoint_copy[new_key] = checkpoint_copy.pop(k)

        model_opt = checkpoint['settings']
        model = Transformer(
            model_opt.src_vocab_size,
            model_opt.tgt_vocab_size,
            model_opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
            emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            d_inner=model_opt.d_inner_hid,
            n_layers=model_opt.n_layers,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout)
        model.load_state_dict(checkpoint_copy)
        model = model.to(self.device)
        self.model = model
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
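The key-renaming idiom above reappears later in this listing; a small reusable helper (hypothetical name, same behavior) keeps it in one place:

def strip_state_dict_prefix(state_dict, prefix='module.model.'):
    """Return a copy of state_dict with `prefix` stripped from the front of every key."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Usage: model.load_state_dict(strip_state_dict_prefix(checkpoint['model']))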
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt.cuda else 'cpu')

        checkpoint = torch.load(opt.model)
        model_opt = checkpoint['settings']
        self.model_opt = model_opt

        if opt.prune:
            # Wrap the Transformer in a NetworkWrapper so the pruner is applied.
            prune_params = {'alpha': opt.prune_alpha}
            pruner = Pruner(device=self.device,
                            load_mask=opt.load_mask,
                            prune_params=prune_params)

            model = NetworkWrapper(
                model_opt.src_vocab_size,
                model_opt.tgt_vocab_size,
                model_opt.max_token_seq_len,
                tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
                emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
                d_k=model_opt.d_k,
                d_v=model_opt.d_v,
                d_model=model_opt.d_model,
                d_word_vec=model_opt.d_word_vec,
                d_inner=model_opt.d_inner_hid,
                n_layers=model_opt.n_layers,
                n_head=model_opt.n_head,
                dropout=model_opt.dropout,
                transformer=pruner)
        else:
            model = Transformer(
                model_opt.src_vocab_size,
                model_opt.tgt_vocab_size,
                model_opt.max_token_seq_len,
                tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
                emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
                d_k=model_opt.d_k,
                d_v=model_opt.d_v,
                d_model=model_opt.d_model,
                d_word_vec=model_opt.d_word_vec,
                d_inner=model_opt.d_inner_hid,
                n_layers=model_opt.n_layers,
                n_head=model_opt.n_head,
                dropout=model_opt.dropout)

        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        model.word_prob_prj = nn.LogSoftmax(dim=1)

        model = model.to(self.device)

        self.model = model
        self.model.eval()
Example #5
    def __init__(self, opt):
        # opt comes from argparse
        self.opt = opt
        self.device = torch.device('cuda' if opt.cuda else 'cpu')
        self.m = opt.m
        # opt.model is the checkpoint path
        checkpoint = torch.load(opt.model)
        # model_opt holds the model hyperparameters
        model_opt = checkpoint['settings']
        self.model_opt = model_opt

        model = Transformer(
            model_opt.src_vocab_size,
            model_opt.tgt_vocab_size,
            model_opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
            emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            d_inner=model_opt.d_inner_hid,
            n_layers=model_opt.n_layers,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout,
            return_attns=opt.return_attns)

        #Load the actual model weights
        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        model.word_prob_prj = nn.LogSoftmax(dim=1)

        model = model.to(self.device)

        self.model = model
        self.model.eval()
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt.cuda else 'cpu')

        checkpoint = torch.load(opt.model)
        model_opt = checkpoint['settings']
        self.model_opt = model_opt
        # Added: strip the 'module.model.' prefix from the checkpoint keys
        # (present when the model was saved wrapped, e.g. in nn.DataParallel).
        checkpoint_copy = checkpoint['model'].copy()
        for k in list(checkpoint_copy.keys()):
            new_key = k.replace('module.model.', '')
            checkpoint_copy[new_key] = checkpoint_copy.pop(k)
        model = Transformer(
            model_opt.src_vocab_size,
            model_opt.tgt_vocab_size,
            model_opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
            emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            d_inner=model_opt.d_inner_hid,
            n_layers=model_opt.n_layers,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout)

        model.load_state_dict(checkpoint_copy)
        print('[Info] Trained model state loaded.')

        model.word_prob_prj = nn.LogSoftmax(dim=1)

        model = model.to(self.device)

        self.model = model
        self.model.eval()
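Every constructor above follows the same recipe: load a checkpoint holding both the training settings and the weights, rebuild the Transformer from those settings, load the state dict, and switch to eval mode. A condensed, self-contained sketch of that recipe, assuming the same checkpoint layout (a dict with 'settings' and 'model' keys) and the reference repo's import path:

import torch
import torch.nn as nn
from transformer.Models import Transformer  # import path assumed from the reference repo layout


def load_trained_transformer(checkpoint_path, device):
    """Rebuild a Transformer from a checkpoint shaped like {'settings': ..., 'model': ...}."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_opt = checkpoint['settings']

    model = Transformer(
        model_opt.src_vocab_size,
        model_opt.tgt_vocab_size,
        model_opt.max_token_seq_len,
        tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
        emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
        d_k=model_opt.d_k,
        d_v=model_opt.d_v,
        d_model=model_opt.d_model,
        d_word_vec=model_opt.d_word_vec,
        d_inner=model_opt.d_inner_hid,
        n_layers=model_opt.n_layers,
        n_head=model_opt.n_head,
        dropout=model_opt.dropout)

    model.load_state_dict(checkpoint['model'])
    model.word_prob_prj = nn.LogSoftmax(dim=1)
    return model.to(device).eval()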
def main():
    '''
    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 256 -warmup 128000
    '''
    global C
    global shapes
    global Beta
    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl',
                        default=None)  # all-in-1 data pickle or bpe field
    # argparse's type=bool treats any non-empty string as True, so use store_true flags instead.
    parser.add_argument('-srn', action='store_true')
    parser.add_argument('-optimize_c', action='store_true')
    parser.add_argument('-Beta', type=float, default=1.0)
    parser.add_argument("-lr", type=float, default=1e-1)
    parser.add_argument("-scheduler_mode", type=str, default=None)
    parser.add_argument("-scheduler_factor", type=float, default=0.5)
    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)  # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model
    Beta = opt.Beta

    if not opt.log and not opt.save_model:
        raise ValueError('No experiment result will be saved; set -log and/or -save_model.')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The number of warmup steps may not be enough.\n'
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'
              'Using a smaller batch size without a longer warmup may end the '
              'warmup stage after training on only a little data.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    #========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(
            opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        raise ValueError('Either -data_pkl or both -train_path and -val_path must be given.')

    print(opt)

    transformer = Transformer(opt.src_vocab_size,
                              opt.trg_vocab_size,
                              src_pad_idx=opt.src_pad_idx,
                              trg_pad_idx=opt.trg_pad_idx,
                              trg_emb_prj_weight_sharing=opt.proj_share_weight,
                              emb_src_trg_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)
    if opt.srn:
        transformer = migrate_to_srn(transformer)
        transformer = transformer.to(device)
    if opt.optimize_c:
        srn_modules = [
            module for module in transformer.modules()
            if isinstance(module, (SRNLinear, SRNConv2d))
        ]
        sranks = []
        shapes = []

        for module in srn_modules:
            W = module.weight.detach()
            shape_w = W.shape
            W = W.view(shape_w[0], -1)
            sranks.append(stable_rank(W).item())
            shapes.append(W.shape)

        # Rule of thumb: initialize each target stable rank from the model's current
        # stable rank, expressed as a fraction of the maximum possible rank min(shape).
        C = [
            Parameter((torch.ones(1) * sranks[i] / min(shapes[i])).view(()))
            for i in range(len(srn_modules))
        ]
        for i, module in enumerate(srn_modules):
            # Parameter.to() is not in-place, so rebind the moved parameter before attaching it.
            C[i] = Parameter(C[i].data.to(device))
            module.c = C[i]
        criteria = criteria_
    else:
        criteria = cal_performance
    optimizer = ScheduledOptim(optim.Adam(transformer.parameters(),
                                          lr=1e-2,
                                          betas=(0.9, 0.98),
                                          eps=1e-09),
                               opt.lr,
                               opt.d_model,
                               opt.n_warmup_steps,
                               mode=opt.scheduler_mode,
                               factor=opt.scheduler_factor,
                               patience=3)

    train(transformer,
          training_data,
          validation_data,
          optimizer,
          device,
          opt,
          loss=criteria)
    print("~~~~~~~~~~~~~C~~~~~~~~~~~~~")
    print(C)
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("-----------Model-----------")
    print(transformer)
    print("---------------------------")
    with torch.no_grad():
        for pname, p in transformer.named_parameters():
            if len(p.shape) > 1:
                print("...Parameter ", pname, ", srank=",
                      stable_rank(p.view(p.shape[0], -1)).item())
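The stable_rank helper used above is not reproduced in this listing. Stable rank is a standard quantity, the squared Frobenius norm over the squared spectral norm, so a plausible stand-in (an assumption about what the missing helper computes, not its actual source) is:

import torch


def stable_rank(w):
    """Stable rank of a 2-D matrix: ||W||_F^2 / sigma_max^2 (never exceeds rank(W))."""
    sv = torch.linalg.svdvals(w)           # singular values in descending order
    return sv.pow(2).sum() / sv[0].pow(2)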
Example #8
def main():
    """ Main function. """

    parser = argparse.ArgumentParser()

    parser.add_argument('-data', required=True)

    parser.add_argument('-epoch', type=int, default=30)
    parser.add_argument('-batch_size', type=int, default=16)

    parser.add_argument('-d_model', type=int, default=64)
    parser.add_argument('-d_rnn', type=int, default=256)
    parser.add_argument('-d_inner_hid', type=int, default=128)
    parser.add_argument('-d_k', type=int, default=16)
    parser.add_argument('-d_v', type=int, default=16)

    parser.add_argument('-n_head', type=int, default=4)
    parser.add_argument('-n_layers', type=int, default=4)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-lr', type=float, default=1e-4)
    parser.add_argument('-smooth', type=float, default=0.1)

    parser.add_argument('-log', type=str, default='log.txt')

    opt = parser.parse_args()

    # default device is CUDA
    opt.device = torch.device('cuda')

    # setup the log file
    with open(opt.log, 'w') as f:
        f.write('Epoch, Log-likelihood, Accuracy, RMSE\n')

    print('[Info] parameters: {}'.format(opt))
    """ prepare dataloader """
    trainloader, testloader, num_types = prepare_dataloader(opt)
    """ prepare model """
    model = Transformer(
        num_types=num_types,
        d_model=opt.d_model,
        d_rnn=opt.d_rnn,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        d_k=opt.d_k,
        d_v=opt.d_v,
        dropout=opt.dropout,
    )
    model.to(opt.device)
    """ optimizer and scheduler """
    optimizer = optim.Adam(filter(lambda x: x.requires_grad,
                                  model.parameters()),
                           opt.lr,
                           betas=(0.9, 0.999),
                           eps=1e-05)
    scheduler = optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.5)
    """ prediction loss function, either cross entropy or label smoothing """
    if opt.smooth > 0:
        pred_loss_func = Utils.LabelSmoothingLoss(opt.smooth,
                                                  num_types,
                                                  ignore_index=-1)
    else:
        pred_loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
    """ number of parameters """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('[Info] Number of parameters: {}'.format(num_params))
    """ train the model """
    train(model, trainloader, testloader, optimizer, scheduler, pred_loss_func,
          opt)
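Utils.LabelSmoothingLoss above is project code that is not shown here. As a reference point only, a generic label-smoothing cross entropy with an ignore index (an assumption about what such a loss does, not the project's implementation) looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """Per-token cross entropy with label smoothing; ignored targets contribute zero."""

    def __init__(self, smoothing, num_classes, ignore_index=-1):
        super().__init__()
        self.smoothing = smoothing
        self.num_classes = num_classes
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        # logits: (N, C), target: (N,)
        log_probs = F.log_softmax(logits, dim=-1)
        mask = target.ne(self.ignore_index)
        safe_target = target.clamp(min=0)  # avoid indexing with the ignore value

        # Smoothed target: 1 - eps on the true class, eps spread over the remaining classes.
        true_dist = torch.full_like(log_probs, self.smoothing / (self.num_classes - 1))
        true_dist.scatter_(1, safe_target.unsqueeze(1), 1.0 - self.smoothing)

        loss = -(true_dist * log_probs).sum(dim=-1)
        return loss * mask.float()  # unreduced, matching reduction='none' in the CE branch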
Example #9
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('-data', default='./data/preprocessedData')

    parser.add_argument('-epoch', type=int, default=50)
    parser.add_argument('-batch_size', type=int, default=64)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default='log')  # None
    parser.add_argument('-save_model', default='trained')  # None
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true', default=True)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    # Loading Dataset
    data = torch.load(opt.data)
    opt.max_token_seq_len = data['settings'].max_token_seq_len

    training_data, validation_data = prepare_dataloaders(data, opt)

    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    # Preparing Model
    if opt.embs_share_weight:
        assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
            'The src/tgt word2idx tables are different but were asked to share word embeddings.'

    print(opt)

    device = torch.device('cuda' if opt.cuda else 'cpu')
    # device = torch.device('cpu')

    transformer = Transformer(opt.src_vocab_size,
                              opt.tgt_vocab_size,
                              opt.max_token_seq_len,
                              tgt_emb_prj_weight_sharing=opt.proj_share_weight,
                              emb_src_tgt_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout)

    discriminator = Discriminator(opt.d_model, 1024, opt.max_token_seq_len,
                                  device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        transformer = nn.DataParallel(transformer)
    transformer.to(device)
    discriminator.to(device)

    optimizer = ScheduledOptim(
        optim.Adam(filter(lambda x: x.requires_grad, transformer.parameters()),
                   betas=(0.9, 0.98),
                   eps=1e-09), opt.d_model, opt.n_warmup_steps)
    optimizer_d = optim.RMSprop(discriminator.parameters(), lr=5e-4)

    train(transformer, discriminator, training_data, validation_data,
          optimizer, optimizer_d, device, opt)
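ScheduledOptim wraps Adam with the warmup learning-rate schedule. As a sketch only (an assumption about its internals, not a copy of the project file), this is the inverse-square-root "Noam" schedule from the Transformer paper that such wrappers typically implement:

class NoamScheduledOptim:
    """Wraps an optimizer with lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = d_model ** -0.5
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _update_learning_rate(self):
        self.n_steps += 1
        scale = min(self.n_steps ** -0.5,
                    self.n_steps * self.n_warmup_steps ** -1.5)
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = self.init_lr * scale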
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-epoch', type=int, default=1)
    parser.add_argument('-batch_size', type=int, default=4)
    parser.add_argument('-context_width', type=int, default=1)
    parser.add_argument('-frame_rate', type=int, default=30)

    #parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=1024)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-n_warmup_steps', type=int, default=400)

    parser.add_argument('-dropout', type=float, default=0.1)

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default='./exp')
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    opt = parser.parse_args()

    cfg_path = './config/transformer.cfg'
    config = configparser.ConfigParser()
    config.read(cfg_path)

    #========= Preparing DataLoader =========#
    training_data = DataLoader('train',
                               config,
                               DEVICE,
                               batch_size=opt.batch_size,
                               context_width=opt.context_width,
                               frame_rate=opt.frame_rate)
    validation_data = DataLoader('dev',
                                 config,
                                 DEVICE,
                                 batch_size=opt.batch_size,
                                 context_width=opt.context_width,
                                 frame_rate=opt.frame_rate)
    test_data = DataLoader('test',
                           config,
                           DEVICE,
                           batch_size=opt.batch_size,
                           context_width=opt.context_width,
                           frame_rate=opt.frame_rate)

    #========= Preparing Model =========#

    print(opt)

    input_dim = training_data.features_dim
    output_dim = training_data.vocab_size
    n_inputs_max_seq = max(training_data.inputs_max_seq_lengths,
                           validation_data.inputs_max_seq_lengths,
                           test_data.inputs_max_seq_lengths)
    n_outputs_max_seq = max(training_data.outputs_max_seq_lengths,
                            validation_data.outputs_max_seq_lengths,
                            test_data.outputs_max_seq_lengths)
    print('*************************')
    print('The max length of inputs is %d' % n_inputs_max_seq)
    print('The max length of targets is %d' % n_outputs_max_seq)

    transformer = Transformer(input_dim,
                              output_dim,
                              n_inputs_max_seq,
                              n_outputs_max_seq,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_inner_hid=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout,
                              device=DEVICE)

    # print(transformer)

    optimizer = ScheduledOptim(
        optim.Adam(transformer.get_trainable_parameters(),
                   betas=(0.9, 0.98),
                   eps=1e-09), opt.d_model, opt.n_warmup_steps)

    def get_criterion(output_dim):
        ''' Cross entropy with zero weight on the PAD token '''
        weight = torch.ones(output_dim)
        weight[Constants.PAD] = 0
        # size_average=False is deprecated; reduction='sum' is the equivalent.
        return nn.CrossEntropyLoss(weight, reduction='sum')

    crit = get_criterion(training_data.vocab_size)

    transformer = transformer.to(DEVICE)
    crit = crit.to(DEVICE)

    train(transformer, training_data, validation_data, crit, optimizer, opt)
Example #11
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('--train_src', required=True)
    parser.add_argument('--valid_src', required=True)
    parser.add_argument('--max_word_seq_len', type=int, default=100)
    parser.add_argument('--min_word_count', type=int, default=5)
    parser.add_argument('--keep_case', action='store_true')

    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_worker', type=int, default=8)

    # parser.add_argument('-d_word_vec', type=int, default=512)
    parser.add_argument('--d_model', type=int, default=512)
    parser.add_argument('--d_inner_hid', type=int, default=2048)
    parser.add_argument('--d_k', type=int, default=64)
    parser.add_argument('--d_v', type=int, default=64)

    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--n_warmup_steps', type=int, default=4000)

    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--embs_share_weight', action='store_true')
    parser.add_argument('--proj_share_weight', action='store_true')

    parser.add_argument('--model', default=None, help='Path to model file')
    parser.add_argument('--log', default=None)
    parser.add_argument('--save_model', default=None)
    parser.add_argument('--save_data', default='./data/word2idx.pth')
    parser.add_argument('--save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    opt.max_token_seq_len = opt.max_word_seq_len + 2  # +2 leaves room for the BOS/EOS tokens
    #========= Loading Dataset =========#
    training_data = torch.utils.data.DataLoader(dataset.TranslationDataset(
        dir_name=opt.train_src,
        max_word_seq_len=opt.max_word_seq_len,
        min_word_count=opt.min_word_count,
        keep_case=opt.keep_case,
        src_word2idx=None,
        tgt_word2idx=None),
                                                num_workers=opt.num_worker,
                                                batch_size=opt.batch_size,
                                                collate_fn=paired_collate_fn,
                                                shuffle=True)
    validation_data = torch.utils.data.DataLoader(dataset.TranslationDataset(
        dir_name=opt.valid_src,
        max_word_seq_len=opt.max_word_seq_len,
        min_word_count=opt.min_word_count,
        keep_case=opt.keep_case,
        src_word2idx=training_data.dataset.src_word2idx,
        tgt_word2idx=training_data.dataset.tgt_word2idx),
                                                  num_workers=opt.num_worker,
                                                  batch_size=opt.batch_size,
                                                  collate_fn=paired_collate_fn,
                                                  shuffle=True)
    data = {
        'dict': {
            'src': training_data.dataset.src_word2idx,
            'tgt': training_data.dataset.tgt_word2idx
        }
    }
    print('[Info] Dumping the processed data to pickle file', opt.save_data)
    torch.save(data, opt.save_data)
    print('[Info] Finish.')
    del data
    opt.src_vocab_size = training_data.dataset.src_vocab_size
    opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
            'The src/tgt word2idx tables are different but were asked to share word embeddings.'

    print(opt)

    device = torch.device('cuda' if opt.cuda else 'cpu')
    transformer = Transformer(opt.src_vocab_size,
                              opt.tgt_vocab_size,
                              opt.max_token_seq_len,
                              tgt_emb_prj_weight_sharing=opt.proj_share_weight,
                              emb_src_tgt_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)

    if opt.model is not None:
        print('[Info] Loading pretrained model from', opt.model)
        checkpoint = torch.load(opt.model)
        model_opt = checkpoint['settings']
        transformer = Transformer(
            model_opt.src_vocab_size,
            model_opt.tgt_vocab_size,
            model_opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
            emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
            d_k=model_opt.d_k,
            d_v=model_opt.d_v,
            d_model=model_opt.d_model,
            d_word_vec=model_opt.d_word_vec,
            d_inner=model_opt.d_inner_hid,
            n_layers=model_opt.n_layers,
            n_head=model_opt.n_head,
            dropout=model_opt.dropout)
        transformer.load_state_dict(checkpoint['model'])
        transformer = transformer.to(device)

    # Build the optimizer only after any pretrained weights have been loaded,
    # so it tracks the parameters of the model that is actually trained.
    optimizer = ScheduledOptim(
        optim.Adam(filter(lambda x: x.requires_grad, transformer.parameters()),
                   betas=(0.9, 0.98),
                   eps=1e-09), opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)