Пример #1
0
def trainIters(attn_model, hidden_size,encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, \
               learning_rate, decoder_learning_ratio, print_every, save_every, clip, dropout, \
               corpus_name, datafile, modelFile=None, need_trim=True):
    """Train an attention seq2seq model for ``n_iteration`` random batches.

    Loads (voc, pairs) from ``datafile`` (optionally trimming rare words via
    the module-level ``MIN_COUNT``), pre-samples one random batch per
    iteration, builds the embedding/encoder/decoder and their Adam
    optimizers (restoring all state from ``modelFile`` when given), then
    trains, printing the average loss every ``print_every`` iterations and
    writing a checkpoint every ``save_every`` iterations under ``save_dir``.

    NOTE(review): ``corpus_name`` is accepted but never used in this body,
    and on resume the training batches are re-sampled at random rather than
    restored from the checkpoint — confirm both are intentional.
    """
    # load train data
    voc, pairs = loadPrepareData(datafile)
    if need_trim:
        # Trim voc and pairs
        pairs = trimRareWords(voc, pairs, MIN_COUNT)
    # Load batches for each iteration (one randomly sampled batch per
    # iteration, all built up front)
    training_batches = [
        batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
        for _ in range(n_iteration)
    ]

    if modelFile:
        # If loading on same machine the model was trained on
        checkpoint = torch.load(modelFile)
        # If loading a model trained on GPU to CPU
        # checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
        encoder_sd = checkpoint['en']
        decoder_sd = checkpoint['de']
        encoder_optimizer_sd = checkpoint['en_opt']
        decoder_optimizer_sd = checkpoint['de_opt']
        embedding_sd = checkpoint['embedding']

    # Embedding shared by encoder and decoder; its state must be restored
    # before the RNN modules wrap it below.
    embedding = nn.Embedding(voc.num_words, hidden_size)
    if modelFile:
        embedding.load_state_dict(embedding_sd)

    # Initialize encoder & decoder models
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.num_words, decoder_n_layers, dropout)
    # get model params (restore saved weights when resuming)
    if modelFile:
        encoder.load_state_dict(encoder_sd)
        decoder.load_state_dict(decoder_sd)

    print('Building optimizers ...')
    # Optimizer state can only be restored once the optimizers exist.
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if modelFile:
        encoder_optimizer.load_state_dict(encoder_optimizer_sd)
        decoder_optimizer.load_state_dict(decoder_optimizer_sd)
    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if modelFile:
        # Resume counting from the iteration stored in the checkpoint.
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    encoder.train()
    decoder.train()

    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, encoder_optimizer,
                     decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress (average loss over the last print_every iterations)
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, \
                                                            iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint (model + optimizer + embedding state)
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, "model", '{}-{}_{}'.format(encoder_n_layers, \
                                                                                          decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'embedding': embedding.state_dict()
                },
                os.path.join(directory,
                             '{}_{}.tar'.format(iteration, 'checkpoint')))
Пример #2
0
def trainIters(corpus,
               reverse,
               n_iteration,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               save_every,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    """Train a Luong-attention seq2seq model on ``corpus``.

    Training batches are cached on disk under ``save_dir/training_data`` and
    reused by later runs with the same (n_iteration, reverse, batch_size).
    If ``loadFilename`` points at a checkpoint, model/optimizer state and
    the perplexity history are restored and training resumes from the saved
    iteration.  A checkpoint is written every ``save_every`` iterations.
    """
    voc, pairs = loadPrepareData(corpus)

    # Location of the on-disk training-batch cache.
    corpus_name = os.path.split(corpus)[-1].split('.')[0]
    cache_dir = os.path.join(save_dir, 'training_data', corpus_name)
    cache_path = os.path.join(
        cache_dir,
        '{}_{}_{}.tar'.format(n_iteration,
                              filename(reverse, 'training_batches'),
                              batch_size))
    try:
        training_batches = torch.load(cache_path)
    # BUG FIX: was ``except BaseException`` (the inline comment admitted it
    # used to be FileNotFoundError).  BaseException swallowed every error,
    # including KeyboardInterrupt; only a missing cache file should trigger
    # regeneration.
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = [
            batch2TrainData(voc,
                            [random.choice(pairs)
                             for _ in range(batch_size)], reverse)
            for _ in range(n_iteration)
        ]
        # Create the cache directory first: torch.save fails on a missing
        # directory, which is probably why the except clause had been
        # widened in the first place.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(training_batches, cache_path)

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    # BUG FIX: removed ``attn_model = 'dot'`` here — it silently overrode
    # the caller's ``attn_model`` argument (the default is already 'dot').
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.n_words, n_layers)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # optimizer (saved state can only be restored once the optimizers exist)
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    for iteration in tqdm(range(start_iteration, n_iteration + 1)):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, embedding,
                     encoder_optimizer, decoder_optimizer, batch_size)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            # exp(mean cross-entropy over the window) == perplexity
            print_loss_avg = math.exp(print_loss / print_every)
            perplexity.append(print_loss_avg)
            print('%d %d%% %.4f' %
                  (iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if iteration % save_every == 0:
            directory = os.path.join(
                save_dir, 'model', corpus_name,
                '{}-{}_{}'.format(n_layers, n_layers, hidden_size))
            os.makedirs(directory, exist_ok=True)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(iteration,
                                       filename(reverse,
                                                'backup_bidir_model'))))
def train(**kwargs):
    """Train the encoder/decoder chatbot described by ``Config``.

    Any keyword argument overrides the ``Config`` attribute of the same
    name.  Model and optimizer state are restored from ``opt.model_ckpt``
    when set, and a checkpoint is written every ``opt.save_every`` epochs
    to a timestamped path derived from ``opt.prefix``.
    """
    opt = Config()
    for k, v in kwargs.items():  # override config attributes with kwargs
        setattr(opt, k, v)

    # data
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix = _data['word2ix']
    # index of the start-of-sequence token (None if absent from the vocab)
    sos = word2ix.get(_data.get('sos'))
    voc_length = len(word2ix)

    # build the models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # resume from a checkpoint, continuing where the previous run stopped
    if opt.model_ckpt:
        checkpoint = torch.load(opt.model_ckpt)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])

    # move to the target device and switch to training mode
    encoder = encoder.to(opt.device)
    decoder = decoder.to(opt.device)
    encoder.train()
    decoder.train()

    # build the optimizers (must come after encoder.to(device), not before)
    encoder_optimizer = torch.optim.Adam(encoder.parameters(),
                                         lr=opt.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(),
                                         lr=opt.learning_rate *
                                         opt.decoder_learning_ratio)
    if opt.model_ckpt:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # running loss accumulator for periodic printing
    print_loss = 0

    for epoch in range(opt.epoch):
        for ii, data in enumerate(dataloader):
            # train on one batch
            loss = train_by_batch(sos, opt, data, encoder_optimizer,
                                  decoder_optimizer, encoder, decoder)
            print_loss += loss
            # print the running average loss
            # NOTE(review): this also fires at ii == 0, when only one batch
            # has contributed, so the first printed "average" is understated.
            if ii % opt.print_every == 0:
                print_loss_avg = print_loss / opt.print_every
                print(
                    "Epoch: {}; Epoch Percent complete: {:.1f}%; Average loss: {:.4f}"
                    .format(epoch, epoch / opt.epoch * 100, print_loss_avg))
                print_loss = 0

        # save a checkpoint (model + optimizer state)
        if epoch % opt.save_every == 0:
            checkpoint_path = '{prefix}_{time}'.format(
                prefix=opt.prefix, time=time.strftime('%m%d_%H%M'))
            torch.save(
                {
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                }, checkpoint_path)
Пример #4
0
def trainIters(corpus,
               pre_modelFile,
               reverse,
               n_iteration,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               save_every,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    """Train a Luong-attention seq2seq model with pretrained word vectors.

    Training batches are cached under ``save_dir/training_data``.  The
    pretrained word2vec model at ``pre_modelFile`` is loaded with gensim
    and handed to ``train`` each iteration.  Checkpoints (model, optimizer
    and perplexity history) are written every ``save_every`` iterations;
    ``loadFilename`` resumes from such a checkpoint.
    """
    voc, pairs = loadPrepareData(corpus)

    # training data (cached on disk, regenerated if the cache is missing)
    corpus_name = os.path.split(corpus)[-1].split('.')[0]
    training_batches = None
    try:
        training_batches = torch.load(os.path.join(save_dir, 'training_data', corpus_name,
                                                   '{}_{}_{}.tar'.format(n_iteration, \
                                                                         filename(reverse, 'training_batches'), \
                                                                         batch_size)))
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = [
            batch2TrainData(voc,
                            [random.choice(pairs)
                             for _ in range(batch_size)], reverse)
            for _ in range(n_iteration)
        ]
        torch.save(training_batches, os.path.join(save_dir, 'training_data', corpus_name,
                                                  '{}_{}_{}.tar'.format(n_iteration, \
                                                                        filename(reverse, 'training_batches'), \
                                                                        batch_size)))
    # model
    checkpoint = None
    #print('Building pretrained word2vector model...')
    # NOTE(review): the first argument of nn.Embedding is num_embeddings
    # (the vocabulary size), not the vector dimension, so this builds a
    # 300-word embedding table — likely should be voc.n_words.  Confirm
    # whether this layer is actually exercised given w2v_model below.
    embedding = nn.Embedding(
        300, hidden_size)  #The dimension of google's model is 300
    #-----------------------------------------------------------------
    #my code
    '''
    EMBEDDING_DIM = 300 #Should be the same as hidden_size!
    if EMBEDDING_DIM != hidden_size:
        sys.exit("EMBEDDING_DIM do not equal to hidden_size. Please correct it.")
    CONTEXT_SIZE = 2
    pre_checkpoint = torch.load(pre_modelFile)
    pretrained_model = NGramLanguageModeler(voc.n_words, EMBEDDING_DIM, CONTEXT_SIZE)
    pretrained_model.load_state_dict(pre_checkpoint['w2v'])
    pretrained_model.train(False)
    embedding = pretrained_model
    '''
    if USE_CUDA:
        embedding = embedding.cuda()

    #-----------------------------------------------------------------
    #replace embedding by pretrained_model
    print('Building encoder and decoder ...')
    encoder = EncoderRNN(300, hidden_size, embedding, n_layers)
    # NOTE(review): this overrides the caller's attn_model argument with
    # 'dot'; remove if other attention methods should be selectable.
    attn_model = 'dot'
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.n_words, n_layers)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # optimizer (state restored only after the optimizers exist)
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # Load Google's pre-trained Word2Vec model.
    print('Loading w2v_model ...')
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',
                        level=logging.INFO)
    w2v_model = gensim.models.KeyedVectors.load_word2vec_format(pre_modelFile,
                                                                binary=True)
    print("Loading complete!")

    # initialize
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    for iteration in tqdm(range(start_iteration, n_iteration + 1)):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, embedding,
                     encoder_optimizer, decoder_optimizer, batch_size,
                     w2v_model, voc)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            # exp(mean loss over the window) == perplexity
            print_loss_avg = math.exp(print_loss / print_every)
            # perplexity.append(print_loss_avg)
            # plotPerplexity(perplexity, iteration)
            print('%d %d%% %.4f' %
                  (iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if (iteration % save_every == 0):
            directory = os.path.join(
                save_dir, 'model', corpus_name,
                '{}-{}_{}'.format(n_layers, n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(iteration,
                                       filename(reverse,
                                                'backup_bidir_model'))))
Пример #5
0
def main():
    """Train (when ``args.train``) or load the encoder/decoder, then predict.

    Training shuffles the data each epoch, iterates full batches of 64, and
    saves both state dicts after every epoch; otherwise the previously
    saved state dicts are restored.  Finishes by calling ``predict``.
    """
    epoch = 1000
    batch_size = 64
    hidden_dim = 300
    use_cuda = True

    encoder = Encoder(num_words, hidden_dim)
    if args.attn:
        attn_model = 'dot'
        decoder = LuongAttnDecoderRNN(attn_model, hidden_dim, num_words)
    else:
        decoder = DecoderRhyme(hidden_dim, num_words, num_target_lengths,
                               num_rhymes)

    if args.train:
        # Zero weight on PAD so padding never contributes to the loss.
        weight = torch.ones(num_words)
        weight[word2idx_mapping[PAD_TOKEN]] = 0
        if use_cuda:
            encoder = encoder.cuda()
            decoder = decoder.cuda()
            weight = weight.cuda()
        encoder_optimizer = Adam(encoder.parameters(), lr=0.001)
        decoder_optimizer = Adam(decoder.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss(weight=weight)

        np.random.seed(1124)
        order = np.arange(len(train_data))

        best_loss = 1e10
        best_epoch = 0
        # Only full batches are used; any remainder is dropped.
        n_batches = len(order) // batch_size

        for e in range(epoch):
            # Reshuffle so batches differ between epochs.
            np.random.shuffle(order)
            shuffled_train_data = train_data[order]
            shuffled_x_lengths = input_lengths[order]
            shuffled_y_lengths = target_lengths[order]
            shuffled_y_rhyme = target_rhymes[order]
            train_loss = 0
            valid_loss = 0
            for b in tqdm(range(n_batches)):
                lo, hi = b * batch_size, (b + 1) * batch_size
                batch_x = torch.LongTensor(
                    shuffled_train_data[lo:hi][:, 0].tolist()).t()
                batch_y = torch.LongTensor(
                    shuffled_train_data[lo:hi][:, 1].tolist()).t()
                batch_x_lengths = shuffled_x_lengths[lo:hi]
                batch_y_lengths = shuffled_y_lengths[lo:hi]
                batch_y_rhyme = shuffled_y_rhyme[lo:hi]

                if use_cuda:
                    batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

                train_loss += train(batch_x, batch_y, batch_y_lengths,
                                    max(batch_y_lengths), batch_y_rhyme,
                                    encoder, decoder, encoder_optimizer,
                                    decoder_optimizer, criterion, use_cuda,
                                    False)

            # BUG FIX: was ``train_loss /= b`` — the last loop index is
            # n_batches - 1, so the mean was computed over one batch too
            # few (and raised ZeroDivisionError for a one-batch dataset,
            # NameError for an empty one).  Divide by the batch count.
            train_loss /= max(n_batches, 1)
            # Validation is currently disabled, so valid_loss stays 0 and
            # best_epoch/best_loss never update.
            print(
                "epoch {}, train_loss {:.4f}, valid_loss {:.4f}, best_epoch {}, best_loss {:.4f}"
                .format(e, train_loss, valid_loss, best_epoch, best_loss))
            torch.save(encoder.state_dict(), args.encoder_path)
            torch.save(decoder.state_dict(), args.decoder_path)
        print(encoder)
        print(decoder)
        print("==============")

    else:
        encoder.load_state_dict(torch.load(
            args.encoder_path))  #, map_location=torch.device('cpu')))
        decoder.load_state_dict(torch.load(
            args.decoder_path))  #, map_location=torch.device('cpu')))
        print(encoder)
        print(decoder)

    predict(encoder, decoder)
Пример #6
0
def trainIters(corpus,
               reverse,
               n_iteration,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               save_every,
               dropout,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    """Train a Luong-attention seq2seq chatbot on ``corpus``.

    ``voc`` maps strings to indices and ``pairs`` holds the dialogue pairs
    awaiting conversion.  No epoch loop is used: ``n_iteration`` batches of
    ``batch_size`` randomly chosen pairs are pre-built (and cached on disk)
    and trained in sequence.  Progress is appended to ``log.txt`` every
    ``print_every`` iterations; a checkpoint is written every
    ``save_every`` iterations under ``save_dir``.
    """
    import time  # hoisted: was re-imported inside the logging branch every time

    voc, pairs = loadPrepareData(corpus)

    # training data: cached on disk so runs with identical settings reuse it
    corpus_name = os.path.split(corpus)[-1].split('.')[0]
    cache_dir = os.path.join(save_dir, 'training_data', corpus_name)
    cache_path = os.path.join(
        cache_dir,
        '{}_{}_{}.tar'.format(n_iteration,
                              filename(reverse, 'training_batches'),
                              batch_size))
    try:
        training_batches = torch.load(cache_path)
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = [
            batch2TrainData(voc,
                            [random.choice(pairs)
                             for _ in range(batch_size)], reverse)
            for _ in range(n_iteration)
        ]
        # BUG FIX: the generated batches were never saved, so the
        # torch.load above could never succeed on a later run.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(training_batches, cache_path)

    # model
    checkpoint = None
    print('Building encoder and decoder ...')

    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers,
                         dropout)
    # BUG FIX: removed ``attn_model = 'dot'`` here — it silently overrode
    # the caller's ``attn_model`` argument (the default is already 'dot').
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.n_words, n_layers, dropout)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # move to the configured device
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # optimizer (saved state can only be restored once the optimizers exist)
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, embedding,
                     encoder_optimizer, decoder_optimizer, batch_size)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            # exp(mean cross-entropy over the window) == perplexity
            print_loss_avg = math.exp(print_loss / print_every)
            with open('log.txt', 'a') as f:
                template = ' Iter: {:0>6d} process: {:.2f} avg_loss: {:.4f} time: {}\n'
                # renamed from ``str``, which shadowed the builtin
                log_line = template.format(
                    iteration, iteration / n_iteration * 100, print_loss_avg,
                    time.asctime(time.localtime(time.time())))
                f.write(log_line)
            print_loss = 0

        if iteration % save_every == 0:
            directory = os.path.join(
                save_dir, 'model', corpus_name,
                '{}-{}_{}'.format(n_layers, n_layers, hidden_size))
            os.makedirs(directory, exist_ok=True)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(iteration,
                                       filename(reverse,
                                                'backup_bidir_model'))))
Пример #7
0
def trainIters(corpus,
               reverse,
               n_iteration,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               save_every,
               dropout,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    """Train a Luong-attention seq2seq model using a per-pair embedding dict.

    ``concate_embedding`` builds the embedding lookup handed to ``train``
    each iteration.  Training batches are cached under
    ``save_dir/training_data``; checkpoints (model, optimizer, perplexity
    history) are written every ``save_every`` iterations, and
    ``loadFilename`` resumes from such a checkpoint.
    """
    voc, pairs = loadPrepareData(corpus)
    embedding_dict = concate_embedding(pairs, voc, hidden_size)

    # training data: cached on disk, regenerated only when the cache is missing
    corpus_name = os.path.split(corpus)[-1].split('.')[0]
    cache_dir = os.path.join(save_dir, 'training_data', corpus_name)
    cache_path = os.path.join(
        cache_dir,
        '{}_{}_{}.tar'.format(n_iteration,
                              filename(reverse, 'training_batches'),
                              batch_size))
    try:
        training_batches = torch.load(cache_path)
    except FileNotFoundError:
        print('Generating training batches...')
        training_batches = [
            batch2TrainData([random.choice(pairs)
                             for _ in range(batch_size)], voc, reverse)
            for _ in range(n_iteration)
        ]
        # BUG FIX: create the cache directory first — torch.save raises on
        # a missing directory, so the cache was never written on a fresh
        # save_dir.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(training_batches, cache_path)

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    encoder = EncoderRNN(hidden_size, batch_size, n_layers, dropout)
    # BUG FIX: removed ``attn_model = 'dot'`` here — it silently overrode
    # the caller's ``attn_model`` argument (the default is already 'dot').
    decoder = LuongAttnDecoderRNN(attn_model, hidden_size, batch_size,
                                  voc.loc_count, n_layers, dropout)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # move to the configured device
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # optimizer (saved state can only be restored once the optimizers exist)
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    for iteration in tqdm(range(start_iteration, n_iteration + 1)):
        training_batch = training_batches[iteration - 1]
        input_vec, input_lengths, target_vec, max_target_len = training_batch

        loss = train(input_vec, input_lengths, target_vec, max_target_len,
                     encoder, decoder, embedding_dict, encoder_optimizer,
                     decoder_optimizer, batch_size)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            # exp(mean cross-entropy over the window) == perplexity
            print_loss_avg = math.exp(print_loss / print_every)
            print('%d %d%% %.4f' %
                  (iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if iteration % save_every == 0:
            directory = os.path.join(
                save_dir, 'model', corpus_name,
                '{}-{}_{}'.format(n_layers, batch_size, hidden_size))
            os.makedirs(directory, exist_ok=True)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(iteration,
                                       filename(reverse,
                                                'backup_bidir_model'))))
Пример #8
0
def train():
    """Train a Luong-attention seq2seq chatbot model end-to-end.

    Reads all hyper-parameters from ``Config``, loads (or builds and caches)
    the training batches, constructs the encoder/decoder pair (optionally
    resuming from ``model_ckpt``), then runs the iteration loop, printing a
    perplexity report every ``print_every`` steps and checkpointing every
    ``save_every`` steps.
    """
    parameter = Config()
    model_name = parameter.model_name
    save_dir = parameter.save_dir
    loadFilename = parameter.model_ckpt

    pretrained_embedding_path = parameter.pretrained_embedding_path
    max_input_length = parameter.max_input_length
    max_generate_length = parameter.max_generate_length
    embedding_dim = parameter.embedding_dim
    batch_size = parameter.batch_size
    hidden_size = parameter.hidden_size
    attn_model = parameter.method
    dropout = parameter.dropout
    clip = parameter.clip
    num_layers = parameter.num_layers

    learning_rate = parameter.learning_rate
    teacher_forcing_ratio = parameter.teacher_forcing_ratio
    decoder_learning_ratio = parameter.decoder_learning_ratio
    n_iteration = parameter.epoch
    print_every = parameter.print_every
    save_every = parameter.save_every
    print(max_input_length, max_generate_length)

    # data: read the vocabulary saved during preprocessing
    voc = read_voc_file()
    print(voc)
    pairs = get_pairs()
    # Reuse cached training batches when present; otherwise build and cache them.
    batches_path = os.path.join(
        save_dir, '{}_{}_{}.tar'.format(n_iteration, 'training_batches', batch_size))
    try:
        training_batches = torch.load(batches_path)
    except FileNotFoundError:
        training_batches = [get_batch(voc, batch_size, pairs, max_input_length, max_generate_length)
                            for _ in range(n_iteration)]
        torch.save(training_batches, batches_path)

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    if pretrained_embedding_path is None:
        embedding = nn.Embedding(len(voc), embedding_dim)
    else:
        embedding = get_weight(voc, pretrained_embedding_path, embedding_dim)
    print('embedding加载完成')
    encoder = EncoderRNN(hidden_size, embedding, num_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, len(voc), num_layers, dropout)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    # optimizer
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])
    # initialize (resume iteration counter / perplexity history if a checkpoint was loaded)
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    # 'record.txt' is truncated as before, but the handle is now closed
    # deterministically (the original leaked it).  NOTE(review): nothing is
    # ever written to this file — confirm whether it is still needed.
    with open('record.txt', 'w', encoding='utf-8'):
        for iteration in tqdm(range(start_iteration, n_iteration + 1)):
            training_batch = training_batches[iteration - 1]
            input_variable, lengths, target_variable, mask, max_target_len = training_batch
            loss = train_by_batch(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                                  decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size,
                                  clip, teacher_forcing_ratio)
            print_loss += loss
            perplexity.append(loss)

            if iteration % print_every == 0:
                # The averaged loss is exponentiated, so the report is a perplexity.
                print_loss_avg = math.exp(print_loss / print_every)
                print('%d %d%% %.4f' % (iteration, iteration / n_iteration * 100, print_loss_avg))
                print_loss = 0

            if iteration % save_every == 0:
                # NOTE(review): the path embeds num_layers twice — one of them
                # was probably meant to be batch_size; kept as-is so existing
                # checkpoint paths stay valid.
                directory = os.path.join(save_dir, 'model', model_name,
                                         '{}-{}_{}'.format(num_layers, num_layers, hidden_size))
                if not os.path.exists(directory):
                    os.makedirs(directory)
                torch.save({
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'backup_bidir_model')))
                print(perplexity)
Пример #9
0
def trainIters(corpus,
               reverse,
               n_epoch,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=1.0):
    """Epoch-based training driver for the attribute/aspect review-expansion model.

    Builds three encoders (user/item attribute embeddings, aspect embeddings,
    and a word-level RNN over text) plus an attention decoder, trains for up
    to ``n_epoch`` epochs, validates after every epoch, checkpoints whenever
    validation loss improves, and stops early on the first epoch whose
    validation loss does not improve.

    Args:
        corpus: corpus identifier passed to ``loadPrepareData``.
        reverse: flag forwarded to ``batchify``/``filename`` (pair direction).
        n_epoch: maximum number of training epochs.
        learning_rate: Adam learning rate for the encoders.
        batch_size: training batch size.
        n_layers: RNN layer count for encoders and decoder.
        hidden_size: RNN hidden size (also the word-embedding size).
        print_every: batches between progress reports.
        loadFilename: optional checkpoint path to resume from.
        attn_model: attention score type (overridden to 'dot' below).
        decoder_learning_ratio: decoder lr multiplier over ``learning_rate``.
    """
    print(
        "corpus: {}, reverse={}, n_epoch={}, learning_rate={}, batch_size={}, n_layers={}, hidden_size={}, decoder_learning_ratio={}"
        .format(corpus, reverse, n_epoch, learning_rate, batch_size, n_layers,
                hidden_size, decoder_learning_ratio))

    voc, pairs, valid_pairs, test_pairs = loadPrepareData(corpus)
    print('load data...')

    path = "data/expansion"
    # training data: load cached batches from disk, or build them with
    # batchify() and cache the result for the next run.
    corpus_name = corpus
    training_batches = None
    try:
        training_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'training_batches'),
                                   batch_size)))
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = batchify(pairs, batch_size, voc, reverse)
        print('Complete building training pairs ...')
        torch.save(
            training_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'training_batches'),
                                   batch_size)))

    # validation/test data: same load-or-build-and-cache pattern, with a
    # fixed smaller batch size for evaluation.
    eval_batch_size = 10
    try:
        val_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'val_batches'),
                                   eval_batch_size)))
    except FileNotFoundError:
        print('Validation pairs not found, generating ...')
        val_batches = batchify(valid_pairs,
                               eval_batch_size,
                               voc,
                               reverse,
                               evaluation=True)
        print('Complete building validation pairs ...')
        torch.save(
            val_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'val_batches'),
                                   eval_batch_size)))

    try:
        test_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'test_batches'),
                                   eval_batch_size)))
    except FileNotFoundError:
        print('Test pairs not found, generating ...')
        test_batches = batchify(test_pairs,
                                eval_batch_size,
                                voc,
                                reverse,
                                evaluation=True)
        print('Complete building test pairs ...')
        torch.save(
            test_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'test_batches'),
                                   eval_batch_size)))

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    # aspect: word-id table used by the decoder to index word embeddings.
    with open(os.path.join(save_dir, '15_aspect.pkl'), 'rb') as fp:
        aspect_ids = pickle.load(fp)
    aspect_num = 15  # 15 | 20 main aspects and each of them has 100 words
    aspect_ids = Variable(
        torch.LongTensor(aspect_ids), requires_grad=False
    )  # convert list into torch Variable, used to index word embedding
    # attribute embeddings
    attr_size = 64  #
    attr_num = 2

    print(
        "corpus: {}, reverse={}, n_words={}, n_epoch={}, learning_rate={}, batch_size={}, n_layers={}, hidden_size={}, decoder_learning_ratio={}, attr_size={}, aspect_num={}"
        .format(corpus, reverse, voc.n_words, n_epoch, learning_rate,
                batch_size, n_layers, hidden_size, decoder_learning_ratio,
                attr_size, aspect_num))
    with open(os.path.join(save_dir, 'user_item.pkl'), 'rb') as fp:
        user_dict, item_dict = pickle.load(fp)
    num_user = len(user_dict)
    num_item = len(item_dict)
    # One attribute/aspect embedding table per entity type (user, item).
    attr_embeddings = []
    attr_embeddings.append(nn.Embedding(num_user, attr_size))
    attr_embeddings.append(nn.Embedding(num_item, attr_size))
    aspect_embeddings = []
    aspect_embeddings.append(nn.Embedding(num_user, aspect_num))
    aspect_embeddings.append(nn.Embedding(num_item, aspect_num))
    if USE_CUDA:
        # Module.cuda() moves parameters in place (and returns self), so
        # rebinding the loop variable here still moves the listed modules.
        for attr_embedding in attr_embeddings:
            attr_embedding = attr_embedding.cuda()
        for aspect_embedding in aspect_embeddings:
            aspect_embedding = aspect_embedding.cuda()
        aspect_ids = aspect_ids.cuda()

    # encoder1: user/item attributes; encoder2: aspect distributions;
    # encoder3: RNN over the summary text.
    encoder1 = AttributeEncoder(attr_size, attr_num, hidden_size,
                                attr_embeddings, n_layers)
    encoder2 = AttributeEncoder(aspect_num, attr_num, hidden_size,
                                aspect_embeddings, n_layers)
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder3 = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    # NOTE(review): this unconditionally overrides the attn_model argument.
    attn_model = 'dot'
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  attr_size, voc.n_words, aspect_ids, n_layers)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder1.load_state_dict(checkpoint['en1'])
        encoder2.load_state_dict(checkpoint['en2'])
        encoder3.load_state_dict(checkpoint['en3'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    if USE_CUDA:
        encoder1 = encoder1.cuda()
        encoder2 = encoder2.cuda()
        encoder3 = encoder3.cuda()
        decoder = decoder.cuda()

    # optimizer: one Adam per encoder, decoder gets a scaled learning rate.
    print('Building optimizers ...')
    encoder1_optimizer = optim.Adam(encoder1.parameters(), lr=learning_rate)
    encoder2_optimizer = optim.Adam(encoder2.parameters(), lr=learning_rate)
    encoder3_optimizer = optim.Adam(encoder3.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder1_optimizer.load_state_dict(checkpoint['en1_opt'])
        encoder2_optimizer.load_state_dict(checkpoint['en2_opt'])
        encoder3_optimizer.load_state_dict(checkpoint['en3_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize (resume epoch counter / perplexity history if resuming)
    print('Initializing ...')
    start_epoch = 0
    perplexity = []
    best_val_loss = None
    print_loss = 0
    if loadFilename:
        start_epoch = checkpoint['epoch'] + 1
        perplexity = checkpoint['plt']

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        # train epoch
        encoder1.train()
        encoder2.train()
        encoder3.train()
        decoder.train()
        print_loss = 0
        start_time = time.time()
        for batch, training_batch in enumerate(training_batches):
            attr_input, summary_input, summary_input_lengths, title_input, title_input_lengths, target_variable, mask, max_target_len = training_batch

            loss = train(attr_input, summary_input, summary_input_lengths,
                         title_input, title_input_lengths, target_variable,
                         mask, max_target_len, encoder1, encoder2, encoder3,
                         decoder, embedding, encoder1_optimizer,
                         encoder2_optimizer, encoder3_optimizer,
                         decoder_optimizer, batch_size)
            print_loss += loss
            perplexity.append(loss)
            #print("batch {} loss={}".format(batch, loss))
            if batch % print_every == 0 and batch > 0:
                cur_loss = print_loss / print_every
                elapsed = time.time() - start_time

                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch, len(training_batches), learning_rate,
                        elapsed * 1000 / print_every, cur_loss,
                        math.exp(cur_loss)))

                print_loss = 0
                start_time = time.time()
        # evaluate: average loss over all validation batches.
        # NOTE(review): the optimizers are also passed to evaluate();
        # presumably it does not step them — confirm in its definition.
        val_loss = 0
        for val_batch in val_batches:
            attr_input, summary_input, summary_input_lengths, title_input, title_input_lengths, target_variable, mask, max_target_len = val_batch
            loss = evaluate(attr_input, summary_input, summary_input_lengths,
                            title_input, title_input_lengths, target_variable,
                            mask, max_target_len, encoder1, encoder2, encoder3,
                            decoder, embedding, encoder1_optimizer,
                            encoder2_optimizer, encoder3_optimizer,
                            decoder_optimizer, batch_size)
            val_loss += loss
        val_loss /= len(val_batches)

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch,
                                         (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        # NOTE(review): 'not best_val_loss' also treats a best loss of exactly
        # 0.0 as "no best yet"; 'best_val_loss is None' would be stricter.
        if not best_val_loss or val_loss < best_val_loss:
            directory = os.path.join(save_dir, 'model',
                                     '{}_{}'.format(n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'epoch': epoch,
                    'en1': encoder1.state_dict(),
                    'en2': encoder2.state_dict(),
                    'en3': encoder3.state_dict(),
                    'de': decoder.state_dict(),
                    'en1_opt': encoder1_optimizer.state_dict(),
                    'en2_opt': encoder2_optimizer.state_dict(),
                    'en3_opt': encoder3_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory, '{}_{}.tar'.format(
                        epoch,
                        filename(reverse, 'lexicon_title_expansion_model'))))
            best_val_loss = val_loss

            # Run on test data.
            test_loss = 0
            for test_batch in test_batches:
                attr_input, summary_input, summary_input_lengths, title_input, title_input_lengths, target_variable, mask, max_target_len = test_batch
                loss = evaluate(attr_input, summary_input,
                                summary_input_lengths, title_input,
                                title_input_lengths, target_variable, mask,
                                max_target_len, encoder1, encoder2, encoder3,
                                decoder, embedding, encoder1_optimizer,
                                encoder2_optimizer, encoder3_optimizer,
                                decoder_optimizer, batch_size)
                test_loss += loss
            test_loss /= len(test_batches)
            print('-' * 89)
            print('| test loss {:5.2f} | test ppl {:8.2f}'.format(
                test_loss, math.exp(test_loss)))
            print('-' * 89)

        # Early stopping: the first epoch without validation improvement
        # ends training.
        if val_loss > best_val_loss:
            break
Пример #10
0
def trainIters(corpus, reverse, n_iteration, learning_rate, batch_size, n_layers, hidden_size,
                print_every, save_every, dropout, loadFilename=None, attn_model='concat', decoder_learning_ratio=5.0):
    """Iteration-based seq2seq training with a 2000-pair held-out validation set.

    Loads and shuffles the corpus, reserves the last 2000 pairs for
    validation, then trains on freshly sampled random batches.  Every
    ``print_every`` iterations it reports training perplexity and runs 100
    validation batches; every ``save_every`` iterations it checkpoints.

    Args:
        corpus: path to the corpus file (basename is used for the save path).
        reverse: flag forwarded to ``batch2TrainData``/``filename``.
        n_iteration: total number of training iterations.
        learning_rate: Adam learning rate for the encoder.
        batch_size: batch size for both training and validation batches.
        n_layers: RNN layer count for encoder and decoder.
        hidden_size: RNN hidden size (also the embedding size).
        print_every: iterations between progress/validation reports.
        save_every: iterations between checkpoints.
        dropout: dropout rate for encoder and decoder.
        loadFilename: optional checkpoint path to resume from.
        attn_model: attention score type (overridden to 'concat' below).
        decoder_learning_ratio: decoder lr multiplier over ``learning_rate``.
    """
    voc, pairs = loadPrepareData(corpus)
    random.shuffle(pairs)
    # Hold out the last 2000 (post-shuffle) pairs for validation.
    pairs_valid = pairs[-2000:]
    pairs = pairs[:-2000]

    # basename of the corpus file, without extension, used in the save path
    corpus_name = os.path.split(corpus)[-1].split('.')[0]

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers, dropout)
    # NOTE(review): this unconditionally overrides the attn_model argument.
    attn_model = 'concat'
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers, dropout)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # optimizer
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize (resume iteration counter / perplexity history if resuming)
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']

    for iteration in tqdm(range(start_iteration, n_iteration + 1)):
        # Sample a fresh random batch each iteration instead of pre-building them.
        training_batch = batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)], reverse)
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            print_loss_avg = math.exp(print_loss / print_every)
            print('%d %d%% %.4f' % (iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0
            # *********************
            #     start valid
            # *********************
            # Distinct local names so the training `loss` (saved in the
            # checkpoint below) is not clobbered by the validation pass.
            valid_loss = 0
            for _ in range(100):
                valid_batch = batch2TrainData(voc, [random.choice(pairs_valid) for _ in range(batch_size)], reverse)
                v_input, v_lengths, v_target, v_mask, v_max_target_len = valid_batch
                valid_loss += train(v_input, v_lengths, v_target, v_mask, v_max_target_len, encoder,
                                    decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, valid=True)
            valid_loss_avg = math.exp(valid_loss / 100)
            print('valid loss %.4f' % valid_loss_avg)

        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, 'model', corpus_name, '{}-{}_{}'.format(n_layers, n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'plt': perplexity
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, filename(reverse, 'backup_bidir_model'))))
Пример #11
0
def trainIters(corpus,
               reverse,
               n_iteration,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               save_every,
               dropout,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    """Train a two-stage (hanzi -> pinyin -> hanzi) seq2seq pipeline.

    Builds two encoder/decoder pairs — one over Chinese characters, one over
    pinyin — loads or generates cached training batches, then runs the
    iteration loop, printing perplexity every ``print_every`` iterations and
    checkpointing all four modules every ``save_every`` iterations.

    Args:
        corpus: path to the corpus file (basename is used for the save path).
        reverse: flag forwarded to ``batch2TrainData``/``filename``.
        n_iteration: total number of training iterations.
        learning_rate: Adam learning rate for the encoders.
        batch_size: training batch size.
        n_layers: RNN layer count for all modules.
        hidden_size: RNN hidden size (also the embedding size).
        print_every: iterations between progress reports.
        save_every: iterations between checkpoints.
        dropout: dropout rate for encoders and decoders.
        loadFilename: optional checkpoint path to resume from.
        attn_model: attention score type (overridden to 'dot' below).
        decoder_learning_ratio: decoder lr multiplier over ``learning_rate``.
    """
    pinyin_voc, word_voc, tuples = loadPrepareData(corpus)

    # training data: load cached batches, or build and cache them.
    corpus_name = os.path.split(corpus)[-1].split('.')[0]
    training_batches = None
    try:
        training_batches = torch.load(os.path.join(save_dir, 'training_data', corpus_name,
                                                   '{}_{}_{}.tar'.format(n_iteration, \
                                                                         filename(reverse, 'training_batches'), \
                                                                         batch_size)))
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = [
            batch2TrainData(pinyin_voc, word_voc,
                            [random.choice(tuples)
                             for _ in range(batch_size)], reverse)
            for _ in range(n_iteration)
        ]
        torch.save(training_batches, os.path.join(save_dir, 'training_data', corpus_name,
                                                  '{}_{}_{}.tar'.format(n_iteration, \
                                                                        filename(reverse, 'training_batches'), \
                                                                        batch_size)))
    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    pinyin_embedding = nn.Embedding(pinyin_voc.n_words, hidden_size)
    word_embedding = nn.Embedding(word_voc.n_words, hidden_size)
    # First-stage encoder: reads Chinese characters.
    encoder = EncoderRNN(word_voc.n_words, hidden_size, word_embedding,
                         n_layers, dropout)
    # Second-stage encoder: reads pinyin.
    encoder_second = EncoderRNN(pinyin_voc.n_words, hidden_size,
                                pinyin_embedding, n_layers, dropout)
    attn_model = 'dot'
    # First-stage decoder: emits pinyin, with attention.
    decoder = LuongAttnDecoderRNN(attn_model, pinyin_embedding, hidden_size,
                                  pinyin_voc.n_words, n_layers, dropout)
    # Second-stage decoder: emits Chinese characters; no attention for now.
    decoder_second = DecoderWithoutAttn(word_embedding, hidden_size,
                                        word_voc.n_words, n_layers, dropout)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        encoder_second.load_state_dict(checkpoint['en_sec'])
        decoder.load_state_dict(checkpoint['de'])
        decoder_second.load_state_dict(checkpoint['de_sec'])
    # use cuda
    # encoder = encoder.to(device)
    # decoder = decoder.to(device)

    # optimizer: one Adam per module; decoders use a scaled learning rate.
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    encoder_second_optimizer = optim.Adam(encoder_second.parameters(),
                                          lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    decoder_second_optimizer = optim.Adam(decoder_second.parameters(),
                                          lr=learning_rate *
                                          decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        encoder_second_optimizer.load_state_dict(checkpoint['en_sec_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])
        decoder_second_optimizer.load_state_dict(checkpoint['de_sec_opt'])

    # initialize (resume iteration counter / perplexity history if resuming)
    print('Initializing ...')
    start_iteration = 1
    perplexity = []
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
        perplexity = checkpoint['plt']
    # progress bar over the remaining iterations
    for iteration in tqdm(range(start_iteration, n_iteration + 1)):
        # fetch the pre-built batch for this iteration
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable_pinyin, target_variable_word, mask, max_target_len = training_batch

        # NOTE(review): encoder_second / decoder_second and their optimizers
        # are built and checkpointed but never passed to train() here —
        # confirm whether the second stage is updated elsewhere.
        loss = train(input_variable, lengths, target_variable_pinyin,
                     target_variable_word, mask, max_target_len, encoder,
                     decoder, pinyin_embedding, word_embedding,
                     encoder_optimizer, decoder_optimizer, batch_size)
        print_loss += loss
        perplexity.append(loss)

        if iteration % print_every == 0:
            # The averaged loss is exponentiated, so the report is a perplexity.
            print_loss_avg = math.exp(print_loss / print_every)
            print('%d %d%% %.4f' %
                  (iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if (iteration % save_every == 0):
            directory = os.path.join(
                save_dir, 'model', corpus_name,
                '{}-{}_{}'.format(n_layers, n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'iteration': iteration,
                    'en': encoder.state_dict(),
                    'en_sec': encoder_second.state_dict(),
                    'de': decoder.state_dict(),
                    'de_sec': decoder_second.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'en_sec_opt': encoder_second_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'de_sec_opt': decoder_second_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(iteration,
                                       filename(reverse,
                                                'backup_bidir_model'))))