Example #1
    nmtModel = NMT(src_vocab_size, trg_vocab_size)
    #classifier = Classifier(wargs.out_size, trg_vocab_size,
    #                        nmtModel.decoder.trg_lookup_table if wargs.copy_trg_emb is True else None)

    if wargs.gpu_id:
        cuda.set_device(wargs.gpu_id[0])
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id[0]), 0)
        nmtModel.cuda()
        #classifier.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()
        #classifier.cpu()
    wlog('done.')

    _dict = _load_model(model_file)
    if len(_dict) == 4: model_dict, eid, bid, optim = _dict
    elif len(_dict) == 5: model_dict, class_dict, eid, bid, optim = _dict

    nmtModel.load_state_dict(model_dict)
    #classifier.load_state_dict(class_dict)
    #nmtModel.classifier = classifier

    wlog('\nFinished loading the model.')

    dec_conf()

    nmtModel.eval()
    #nmtModel.classifier.eval()
    tor = Translator(nmtModel,
                     src_vocab.idx2key,
Example #2
def main():

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    vocab_data = {}
    train_srcD_file = wargs.src_vocab_from
    wlog('\nPreparing out-of-domain source vocabulary from {} ... '.format(
        train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict,
                              wargs.src_dict_size)
    #DANN
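    # DANN: the in-domain vocabulary file is folded into the vocabulary just built
    # from the out-of-domain corpus (updata_vocab receives the existing src_vocab),
    # so both domains end up sharing a single source vocabulary.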
    train_srcD_file_domain = wargs.src_domain_vocab_from
    wlog('\nPreparing in-domain source vocabulary from {} ... '.format(
        train_srcD_file_domain))
    src_vocab = updata_vocab(train_srcD_file_domain, src_vocab, wargs.src_dict,
                             wargs.src_dict_size)

    vocab_data['src'] = src_vocab

    train_trgD_file = wargs.trg_vocab_from
    wlog('\nPreparing out-of-domain target vocabulary from {} ... '.format(
        train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict,
                              wargs.trg_dict_size)

    #DANN
    train_trgD_file_domain = wargs.trg_domain_vocab_from
    wlog('\nPreparing in-domain target vocabulary from {} ... '.format(
        train_trgD_file_domain))
    trg_vocab = updata_vocab(train_trgD_file_domain, trg_vocab, wargs.trg_dict,
                             wargs.trg_dict_size)

    vocab_data['trg'] = trg_vocab

    train_src_file = wargs.train_src
    train_trg_file = wargs.train_trg
    if wargs.fine_tune is False:
        wlog('\nPreparing out-of-domain training set from {} and {} ... '.
             format(train_src_file, train_trg_file))
        train_src_tlst, train_trg_tlst = wrap_data(
            train_src_file,
            train_trg_file,
            vocab_data['src'],
            vocab_data['trg'],
            max_seq_len=wargs.max_seq_len)
    else:
        wlog('\nNo out-of-domain training set ...')

    #DANN
    train_src_file_domain = wargs.train_src_domain
    train_trg_file_domain = wargs.train_trg_domain
    wlog('\nPreparing in-domain training set from {} and {} ... '.format(
        train_src_file_domain, train_trg_file_domain))
    train_src_tlst_domain, train_trg_tlst_domain = wrap_data(
        train_src_file_domain,
        train_trg_file_domain,
        vocab_data['src'],
        vocab_data['trg'],
        max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
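    # Illustration only (hypothetical values): wrap_data returns plain Python lists
    # of per-sentence index tensors, e.g.
    #   train_src_tlst ~ [torch.LongTensor([13, 52, 8]), torch.LongTensor([4, 9, 27, 6]), ...]
    # Each tensor holds one sentence of word indices; no padding is applied at this stage.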
    valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                  wargs.val_src_suffix)
    wlog('\nPreparing validation set from {} ... '.format(valid_file))
    valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab)

    if wargs.fine_tune is False:
        wlog('Out-of-domain Sentence-pairs count in training data: {}'.format(
            len(train_src_tlst)))
    wlog('In-domain Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst_domain)))

    src_vocab_size, trg_vocab_size = vocab_data['src'].size(
    ), vocab_data['trg'].size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))

    if wargs.fine_tune is False:
        batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    else:
        batch_train = None

    batch_valid = Input(valid_src_tlst, None, 1, volatile=True)
    #DANN
    batch_train_domain = Input(train_src_tlst_domain, train_trg_tlst_domain,
                               wargs.batch_size)

    tests_data = None
    if wargs.tests_prefix is not None:
        init_dir(wargs.dir_tests)
        tests_data = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('Preparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = val_wrap_data(test_file, src_vocab)
            tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True)

    sv = vocab_data['src'].idx2key
    tv = vocab_data['trg'].idx2key

    nmtModel = NMT(src_vocab_size, trg_vocab_size)

    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train), 'Requires pre-trained model'
        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
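        # Copy every pre-trained tensor whose name matches, freezing it when
        # wargs.fix_pre_params is set; map_vocab.{weight,bias} may instead come from
        # the separately saved classifier state dict; all remaining parameters are
        # freshly initialized by init_params.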
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1
    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)

    if wargs.gpu_id:
        nmtModel.cuda()
        wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0]))
    else:
        nmtModel.cpu()
        wlog('Push model onto CPU ... ')

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, batch_train_domain, vocab_data,
                      optim, batch_valid, tests_data)

    trainer.train()
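
The pre_train branch above unpacks a checkpoint whose layout depends on its length: four fields for the plain model, five when a separate classifier state dict was saved. A minimal sketch of that unpacking in isolation (the helper name and the explicit error branch are assumptions, not repository code):

def unpack_checkpoint(_dict):
    # 4 fields: (model_dict, eid, bid, optim); 5 fields add class_dict in second place.
    class_dict = None
    if len(_dict) == 4:
        model_dict, eid, bid, optim = _dict
    elif len(_dict) == 5:
        model_dict, class_dict, eid, bid, optim = _dict
    else:
        raise ValueError('unexpected checkpoint layout: {} fields'.format(len(_dict)))
    return model_dict, class_dict, eid, bid, optim
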
Example #3
def main():

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    vocab_data = {}
    train_srcD_file = wargs.src_vocab_from
    wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict,
                              wargs.src_dict_size)
    vocab_data['src'] = src_vocab

    train_trgD_file = wargs.trg_vocab_from
    wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict,
                              wargs.trg_dict_size)
    vocab_data['trg'] = trg_vocab

    train_src_file = wargs.train_src
    train_trg_file = wargs.train_trg
    wlog('\nPreparing training set from {} and {} ... '.format(
        train_src_file, train_trg_file))
    train_src_tlst, train_trg_tlst = wrap_data(train_src_file,
                                               train_trg_file,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    '''
    devs = {}
    dev_src = wargs.val_tst_dir + wargs.val_prefix + '.src'
    dev_trg = wargs.val_tst_dir + wargs.val_prefix + '.ref0'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src, dev_trg))
    dev_src, dev_trg = wrap_data(dev_src, dev_trg, src_vocab, trg_vocab)
    devs['src'], devs['trg'] = dev_src, dev_trg
    '''

    valid_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                  wargs.val_src_suffix)
    wlog('\nPreparing validation set from {} ... '.format(valid_file))
    valid_src_tlst, valid_src_lens = val_wrap_data(valid_file, src_vocab)

    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))
    src_vocab_size, trg_vocab_size = vocab_data['src'].size(
    ), vocab_data['trg'].size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))

    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    batch_valid = Input(valid_src_tlst, None, 1, volatile=True)

    tests_data = None
    if wargs.tests_prefix is not None:
        init_dir(wargs.dir_tests)
        tests_data = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('Preparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = val_wrap_data(test_file, src_vocab)
            tests_data[prefix] = Input(test_src_tlst, None, 1, volatile=True)
    '''
    # lookup_table on cpu to save memory
    src_lookup_table = nn.Embedding(wargs.src_dict_size + 4,
                                    wargs.src_wemb_size, padding_idx=utils.PAD).cpu()
    trg_lookup_table = nn.Embedding(wargs.trg_dict_size + 4,
                                    wargs.trg_wemb_size, padding_idx=utils.PAD).cpu()

    wlog('Lookup table on CPU ... ')
    wlog(src_lookup_table)
    wlog(trg_lookup_table)
    '''

    sv = vocab_data['src'].idx2key
    tv = vocab_data['trg'].idx2key

    nmtModel = NMT(src_vocab_size, trg_vocab_size)
    #classifier = Classifier(wargs.out_size, trg_vocab_size,
    #                        nmtModel.decoder.trg_lookup_table if wargs.copy_trg_emb is True else None)

    if wargs.pre_train:

        assert os.path.exists(wargs.pre_train)
        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

        #tor = Translator(nmtModel, sv, tv)
        #tor.trans_tests(tests_data, eid, bid)

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        #for n, p in classifier.named_parameters(): init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)

    if wargs.gpu_id:
        nmtModel.cuda()
        #classifier.cuda()
        wlog('Push model onto GPU[{}] ... '.format(wargs.gpu_id[0]))
    else:
        nmtModel.cpu()
        #classifier.cpu()
        wlog('Push model onto CPU ... ')

    #nmtModel.classifier = classifier
    #nmtModel.decoder.map_vocab = classifier.map_vocab
    '''
    nmtModel.src_lookup_table = src_lookup_table
    nmtModel.trg_lookup_table = trg_lookup_table
    print nmtModel.src_lookup_table.weight.data.is_cuda

    nmtModel.classifier.init_weights(nmtModel.trg_lookup_table)
    '''

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    #tor = Translator(nmtModel, sv, tv, wargs.search_mode)
    #tor.trans_tests(tests_data, pre_dict['epoch'], pre_dict['batch'])

    trainer = Trainer(nmtModel, batch_train, vocab_data, optim, batch_valid,
                      tests_data)

    trainer.train()
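
The docstring above describes wrap_data's output as a list of un-padded per-sentence LongTensors. A generic sketch of padding such a list into a single batch tensor, for illustration only (this is not the repository's Input class, and PAD = 0 is an assumption):

import torch

PAD = 0  # assumed padding index

def pad_batch(sent_tensors):
    # Stack variable-length sentence tensors into a (batch, max_len) LongTensor,
    # filling the tail of shorter sentences with PAD.
    max_len = max(t.size(0) for t in sent_tensors)
    batch = torch.full((len(sent_tensors), max_len), PAD, dtype=torch.long)
    for i, t in enumerate(sent_tensors):
        batch[i, :t.size(0)] = t
    return batch

print(pad_batch([torch.LongTensor([5, 28, 97]), torch.LongTensor([5, 41, 12, 7])]))
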
Example #4
def main():

    #if wargs.ss_type is not None: assert wargs.model == 1, 'Only rnnsearch support schedule sample'
    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    vocabs = {}
    wlog('\n[o/Subword] Preparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src,
                              wargs.src_dict,
                              wargs.src_dict_size,
                              wargs.max_seq_len,
                              char=wargs.src_char)
    wlog('\n[o/Subword] Preparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size,
                              wargs.max_seq_len)
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data,
                                               wargs.train_prefix,
                                               wargs.train_src_suffix,
                                               wargs.train_trg_suffix,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len,
                                               char=wargs.src_char)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_sort=True)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_src_suffix)
        val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_ref_suffix)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            sort_data=False,
            max_seq_len=wargs.dev_max_seq_len,
            char=wargs.src_char)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            1,
                            volatile=True,
                            batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be a list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file,
                                             src_vocab,
                                             char=wargs.src_char)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        1,
                                        volatile=True,
                                        batch_sort=False)
    wlog('\n## Finished preparing the dataset! ##\n')

    nmtModel = NMT(src_vocab_size, trg_vocab_size)

    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train)

        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu,
                      model=wargs.model)

    if wargs.gpu_id is not None:
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        nmtModel.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()

    wlog('done.')

    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    optim.init_optimizer(nmtModel.parameters())

    trainer = Trainer(nmtModel, batch_train, vocabs, optim, batch_valid,
                      batch_tests)

    trainer.train()
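
The two numbers logged as 'Parameters number' above are the count of parameter tensors and the total count of scalar parameters. A toy illustration on a standalone module (nn.Linear is only a stand-in here, not part of the repository's NMT model):

import torch.nn as nn

model = nn.Linear(4, 3)  # weight is 3x4, bias has 3 entries
pcnt1 = len([p for p in model.parameters()])             # parameter tensors -> 2
pcnt2 = sum([p.nelement() for p in model.parameters()])  # scalar parameters -> 12 + 3 = 15
print('Parameters number: {}/{}'.format(pcnt1, pcnt2))
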
Example #5
def main():

    # Check if CUDA is available
    if cuda.is_available():
        wlog(
            'CUDA is available, specify a device with the gpu_id argument (e.g. gpu_id=[3])'
        )
    else:
        wlog('Warning: CUDA is not available, trying CPU')

    if wargs.gpu_id:
        cuda.set_device(wargs.gpu_id[0])
        wlog('Using GPU {}'.format(wargs.gpu_id[0]))

    init_dir(wargs.dir_model)
    init_dir(wargs.dir_valid)
    '''
    train_srcD_file = wargs.dir_data + 'train.10k.zh5'
    wlog('\nPreparing source vocabulary from {} ... '.format(train_srcD_file))
    src_vocab = extract_vocab(train_srcD_file, wargs.src_dict, wargs.src_dict_size)

    train_trgD_file = wargs.dir_data + 'train.10k.en5'
    wlog('\nPreparing target vocabulary from {} ... '.format(train_trgD_file))
    trg_vocab = extract_vocab(train_trgD_file, wargs.trg_dict, wargs.trg_dict_size)

    train_src_file = wargs.dir_data + 'train.10k.zh0'
    train_trg_file = wargs.dir_data + 'train.10k.en0'
    wlog('\nPreparing training set from {} and {} ... '.format(train_src_file, train_trg_file))
    train_src_tlst, train_trg_tlst = wrap_data(train_src_file, train_trg_file, src_vocab, trg_vocab)
    #list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...], no padding
    wlog('Sentence-pairs count in training data: {}'.format(len(train_src_tlst)))
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(src_vocab_size, trg_vocab_size))
    batch_train = Input(train_src_tlst, train_trg_tlst, wargs.batch_size)
    '''

    src = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_src_suffix))
    trg = os.path.join(
        wargs.dir_data, '{}.{}'.format(wargs.train_prefix,
                                       wargs.train_trg_suffix))
    vocabs = {}
    wlog('\nPreparing source vocabulary from {} ... '.format(src))
    src_vocab = extract_vocab(src, wargs.src_dict, wargs.src_dict_size)
    wlog('\nPreparing target vocabulary from {} ... '.format(trg))
    trg_vocab = extract_vocab(trg, wargs.trg_dict, wargs.trg_dict_size)
    src_vocab_size, trg_vocab_size = src_vocab.size(), trg_vocab.size()
    wlog('Vocabulary size: |source|={}, |target|={}'.format(
        src_vocab_size, trg_vocab_size))
    vocabs['src'], vocabs['trg'] = src_vocab, trg_vocab

    wlog('\nPreparing training set from {} and {} ... '.format(src, trg))
    trains = {}
    train_src_tlst, train_trg_tlst = wrap_data(wargs.dir_data,
                                               wargs.train_prefix,
                                               wargs.train_src_suffix,
                                               wargs.train_trg_suffix,
                                               src_vocab,
                                               trg_vocab,
                                               max_seq_len=wargs.max_seq_len)
    '''
    list [torch.LongTensor (sentence), torch.LongTensor, torch.LongTensor, ...]
    no padding
    '''
    batch_train = Input(train_src_tlst,
                        train_trg_tlst,
                        wargs.batch_size,
                        batch_sort=True)
    wlog('Sentence-pairs count in training data: {}'.format(
        len(train_src_tlst)))

    batch_valid = None
    if wargs.val_prefix is not None:
        val_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_src_suffix)
        val_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.val_prefix,
                                        wargs.val_ref_suffix)
        wlog('\nPreparing validation set from {} and {} ... '.format(
            val_src_file, val_trg_file))
        valid_src_tlst, valid_trg_tlst = wrap_data(
            wargs.val_tst_dir,
            wargs.val_prefix,
            wargs.val_src_suffix,
            wargs.val_ref_suffix,
            src_vocab,
            trg_vocab,
            shuffle=False,
            sort_data=False,
            max_seq_len=wargs.dev_max_seq_len)
        batch_valid = Input(valid_src_tlst,
                            valid_trg_tlst,
                            1,
                            volatile=True,
                            batch_sort=False)

    batch_tests = None
    if wargs.tests_prefix is not None:
        assert isinstance(wargs.tests_prefix,
                          list), 'Test files should be a list.'
        init_dir(wargs.dir_tests)
        batch_tests = {}
        for prefix in wargs.tests_prefix:
            init_dir(wargs.dir_tests + '/' + prefix)
            test_file = '{}{}.{}'.format(wargs.val_tst_dir, prefix,
                                         wargs.val_src_suffix)
            wlog('\nPreparing test set from {} ... '.format(test_file))
            test_src_tlst, _ = wrap_tst_data(test_file, src_vocab)
            batch_tests[prefix] = Input(test_src_tlst,
                                        None,
                                        1,
                                        volatile=True,
                                        batch_sort=False)
    wlog('\n## Finished preparing the dataset! ##\n')

    nmtModel = NMT(src_vocab_size, trg_vocab_size)
    if wargs.pre_train is not None:

        assert os.path.exists(wargs.pre_train)

        _dict = _load_model(wargs.pre_train)
        # initializing parameters of interactive attention model
        class_dict = None
        if len(_dict) == 4: model_dict, eid, bid, optim = _dict
        elif len(_dict) == 5:
            model_dict, class_dict, eid, bid, optim = _dict
        for name, param in nmtModel.named_parameters():
            if name in model_dict:
                param.requires_grad = not wargs.fix_pre_params
                param.data.copy_(model_dict[name])
                wlog('{:7} -> grad {}\t{}'.format('Model', param.requires_grad,
                                                  name))
            elif name.endswith('map_vocab.weight'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.weight'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            elif name.endswith('map_vocab.bias'):
                if class_dict is not None:
                    param.requires_grad = not wargs.fix_pre_params
                    param.data.copy_(class_dict['map_vocab.bias'])
                    wlog('{:7} -> grad {}\t{}'.format('Model',
                                                      param.requires_grad,
                                                      name))
            else:
                init_params(param, name, True)

        wargs.start_epoch = eid + 1

    else:
        for n, p in nmtModel.named_parameters():
            init_params(p, n, True)
        optim = Optim(wargs.opt_mode,
                      wargs.learning_rate,
                      wargs.max_grad_norm,
                      learning_rate_decay=wargs.learning_rate_decay,
                      start_decay_from=wargs.start_decay_from,
                      last_valid_bleu=wargs.last_valid_bleu)
        optim.init_optimizer(nmtModel.parameters())

    if wargs.gpu_id:
        wlog('Push model onto GPU {} ... '.format(wargs.gpu_id), 0)
        nmtModel.cuda()
    else:
        wlog('Push model onto CPU ... ', 0)
        nmtModel.cpu()

    wlog('done.')
    wlog(nmtModel)
    wlog(optim)
    pcnt1 = len([p for p in nmtModel.parameters()])
    pcnt2 = sum([p.nelement() for p in nmtModel.parameters()])
    wlog('Parameters number: {}/{}'.format(pcnt1, pcnt2))

    trainer = Trainer(nmtModel,
                      src_vocab.idx2key,
                      trg_vocab.idx2key,
                      optim,
                      trg_vocab_size,
                      valid_data=batch_valid,
                      tests_data=batch_tests)

    # add 1000 to train
    train_all_chunks = (train_src_tlst, train_trg_tlst)
    dh = DataHisto(train_all_chunks)
    '''
    dev_src0 = wargs.dir_data + 'dev.1k.zh0'
    dev_trg0 = wargs.dir_data + 'dev.1k.en0'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src0, dev_trg0))
    dev_src0, dev_trg0 = wrap_data(dev_src0, dev_trg0, src_vocab, trg_vocab)
    wlog(len(train_src_tlst))

    dev_src1 = wargs.dir_data + 'dev.1k.zh1'
    dev_trg1 = wargs.dir_data + 'dev.1k.en1'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src1, dev_trg1))
    dev_src1, dev_trg1 = wrap_data(dev_src1, dev_trg1, src_vocab, trg_vocab)

    dev_src2 = wargs.dir_data + 'dev.1k.zh2'
    dev_trg2 = wargs.dir_data + 'dev.1k.en2'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src2, dev_trg2))
    dev_src2, dev_trg2 = wrap_data(dev_src2, dev_trg2, src_vocab, trg_vocab)

    dev_src3 = wargs.dir_data + 'dev.1k.zh3'
    dev_trg3 = wargs.dir_data + 'dev.1k.en3'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src3, dev_trg3))
    dev_src3, dev_trg3 = wrap_data(dev_src3, dev_trg3, src_vocab, trg_vocab)

    dev_src4 = wargs.dir_data + 'dev.1k.zh4'
    dev_trg4 = wargs.dir_data + 'dev.1k.en4'
    wlog('\nPreparing dev set for tuning from {} and {} ... '.format(dev_src4, dev_trg4))
    dev_src4, dev_trg4 = wrap_data(dev_src4, dev_trg4, src_vocab, trg_vocab)
    wlog(len(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0))
    batch_dev = Input(dev_src4+dev_src3+dev_src2+dev_src1+dev_src0, dev_trg4+dev_trg3+dev_trg2+dev_trg1+dev_trg0, wargs.batch_size)
    '''

    batch_dev = None
    assert wargs.dev_prefix is not None, 'A development set is required for tuning.'
    dev_src_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix,
                                    wargs.val_src_suffix)
    dev_trg_file = '{}{}.{}'.format(wargs.val_tst_dir, wargs.dev_prefix,
                                    wargs.val_ref_suffix)
    wlog('\nPreparing dev set from {} and {} ... '.format(
        dev_src_file, dev_trg_file))
    valid_src_tlst, valid_trg_tlst = wrap_data(
        wargs.val_tst_dir,
        wargs.dev_prefix,
        wargs.val_src_suffix,
        wargs.val_ref_suffix,
        src_vocab,
        trg_vocab,
        shuffle=True,
        sort_data=True,
        max_seq_len=wargs.dev_max_seq_len)
    batch_dev = Input(valid_src_tlst,
                      valid_trg_tlst,
                      wargs.batch_size,
                      batch_sort=True)

    trainer.train(dh, batch_dev, 0, merge=True, name='DH_{}'.format('dev'))
    '''