Example #1
def main(args):
    print("Loading data")
    corpus = data.Corpus(args.data,
                         max_vocab_size=args.max_vocab,
                         max_length=args.max_length)
    vocab_size = len(corpus.word2idx)
    print("\ttraining data size: ", corpus.train_data.size)
    print("\tvocabulary size: ", vocab_size)
    print("Constructing model")
    print(args)
    device = torch.device('cpu' if args.nocuda else 'cuda')
    model = LM(vocab_size, args.embed_size, args.hidden_size,
               args.dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.wd)
    best_loss = None

    print("\nStart training")
    try:
        for epoch in range(1, args.epochs + 1):
            epoch_start_time = time.time()
            train_ce, train_ppl = train(corpus.train_data, model, optimizer,
                                        epoch, device)
            valid_ce, valid_ppl = evaluate(corpus.valid_data, model, device)
            print('-' * 70)
            meta = "| epoch {:2d} | time {:5.2f}s ".format(
                epoch,
                time.time() - epoch_start_time)
            print(meta + "| train loss {:5.2f} | train ppl {:5.2f}".format(
                train_ce, train_ppl))
            print(len(meta) * ' ' + "| valid loss {:5.2f} "
                  "| valid ppl {:5.2f}".format(valid_ce, valid_ppl),
                  flush=True)
            if best_loss is None or valid_ce < best_loss:
                best_loss = valid_ce
                with open(get_savepath(args), 'wb') as f:
                    torch.save(model, f)

    except KeyboardInterrupt:
        print('-' * 70)
        print('Exiting from training early')

    with open(get_savepath(args), 'rb') as f:
        model = torch.load(f)
    test_ce, test_ppl = evaluate(corpus.test_data, model, device)
    print('=' * 70)
    print("| End of training | test loss {:5.2f} | test ppl {:5.2f}".format(
        test_ce, test_ppl))
    print('=' * 70)
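
Example #1 relies on an `args` namespace and a `get_savepath` helper that are not shown. A minimal sketch of what a matching argparse driver could look like; the flag defaults, the `--save-dir` option, and the body of `get_savepath` are assumptions made only for illustration:

import argparse
import os

def get_savepath(args):
    # Hypothetical helper: derive a checkpoint name from the hyperparameters.
    name = "lm_e{}_h{}_d{}.pt".format(args.embed_size, args.hidden_size, args.dropout)
    return os.path.join(args.save_dir, name)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the LM from Example #1")
    parser.add_argument("--data", default="./data")
    parser.add_argument("--max-vocab", type=int, default=50000)
    parser.add_argument("--max-length", type=int, default=100)
    parser.add_argument("--embed-size", type=int, default=256)
    parser.add_argument("--hidden-size", type=int, default=512)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--wd", type=float, default=0.0)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--nocuda", action="store_true")
    parser.add_argument("--save-dir", default="./checkpoints")  # assumed, used only by get_savepath above
    main(parser.parse_args())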
Example #2
def train():
    # generate file paths
    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    savedir = os.path.join(config.SAVE_DIR, "hierarchical_bottom", f"{now}")
    os.makedirs(savedir, exist_ok=True)
    checkpointfn = os.path.join(savedir, "checkpoint.model")
    logfn = os.path.join(savedir, "run.log")
    torch.manual_seed(config.SEED)

    # create logger
    logger = logging.getLogger()
    if logger.hasHandlers():
        logger.handlers.clear()

    logger.setLevel(logging.DEBUG)

    fmt = logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s")
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)

    logfile = logging.FileHandler(logfn, "a")
    logfile.setFormatter(fmt)
    logfile.setLevel(logging.DEBUG)
    logger.addHandler(logfile)

    with open("config.py", "r") as f:
        for l in f:
            logging.debug(l.strip())

    dataset = BrownDataset(config.CONTEXT_SIZE)
    model = LM()
    num_params = sum(p.numel() for p in model.parameters())
    logging.debug(f"The model has {num_params:,} parameters")

    # Trainer init
    logging.debug("Initiate the training environment")
    trainer = LMTrainer(model, dataset, checkpointfn)

    # Training
    logging.debug("Starting the training")
    for epoch in tqdm.tqdm(range(config.EPOCHS),
                           total=config.EPOCHS,
                           desc="EPOCH"):
        trainer.run_epoch()
        if trainer.patience > config.PATIENCE:
            logging.info("patience over {}, exiting".format(config.PATIENCE))
            break
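
Example #2 pulls its hyperparameters from a `config` module that is not shown. A minimal sketch of a matching config.py; only the attribute names come from the example, the values are placeholders:

# config.py (hypothetical values)
SAVE_DIR = "./runs"
SEED = 1234
CONTEXT_SIZE = 32
EPOCHS = 50
PATIENCE = 5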
Example #3
def main(test_data_path):
    
    dic = pickle.load(open('vocab.pkl','rb'))
    word_vocab = dic['word_vocab']
    char_vocab = dic['char_vocab']
    max_len = dic['max_len']
    batch_size = config.batch_size
    embed_dim = config.embed_dim
    out_channels = config.out_channels
    kernels = config.kernels
    hidden_size = config.hidden_size
    learning_rate = config.learning_rate
    seq_len = config.seq_len

    test_data, _ = corpus_to_word(test_data_path, batch_size)
    
    test_idx = word_to_idx(test_data, word_vocab)
    test_idx = test_idx.contiguous().view(batch_size, -1)

    test_data = word_to_char(test_data, char_vocab, max_len)
    test_data = torch.from_numpy(test_data)
    test_data = test_data.contiguous().view(batch_size, -1, max_len)

    model = LM(word_vocab, char_vocab, max_len, embed_dim, out_channels,
               kernels, hidden_size)

    if torch.cuda.is_available():
        model.cuda()

    model.load_state_dict(torch.load('model.pkl'))

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                                           factor=0.5,
                                                           patience=1,
                                                           verbose=True)

    hidden_state = (Variable(torch.zeros(2, batch_size, hidden_size).cuda(), volatile=False),
                    Variable(torch.zeros(2, batch_size, hidden_size).cuda(), volatile=False))
    model.eval()
    test_loss = eval(seq_len,test_data,test_idx,model,hidden_state, criterion)
    test_loss = np.exp(test_loss)
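
Example #3 uses the legacy `Variable(..., volatile=...)` API, which has been removed from modern PyTorch. If the snippet were ported, the hidden state could be built from plain tensors and the evaluation wrapped in `torch.no_grad()`; a sketch under that assumption (the `eval` helper itself is not shown in the example):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_state = (torch.zeros(2, batch_size, hidden_size, device=device),
                torch.zeros(2, batch_size, hidden_size, device=device))

model.eval()
with torch.no_grad():
    test_loss = eval(seq_len, test_data, test_idx, model, hidden_state, criterion)
print('test loss: %.3f, test perplexity: %.3f' % (test_loss, np.exp(test_loss)))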
Example #4
def try_params(n_iterations, params):
    n_iterations = int(n_iterations)
    use_cuda = (len(options.gpuid) >= 1)
    if options.gpuid:
        cuda.set_device(options.gpuid[0])

    src_vocab = dill.load(open('src_vocab.pickle', 'rb'))
    trg_vocab = dill.load(open('trg_vocab.pickle', 'rb'))

    src_dev = dill.load(open('src_dev.pickle', 'rb'))
    trg_dev = dill.load(open('trg_dev.pickle', 'rb'))
    batched_dev_src, batched_dev_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_dev, options.batch_size, src_vocab.stoi["<blank>"])
    batched_dev_trg, batched_dev_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_dev, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)

    batches = []

    if options.contain_bilingual:
        print('Load')
        src_train = dill.load(open('src_sents1.pickle', 'rb'))
        print('Load src sents 1')
        trg_train = dill.load(open('trg_sents1.pickle', 'rb'))
        print('Load trg sents 1')
        batched_train_src1, batched_train_src_mask1, sort_index = utils.tensor.advanced_batchize(
            src_train, options.batch_size, src_vocab.stoi["<blank>"])
        batched_train_trg1, batched_train_trg_mask1 = utils.tensor.advanced_batchize_no_sort(
            trg_train, options.batch_size, trg_vocab.stoi["<blank>"],
            sort_index)
        batches = batches + [(1, i) for i in range(len(batched_train_src1))]
        if options.mono_loss:
            batches = batches + [(4, i)
                                 for i in range(len(batched_train_src1))]
            batches = batches + [(5, i)
                                 for i in range(len(batched_train_src1))]

    if options.contain_trg:
        print('Load')
        # src_train = dill.load(open('src_sents2.pickle', 'rb'))
        # print('Load src sents 2')
        trg_train = dill.load(open('trg_sents2.pickle', 'rb'))
        print('Load trg sents 2')
        # batched_train_src2, batched_train_src_mask2, sort_index = utils.tensor.advanced_batchize(src_train, options.batch_size, src_vocab.stoi["<blank>"])
        batched_train_trg2, batched_train_trg_mask2 = utils.tensor.advanced_batchize_no_sort(
            trg_train, options.batch_size, trg_vocab.stoi["<blank>"],
            sort_index)
        batches = batches + [(2, i) for i in range(len(batched_train_trg2))]

    if options.contain_src:
        print('Load')
        src_train = dill.load(open('src_sents3.pickle', 'rb'))
        print('Load src sents 3')
        # trg_train = dill.load(open('trg_sents3.pickle', 'rb'))
        # print('Load trg sents 3')
        batched_train_src3, batched_train_src_mask3, sort_index = utils.tensor.advanced_batchize(
            src_train, options.batch_size, src_vocab.stoi["<blank>"])
        # batched_train_trg3, batched_train_trg_mask3 = utils.tensor.advanced_batchize_no_sort(trg_train, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)
        batches = batches + [(3, i) for i in range(len(batched_train_src3))]

    src_vocab_size = len(src_vocab)
    trg_vocab_size = len(trg_vocab)

    if os.path.isfile(options.load_file_src) and os.path.isfile(
            options.load_file_trg):
        src_lm = torch.load(open(options.load_file_src, 'rb'))
        trg_lm = torch.load(open(options.load_file_trg, 'rb'))
    else:
        src_lm = LM(src_vocab_size, src_vocab.stoi['<s>'],
                    src_vocab.stoi['</s>'], params['embedding_size'],
                    params['hidden_size'], params['dropout'], use_cuda)
        trg_lm = LM(trg_vocab_size, trg_vocab.stoi['<s>'],
                    trg_vocab.stoi['</s>'], params['embedding_size'],
                    params['hidden_size'], params['dropout'], use_cuda)

    if use_cuda:
        src_lm.cuda()
        trg_lm.cuda()
    else:
        src_lm.cpu()
        trg_lm.cpu()

    criterion = torch.nn.NLLLoss()
    optimizer_src = getattr(torch.optim, options.optimizer)(
        src_lm.parameters(), params['learning_rate'])
    optimizer_trg = getattr(torch.optim, options.optimizer)(
        trg_lm.parameters(), params['learning_rate'])

    # main training loop
    # last_dev_avg_loss = float("inf")
    for epoch_i in range(n_iterations):
        print(epoch_i)
        logging.info("At {0}-th epoch.".format(epoch_i))

        shuffle(batches)
        src_lm.train()
        trg_lm.train()
        for i, (index, batch_i) in enumerate(batches):

            train_src_batch = None
            train_src_mask = None
            train_trg_batch = None
            train_trg_mask = None

            if index == 1:
                train_src_batch = Variable(batched_train_src1[batch_i])
                train_src_mask = Variable(batched_train_src_mask1[batch_i])
                train_trg_batch = Variable(batched_train_trg1[batch_i])
                train_trg_mask = Variable(batched_train_trg_mask1[batch_i])
                if use_cuda:
                    train_src_batch = train_src_batch.cuda()
                    train_trg_batch = train_trg_batch.cuda()
                    train_src_mask = train_src_mask.cuda()
                    train_trg_mask = train_trg_mask.cuda()
            elif index == 2:
                train_trg_batch = Variable(batched_train_trg2[batch_i])
                train_trg_mask = Variable(batched_train_trg_mask2[batch_i])
                if use_cuda:
                    train_trg_batch = train_trg_batch.cuda()
                    train_trg_mask = train_trg_mask.cuda()
            elif index == 3:
                train_src_batch = Variable(batched_train_src3[batch_i])
                train_src_mask = Variable(batched_train_src_mask3[batch_i])
                if use_cuda:
                    train_src_batch = train_src_batch.cuda()
                    train_src_mask = train_src_mask.cuda()
            elif index == 4:
                train_src_batch = Variable(batched_train_src1[batch_i])
                train_src_mask = Variable(batched_train_src_mask1[batch_i])
                if use_cuda:
                    train_src_batch = train_src_batch.cuda()
                    train_src_mask = train_src_mask.cuda()
            elif index == 5:
                train_trg_batch = Variable(batched_train_trg1[batch_i])
                train_trg_mask = Variable(batched_train_trg_mask1[batch_i])
                if use_cuda:
                    train_trg_batch = train_trg_batch.cuda()
                    train_trg_mask = train_trg_mask.cuda()
            else:
                raise ValueError()

            total_loss = 0
            if index == 1:
                optimizer_trg.zero_grad()
                optimizer_src.zero_grad()
                h_src, c_src = src_lm(sent=train_src_batch)
                use_teacher_forcing = random.random() < params['teacher_forcing_ratio']
                sys_out_batch = trg_lm(h=h_src,
                                       c=c_src,
                                       encode=False,
                                       tgt_sent=train_trg_batch,
                                       teacher_forcing=use_teacher_forcing)

                train_trg_mask_tmp = train_trg_mask.view(-1)
                train_trg_batch_tmp = train_trg_batch.view(-1)
                train_trg_batch_tmp = train_trg_batch_tmp.masked_select(
                    train_trg_mask_tmp)
                train_trg_mask_tmp = train_trg_mask_tmp.unsqueeze(1).expand(
                    len(train_trg_mask_tmp), trg_vocab_size)
                sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
                sys_out_batch = sys_out_batch.masked_select(
                    train_trg_mask_tmp).view(-1, trg_vocab_size)
                loss = criterion(sys_out_batch, train_trg_batch_tmp)
                loss.backward()
                optimizer_src.step()
                optimizer_trg.step()
                if i % 100 == 0:
                    logging.debug("loss at batch {0}: {1}".format(
                        i, loss.data[0]))

            elif options.mono_loss and train_src_batch is not None:
                optimizer_trg.zero_grad()
                optimizer_src.zero_grad()
                h_src, c_src = src_lm(sent=train_src_batch)
                use_teacher_forcing = random.random() < params['teacher_forcing_ratio']
                sys_out_batch = src_lm(h=h_src,
                                       c=c_src,
                                       encode=False,
                                       tgt_sent=train_src_batch,
                                       teacher_forcing=use_teacher_forcing)

                train_src_mask_tmp = train_src_mask.view(-1)
                train_src_batch_tmp = train_src_batch.view(-1)
                train_src_batch_tmp = train_src_batch_tmp.masked_select(
                    train_src_mask_tmp)
                train_src_mask_tmp = train_src_mask_tmp.unsqueeze(1).expand(
                    len(train_src_mask_tmp), src_vocab_size)
                sys_out_batch = sys_out_batch.view(-1, src_vocab_size)
                sys_out_batch = sys_out_batch.masked_select(
                    train_src_mask_tmp).view(-1, src_vocab_size)
                loss = criterion(sys_out_batch, train_src_batch_tmp)
                # scale the monolingual loss, ramping its weight up with the epoch index
                loss *= params['mono_loss_multi'] * (1.0 / 10 * epoch_i)
                loss.backward()
                optimizer_src.step()
                optimizer_trg.step()
                if i % 100 == 0:
                    logging.debug("loss at batch {0}: {1}".format(
                        i, loss.data[0]))

            elif train_trg_batch is not None and options.mono_loss:
                optimizer_trg.zero_grad()
                optimizer_src.zero_grad()

                h_trg, c_trg = trg_lm(sent=train_trg_batch)
                use_teacher_forcing = random.random() < params['teacher_forcing_ratio']
                sys_out_batch = trg_lm(h=h_trg,
                                       c=c_trg,
                                       encode=False,
                                       tgt_sent=train_trg_batch,
                                       teacher_forcing=use_teacher_forcing)

                train_trg_mask_tmp = train_trg_mask.view(-1)
                train_trg_batch_tmp = train_trg_batch.view(-1)
                train_trg_batch_tmp = train_trg_batch_tmp.masked_select(
                    train_trg_mask_tmp)
                train_trg_mask_tmp = train_trg_mask_tmp.unsqueeze(1).expand(
                    len(train_trg_mask_tmp), trg_vocab_size)
                sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
                sys_out_batch = sys_out_batch.masked_select(
                    train_trg_mask_tmp).view(-1, trg_vocab_size)
                loss = criterion(sys_out_batch, train_trg_batch_tmp)
                # scale the monolingual loss, ramping its weight up with the epoch index
                loss *= params['mono_loss_multi'] * (1.0 / 10 * epoch_i)
                loss.backward()
                optimizer_src.step()
                optimizer_trg.step()
                if i % 100 == 0:
                    logging.debug("loss at batch {0}: {1}".format(
                        i, loss.data[0]))

        # validation -- this is a crude estimation because there might be some padding at the end
        dev_loss = 0.0
        src_lm.eval()
        trg_lm.eval()
        for batch_i in range(len(batched_dev_src)):
            dev_src_batch = Variable(batched_dev_src[batch_i], volatile=True)
            dev_trg_batch = Variable(batched_dev_trg[batch_i], volatile=True)
            dev_src_mask = Variable(batched_dev_src_mask[batch_i],
                                    volatile=True)
            dev_trg_mask = Variable(batched_dev_trg_mask[batch_i],
                                    volatile=True)
            if use_cuda:
                dev_src_batch = dev_src_batch.cuda()
                dev_trg_batch = dev_trg_batch.cuda()
                dev_src_mask = dev_src_mask.cuda()
                dev_trg_mask = dev_trg_mask.cuda()

            h_src, c_src = src_lm(sent=dev_src_batch)
            sys_out_batch = trg_lm(h=h_src,
                                   c=c_src,
                                   encode=False,
                                   tgt_sent=dev_trg_batch)

            dev_trg_mask = dev_trg_mask.view(-1)
            dev_trg_batch = dev_trg_batch.view(-1)
            dev_trg_batch = dev_trg_batch.masked_select(dev_trg_mask)
            dev_trg_mask = dev_trg_mask.unsqueeze(1).expand(
                len(dev_trg_mask), trg_vocab_size)
            sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
            sys_out_batch = sys_out_batch.masked_select(dev_trg_mask).view(
                -1, trg_vocab_size)

            loss = criterion(sys_out_batch, dev_trg_batch)
            logging.debug("dev loss at batch {0}: {1}".format(
                batch_i, loss.data[0]))
            dev_loss += loss

        dev_avg_loss = dev_loss / len(batched_dev_src)
        logging.info(
            "Average loss value per instance is {0} at the end of epoch {1}".
            format(dev_avg_loss.data[0], epoch_i))
        # if (last_dev_avg_loss - dev_avg_loss).data[0] < options.estop:
        # logging.info("Early stopping triggered with threshold {0} (previous dev loss: {1}, current: {2})".format(epoch_i, last_dev_avg_loss.data[0], dev_avg_loss.data[0]))
        # break

        # torch.save(src_lm, open(options.model_file_src + ".nll_{0:.2f}.epoch_{1}".format(dev_avg_loss.data[0], epoch_i), 'wb'), pickle_module=dill)
        # torch.save(trg_lm, open(options.model_file_trg + ".nll_{0:.2f}.epoch_{1}".format(dev_avg_loss.data[0], epoch_i), 'wb'), pickle_module=dill)
        # last_dev_avg_loss = dev_avg_loss

    return {'loss': dev_avg_loss.data[0]}
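
The masked NLL computation in Example #4 is repeated nearly verbatim in the bilingual and the two monolingual branches. It could be factored into a small helper; a sketch that keeps the same old-style Variable/masked_select pattern used above:

def masked_nll(sys_out_batch, target_batch, target_mask, vocab_size, criterion):
    # Flatten the mask and targets, keep only non-padding positions,
    # then select the matching rows of the (batch * length, vocab) output.
    mask_flat = target_mask.view(-1)
    target_flat = target_batch.view(-1).masked_select(mask_flat)
    mask_expanded = mask_flat.unsqueeze(1).expand(len(mask_flat), vocab_size)
    out_flat = sys_out_batch.view(-1, vocab_size)
    out_flat = out_flat.masked_select(mask_expanded).view(-1, vocab_size)
    return criterion(out_flat, target_flat)

With it, the bilingual branch would reduce to loss = masked_nll(sys_out_batch, train_trg_batch, train_trg_mask, trg_vocab_size, criterion), and the monolingual branches likewise.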
Example #5
          (val_loss / count, np.exp(val_loss / count)))

    return val_loss / count


model = LM(word_vocab, char_vocab, max_len, embed_dim, out_channels, kernels,
           hidden_size, batch_size)

if torch.cuda.is_available():
    model.cuda()

model.load_state_dict(torch.load('model.pkl'))

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(),
                            lr=learning_rate,
                            weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       'min',
                                                       factor=0.5,
                                                       patience=1,
                                                       verbose=True)

hidden_state = (to_var(torch.zeros(2, batch_size, hidden_size)),
                to_var(torch.zeros(2, batch_size, hidden_size)))

#validate
test_loss = eval(seq_len, test_data, test_label, model, hidden_state)
test_loss = np.exp(test_loss)
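
Examples #5 and #7 call a `to_var` helper that is not shown. In code of this (pre-0.4 PyTorch) vintage it is usually a thin wrapper that moves a tensor to the GPU when one is available and wraps it in a Variable; a sketch under that assumption:

import torch
from torch.autograd import Variable

def to_var(x, volatile=False):
    # Hypothetical reconstruction of the helper assumed by the examples.
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)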
Example #6
    def train_(self):


        cur_best = 10000

        model = LM(self.unique_words, self.char_vocab, self.max_len,
                   self.embed_dim, self.channels, self.kernels,
                   self.hidden_size)

        if torch.cuda.is_available():
            model.cuda()
        
        learning_rate = self.learning_rate
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

        for epoch in range(self.epochs):

            model.train(True)

            # separate (h, c) zero states for the LSTM; .cuda() assumes a GPU is available
            hidden_state = [torch.zeros(2, self.batch_size, self.hidden_size).cuda(),
                            torch.zeros(2, self.batch_size, self.hidden_size).cuda()]

            for i in range(0, self.train.size(1)-self.seq_len, self.seq_len):

                model.zero_grad()

                inputs = self.train[:, i : i + self.seq_len,:].cuda() # 20 * 35 * 21
                targets = self.train_idx[:, (i+1) : (i+1) + self.seq_len].cuda() # 20 * 35

                # detach the hidden state from the previous graph (truncated BPTT)
                hidden_state = [state.detach() for state in hidden_state]

                output, hidden_state = model(inputs, hidden_state)  # returns logits and the updated hidden state
                
                loss = criterion(output, targets.view(-1))
                        
                loss.backward()
                
                nn.utils.clip_grad_norm_(model.parameters(), 5) # clipping
                
                optimizer.step()
                
                step = (i+1) // self.seq_len                    
            
                if step % 100 == 0:
                    print('Epoch %d/%d, Batch x Seq_Len %d/%d, Loss: %.3f, Perplexity: %5.2f'
                          % (epoch, self.epochs, step, self.num_batches // self.seq_len,
                             loss.item(), np.exp(loss.item())))
            
            model.eval() 
            val_loss = self._validate(self.seq_len, self.valid, self.valid_idx, model, hidden_state, criterion)
            val_perplex = np.exp(val_loss)
                        
            if cur_best - val_perplex < 1:  # little or no improvement in perplexity
                if learning_rate > 0.03:
                    learning_rate = learning_rate * 0.5
                    print("Adjusted learning_rate : %.5f" % learning_rate)
                    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

            if val_perplex < cur_best:
                print("The current best val loss: ", val_loss)
                cur_best = val_perplex
                torch.save(model.state_dict(), 'model.pkl')
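
Example #6 calls a `self._validate(...)` method that is not part of the snippet. A minimal sketch of what it might do, mirroring the training loop above but accumulating the mean loss without parameter updates; the signature is copied from the call, the body is an assumption:

    def _validate(self, seq_len, data, data_idx, model, hidden_state, criterion):
        total_loss, count = 0.0, 0
        with torch.no_grad():
            for i in range(0, data.size(1) - seq_len, seq_len):
                inputs = data[:, i:i + seq_len, :].cuda()
                targets = data_idx[:, (i + 1):(i + 1) + seq_len].cuda()
                output, hidden_state = model(inputs, hidden_state)
                total_loss += criterion(output, targets.view(-1)).item()
                count += 1
        # mean cross-entropy; the caller exponentiates it to get perplexity
        return total_loss / count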
Example #7
for trial in range(num_trial):

    pivot = 100000

    model = LM(word_vocab, char_vocab, max_len, embed_dim, out_channels,
               kernels, hidden_size, batch_size)

    if torch.cuda.is_available():
        model.cuda()

    learning_rate = 1.0

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=1,
                                                           verbose=True)

    for epoch in range(num_epochs):

        hidden_state = (to_var(torch.zeros(2, batch_size, hidden_size)),
                        to_var(torch.zeros(2, batch_size, hidden_size)))

        model.train(True)
        for i in range(0, data.size(1) - seq_len, seq_len):

            inputs = to_var(data[:, i:i + seq_len, :])
Example #8
def main(options):

    use_cuda = (len(options.gpuid) >= 1)
    if options.gpuid:
        cuda.set_device(options.gpuid[0])

    src_vocab = dill.load(open('src_vocab.pickle', 'rb'))
    trg_vocab = dill.load(open('trg_vocab.pickle', 'rb'))

    src_dev = dill.load(open('src_dev.pickle', 'rb'))
    trg_dev = dill.load(open('trg_dev.pickle', 'rb'))
    batched_dev_src, batched_dev_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_dev, options.batch_size, src_vocab.stoi["<blank>"])
    batched_dev_trg, batched_dev_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_dev, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)

    # batches = []

    # if options.contain_bilingual:
    print('Load')
    src_train = dill.load(open('src_sents1.pickle', 'rb'))
    print('Load src sents 1')
    trg_train = dill.load(open('trg_sents1.pickle', 'rb'))
    print('Load trg sents 1')
    src_train = src_train + dill.load(open('src_sents2.pickle', 'rb'))
    print('Load src sents 2')
    trg_train = trg_train + dill.load(open('trg_sents2.pickle', 'rb'))
    print('Load trg sents 2')
    src_train = src_train + dill.load(open('src_sents3.pickle', 'rb'))
    print('Load src sents 3')
    trg_train = trg_train + dill.load(open('trg_sents3.pickle', 'rb'))
    print('Load trg sents 3')

    batched_train_src, batched_train_src_mask, sort_index = utils.tensor.advanced_batchize(
        src_train, options.batch_size, src_vocab.stoi["<blank>"])
    batched_train_trg, batched_train_trg_mask = utils.tensor.advanced_batchize_no_sort(
        trg_train, options.batch_size, trg_vocab.stoi["<blank>"], sort_index)

    src_vocab_size = len(src_vocab)
    trg_vocab_size = len(trg_vocab)

    if os.path.isfile(options.load_file_src) and os.path.isfile(
            options.load_file_trg):
        src_lm = torch.load(open(options.load_file_src, 'rb'))
        trg_lm = torch.load(open(options.load_file_trg, 'rb'))
    else:
        src_lm = LM(src_vocab_size, src_vocab.stoi['<s>'],
                    src_vocab.stoi['</s>'], options.embedding_size,
                    options.hidden_size, options.dropout, use_cuda)
        trg_lm = LM(trg_vocab_size, trg_vocab.stoi['<s>'],
                    trg_vocab.stoi['</s>'], options.embedding_size,
                    options.hidden_size, options.dropout, use_cuda)

    if use_cuda:
        src_lm.cuda()
        trg_lm.cuda()
    else:
        src_lm.cpu()
        trg_lm.cpu()

    criterion = torch.nn.NLLLoss()
    optimizer_src = getattr(torch.optim, options.optimizer)(
        src_lm.parameters(), options.learning_rate)
    optimizer_trg = getattr(torch.optim, options.optimizer)(
        trg_lm.parameters(), options.learning_rate)

    # main training loop
    # last_dev_avg_loss = float("inf")
    for epoch_i in range(options.epochs):
        print(epoch_i)
        logging.info("At {0}-th epoch.".format(epoch_i))
        src_lm.train()
        trg_lm.train()
        for i, batch_i in enumerate(range(len(batched_train_src))):
            optimizer_trg.zero_grad()
            optimizer_src.zero_grad()

            train_src_batch = Variable(batched_train_src[batch_i])
            train_src_mask = Variable(batched_train_src_mask[batch_i])
            train_trg_batch = Variable(batched_train_trg[batch_i])
            train_trg_mask = Variable(batched_train_trg_mask[batch_i])
            if use_cuda:
                train_src_batch = train_src_batch.cuda()
                train_trg_batch = train_trg_batch.cuda()
                train_src_mask = train_src_mask.cuda()
                train_trg_mask = train_trg_mask.cuda()

            h_src, c_src = src_lm(sent=train_src_batch)
            use_teacher_forcing = random.random() < options.teacher_forcing_ratio
            sys_out_batch = trg_lm(h=h_src,
                                   c=c_src,
                                   encode=False,
                                   tgt_sent=train_trg_batch,
                                   teacher_forcing=use_teacher_forcing)

            train_trg_mask_tmp = train_trg_mask.view(-1)
            train_trg_batch_tmp = train_trg_batch.view(-1)
            train_trg_batch_tmp = train_trg_batch_tmp.masked_select(
                train_trg_mask_tmp)
            train_trg_mask_tmp = train_trg_mask_tmp.unsqueeze(1).expand(
                len(train_trg_mask_tmp), trg_vocab_size)
            sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
            sys_out_batch = sys_out_batch.masked_select(
                train_trg_mask_tmp).view(-1, trg_vocab_size)
            loss = criterion(sys_out_batch, train_trg_batch_tmp)
            loss.backward()
            optimizer_src.step()
            optimizer_trg.step()
            if i % 100 == 0:
                logging.debug("loss at batch {0}: {1}".format(i, loss.data[0]))

        # validation -- this is a crude estimation because there might be some padding at the end
        dev_loss = 0.0
        src_lm.eval()
        trg_lm.eval()
        for batch_i in range(len(batched_dev_src)):
            dev_src_batch = Variable(batched_dev_src[batch_i], volatile=True)
            dev_trg_batch = Variable(batched_dev_trg[batch_i], volatile=True)
            dev_src_mask = Variable(batched_dev_src_mask[batch_i],
                                    volatile=True)
            dev_trg_mask = Variable(batched_dev_trg_mask[batch_i],
                                    volatile=True)
            if use_cuda:
                dev_src_batch = dev_src_batch.cuda()
                dev_trg_batch = dev_trg_batch.cuda()
                dev_src_mask = dev_src_mask.cuda()
                dev_trg_mask = dev_trg_mask.cuda()

            h_src, c_src = src_lm(sent=dev_src_batch)
            sys_out_batch = trg_lm(h=h_src,
                                   c=c_src,
                                   encode=False,
                                   tgt_sent=dev_trg_batch)

            dev_trg_mask = dev_trg_mask.view(-1)
            dev_trg_batch = dev_trg_batch.view(-1)
            dev_trg_batch = dev_trg_batch.masked_select(dev_trg_mask)
            dev_trg_mask = dev_trg_mask.unsqueeze(1).expand(
                len(dev_trg_mask), trg_vocab_size)
            sys_out_batch = sys_out_batch.view(-1, trg_vocab_size)
            sys_out_batch = sys_out_batch.masked_select(dev_trg_mask).view(
                -1, trg_vocab_size)

            loss = criterion(sys_out_batch, dev_trg_batch)
            logging.debug("dev loss at batch {0}: {1}".format(
                batch_i, loss.data[0]))
            dev_loss += loss

        dev_avg_loss = dev_loss / len(batched_dev_src)
        logging.info(
            "Average loss value per instance is {0} at the end of epoch {1}".
            format(dev_avg_loss.data[0], epoch_i))

        # if (last_dev_avg_loss - dev_avg_loss).data[0] < options.estop:
        # logging.info("Early stopping triggered with threshold {0} (previous dev loss: {1}, current: {2})".format(epoch_i, last_dev_avg_loss.data[0], dev_avg_loss.data[0]))
        # break

        torch.save(
            src_lm,
            open(
                options.model_file_src +
                ".nll_{0:.2f}.epoch_{1}".format(dev_avg_loss.data[0], epoch_i),
                'wb'),
            pickle_module=dill)
        torch.save(
            trg_lm,
            open(
                options.model_file_trg +
                ".nll_{0:.2f}.epoch_{1}".format(dev_avg_loss.data[0], epoch_i),
                'wb'),
            pickle_module=dill)