コード例 #1
0
def load_model(saved_vae, stored_info, device, cache_path=str(Path('../tmp')), seed=None):
    """Load a cached dataset/config pickle and a saved VAE checkpoint.

    Args:
        saved_vae: path to the saved VAE ``state_dict`` file.
        stored_info: path (or bare filename) of the cached pickle; only the
            basename is used and it is looked up under ``cache_path``.
        device: torch device the model is moved to.
        cache_path: directory holding the cached pickle.
        seed: optional RNG seed; when ``None`` a time-based seed is used.

    Returns:
        Tuple ``(vae, dataset, z_size, random_state)``.

    Raises:
        FileNotFoundError: if ``saved_vae`` does not exist.
    """
    stored_info = stored_info.split(os.sep)[-1]
    cache_file = os.path.join(cache_path, stored_info)

    start_load = time.time()
    print(f"Fetching cached info at {cache_file}")
    # NOTE(review): pickle.load executes arbitrary code from the file --
    # only load caches produced by this project.
    with open(cache_file, "rb") as f:
        dataset, z_size, condition_size, condition_on, decoder_hidden_size, encoder_hidden_size, n_encoder_layers = pickle.load(f)
    end_load = time.time()
    print(f"Cache {cache_file} loaded (load time: {end_load - start_load:.2f}s)")

    # Fail fast with a clear error: the original fell through to
    # `return vae, ...` with `vae` and `random_state` unbound
    # (UnboundLocalError) when the checkpoint was missing.
    if not os.path.exists(saved_vae):
        raise FileNotFoundError(f"Saved model not found: {saved_vae}")

    print(f"Found saved model {saved_vae}")
    start_load_model = time.time()

    e = model.EncoderRNN(dataset.input_side.n_words, encoder_hidden_size, z_size, n_encoder_layers, bidirectional=True)
    d = model.DecoderRNN(z_size, dataset.trn_split.n_conditions, condition_size, decoder_hidden_size, dataset.input_side.n_words, 1, word_dropout=0)
    vae = model.VAE(e, d).to(device)
    # map_location keeps GPU-trained weights loadable on CPU-only hosts.
    vae.load_state_dict(torch.load(saved_vae, map_location=lambda storage, loc: storage))
    vae.eval()
    print(f"Trained for {vae.steps_seen} steps (load time: {time.time() - start_load_model:.2f}s)")

    print("Setting new random seed")
    if seed is None:
        # TODO: torch.manual_seed(1999) in model.py is affecting this
        new_seed = int(time.time()) % 4294967295  # torch seeds must be in [0, 4294967295]
    else:
        new_seed = seed
    torch.manual_seed(new_seed)

    random_state = np.random.RandomState(new_seed)
    #random_state.shuffle(dataset.trn_pairs)

    return vae, dataset, z_size, random_state
コード例 #2
0
File: utils.py  Project: parth126/NeuralSumm
def build_model_from_scratch(args, corpus, embed_df):
    """Construct a fresh encoder and attention classifier from config.

    Args:
        args: namespace of hyper-parameters (``model``, ``embed``, ``nhid``,
            ``nlayers``, ``dropout``, ``ntopic``, ``hhid``, ``cembed``,
            ``max_len``, ``cuda``).
        corpus: corpus object exposing ``dictionary`` (token vocabulary).
        embed_df: pre-trained embedding table used to seed the embedding layer.

    Returns:
        ``(Encoder, Classifier)``, moved to the GPU when ``args.cuda`` is set.
    """
    print("Building the initial models")
    ntokens = len(corpus.dictionary)
    ntopic = args.ntopic
    iembedding_tensor = init_embedding(args.embed, corpus.dictionary, embed_df)
    Encoder = model.EncoderRNN(args.model, ntokens, args.embed, args.nhid,
                               args.nlayers, args.dropout, iembedding_tensor)
    # The original left a previous classifier variant inside a bare
    # triple-quoted string (a no-op expression statement); kept here as a
    # real comment instead:
    #   "Regular classifier without any attention"
    #   Classifier = model.AttentionClassifier(ntopic, args.nhid, args.hhid,
    #                                          args.cembed, args.max_len,
    #                                          args.dropout)
    Classifier = model.AttentionClassifier(ntopic, args.nhid, args.hhid,
                                           args.cembed, args.max_len,
                                           args.dropout)

    if args.cuda:
        Encoder = Encoder.cuda()
        Classifier = Classifier.cuda()
    return Encoder, Classifier
コード例 #3
0
File: nmt.py  Project: vaibhav4595/Seq2Seq_MT
    def __init__(self, embed_size, hidden_size, vocab, dropout_rate=0.2):
        """Assemble a GPU-resident seq2seq model (EncoderRNN + DecoderRNN).

        Args:
            embed_size: word-embedding dimensionality.
            hidden_size: RNN hidden-state size.
            vocab: vocabulary with ``src``/``tgt`` sides exposing ``word2id``.
            dropout_rate: dropout probability stored on the instance.
        """
        super(NMT, self).__init__()

        # Keep hyper-parameters and the vocabulary on the instance.
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab

        n_src = len(self.vocab.src.word2id)
        n_tgt = len(self.vocab.tgt.word2id)

        self.encoder = model.EncoderRNN(vocab_size=n_src,
                                        embed_size=self.embed_size,
                                        hidden_size=self.hidden_size).cuda()
        self.decoder = model.DecoderRNN(embed_size=self.embed_size,
                                        hidden_size=self.hidden_size,
                                        output_size=n_tgt).cuda()

        # Per-token cross-entropy loss, also placed on the GPU.
        self.criterion = torch.nn.CrossEntropyLoss().cuda()
コード例 #4
0
            bleu_total += bleu
        print_loss_total /= test_len
        bleu_total /= test_len
        print(f'Test loss: {print_loss_total}, bleu: {bleu_total}')
        
        with open(f'{latent_hidden_size}/train_loss', 'a') as f:
            f.write(f'{str(train_loss_total/tot_cnt)}\n')
        with open(f'{latent_hidden_size}/train_KL_loss', 'a') as f:
            f.write(f'{str(train_KL_total/tot_cnt)}\n')
        with open(f'{latent_hidden_size}/test_bleu', 'a') as f:
            f.write(f'{str(bleu_total)}\n')

        test_bleu_list.append(bleu_total)
        train_loss_list.append(train_loss_total/tot_cnt)
        train_KL_list.append(train_KL_total/tot_cnt)
        train_loss_total = 0
        train_KL_total = 0
        tot_cnt = 0

        if bleu_total > highest_score:
            highest_score = bleu_total
            torch.save(encoder, f'/home/karljackab/DL/lab5/{latent_hidden_size}/encoder_{str(bleu_total)}.pkl')
            torch.save(decoder, f'/home/karljackab/DL/lab5/{latent_hidden_size}/decoder_{str(bleu_total)}.pkl')
            torch.save(enc_last, f'/home/karljackab/DL/lab5/{latent_hidden_size}/enc_last_{str(bleu_total)}.pkl')
            print('save model')

# Build the latent projection, encoder and decoder.
# NOTE(review): hidden_size+4 widens the hidden state by 4 everywhere --
# presumably room for an extra condition embedding; confirm against model.py.
enc_last = model.EncodeLast(hidden_size+4, latent_hidden_size, device).to(device)
encoder = model.EncoderRNN(vocab_size, hidden_size+4, device).to(device)
decoder = model.DecoderRNN(hidden_size+4, vocab_size, device).to(device)

# Train for 300 iterations/epochs, logging every 2000 steps.
trainIters(encoder, decoder, enc_last, 300, print_every=2000)
コード例 #5
0
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab,
                 dropout_rate,
                 num_layers,
                 bidirectional,
                 attention_type,
                 self_attention,
                 tau,
                 gamma1,
                 gamma2,
                 cost_fcn,
                 uniform_init,
                 embedding_file=None):
        """Build an attentional encoder-decoder NMT model.

        Args:
            embed_size: word-embedding dimensionality.
            hidden_size: RNN hidden-state size.
            vocab: vocabulary with ``src``/``tgt`` sides exposing ``word2id``
                and ``id2word``.
            dropout_rate: dropout probability for encoder/decoder.
            num_layers: number of RNN layers.
            bidirectional: whether the encoder is bidirectional.
            attention_type: decoder attention variant.
            self_attention: whether the decoder uses self-attention.
            tau, gamma1, gamma2, cost_fcn: objective hyper-parameters, stored
                on the instance for later use.
            uniform_init: half-width of the uniform weight initialisation.
            embedding_file: optional path to pre-trained source embeddings in
                word2vec text format (one header line, then "word v1 v2 ...").
        """
        super(NMT, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.vocab = vocab
        self.bidirectional = bidirectional
        self.tau = tau
        self.gamma1 = gamma1
        self.gamma2 = gamma2
        self.cost_fcn = cost_fcn
        src_vocab_size = len(self.vocab.src.word2id)
        tgt_vocab_size = len(self.vocab.tgt.word2id)

        if embedding_file is not None:
            word_vectors = {}
            print("Loading the vectors.")
            # Context manager guarantees the handle is closed even on error
            # (the original open()/close() pair leaked on exceptions).
            with open(embedding_file) as vec_file:
                for line_no, line in enumerate(vec_file):
                    if line_no == 0:
                        continue  # skip the first (header) line, as before
                    word, vec = line.split(' ', 1)
                    # np.fromstring on text is deprecated; parse explicitly.
                    word_vectors[word] = np.array(vec.split(), dtype=np.float64)
            print("Done.")

            # Rows default to zero for words without a pre-trained vector.
            embeddings = np.zeros((len(self.vocab.src.id2word), self.embed_size))
            for idx in range(len(self.vocab.src.id2word)):
                token = self.vocab.src.id2word[idx]
                if token in word_vectors:
                    embeddings[idx] = word_vectors[token]
        else:
            embeddings = None

        self.encoder = model.EncoderRNN(vocab_size=src_vocab_size,
                                        embed_size=self.embed_size,
                                        hidden_size=hidden_size,
                                        dropout_rate=dropout_rate,
                                        num_layers=num_layers,
                                        bidirectional=bidirectional,
                                        embeddings=embeddings)
        self.decoder = model.DecoderRNN(embed_size=self.embed_size,
                                        hidden_size=self.hidden_size,
                                        output_size=tgt_vocab_size,
                                        dropout_rate=dropout_rate,
                                        num_layers=num_layers,
                                        attention_type=attention_type,
                                        self_attention=self_attention,
                                        bidirectional=bidirectional)
        self.encoder = self.encoder.cuda()
        self.decoder = self.decoder.cuda()

        # Initialize all weights uniformly in [-uniform_init, uniform_init];
        # uniform_ is the in-place, non-deprecated spelling of nn.init.uniform.
        for param in list(self.encoder.parameters()) + list(
                self.decoder.parameters()):
            torch.nn.init.uniform_(param, a=-uniform_init, b=uniform_init)

        # Per-token (unreduced) loss; `reduce=0` is the deprecated spelling
        # of reduction='none'.
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
コード例 #6
0
File: app.py  Project: sakshi148/Chatbot
    # If loading a model trained on GPU to CPU
    checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']

print('Building encoder and decoder ...')
# Initialize word embeddings; the same embedding layer is passed to both
# encoder and decoder below.
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    # Restore pretrained embedding weights from the loaded checkpoint.
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = model.EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = model.LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                    voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    # Restore trained weights for both networks from the checkpoint.
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Move both models to the configured device (CPU or GPU).
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')


def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [dp.indexesFromSentence(voc, sentence)]
コード例 #7
0
File: main.py  Project: sparse-mvs-2/PPO_dynamic
def main(config):
    """Build dataset + seq2seq/GAN models and run the selected training.

    ``config.dataset`` picks the corpus ('real' OpenSubtitles text or the
    synthetic 'counting' task); ``config.type`` picks the regime
    ('reinforce' or 'gan').  Remaining config fields are PPO and optimizer
    hyper-parameters forwarded to the trainers.
    """

    if(config.dataset == 'real'):
        #initialize the dictionary
        lang_real = prepare.Lang_real('txt')
        lines = open('data/opensubtitles/vocab4000').read().strip().split('\n')
        for sen in lines:
            lang_real.addSentence(sen)
        lang_txt = lang_real

        train_data = prepare.get_dataset('data/opensubtitles/train.txt', batch_size = 16, lang_txt = lang_real, task = 'real')
        shuffle(train_data)
        dev_data = prepare.get_dataset('data/opensubtitles/dev.txt', batch_size = 16, lang_txt = lang_real, task = 'real')
        test_data = prepare.get_dataset('data/opensubtitles/test.txt', batch_size = 16, lang_txt = lang_real, task = 'real')

    elif(config.dataset == 'counting'):
        lang_counting = prepare.Lang_counting('txt')
        lang_txt = lang_counting

        train_data = prepare.get_dataset('data/counting/train_counting.txt', batch_size = 16, lang_txt = lang_counting, task = 'counting')
        shuffle(train_data)
        dev_data = prepare.get_dataset('data/counting/dev_counting.txt', batch_size = 16, lang_txt = lang_counting, task = 'counting')
        test_data = prepare.get_dataset_test_counting('data/counting/test_counting.txt', batch_size = 16)

    # Generator, evaluator, "previous" (frozen) copies for PPO ratios, and
    # discriminator/evaluation encoder-decoder pairs for the GAN setting.
    feature = config.feature
    encoder = model.EncoderRNN(feature, feature, lang_txt.n_words)
    decoder = model.DecoderRNN(feature, feature,  lang_txt.n_words)
    evaluater = model.EvaluateR(feature)
    decoder_prev =  model.DecoderRNN(feature, feature,  lang_txt.n_words)
    encoder_prev =  model.EncoderRNN(feature, feature,  lang_txt.n_words)
    dis_encoder = model.disEncoderRNN(feature, feature, lang_txt.n_words)
    dis_decoder = model.disDecoderRNN(feature, feature, lang_txt.n_words)
    eva_encoder = model.disEncoderRNN(feature, feature, lang_txt.n_words)
    eva_decoder = model.disDecoderRNN(feature, feature, lang_txt.n_words)
    if use_cuda:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        evaluater= evaluater.cuda()
        decoder_prev = decoder_prev.cuda()
        encoder_prev = encoder_prev.cuda()
        dis_encoder = dis_encoder.cuda(0)
        dis_decoder = dis_decoder.cuda(0)
        eva_encoder = eva_encoder.cuda(0)
        eva_decoder = eva_decoder.cuda(0)


    print_every = config.print_every
    dev_every = config.dev_every
    use_ppo = config.use_ppo
    ppo_a1 = config.ppo_a1
    ppo_a2 = config.ppo_a2
    ppo_b1 = config.ppo_b1
    ppo_b2 = config.ppo_b2

    # REINFORCE-style training of the seq2seq generator.
    if(config.type == 'reinforce'):

        lr = config.lr
        test1 = train.seq2seq(lang_txt, dev_data,test_data,  encoder, decoder, evaluater, 
                   encoder_prev,decoder_prev,
                    task = config.dataset,
                   god_rs_dev = [],
                    god_loss_dev = [],
                    god_loss = [],
                    god_rs_test = [])

        losses, rewards = test1.trainIters(train_data,1,1,
                                           use_ppo = use_ppo,actor_fixed = False, 
                                           min_rein_step = 0, max_rein_step = 5,
                                          ppo_b1 = ppo_b1, 
                                          ppo_b2 = ppo_b2, 
                                          ppo_a1 = ppo_a1,
                                          ppo_a2 = ppo_a2,
                                           ppo_a3 = 1e10,
                                           rate = 1,
                                          lr = lr,
                                          dev_every = dev_every,
                                          print_every = print_every,
                                          plot_every = 5000000000,
                                          name = '_z',
                                          file_name = 'MIXER')
    # Adversarial (GAN) training with separate generator/discriminator lrs.
    elif(config.type == 'gan'):

        test_gan = train.ganSeq2seq(lang_txt, dev_data,test_data,
                encoder,decoder,
                dis_encoder,dis_decoder,
                eva_encoder, eva_decoder,
                encoder_prev,
                decoder_prev,
                god_rs_dev = [],
               god_loss_dev = [],
               god_loss = [],
               god_rs_test = [],
                task = config.dataset)

        loss_g, loss_d = test_gan.trainIters(train_data, 0,0,1,
                                  use_ppo= config.use_ppo,g_lr = config.g_lr, d_lr = config.d_lr,
                                  search_n = 1, width = 1,
                                 ppo_b1 = ppo_b1, 
                                  ppo_b2 = ppo_b2, 
                                  ppo_a1 = ppo_a1,
                                  ppo_a2 = ppo_a2, 
                                  ppo_a3 = 10000000000,
                                 print_every = print_every,
                                 plot_every = 50000000000,
                                 dev_every = dev_every)
コード例 #8
0
        decoder_input = Variable(torch.LongTensor([[ni]]))
        decoder_input = decoder_input.cuda() if use_cuda else decoder_input

    return decoded_words


def evaluateRandomly(encoder, decoder, n=10):
    """Translate ``n`` random pairs from the global ``pairs`` list, printing
    source (>), reference (=) and model output (<) for each."""
    for _ in range(n):
        sample = random.choice(pairs)
        print('>', sample[0])
        print('=', sample[1])
        translated = evaluate(encoder, decoder, sample[0])
        print('<', ' '.join(translated))
        print('')


# Model hyper-parameters for this run.
hidden_dim = 256
embedding_dim = 100

# Plain (non-attention) encoder/decoder pair over the two language vocabularies.
encoder_1 = model.EncoderRNN(input_lang.n_words, embedding_dim, hidden_dim)
decoder_1 = model.DecoderRNN(output_lang.n_words, embedding_dim, hidden_dim)
#attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, 1, dropout_p=0.1)

if use_cuda:
    encoder_1 = encoder_1.cuda()
    decoder_1 = decoder_1.cuda()
    #attn_decoder1 = attn_decoder1.cuda()

# Train for 75000 iterations.
trainIters(encoder_1, decoder_1, 75000)
コード例 #9
0
File: main.py  Project: hamdans-eth/colors
                   iter / n_iters * 100, print_loss_avg))

            print_rec_total = print_rec_total / print_every
            print_kl_total = print_kl_total / print_every
            print('average kl =  %.4f' % print_kl_total)
            print('average reconstruction =  %.4f' % print_rec_total)

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
            print_kl_total = 0
            print_rec_total = 0


# Encoder/decoder over the colour-description vocabulary, plus a linear map
# from the latent space into the decoder's embedding space.
encoder = model.EncoderRNN(vocabulary.n_words, latent_space,
                           embeddings).to(device)
decoder = model.DecoderRNN(embedding_space, embeddings,
                           vocabulary.n_words).to(device)
linear = model.RGB_to_Hidden(latent_space, embedding_space).to(device)

# Train all three modules jointly, logging/plotting every 500 iterations.
trainIters(encoder,
           decoder,
           linear,
           epochs,
           plot_every=500,
           print_every=500,
           learning_rate=learning_r)

if SAVE:
    dirpath = os.getcwd()
    encoder_path = dirpath + '/enc'
コード例 #10
0
File: training.py  Project: omrimas/StyleTransfer
            loss = train(input_tensor, target_tensor, encoder,
                         decoders[train_file], encoder_optimizer,
                         decoder_optimizers[train_file], criterion)
            print_loss_total += loss
            plot_loss_total += loss

            if iter % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' %
                      (timeSince(start, iter / n_iters), iter,
                       iter / n_iters * 100, print_loss_avg))

            if iter % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

    showPlot(plot_losses)


hidden_size = 256
# One shared encoder; one decoder per training file (per style).
encoder1 = model.EncoderRNN(vocab.n_words, hidden_size, device).to(device)
decoders = {}
for train_file in TRAINING_DATA_FILES:
    decoders[train_file] = model.Decoder1RNN(hidden_size, vocab.n_words,
                                             device).to(device)

# Train encoder + all decoders for 500 iterations.
trainIters(encoder1, decoders, 500, print_every=5000)
コード例 #11
0

def evaluateAndShowAttention(input_sentence):
    """Translate ``input_sentence`` with the global attention model and
    visualise the decoder's attention weights."""
    words, attn = evaluate(encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(words))
    showAttention(input_sentence, words, attn)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load eng/fra pairs; the final True presumably reverses the translation
    # direction -- confirm against p.prepareData.
    input_lang, output_lang, pairList = p.prepareData('eng', 'fra', True)
    hidden_size = 256
    encoder1 = m.EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder1 = m.AttnDecoderRNN(hidden_size,
                                     output_lang.n_words,
                                     dropout_p=0.1).to(device)

    # Train, then spot-check translations and visualise attention maps.
    trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

    evaluateRandomly(encoder1, attn_decoder1)
    output_words, attentions = evaluate(encoder1, attn_decoder1,
                                        "je suis trop froid .")
    plt.matshow(attentions.numpy())

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    evaluateAndShowAttention("elle est trop petit .")
    evaluateAndShowAttention("je ne crains pas de mourir .")
    evaluateAndShowAttention("c est un jeune directeur plein de talent .")
コード例 #12
0
File: main_LSTM.py  Project: livelifeyiyi/DLworks
    relations = set(tmp)
    print(relations)
    relation_count = len(
        relations)  # args.relation_tag_size  # data['relation_tag_size']
    noisy_count = args.noisy_tag_size  # ata['noisy_tag_size']
    learning_rate = args.lr  # data['lr']
    l2 = args.l2  # data['l2']
    print("relation count: ", relation_count)
    print("Reading vector file......")
    vec_model = KeyedVectors.load_word2vec_format(args.datapath +
                                                  'vector2.txt',
                                                  binary=False)
    # vec_model = KeyedVectors.load_word2vec_format('/home/xiaoya/data/GoogleNews-vectors-negative300.bin.gz', binary=True)

    # load models
    encoder = model.EncoderRNN(args, wv).to(device)
    decoder = model.DecoderRNN(args, wv).to(device)
    RE_model = model.RE_RNN(args, wv, relation_count).to(device)

    criterion = nn.NLLLoss()  # CrossEntropyLoss()
    # criterion_RE = nn.BCELoss()
    # attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        RE_model = RE_model.cuda()
        criterion = criterion.cuda()
        # criterion_RE = criterion_RE.cuda()

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=learning_rate,
        os.path.join(args.save, 'decoder' + str(number) + '.pt'))
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        context = context.cuda()
        decoder = decoder.cuda()

    for dialog in validation_data:
        sample(my_lang, dialog, encoder, context, decoder)
        time.sleep(3)

    sys.exit(0)

learning_rate = args.lr
criterion = nn.NLLLoss()
if not args.restore:
    # Fresh hierarchical model: sentence encoder -> context RNN -> decoder.
    # The context/decoder input sizes are the flattened hidden states of the
    # previous stage (hidden * layers).
    encoder = model.EncoderRNN(len(my_lang.word2index), args.encoder_hidden, \
            args.encoder_layer, args.dropout)
    context = model.ContextRNN(args.encoder_hidden * args.encoder_layer, args.context_hidden, \
            args.context_layer, args.dropout)
    decoder = model.DecoderRNN(args.context_hidden * args.context_layer, args.decoder_hidden, \
            len(my_lang.word2index), args.decoder_layer, args.dropout)
else:
    # Resume: checkpoint.pt stores the iteration number used to name the
    # per-module snapshot files.
    print("Load last model in %s" % (args.save))
    number = torch.load(os.path.join(args.save, 'checkpoint.pt'))
    encoder = torch.load(
        os.path.join(args.save, 'encoder' + str(number) + '.pt'))
    context = torch.load(
        os.path.join(args.save, 'context' + str(number) + '.pt'))
    decoder = torch.load(
        os.path.join(args.save, 'decoder' + str(number) + '.pt'))
    if torch.cuda.is_available():
        encoder = encoder.cuda()
コード例 #14
0
    training_pairs = [
        tensor_from_pairs(src_lang, tar_lang, random.choice(pairs))
        for _ in range(iters)
    ]
    criterion = nn.NLLLoss()

    for iter in range(iters):
        train_pair = training_pairs[iter]
        input_tensor = train_pair[0]
        target_tensor = train_pair[1]

        loss = bch_train(input_tensor, target_tensor, encoder, decoder,
                         e_optim, d_optim, criterion)
        step_loss.append(loss)

        if iter % 20 == 0:
            print(f'Loss:\t{loss}')


if __name__ == '__main__':
    # Build the parallel corpus and language vocabularies from disk.
    lines = utils.readCorpus("./data/train.txt")
    ori_lang, tar_lang, pairs = utils.readLang(lines)
    hidden_size = 256
    encoder = model.EncoderRNN(ori_lang.n_words, hidden_size).to(model.device)
    # Alternative decoders kept for reference:
    # decoder = model.DecoderRNN(hidden_size, tar_lang.n_words)
    # decoder = model.BahdanauDecoderRNN(hidden_size, tar_lang.n_words).to(model.device)
    decoder = model.LuongDecoderRNN(hidden_size,
                                    tar_lang.n_words,
                                    attention_method="concat").to(model.device)
    train(ori_lang, tar_lang, pairs, encoder, decoder)
コード例 #15
0
        for i in range(n):
            output_words = evaluate(encoder1, encoder2, decoder, d['image'][i],
                                    d['post'][i], d['tags'][i])
            output_sentence = ' '.join(output_words)
            print('ground truth:', d['comment'][i])
            print('generated:', output_sentence)
        break


def evaluateScore(encoder, decoder, weights):
    """Average sentence-level BLEU of the model over the validation set.

    ``weights`` are the n-gram weights forwarded to ``sentence_bleu``.
    Returns the mean score as a float.
    """
    loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
    score_sum = 0
    for batch in loader:
        hypothesis = evaluate(encoder, decoder, batch['image'][0],
                              batch['post'][0], batch['tags'][0])
        reference = [batch['comment'][0].split(' ')]
        score_sum += sentence_bleu(reference, hypothesis, weights=weights)
    return float(score_sum) / val_data_size


# Two encoders (RNN over 300-dim post text, plus a second encoder -- likely
# for image features, confirm against model.py) and one decoder.
encoder1 = model.EncoderRNN(300, post_hidden_size).to(device)
encoder2 = model.Encoder(input_size, final_hidden_size).to(device)
decoder = model.DecoderRNN(final_hidden_size, vocab.n_words).to(device)

# Train, then print 10 random validation samples for inspection.
trainIters(encoder1, encoder2, decoder, learning_rate=0.0001)
evaluateRandomly(encoder1, encoder2, decoder, 'val', 10)
コード例 #16
0
File: main.py  Project: thzll2001/EasyNLP
def evaluateAndShowAttention(input_sentence):
    """Run the global attention model on one sentence, print the translation
    and plot its attention matrix."""
    decoded, attention = evaluate(encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(decoded))
    showAttention(input_sentence, decoded, attention)


if __name__ == '__main__':
    # Load eng/fra pairs; the final True presumably reverses the translation
    # direction -- confirm against data.prepareData.
    input_lang, output_lang, pairs = data.prepareData('eng', 'fra', True)
    print(random.choice(pairs))

    teacher_forcing_ratio = 0.5

    hidden_size = 256
    encoder1 = model.EncoderRNN(input_lang.n_words, hidden_size,
                                args.device).to(args.device)
    attn_decoder1 = model.AttnDecoderRNN(hidden_size, output_lang.n_words,
                                         args.device, 0.1,
                                         args.MAX_LENGTH).to(args.device)

    # Train, then spot-check translations and visualise attention weights.
    trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

    ######################################################################

    evaluateRandomly(encoder1, attn_decoder1)

    output_words, attentions = evaluate(encoder1, attn_decoder1,
                                        "je suis trop froid .")
    plt.matshow(attentions.numpy())

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
def main(args):
    """Adversarially train an audio-to-comment seq2seq generator.

    Builds an EncoderRNN/DecoderRNN generator (with optional noise vector of
    size Z_DIM concatenated to the audio features) and an LSTM discriminator,
    then alternates discriminator and generator updates over the data loader,
    saving checkpoints every ``args.save_step`` epochs.
    """

    # Seed all RNGs from one random value (loose reproducibility only,
    # since the seed itself is random).
    manualSeed = random.randint(1, 100)
    # print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    audio_len, comment_len, mfcc_dim = caculate_max_len(
        args.audio_dir, args.text_path, vocab)
    # mfcc_features = audio_preprocess(args.audio_dir, N, AUDIO_LEN, MFCC_DIM).astype(np.float32)

    # Build data loader
    data_loader = data_get(args.audio_dir, audio_len, args.text_path,
                           comment_len, vocab)

    # Build the models; the decoder input is widened by Z_DIM to make room
    # for the concatenated noise vector.
    encoder = model.EncoderRNN(mfcc_dim, args.embed_size,
                               args.hidden_size).to(device)
    decoder = model.DecoderRNN(args.embed_size + Z_DIM, args.hidden_size,
                               len(vocab), args.num_layers).to(device)
    # decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

    # Losses: BCE-with-logits for the discriminator, cross entropy for the
    # generator's language-model objective.
    criterion_BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
    criterion_CrossEntropyLoss = nn.CrossEntropyLoss()

    # Generator optimizer over encoder + decoder parameters.
    # criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # GAN discriminator (note from original: in_dim=len(vocab) is an
    # alternative to in_dim=1 here).
    netD = model.LSTMDiscriminator(in_dim=1, hidden_dim=256).to(device)
    # Separate optimizer for the discriminator.
    optimizerD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, ((audio, audio_len), (comment,
                                     comment_len)) in enumerate(data_loader):
            audio = audio.to(device)
            audio = audio.unsqueeze(0)
            comment = comment.to(device)
            comment = comment.unsqueeze(0)
            targets = pack_padded_sequence(comment, [comment_len],
                                           batch_first=True)[0]

            batch_size = comment.shape[0]
            seq_len = targets.shape[0]
            # Discriminator labels: 0 = generated comment, 1 = real comment.
            label0 = torch.full((batch_size, seq_len, 1), 0, device=device)
            label1 = torch.full((batch_size, seq_len, 1), 1, device=device)
            # Discriminator on the real comment.
            logits_real = netD(comment, [comment_len])  # batch*seq
            errD_real = criterion_BCEWithLogitsLoss(logits_real, label1)

            # Discriminator on a generated (fake) comment; .detach() keeps
            # this step from updating the generator.
            audio_features = encoder(audio, [audio_len])
            if (Z_DIM > 0):
                z = Variable(torch.randn(audio_features.shape[0],
                                         Z_DIM)).cuda()
                audio_features = torch.cat([z, audio_features], 1)
            outputs = decoder(audio_features, comment, [comment_len])
            # Greedy token choice of the generator.
            max_v, max_index = outputs.detach().max(1)
            logits_fake = netD(max_index.unsqueeze(0),
                               [comment_len])  # batch*seq*1
            errD_fake = criterion_BCEWithLogitsLoss(logits_fake, label0)
            errD = errD_fake + errD_real
            optimizerD.zero_grad()
            errD.backward()
            optimizerD.step()

            # Generator step: re-encode (fresh graph), fool the
            # discriminator and fit the cross-entropy objective.
            audio_features = encoder(audio, [audio_len])
            if (Z_DIM > 0):
                z = Variable(torch.randn(audio_features.shape[0],
                                         Z_DIM)).cuda()
                audio_features = torch.cat([z, audio_features], 1)
            outputs = decoder(audio_features, comment, [comment_len])
            max_v, max_index = outputs.max(1)
            logits_fake = netD(max_index.unsqueeze(0),
                               [comment_len])  # batch*seq*vobsize
            errG = criterion_BCEWithLogitsLoss(logits_fake, label1)
            loss = criterion_CrossEntropyLoss(outputs, targets) + errG
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}],  Loss_D:  {:.4f}, Loss_G: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, errD.item(),
                            loss.item(), np.exp(loss.item())))

        # Save the model checkpoints.
        # NOTE(review): `i` leaks out of the inner loop; this raises
        # NameError if the data loader is empty -- confirm intended.
        if (epoch + 1) % args.save_step == 0:
            torch.save(
                decoder.state_dict(),
                os.path.join(args.model_path,
                             'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            torch.save(
                encoder.state_dict(),
                os.path.join(args.model_path,
                             'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))