Esempio n. 1
0
def load_model(saved_vae, stored_info, device, cache_path=str(Path('../tmp')), seed=None):
    """Load a pickled dataset/config bundle and a trained VAE checkpoint.

    Args:
        saved_vae: path to the saved VAE ``state_dict`` file.
        stored_info: path (or bare filename) of the pickle cache; only its
            basename is used, resolved against ``cache_path``.
        device: torch device the VAE is moved to.
        cache_path: directory containing the pickle cache.
        seed: optional RNG seed; when None a time-derived seed is used.

    Returns:
        Tuple ``(vae, dataset, z_size, random_state)``.

    Raises:
        FileNotFoundError: if ``saved_vae`` does not exist. (The original
            code fell through to a NameError on the return in that case.)
    """
    stored_info = stored_info.split(os.sep)[-1]
    cache_file = os.path.join(cache_path, stored_info)

    start_load = time.time()
    print(f"Fetching cached info at {cache_file}")
    # NOTE(review): pickle.load executes arbitrary code if the cache is not
    # trusted; assumed to be locally generated.
    with open(cache_file, "rb") as f:
        dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers = pickle.load(f)
    end_load = time.time()
    print(f"Cache {cache_file} loaded (load time: {end_load - start_load:.2f}s)")

    if not os.path.exists(saved_vae):
        # Fail loudly instead of hitting a NameError on the return below.
        raise FileNotFoundError(f"Saved model not found: {saved_vae}")

    print(f"Found saved model {saved_vae}")
    start_load_model = time.time()

    e = model.EncoderRNN(dataset.input_side.n_words, encoder_hidden_size, z_size, n_encoder_layers, bidirectional=True)
    d = model.DecoderRNN(z_size, dataset.trn_split.n_conditions, condition_size, decoder_hidden_size, dataset.input_side.n_words, 1, word_dropout=0)
    vae = model.VAE(e, d).to(device)
    # map_location keeps GPU-trained checkpoints loadable on CPU-only hosts.
    vae.load_state_dict(torch.load(saved_vae, map_location=lambda storage, loc: storage))
    vae.eval()
    print(f"Trained for {vae.steps_seen} steps (load time: {time.time() - start_load_model:.2f}s)")

    print("Setting new random seed")
    if seed is None:
        # TODO: torch.manual_seed(1999) in model.py is affecting this
        new_seed = int(time.time())
        new_seed = abs(new_seed) % 4294967295  # must be between 0 and 4294967295
    else:
        new_seed = seed
    torch.manual_seed(new_seed)

    random_state = np.random.RandomState(new_seed)
    #random_state.shuffle(dataset.trn_pairs)

    return vae, dataset, z_size, random_state
Esempio n. 2
0
def build_model_from_scratch(args, corpus, embed_df):
    """Construct a fresh RNN encoder and attention classifier from ``args``.

    Returns the (encoder, classifier) pair, moved to CUDA when requested.
    """
    print("Building the initial models")
    vocab_size = len(corpus.dictionary)
    num_topics = args.ntopic
    pretrained = init_embedding(args.embed, corpus.dictionary, embed_df)

    encoder = model.EncoderRNN(args.model, vocab_size, args.embed, args.nhid,
                               args.nlayers, args.dropout, pretrained)
    classifier = model.AttentionClassifier(num_topics, args.nhid, args.hhid,
                                           args.cembed, args.max_len,
                                           args.dropout)

    if args.cuda:
        encoder, classifier = encoder.cuda(), classifier.cuda()
    return encoder, classifier
Esempio n. 3
0
    def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2):
        """Build a CUDA-resident encoder/decoder NMT model over ``vocab``.

        ``vocab`` must expose ``src``/``tgt`` sides with a ``word2id`` map.
        """
        super(NMT, self).__init__()

        # Remember hyper-parameters and the vocabulary on the instance.
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab

        n_src = len(self.vocab.src.word2id)
        n_tgt = len(self.vocab.tgt.word2id)

        self.encoder = model.EncoderRNN(vocab_size=n_src,
                                        embed_size=self.embed_size,
                                        hidden_size=self.hidden_size).cuda()
        self.decoder = model.DecoderRNN(embed_size=self.embed_size,
                                        hidden_size=self.hidden_size,
                                        output_size=n_tgt).cuda()

        self.criterion = torch.nn.CrossEntropyLoss().cuda()
            bleu_total += bleu
        print_loss_total /= test_len
        bleu_total /= test_len
        print(f'Test loss: {print_loss_total}, bleu: {bleu_total}')
        
        with open(f'{latent_hidden_size}/train_loss', 'a') as f:
            f.write(f'{str(train_loss_total/tot_cnt)}\n')
        with open(f'{latent_hidden_size}/train_KL_loss', 'a') as f:
            f.write(f'{str(train_KL_total/tot_cnt)}\n')
        with open(f'{latent_hidden_size}/test_bleu', 'a') as f:
            f.write(f'{str(bleu_total)}\n')

        test_bleu_list.append(bleu_total)
        train_loss_list.append(train_loss_total/tot_cnt)
        train_KL_list.append(train_KL_total/tot_cnt)
        train_loss_total = 0
        train_KL_total = 0
        tot_cnt = 0

        if bleu_total > highest_score:
            highest_score = bleu_total
            torch.save(encoder, f'/home/karljackab/DL/lab5/{latent_hidden_size}/encoder_{str(bleu_total)}.pkl')
            torch.save(decoder, f'/home/karljackab/DL/lab5/{latent_hidden_size}/decoder_{str(bleu_total)}.pkl')
            torch.save(enc_last, f'/home/karljackab/DL/lab5/{latent_hidden_size}/enc_last_{str(bleu_total)}.pkl')
            print('save model')

# Build the three sub-modules; hidden_size, latent_hidden_size, vocab_size
# and device are globals defined elsewhere in this file.
enc_last = model.EncodeLast(hidden_size+4, latent_hidden_size, device).to(device)
encoder = model.EncoderRNN(vocab_size, hidden_size+4, device).to(device)
decoder = model.DecoderRNN(hidden_size+4, vocab_size, device).to(device)

# Train for 300 units (presumably epochs — confirm against trainIters),
# printing progress every 2000 steps.
trainIters(encoder, decoder, enc_last, 300, print_every=2000)
Esempio n. 5
0
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab,
                 dropout_rate,
                 num_layers,
                 bidirectional,
                 attention_type,
                 self_attention,
                 tau,
                 gamma1,
                 gamma2,
                 cost_fcn,
                 uniform_init,
                 embedding_file=None):
        """Build a CUDA-resident attentional NMT model.

        Args:
            embed_size: word-embedding dimensionality.
            hidden_size: RNN hidden size for encoder and decoder.
            vocab: object with ``src``/``tgt`` sides exposing ``word2id``
                and ``id2word``.
            dropout_rate: dropout applied in encoder and decoder.
            num_layers: RNN depth on both sides.
            bidirectional: whether the encoder runs in both directions.
            attention_type: decoder attention variant.
            self_attention: whether the decoder uses self-attention.
            tau, gamma1, gamma2, cost_fcn: objective hyper-parameters stored
                for use by other methods of this class.
            uniform_init: half-width of the uniform weight initialisation.
            embedding_file: optional GloVe-style text file used to initialise
                the source embeddings; its first line is skipped as a header.
        """
        super(NMT, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab
        self.bidirectional = bidirectional
        self.tau = tau
        self.gamma1 = gamma1
        self.gamma2 = gamma2
        self.cost_fcn = cost_fcn
        src_vocab_size = len(self.vocab.src.word2id)
        tgt_vocab_size = len(self.vocab.tgt.word2id)

        if embedding_file is not None:
            glove = {}
            print("Loading the vectors.")
            # Context manager guarantees the file is closed (original used
            # a bare open()/close() pair).
            with open(embedding_file) as f:
                for i, line in enumerate(f):
                    if i == 0:
                        continue  # first line is a header, not a vector
                    word, vec = line.split(' ', 1)
                    # np.fromstring(..., sep=' ') is deprecated for text
                    # parsing; split-and-convert is the supported equivalent.
                    glove[word] = np.array(vec.split(), dtype=np.float64)
            print("Done.")

            # Rows default to zeros for words without a pretrained vector.
            embeddings = np.zeros((len(self.vocab.src.id2word), self.embed_size))
            for idx in range(len(self.vocab.src.id2word)):
                word = self.vocab.src.id2word[idx]
                if word in glove:
                    embeddings[idx] = glove[word]
        else:
            embeddings = None

        self.encoder = model.EncoderRNN(vocab_size=src_vocab_size,
                                        embed_size=self.embed_size,
                                        hidden_size=hidden_size,
                                        dropout_rate=dropout_rate,
                                        num_layers=num_layers,
                                        bidirectional=bidirectional,
                                        embeddings=embeddings)
        self.decoder = model.DecoderRNN(embed_size=self.embed_size,
                                        hidden_size=self.hidden_size,
                                        output_size=tgt_vocab_size,
                                        dropout_rate=dropout_rate,
                                        num_layers=num_layers,
                                        attention_type=attention_type,
                                        self_attention=self_attention,
                                        bidirectional=bidirectional)
        self.encoder = self.encoder.cuda()
        self.decoder = self.decoder.cuda()

        # Initialize all parameter weights uniformly in
        # [-uniform_init, uniform_init]. uniform_ is the in-place
        # initializer; the non-underscore torch.nn.init.uniform is
        # deprecated/removed in modern PyTorch.
        for param in list(self.encoder.parameters()) + list(
                self.decoder.parameters()):
            torch.nn.init.uniform_(param, a=-uniform_init, b=uniform_init)

        # reduce=0 (deprecated) requested per-element losses;
        # reduction='none' is the supported equivalent.
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
Esempio n. 6
0
    # If loading a model trained on GPU to CPU
    checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']

# Build the embedding/encoder/decoder trio; voc, hidden_size, loadFilename,
# the *_sd state dicts, attn_model, *_n_layers, dropout and device come from
# earlier in this file.
print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    # Restore pretrained embedding weights from the loaded checkpoint.
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = model.EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = model.LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                    voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    # Resume: load trained encoder/decoder weights.
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')


def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [dp.indexesFromSentence(voc, sentence)]
Esempio n. 7
0
def main(config):
    """Train a seq2seq model with either REINFORCE ('reinforce') or a GAN
    objective ('gan'), on the 'real' (OpenSubtitles) or synthetic
    'counting' dataset, as selected by ``config``."""

    if(config.dataset == 'real'):
        #initialize the dictionary
        lang_real = prepare.Lang_real('txt')
        lines = open('data/opensubtitles/vocab4000').read().strip().split('\n')
        for sen in lines:
            lang_real.addSentence(sen)
        lang_txt = lang_real

        train_data = prepare.get_dataset('data/opensubtitles/train.txt', batch_size = 16, lang_txt = lang_real, task = 'real')
        shuffle(train_data)
        dev_data = prepare.get_dataset('data/opensubtitles/dev.txt', batch_size = 16, lang_txt = lang_real, task = 'real')
        test_data = prepare.get_dataset('data/opensubtitles/test.txt', batch_size = 16, lang_txt = lang_real, task = 'real')

    elif(config.dataset == 'counting'):
        # Synthetic counting task: vocabulary is built programmatically.
        lang_counting = prepare.Lang_counting('txt')
        lang_txt = lang_counting

        train_data = prepare.get_dataset('data/counting/train_counting.txt', batch_size = 16, lang_txt = lang_counting, task = 'counting')
        shuffle(train_data)
        dev_data = prepare.get_dataset('data/counting/dev_counting.txt', batch_size = 16, lang_txt = lang_counting, task = 'counting')
        test_data = prepare.get_dataset_test_counting('data/counting/test_counting.txt', batch_size = 16)

    # Generator pair + evaluator, "prev" copies (presumably for PPO-style
    # ratio clipping — confirm against train.py), and discriminator/
    # evaluator encoder-decoder pairs used only by the GAN branch.
    feature = config.feature
    encoder = model.EncoderRNN(feature, feature, lang_txt.n_words)
    decoder = model.DecoderRNN(feature, feature,  lang_txt.n_words)
    evaluater = model.EvaluateR(feature)
    decoder_prev =  model.DecoderRNN(feature, feature,  lang_txt.n_words)
    encoder_prev =  model.EncoderRNN(feature, feature,  lang_txt.n_words)
    dis_encoder = model.disEncoderRNN(feature, feature, lang_txt.n_words)
    dis_decoder = model.disDecoderRNN(feature, feature, lang_txt.n_words)
    eva_encoder = model.disEncoderRNN(feature, feature, lang_txt.n_words)
    eva_decoder = model.disDecoderRNN(feature, feature, lang_txt.n_words)
    if use_cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        evaluater= evaluater.cuda()
        decoder_prev = decoder_prev.cuda()
        encoder_prev = encoder_prev.cuda()
        dis_encoder = dis_encoder.cuda(0)
        dis_decoder = dis_decoder.cuda(0)
        eva_encoder = eva_encoder.cuda(0)
        eva_decoder = eva_decoder.cuda(0)

    # PPO clipping hyper-parameters forwarded to the trainers below.
    print_every = config.print_every
    dev_every = config.dev_every
    use_ppo = config.use_ppo
    ppo_a1 = config.ppo_a1
    ppo_a2 = config.ppo_a2
    ppo_b1 = config.ppo_b1
    ppo_b2 = config.ppo_b2

    if(config.type == 'reinforce'):

        lr = config.lr
        test1 = train.seq2seq(lang_txt, dev_data,test_data,  encoder, decoder, evaluater,
                   encoder_prev,decoder_prev,
                    task = config.dataset,
                   god_rs_dev = [],
                    god_loss_dev = [],
                    god_loss = [],
                    god_rs_test = [])

        # plot_every is set huge so plotting is effectively disabled.
        losses, rewards = test1.trainIters(train_data,1,1,
                                           use_ppo = use_ppo,actor_fixed = False,
                                           min_rein_step = 0, max_rein_step = 5,
                                          ppo_b1 = ppo_b1,
                                          ppo_b2 = ppo_b2,
                                          ppo_a1 = ppo_a1,
                                          ppo_a2 = ppo_a2,
                                           ppo_a3 = 1e10,
                                           rate = 1,
                                          lr = lr,
                                          dev_every = dev_every,
                                          print_every = print_every,
                                          plot_every = 5000000000,
                                          name = '_z',
                                          file_name = 'MIXER')
    elif(config.type == 'gan'):

        test_gan = train.ganSeq2seq(lang_txt, dev_data,test_data,
                encoder,decoder,
                dis_encoder,dis_decoder,
                eva_encoder, eva_decoder,
                encoder_prev,
                decoder_prev,
                god_rs_dev = [],
               god_loss_dev = [],
               god_loss = [],
               god_rs_test = [],
                task = config.dataset)

        # plot_every again effectively disables plotting.
        loss_g, loss_d = test_gan.trainIters(train_data, 0,0,1,
                                  use_ppo= config.use_ppo,g_lr = config.g_lr, d_lr = config.d_lr,
                                  search_n = 1, width = 1,
                                 ppo_b1 = ppo_b1,
                                  ppo_b2 = ppo_b2,
                                  ppo_a1 = ppo_a1,
                                  ppo_a2 = ppo_a2,
                                  ppo_a3 = 10000000000,
                                 print_every = print_every,
                                 plot_every = 50000000000,
                                 dev_every = dev_every)
Esempio n. 8
0
        decoder_input = Variable(torch.LongTensor([[ni]]))
        decoder_input = decoder_input.cuda() if use_cuda else decoder_input

    return decoded_words


def evaluateRandomly(encoder, decoder, n=10):
    """Pick ``n`` random pairs from the global ``pairs`` and print, for each,
    the source ('>'), the reference ('='), and the decoded output ('<')."""
    for _ in range(n):
        sample = random.choice(pairs)
        print('>', sample[0])
        print('=', sample[1])
        decoded = evaluate(encoder, decoder, sample[0])
        print('<', ' '.join(decoded))
        print('')


# Model hyper-parameters.
hidden_dim = 256
embedding_dim = 100

# input_lang / output_lang are built earlier in this file.
encoder_1 = model.EncoderRNN(input_lang.n_words, embedding_dim, hidden_dim)
decoder_1 = model.DecoderRNN(output_lang.n_words, embedding_dim, hidden_dim)
#attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, 1, dropout_p=0.1)

if use_cuda:
    encoder_1 = encoder_1.cuda()
    decoder_1 = decoder_1.cuda()
    #attn_decoder1 = attn_decoder1.cuda()

# 75000 training iterations.
trainIters(encoder_1, decoder_1, 75000)
Esempio n. 9
0
                   iter / n_iters * 100, print_loss_avg))

            print_rec_total = print_rec_total / print_every
            print_kl_total = print_kl_total / print_every
            print('average kl =  %.4f' % print_kl_total)
            print('average reconstruction =  %.4f' % print_rec_total)

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
            print_kl_total = 0
            print_rec_total = 0


# Encoder/decoder plus a linear bridge mapping an RGB latent into the
# decoder's hidden space; vocabulary, latent_space, embedding_space,
# embeddings, device, epochs, learning_r and SAVE are defined elsewhere.
encoder = model.EncoderRNN(vocabulary.n_words, latent_space,
                           embeddings).to(device)
decoder = model.DecoderRNN(embedding_space, embeddings,
                           vocabulary.n_words).to(device)
linear = model.RGB_to_Hidden(latent_space, embedding_space).to(device)

trainIters(encoder,
           decoder,
           linear,
           epochs,
           plot_every=500,
           print_every=500,
           learning_rate=learning_r)

if SAVE:
    # Persist the trained encoder under the current working directory.
    dirpath = os.getcwd()
    encoder_path = dirpath + '/enc'
Esempio n. 10
0
            loss = train(input_tensor, target_tensor, encoder,
                         decoders[train_file], encoder_optimizer,
                         decoder_optimizers[train_file], criterion)
            print_loss_total += loss
            plot_loss_total += loss

            if iter % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' %
                      (timeSince(start, iter / n_iters), iter,
                       iter / n_iters * 100, print_loss_avg))

            if iter % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

    showPlot(plot_losses)


# One shared encoder, plus one decoder per training file (multi-task setup).
hidden_size = 256
encoder1 = model.EncoderRNN(vocab.n_words, hidden_size, device).to(device)
decoders = {}
for train_file in TRAINING_DATA_FILES:
    decoders[train_file] = model.Decoder1RNN(hidden_size, vocab.n_words,
                                             device).to(device)

# 500 iterations; print_every=5000 means per-step logging rarely fires.
trainIters(encoder1, decoders, 500, print_every=5000)
Esempio n. 11
0

def evaluateAndShowAttention(input_sentence):
    """Decode ``input_sentence`` with the global encoder1/attn_decoder1,
    print the input and output, and render the attention weights."""
    decoded, attn = evaluate(encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(decoded))
    showAttention(input_sentence, decoded, attn)


if __name__ == '__main__':
    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build eng/fra language indexes and pairs (third arg presumably
    # reverses the direction — confirm against prepareData).
    input_lang, output_lang, pairList = p.prepareData('eng', 'fra', True)
    hidden_size = 256
    encoder1 = m.EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder1 = m.AttnDecoderRNN(hidden_size,
                                     output_lang.n_words,
                                     dropout_p=0.1).to(device)

    trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

    # Qualitative checks: random samples plus attention visualisations.
    evaluateRandomly(encoder1, attn_decoder1)
    output_words, attentions = evaluate(encoder1, attn_decoder1,
                                        "je suis trop froid .")
    plt.matshow(attentions.numpy())

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    evaluateAndShowAttention("elle est trop petit .")
    evaluateAndShowAttention("je ne crains pas de mourir .")
    evaluateAndShowAttention("c est un jeune directeur plein de talent .")
Esempio n. 12
0
    relations = set(tmp)
    print(relations)
    relation_count = len(
        relations)  # args.relation_tag_size  # data['relation_tag_size']
    noisy_count = args.noisy_tag_size  # ata['noisy_tag_size']
    learning_rate = args.lr  # data['lr']
    l2 = args.l2  # data['l2']
    print("relation count: ", relation_count)
    print("Reading vector file......")
    vec_model = KeyedVectors.load_word2vec_format(args.datapath +
                                                  'vector2.txt',
                                                  binary=False)
    # vec_model = KeyedVectors.load_word2vec_format('/home/xiaoya/data/GoogleNews-vectors-negative300.bin.gz', binary=True)

    # load models
    encoder = model.EncoderRNN(args, wv).to(device)
    decoder = model.DecoderRNN(args, wv).to(device)
    RE_model = model.RE_RNN(args, wv, relation_count).to(device)

    criterion = nn.NLLLoss()  # CrossEntropyLoss()
    # criterion_RE = nn.BCELoss()
    # attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        RE_model = RE_model.cuda()
        criterion = criterion.cuda()
        # criterion_RE = criterion_RE.cuda()

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=learning_rate,
        os.path.join(args.save, 'decoder' + str(number) + '.pt'))
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        context = context.cuda()
        decoder = decoder.cuda()

    for dialog in validation_data:
        sample(my_lang, dialog, encoder, context, decoder)
        time.sleep(3)

    sys.exit(0)

# Build the hierarchical dialogue model from scratch, or resume from the
# latest checkpoint recorded in args.save.
learning_rate = args.lr
criterion = nn.NLLLoss()
if not args.restore:
    # Fresh model: utterance encoder -> context RNN -> decoder.
    encoder = model.EncoderRNN(len(my_lang.word2index), args.encoder_hidden, \
            args.encoder_layer, args.dropout)
    context = model.ContextRNN(args.encoder_hidden * args.encoder_layer, args.context_hidden, \
            args.context_layer, args.dropout)
    decoder = model.DecoderRNN(args.context_hidden * args.context_layer, args.decoder_hidden, \
            len(my_lang.word2index), args.decoder_layer, args.dropout)
else:
    print("Load last model in %s" % (args.save))
    # checkpoint.pt stores the index of the most recent saved model triple.
    number = torch.load(os.path.join(args.save, 'checkpoint.pt'))
    encoder = torch.load(
        os.path.join(args.save, 'encoder' + str(number) + '.pt'))
    context = torch.load(
        os.path.join(args.save, 'context' + str(number) + '.pt'))
    decoder = torch.load(
        os.path.join(args.save, 'decoder' + str(number) + '.pt'))
    if torch.cuda.is_available():
        encoder = encoder.cuda()
Esempio n. 14
0
    training_pairs = [
        tensor_from_pairs(src_lang, tar_lang, random.choice(pairs))
        for _ in range(iters)
    ]
    criterion = nn.NLLLoss()

    for iter in range(iters):
        train_pair = training_pairs[iter]
        input_tensor = train_pair[0]
        target_tensor = train_pair[1]

        loss = bch_train(input_tensor, target_tensor, encoder, decoder,
                         e_optim, d_optim, criterion)
        step_loss.append(loss)

        if iter % 20 == 0:
            print(f'Loss:\t{loss}')


if __name__ == '__main__':
    # Build parallel language indexes from the raw training corpus.
    lines = utils.readCorpus("./data/train.txt")
    ori_lang, tar_lang, pairs = utils.readLang(lines)
    hidden_size = 256
    encoder = model.EncoderRNN(ori_lang.n_words, hidden_size).to(model.device)
    # decoder = model.DecoderRNN(hidden_size, tar_lang.n_words)
    # decoder = model.BahdanauDecoderRNN(hidden_size, tar_lang.n_words).to(model.device)
    # Luong-style attention decoder with 'concat' scoring.
    decoder = model.LuongDecoderRNN(hidden_size,
                                    tar_lang.n_words,
                                    attention_method="concat").to(model.device)
    train(ori_lang, tar_lang, pairs, encoder, decoder)
Esempio n. 15
0
        for i in range(n):
            output_words = evaluate(encoder1, encoder2, decoder, d['image'][i],
                                    d['post'][i], d['tags'][i])
            output_sentence = ' '.join(output_words)
            print('ground truth:', d['comment'][i])
            print('generated:', output_sentence)
        break


def evaluateScore(encoder, decoder, weights):
    """Average sentence-level BLEU over the global validation set.

    ``weights`` is the n-gram weighting tuple passed to ``sentence_bleu``.
    """
    loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
    total = 0
    for batch in loader:
        hypothesis = evaluate(encoder, decoder, batch['image'][0],
                              batch['post'][0], batch['tags'][0])
        reference = batch['comment'][0].split(' ')
        total += sentence_bleu([reference], hypothesis, weights=weights)
    return float(total) / val_data_size


# Two encoders (text post RNN + extra-feature encoder) feed one decoder;
# post_hidden_size, input_size, final_hidden_size, vocab and device are
# defined earlier in this file.
encoder1 = model.EncoderRNN(300, post_hidden_size).to(device)
encoder2 = model.Encoder(input_size, final_hidden_size).to(device)
decoder = model.DecoderRNN(final_hidden_size, vocab.n_words).to(device)

trainIters(encoder1, encoder2, decoder, learning_rate=0.0001)
# Show 10 qualitative samples from the validation split.
evaluateRandomly(encoder1, encoder2, decoder, 'val', 10)
Esempio n. 16
0
def evaluateAndShowAttention(input_sentence):
    """Translate ``input_sentence`` with the global encoder1/attn_decoder1,
    print input and output, and visualise the attention weights."""
    output_words, attentions = evaluate(encoder1, attn_decoder1,
                                        input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions)


if __name__ == '__main__':
    input_lang, output_lang, pairs = data.prepareData('eng', 'fra', True)
    print(random.choice(pairs))

    # Probability of feeding the ground-truth token during training
    # (read by training routines defined elsewhere in this file).
    teacher_forcing_ratio = 0.5

    hidden_size = 256
    encoder1 = model.EncoderRNN(input_lang.n_words, hidden_size,
                                args.device).to(args.device)
    attn_decoder1 = model.AttnDecoderRNN(hidden_size, output_lang.n_words,
                                         args.device, 0.1,
                                         args.MAX_LENGTH).to(args.device)

    trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

    ######################################################################

    # Qualitative evaluation: random samples and an attention heat-map.
    evaluateRandomly(encoder1, attn_decoder1)

    output_words, attentions = evaluate(encoder1, attn_decoder1,
                                        "je suis trop froid .")
    plt.matshow(attentions.numpy())

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
def main(args):
    """Train an audio-to-comment seq2seq generator adversarially: an
    EncoderRNN/DecoderRNN generator paired with an LSTM discriminator,
    optimised with BCE (adversarial) + cross-entropy (reconstruction)."""

    # random set
    manualSeed = random.randint(1, 100)
    # print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Maximum audio/comment lengths and the MFCC feature dimension.
    audio_len, comment_len, mfcc_dim = caculate_max_len(
        args.audio_dir, args.text_path, vocab)
    # mfcc_features = audio_preprocess(args.audio_dir, N, AUDIO_LEN, MFCC_DIM).astype(np.float32)

    # Build data loader
    data_loader = data_get(args.audio_dir, audio_len, args.text_path,
                           comment_len, vocab)

    # Build the models; the decoder input is widened by Z_DIM to take a
    # random noise vector concatenated to the audio features.
    encoder = model.EncoderRNN(mfcc_dim, args.embed_size,
                               args.hidden_size).to(device)
    decoder = model.DecoderRNN(args.embed_size + Z_DIM, args.hidden_size,
                               len(vocab), args.num_layers).to(device)
    # decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

    # Loss and optimizer
    criterion_BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
    criterion_CrossEntropyLoss = nn.CrossEntropyLoss()

    # Loss and optimizer
    # criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # GAN                               #296'''in_dim=len(vocab)'''
    netD = model.LSTMDiscriminator(in_dim=1, hidden_dim=256).to(device)
    # setup optimizer
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, ((audio, audio_len), (comment,
                                     comment_len)) in enumerate(data_loader):
            audio = audio.to(device)
            audio = audio.unsqueeze(0)
            comment = comment.to(device)
            comment = comment.unsqueeze(0)
            targets = pack_padded_sequence(comment, [comment_len],
                                           batch_first=True)[0]

            batch_size = comment.shape[0]
            seq_len = targets.shape[0]
            # discriminator:1 -- real comment
            # label0 = "fake", label1 = "real" targets for the BCE loss.
            label0 = torch.full((batch_size, seq_len, 1), 0, device=device)
            label1 = torch.full((batch_size, seq_len, 1), 1, device=device)
            # real sample
            logits_real = netD(comment, [comment_len])  # batch*seq
            errD_real = criterion_BCEWithLogitsLoss(logits_real, label1)

            # discriminator:2 -- real comment
            audio_features = encoder(audio, [audio_len])
            if (Z_DIM > 0):
                z = Variable(torch.randn(audio_features.shape[0],
                                         Z_DIM)).cuda()
                audio_features = torch.cat([z, audio_features], 1)
            outputs = decoder(audio_features, comment, [comment_len])
            # generate comment discrimination
            # detach() keeps discriminator gradients out of the generator.
            max_v, max_index = outputs.detach().max(1)
            logits_fake = netD(max_index.unsqueeze(0),
                               [comment_len])  # batch*seq*1
            errD_fake = criterion_BCEWithLogitsLoss(logits_fake, label0)
            errD = errD_fake + errD_real
            optimizerD.zero_grad()
            errD.backward()
            optimizerD.step()

            # 2.generator
            audio_features = encoder(audio, [audio_len])
            if (Z_DIM > 0):
                z = Variable(torch.randn(audio_features.shape[0],
                                         Z_DIM)).cuda()
                audio_features = torch.cat([z, audio_features], 1)
            outputs = decoder(audio_features, comment, [comment_len])
            max_v, max_index = outputs.max(1)
            logits_fake = netD(max_index.unsqueeze(0),
                               [comment_len])  # batch*seq*vobsize
            # Generator tries to make the discriminator output "real".
            errG = criterion_BCEWithLogitsLoss(logits_fake, label1)
            loss = criterion_CrossEntropyLoss(outputs, targets) + errG
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}],  Loss_D:  {:.4f}, Loss_G: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, errD.item(),
                            loss.item(), np.exp(loss.item())))

            # Save the model checkpoints
        if (epoch + 1) % args.save_step == 0:
            torch.save(
                decoder.state_dict(),
                os.path.join(args.model_path,
                             'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            torch.save(
                encoder.state_dict(),
                os.path.join(args.model_path,
                             'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))