Example #1
0
def main():
    """Generate responses for a query file with a previously trained model.

    Model hyper-parameters are not taken from the command line: they are read
    from the ``args.pkl`` that the training run saved under ``--load_path``.
    The weights of ``--load_epoch`` are restored with ``reload_model`` and
    ``predict`` is called with ``--output_file`` (``None`` when omitted).

    Note: ``--dec_algorithm`` is parsed but not used by this function.

    Raises:
        RuntimeError: if ``<load_path>/args.pkl`` does not exist.
    """
    # The flag assignment below is meant to update the module-level USE_CUDA.
    # Without this declaration it created a local variable, and the later
    # device check raised UnboundLocalError whenever use_cuda was disabled.
    global USE_CUDA

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--test_query_file', '-i', type=str, required=True)
    argparser.add_argument('--load_path', '-p', type=str, required=True)
    # TODO: load epoch -> load best model
    argparser.add_argument('--load_epoch', '-e', type=int, required=True)

    argparser.add_argument('--output_file', '-o', type=str)
    argparser.add_argument('--dec_algorithm',
                           '-algo',
                           type=str,
                           default='greedy')

    new_args = argparser.parse_args()

    arg_file = os.path.join(new_args.load_path, 'args.pkl')
    if not os.path.exists(arg_file):
        raise RuntimeError('No default arguments file to load')
    # NOTE: unpickling can execute arbitrary code; only load trusted runs.
    with open(arg_file, 'rb') as f:
        args = pickle.load(f)

    use_cuda = args.use_cuda
    if isinstance(use_cuda, str):
        # args.pkl files may hold the raw command-line string (e.g. "False"),
        # which would otherwise be truthy.
        use_cuda = use_cuda.strip().lower() in ('1', 'true', 'yes', 'y', 'on')
    if use_cuda:
        USE_CUDA = True

    vocab, rev_vocab = load_vocab(args.vocab_file,
                                  max_vocab=args.max_vocab_size)
    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size,
                                   args.emb_dim,
                                   padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size,
                   args.emb_dim,
                   args.hidden_dim,
                   args.n_layers,
                   bidirectional=True,
                   variable_lengths=True)
    # The decoder state is 2 * hidden_dim because the encoder is bidirectional.
    G = Generator(vocab_size, args.response_max_len, args.emb_dim,
                  2 * args.hidden_dim, args.n_layers)

    if use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()

    reload_model(new_args.load_path, new_args.load_epoch, word_embeddings, E,
                 G)

    predict(new_args.test_query_file, args.response_max_len, vocab, rev_vocab,
            word_embeddings, E, G, new_args.output_file)
Example #2
0
def trainDemo(lang, dataSet, nlVocab, codeVocab, train_variables):
    """Build a fresh encoder / attention-decoder pair and train it.

    The encoder reads ``codeVocab`` tokens and the decoder emits ``nlVocab``
    tokens. Both modules are moved to the GPU when ``setting.USE_CUDA`` is
    set, then handed to ``trainIters`` for 2,000,000 iterations, reporting
    every 5,000.
    """
    print("Training...")
    hidden = setting.HIDDDEN_SIAZE
    encoder = EncoderRNN(codeVocab.n_words, hidden)
    decoder = AttnDecoderRNN(hidden, nlVocab.n_words, 1, dropout_p=0.1)

    if setting.USE_CUDA:
        encoder, decoder = encoder.cuda(), decoder.cuda()

    trainIters(lang, dataSet, train_variables, encoder, decoder,
               2000000, print_every=5000)
#

hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1)

# "-t" anywhere on the command line forces training even if saved weights exist.
TRAIN = "-t" in sys.argv

# Iteration count; overridden by the second command-line token when exactly
# two tokens follow the script name (e.g. `script.py -t 10000`).
TRAIN_ITER = 7500
if len(sys.argv) == 3:
    TRAIN_ITER = int(sys.argv[2])

if use_cuda:
    encoder1 = encoder1.cuda()
    attn_decoder1 = attn_decoder1.cuda()

if os.path.exists("encoder.pt") and os.path.exists("decoder.pt") and not TRAIN:
    print("Found saved models")
    # Deserialize onto the CPU so checkpoints written on a GPU machine also
    # load without CUDA; load_state_dict copies the values onto whatever
    # device the modules are already on.
    encoder_state = torch.load('encoder.pt',
                               map_location=lambda storage, loc: storage)
    decoder_state = torch.load('decoder.pt',
                               map_location=lambda storage, loc: storage)
    encoder1.load_state_dict(encoder_state)
    attn_decoder1.load_state_dict(decoder_state)
else:
    trainIters(encoder1, attn_decoder1, TRAIN_ITER, print_every=50)
    # Only freshly trained weights need to be written back to disk.
    torch.save(encoder1.state_dict(), "encoder.pt")
    torch.save(attn_decoder1.state_dict(), "decoder.pt")

######################################################################
Example #4
0
def adversarial():
    """Adversarially fine-tune a pre-trained generator against a CNN discriminator.

    Command-line options select the pre-trained run (``--load_path`` /
    ``--load_epoch``), the discriminator shape (``--filter_num``,
    ``--filter_sizes``) and the optimisation settings. Every other setting is
    taken from ``<load_path>/args.pkl`` written by the pre-training run.

    Each epoch walks the training data once. The discriminator ``D`` is
    updated on every ``training_ratio``-th step and the generator RNN
    (``G.rnn``) on every step. At the end of each epoch ``save_model`` writes
    a checkpoint into ``<exp_dir>/adversarial/<timestamp>``, which also
    receives ``adversarial.log`` and the merged ``args.pkl``.

    Raises:
        RuntimeError: if ``<load_path>/args.pkl`` does not exist.
    """
    import ast

    # Named project logger ("lan2720"); tqdm_logging.config attaches its handlers below.
    logger = logging.getLogger("lan2720")

    argparser = argparse.ArgumentParser(add_help=False)
    argparser.add_argument('--load_path', '-p', type=str, required=True)
    # TODO: load best
    argparser.add_argument('--load_epoch', '-e', type=int, required=True)

    argparser.add_argument('--filter_num', type=int, required=True)
    # A Python literal such as "3,4,5" or "[3, 4, 5]".
    argparser.add_argument('--filter_sizes', type=str, required=True)

    argparser.add_argument('--training_ratio', type=int, default=2)
    argparser.add_argument('--g_learning_rate', '-glr', type=float, default=0.001)
    argparser.add_argument('--d_learning_rate', '-dlr', type=float, default=0.001)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)

    # new arguments used in adversarial
    new_args = argparser.parse_args()

    # load default arguments saved by the pre-training run
    default_arg_file = os.path.join(new_args.load_path, 'args.pkl')
    if not os.path.exists(default_arg_file):
        raise RuntimeError('No default argument file in %s' % new_args.load_path)
    # NOTE: unpickling can execute arbitrary code; only load trusted runs.
    with open(default_arg_file, 'rb') as f:
        args = pickle.load(f)

    args.mode = 'adversarial'
    args.print_every = 1
    args.g_learning_rate = new_args.g_learning_rate
    args.d_learning_rate = new_args.d_learning_rate
    args.batch_size = new_args.batch_size

    # add new arguments
    args.load_path = new_args.load_path
    args.load_epoch = new_args.load_epoch
    args.filter_num = new_args.filter_num
    args.filter_sizes = new_args.filter_sizes
    args.training_ratio = new_args.training_ratio

    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode, time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger, os.path.join(exp_dirname, 'adversarial.log'),
                        mode='w', silent=False, debug=True)

    # load vocabulary
    vocab, rev_vocab = load_vocab(args.vocab_file, max_vocab=args.max_vocab_size)
    vocab_size = len(vocab)

    # Parse with ast.literal_eval rather than eval: this codebase appears to
    # define its own `eval(...)` validation routine that can shadow the
    # builtin, and only a literal is ever needed here.
    filter_sizes = ast.literal_eval(args.filter_sizes)

    word_embeddings = nn.Embedding(vocab_size, args.emb_dim, padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size, args.emb_dim, args.hidden_dim, args.n_layers, args.dropout_rate, bidirectional=True, variable_lengths=True)
    G = Generator(vocab_size, args.response_max_len, args.emb_dim, 2*args.hidden_dim, args.n_layers, dropout_p=args.dropout_rate)
    D = Discriminator(args.emb_dim, args.filter_num, filter_sizes)

    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()
        D.cuda()

    # Only the generator RNN and the discriminator are optimised; the
    # embeddings and encoder keep their pre-trained weights.
    opt_G = torch.optim.Adam(G.rnn.parameters(), lr=args.g_learning_rate)
    opt_D = torch.optim.Adam(D.parameters(), lr=args.d_learning_rate)

    logger.info('----------------------------------')
    logger.info('Adversarial a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))

    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' + args.train_response_file)

    reload_model(args.load_path, args.load_epoch, word_embeddings, E, G)

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # TODO: num_epoch is old one
    for e in range(args.num_epoch):
        train_data_generator = batcher(args.batch_size, args.train_query_file, args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        cur_time = time.time()
        while True:
            try:
                # next() works for both Python 2 and Python 3 iterators.
                post_sentences, response_sentences = next(train_data_generator)
            except StopIteration:
                save_model(exp_dirname, e, word_embeddings, E, G, D)
                break

            # prepare data
            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [sentence2id(sent, vocab) for sent in response_sentences]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(response_ids, args.response_max_len)
            # sort by post length (descending)
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()

            embedded_post = word_embeddings(posts_var)
            real_responses = word_embeddings(responses_var)

            # forward
            _, dec_init_state = E(embedded_post, input_lengths=posts_length.numpy())
            fake_responses = G(dec_init_state, word_embeddings)  # [B, T, emb_size]

            # Discriminator loss on detached inputs: D's gradient has no use in
            # the (unoptimised) embeddings/encoder or in G, and the generator
            # graph must not hold D weights that opt_D.step() modifies in place.
            post_for_d = embedded_post.detach()
            prob_real = D(post_for_d, real_responses.detach())
            prob_fake_for_d = D(post_for_d, fake_responses.detach())
            D_loss = - torch.mean(torch.log(prob_real) + torch.log(1. - prob_fake_for_d))

            if step % args.training_ratio == 0:
                opt_D.zero_grad()
                D_loss.backward()
                opt_D.step()

            # Generator loss, scored by the discriminator's current weights.
            prob_fake = D(embedded_post, fake_responses)
            G_loss = torch.mean(torch.log(1. - prob_fake))

            opt_G.zero_grad()
            G_loss.backward()
            opt_G.step()

            if step % args.print_every == 0:
                # .mean() turns a 1-element or 0-d array into a scalar, so this
                # works whether the loss tensor is shape (1,) or ().
                logger.info('Step %5d: D accuracy=%.2f (0.5 for D to converge) D score=%.2f (-1.38 for G to converge) (%.1f iters/sec)' % (
                    step,
                    prob_real.cpu().data.numpy().mean(),
                    -D_loss.cpu().data.numpy().mean(),
                    args.print_every/(time.time()-cur_time)))
                cur_time = time.time()
            step = step + 1
Example #5
0
    else:
        train_embedding = Embedding(filename=args.glove_filename, embedding_size=embedding_size).load_embedding(train_dataset.src_vocab)
        target_embedding = Embedding(filename=args.glove_filename, embedding_size=embedding_size).load_embedding(train_dataset.tgt_vocab)
    encoder.embedding.weight.data.copy_(train_embedding)
    decoder.embedding.weight.data.copy_(target_embedding)
    if opts.fixed_embeddings:
        encoder.embedding.weight.requires_grad = False
        decoder.embedding.weight.requires_grad = False
    else:
        decoder.embedding.weight.requires_grad = True
print("emb end")
# Restore checkpointed weights (when resuming) before the modules are moved to the GPU.
if LOAD_CHECKPOINT:
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
if USE_CUDA:
    encoder.cuda()
    decoder.cuda()
# Hard-coded override: always make the encoder input embedding trainable, even
# if it was frozen earlier (e.g. by opts.fixed_embeddings).
# NOTE(review): the decoder embedding is not unfrozen here, so with
# fixed_embeddings the encoder and decoder differ — confirm this is intended.
FINE_TUNE = True
if FINE_TUNE:
    encoder.embedding.weight.requires_grad = True

# Print both architectures and which embedding matrices will receive gradients.
print('='*100)
print('Model log:\n')
print(encoder)
print(decoder)
print('- Encoder input embedding requires_grad={}'.format(encoder.embedding.weight.requires_grad))
print('- Decoder input embedding requires_grad={}'.format(decoder.embedding.weight.requires_grad))
print('- Decoder output embedding requires_grad={}'.format(decoder.W_s.weight.requires_grad))
print('='*100 + '\n')

# Initialize optimizers (we can experiment different learning rates)
Example #6
0
class BiLSTMModel(nn.Module):
    """Sentence-pair classifier: a shared bidirectional encoder plus an MLP head.

    The text and the hypothesis are encoded by the same ``EncoderRNN``. For
    each sentence, the last forward-direction output and the first
    backward-direction output are concatenated into a vector; the pair
    ``(u, v)`` is combined as ``[u, v, u * v, |u - v|]`` and mapped by
    ``self.mlp`` to a softmax over 3 classes.

    Construct with ``BiLSTMModel()`` and then call either ``init_model(...)``
    (fresh weights) or ``import_state(...)`` (weights from a bundle written
    by ``export_state``).
    """

    def __init__(self):
        super(BiLSTMModel, self).__init__()
        self.base_rnn = None
        self.wd = None

    @staticmethod
    def _build_mlp(hidden_size, dropout_p):
        """Build the 3-way classification head shared by init_model and import_state."""
        # Input width 8 * hidden_size = 4 features x 2 directions, assuming each
        # encoder direction outputs hidden_size features.
        return torch.nn.Sequential(
            torch.nn.Linear(int(hidden_size * 8), int(hidden_size)),
            torch.nn.ReLU(), torch.nn.Dropout(dropout_p),
            torch.nn.Linear(int(hidden_size), 3), torch.nn.Softmax(dim=1))

    def _finish_setup(self):
        """Move the submodules to the GPU when USE_CUDA is set and record their parameters."""
        if USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.mlp = self.mlp.cuda()
        # Lists rather than bare `parameters()` generators, which are exhausted
        # after a single pass.
        self.parameter_list = [
            list(self.encoder.parameters()),
            list(self.mlp.parameters()),
        ]

    def init_model(self,
                   wd,
                   hidden_size,
                   e_layers,
                   d_layers,
                   base_rnn,
                   pretrained_embeddings=None,
                   dropout_p=0.1):
        """Create fresh encoder and MLP weights and return ``self``.

        :param wd: word dictionary (``n_words``, ``word2index``, ...)
        :param hidden_size: encoder hidden size; also the MLP hidden width
        :param e_layers: number of encoder layers
        :param d_layers: accepted for API compatibility; unused
        :param base_rnn: RNN class for the encoder (e.g. ``nn.LSTM`` / ``nn.GRU``)
        :param pretrained_embeddings: embedding weights for the encoder, or
            ``True`` to load GloVe vectors for ``wd``
        :param dropout_p: dropout probability in the MLP head
        """
        self.base_rnn = base_rnn
        self.wd = wd
        self.dropout_p = dropout_p
        if pretrained_embeddings is True:
            print("Loading GloVe Embeddings ...")
            pretrained_embeddings = load_glove_embeddings(
                wd.word2index, hidden_size)

        self.encoder = EncoderRNN(wd.n_words,
                                  hidden_size,
                                  n_layers=e_layers,
                                  base_rnn=base_rnn,
                                  pretrained_embeddings=pretrained_embeddings)
        self.mlp = self._build_mlp(hidden_size, dropout_p)
        self._finish_setup()

        return self

    def forward(self, batch, inference=False):
        """Return class probabilities (B x 3) for a batch of sentence pairs.

        ``batch`` is ``(text, text_lengths, hyp, hyp_lengths, labels)``; the
        labels element is absent when ``inference`` is True. The original code
        took ``text_batch.size(1)`` as the batch size, i.e. it assumes
        sequence-first tensors — TODO confirm against EncoderRNN.
        NOTE(review): ``[-1]`` takes the forward output at the last padded
        position; if EncoderRNN packs sequences, shorter sentences may get
        zeros there — verify.
        """
        if inference is True:
            text_batch, text_lengths, hyp_batch, hyp_lengths = batch
        else:
            text_batch, text_lengths, hyp_batch, hyp_lengths, labels = batch

        # The same encoder is applied to both sentences.
        text_fwd, text_bkwd, _ = self.encoder(text_batch, text_lengths)
        hyp_fwd, hyp_bkwd, _ = self.encoder(hyp_batch, hyp_lengths)

        # Forward direction: last time step; backward direction: first time step.
        text_enc = torch.cat((text_fwd[-1, :, :], text_bkwd[0, :, :]), dim=1)
        hyp_enc = torch.cat((hyp_fwd[-1, :, :], hyp_bkwd[0, :, :]), dim=1)

        features = torch.cat(
            [text_enc, hyp_enc, text_enc * hyp_enc,
             torch.abs(text_enc - hyp_enc)], dim=1)
        return self.mlp(features)  # B x 3

    def get_loss_for_batch(self, batch):
        """Return the mean negative log-likelihood of the labels in ``batch[-1]``.

        ``forward`` already ends in a softmax, so the loss is the NLL of the log
        of those probabilities. ``CrossEntropyLoss`` expects unnormalised logits
        and would apply a second softmax, flattening the gradients.
        """
        labels = batch[-1]
        probs = self(batch)
        # The clamp avoids log(0) = -inf when a probability underflows.
        log_probs = torch.log(torch.clamp(probs, min=1e-12))
        return torch.nn.NLLLoss()(log_probs, labels)

    def _numpy_to_variables(self, batch, indices):
        """Return ``batch`` as a list with the numpy arrays at ``indices`` turned into Variables."""
        batch = list(batch)
        for i in indices:
            var = Variable(torch.from_numpy(batch[i]))
            if USE_CUDA:
                var = var.cuda()
            batch[i] = var
        return batch

    def torch_batch_from_numpy_batch(self, batch):
        """Convert text, hypothesis and labels (indices 0, 2, 4); lengths stay numpy."""
        return self._numpy_to_variables(batch, (0, 2, 4))

    # Trains on a single batch
    def train_batch(self, batch, tl_mode=False):
        """Run forward/backward on one numpy batch and return the loss value.

        Gradients are accumulated but no optimizer step is taken here; zeroing
        and stepping are left to the caller. ``tl_mode`` is unused.
        """
        self.train()

        batch = self.torch_batch_from_numpy_batch(batch)
        loss = self.get_loss_for_batch(batch)
        loss.backward()

        return loss.item()

    def validate(self, batch):
        """Return the loss on one numpy batch in eval mode, without tracking gradients."""
        self.eval()

        batch = self.torch_batch_from_numpy_batch(batch)
        with torch.no_grad():
            return self.get_loss_for_batch(batch).item()

    def score(self, data):
        """Return the accuracy over ``data`` (0.0 when it yields no batches)."""
        batch_size = 1
        batches = nli_batches(batch_size, data)

        total_correct = 0
        for batch in tqdm(batches):
            batch = self.torch_batch_from_numpy_batch(batch)
            total_correct += self._acc_for_batch(batch)

        total = len(batches) * batch_size
        # float() keeps this a true division under Python 2 as well.
        return float(total_correct) / total if total else 0.0

    def _acc_for_batch(self, batch):
        '''
        :param batch: converted batch whose last element holds the labels
        :return: The number of correct predictions in a batch
        '''
        self.eval()

        with torch.no_grad():
            outputs = self(batch)
        predictions = outputs.max(1)[1]

        labels = batch[-1]

        # Non-zero differences are exactly the wrong predictions.
        num_error = torch.nonzero(labels - predictions)
        return labels.size(0) - num_error.size(0)

    def export_state(self, dir, label, epoch=-1):
        """Bundle the weights, vocabulary and model info into the tar ``cwd + dir + label``.

        The individual files are written under ``dir`` and then removed once
        bundled. ``epoch`` is recorded when it is non-negative (-1 means "no
        epoch").
        """
        print("Saving models.")

        cwd = os.getcwd() + '/'

        enc_out = dir + ENC_1_FILE
        mlp_out = dir + MLP_FILE
        i2w_out = dir + I2W_FILE
        w2i_out = dir + W2I_FILE
        w2c_out = dir + W2C_FILE
        inf_out = dir + INF_FILE

        torch.save(self.encoder.state_dict(), enc_out)
        torch.save(self.mlp.state_dict(), mlp_out)

        for path, obj in ((i2w_out, self.wd.index2word),
                          (w2i_out, self.wd.word2index),
                          (w2c_out, self.wd.word2count)):
            with open(path, 'wb') as fh:
                pickle.dump(obj, fh)

        # Info file: hidden size, layers, vocab size, LSTM flag and optional
        # epoch, one value per line.
        using_lstm = 1 if self.base_rnn == nn.LSTM else 0
        info_lines = [str(self.encoder.hidden_size), str(self.encoder.n_layers),
                      str(self.wd.n_words), str(using_lstm)]
        if epoch >= 0:
            info_lines.append(str(epoch))
        with open(inf_out, 'w') as info:
            info.write("\n".join(info_lines))

        files = [enc_out, mlp_out, i2w_out, w2i_out, w2c_out, inf_out]

        print("Bundling models")

        with tarfile.open(cwd + dir + label, mode='w') as tf:
            for path in files:
                tf.add(path)

        for path in files:
            os.remove(path)

        print("Finished saving models.")

    def import_state(self, model_file, active_dir=TEMP_DIR, load_epoch=False):
        """Rebuild the model from a tar bundle written by ``export_state``.

        The bundle is extracted into ``active_dir`` and every file is read back
        from there. Returns ``self``, or ``(self, epoch)`` when ``load_epoch``
        is True; ``epoch`` is -1 if the bundle recorded none.
        """
        print("Loading models.")
        with tarfile.open(model_file) as tf:
            # Flatten member paths to bare file names so extraction stays
            # inside active_dir.
            for member in tf.getmembers():
                if member.isreg():
                    member.name = os.path.basename(member.name)
                    tf.extract(member, path=active_dir)

        with open(active_dir + INF_FILE, 'r') as info:
            lns = info.readlines()
        hidden_size, e_layers, n_words, using_lstm = [int(i) for i in lns[:4]]

        if load_epoch:
            # The epoch line is optional; without it lns[-1] would be the LSTM flag.
            epoch = int(lns[4]) if len(lns) > 4 else -1

        with open(active_dir + I2W_FILE, 'rb') as i2w:
            i2w_dict = pickle.load(i2w)
        with open(active_dir + W2I_FILE, 'rb') as w2i:
            w2i_dict = pickle.load(w2i)
        with open(active_dir + W2C_FILE, 'rb') as w2c:
            w2c_dict = pickle.load(w2c)
        wd = WordDict(dicts=[w2i_dict, i2w_dict, w2c_dict, n_words])

        self.base_rnn = nn.LSTM if using_lstm == 1 else nn.GRU
        self.wd = wd
        self.dropout_p = 0.1
        self.encoder = EncoderRNN(wd.n_words,
                                  hidden_size,
                                  n_layers=e_layers,
                                  base_rnn=self.base_rnn)
        self.mlp = self._build_mlp(hidden_size, self.dropout_p)

        # Without CUDA, remap GPU-saved tensors onto the CPU.
        map_location = None if USE_CUDA else (lambda storage, loc: storage)
        self.encoder.load_state_dict(
            torch.load(active_dir + ENC_1_FILE, map_location=map_location))
        self.mlp.load_state_dict(
            torch.load(active_dir + MLP_FILE, map_location=map_location))
        self._finish_setup()

        self.encoder.eval()
        self.mlp.eval()

        print("Loaded models.")

        if load_epoch:
            return self, epoch
        else:
            return self

    def torch_batch_from_numpy_batch_without_label(self, batch):
        """Convert text and hypothesis (indices 0, 2) of an unlabelled batch."""
        return self._numpy_to_variables(batch, (0, 2))

    def predict(self, data):
        """Return the predicted class index for every example in ``data``."""
        batch_size = 1
        batches = nli_batches_without_label(batch_size, data)

        # Disable dropout; previously the caller's training mode leaked in.
        self.eval()
        predictions = []
        with torch.no_grad():
            for batch in tqdm(batches):
                batch = self.torch_batch_from_numpy_batch_without_label(batch)
                outputs = self(batch, inference=True)
                predictions.append(outputs.max(1)[1])

        return torch.cat(predictions)

    def add_new_vocabulary(self, genre):
        """Add the words of a MultiNLI genre to ``self.wd`` and extend the encoder embeddings.

        Sentences come from ``get_multinli_text_hyp_labels(genre=genre)`` and
        the matched validation set. The number of added words is stored in
        ``self.new_vocab_size``.
        """
        old_vocab_size = self.wd.n_words
        print("Previous vocabulary size: " + str(old_vocab_size))

        train_set = nli_preprocessor.get_multinli_text_hyp_labels(genre=genre)
        matched_val_set = nli_preprocessor.get_multinli_matched_val_set()

        unmerged_sentences = []
        for data in [train_set, matched_val_set]:
            unmerged_sentences.extend([data["text"], data["hyp"]])

        for line in itertools.chain.from_iterable(unmerged_sentences):
            self.wd.add_sentence(line)

        print("New vocabulary size: " + str(self.wd.n_words))

        print("Extending the Embedding layer with new vocabulary...")
        num_new_words = self.wd.n_words - old_vocab_size
        self.encoder.extend_embedding_layer(self.wd.word2index, num_new_words)

        self.new_vocab_size = num_new_words

    def freeze_source_params(self):
        """Freeze RNN weights and M_k/M_v weights not named "target_4"; print what stays trainable."""
        for name, param in self.named_parameters():
            if "rnn" in name:
                param.requires_grad = False
            if ("M_k" in name or "M_v" in name) and "target_4" not in name:
                param.requires_grad = False
        for name, param in self.named_parameters():
            if param.requires_grad is True:
                print(name)
Example #7
0
def pretrain():
    """Pre-train the embedding / encoder / generator stack with teacher forcing.

    All settings come from the command line. The run writes ``train.log``,
    ``args.pkl`` and one checkpoint per epoch (via ``save_model``) into
    ``<exp_dir>/<mode>/<timestamp>``, and runs the project's validation
    ``eval`` after each epoch. With ``--resume``, the weights of
    ``--resume_epoch`` in ``--resume_dir`` are reloaded and training
    continues from the following epoch. When ``--vocab_file`` is empty, the
    vocabulary is built from the training files and the process exits.

    Logging goes through the module-level ``logger``.
    """

    def _str2bool(value):
        # argparse passes the raw string to `type`, and bool('False') is True.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    # Parse command line arguments
    argparser = argparse.ArgumentParser()

    # train
    argparser.add_argument('--mode',
                           '-m',
                           choices=('pretrain', 'adversarial', 'inference'),
                           type=str,
                           required=True)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)
    argparser.add_argument('--num_epoch', '-e', type=int, default=10)
    argparser.add_argument('--print_every', type=int, default=100)
    argparser.add_argument('--use_cuda', type=_str2bool, default=True)
    argparser.add_argument('--g_learning_rate',
                           '-glr',
                           type=float,
                           default=0.001)
    argparser.add_argument('--d_learning_rate',
                           '-dlr',
                           type=float,
                           default=0.001)

    # resume
    argparser.add_argument('--resume', action='store_true', dest='resume')
    argparser.add_argument('--resume_dir', type=str)
    argparser.add_argument('--resume_epoch', type=int)

    # save
    argparser.add_argument('--exp_dir', type=str, required=True)

    # model
    argparser.add_argument('--emb_dim', type=int, default=128)
    argparser.add_argument('--hidden_dim', type=int, default=256)
    argparser.add_argument('--dropout_rate', '-drop', type=float, default=0.5)
    argparser.add_argument('--n_layers', type=int, default=1)
    argparser.add_argument('--response_max_len', type=int, default=15)

    # data
    argparser.add_argument('--train_query_file',
                           '-tqf',
                           type=str,
                           required=True)
    argparser.add_argument('--train_response_file',
                           '-trf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_query_file',
                           '-vqf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_response_file',
                           '-vrf',
                           type=str,
                           required=True)
    argparser.add_argument('--vocab_file', '-vf', type=str, default='')
    argparser.add_argument('--max_vocab_size', '-mv', type=int, default=100000)

    args = argparser.parse_args()

    # Validate resume options up front, before any output directory is created.
    if args.resume:
        if args.resume_epoch is None or args.resume_epoch < 0:
            argparser.error('If resume training, please assign resume_epoch')
        if args.resume_dir is None:
            argparser.error('If resume training, please assign resume_dir')

    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode,
                               time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger,
                        os.path.join(exp_dirname, 'train.log'),
                        mode='w',
                        silent=False,
                        debug=True)

    if not args.vocab_file:
        logger.info("no vocabulary file")
        build_vocab(args.train_query_file,
                    args.train_response_file,
                    seperated=True)
        sys.exit()

    vocab, rev_vocab = load_vocab(args.vocab_file,
                                  max_vocab=args.max_vocab_size)
    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size,
                                   args.emb_dim,
                                   padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size,
                   args.emb_dim,
                   args.hidden_dim,
                   args.n_layers,
                   args.dropout_rate,
                   bidirectional=True,
                   variable_lengths=True)
    # The decoder state is 2 * hidden_dim because the encoder is bidirectional.
    G = Generator(vocab_size,
                  args.response_max_len,
                  args.emb_dim,
                  2 * args.hidden_dim,
                  args.n_layers,
                  dropout_p=args.dropout_rate)

    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()

    # Summed (not averaged) NLL; it is divided by the batch size below.
    loss_func = nn.NLLLoss(size_average=False)
    params = list(word_embeddings.parameters()) + list(E.parameters()) + list(
        G.parameters())
    opt = torch.optim.Adam(params, lr=args.g_learning_rate)

    logger.info('----------------------------------')
    logger.info('Pre-train a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))

    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' +
                args.train_response_file)

    # resume training from other experiment
    if args.resume:
        reload_model(args.resume_dir, args.resume_epoch, word_embeddings, E, G)
        start_epoch = args.resume_epoch + 1
    else:
        start_epoch = 0

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    for e in range(start_epoch, args.num_epoch):
        logger.info('---------------------training--------------------------')
        train_data_generator = batcher(args.batch_size, args.train_query_file,
                                       args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        # Running sums since the last log line, kept as Python floats so each
        # step's autograd graph is released immediately.
        total_loss = 0.0
        total_valid_char = 0.0
        cur_time = time.time()
        while True:
            try:
                # next() works for both Python 2 and Python 3 iterators.
                post_sentences, response_sentences = next(train_data_generator)
            except StopIteration:
                # save model
                save_model(exp_dirname, e, word_embeddings, E, G)
                # evaluation
                eval(args.valid_query_file, args.valid_response_file,
                     args.batch_size, word_embeddings, E, G, loss_func,
                     args.use_cuda, vocab, args.response_max_len)
                break

            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [
                sentence2id(sent, vocab) for sent in response_sentences
            ]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(
                response_ids, args.response_max_len)
            # sort by post length (descending)
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            # Append EOS after each response: add one extra column, then write
            # SYM_EOS at each sentence's length position.
            references_var = torch.cat([
                responses_var,
                Variable(torch.zeros(responses_var.size(0), 1).long(),
                         requires_grad=False)
            ],
                                       dim=1)
            for idx, length in enumerate(responses_length):
                references_var[idx, length] = SYM_EOS

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()
                references_var = references_var.cuda()

            embedded_post = word_embeddings(posts_var)
            embedded_response = word_embeddings(responses_var)

            _, dec_init_state = E(embedded_post,
                                  input_lengths=posts_length.numpy())
            log_softmax_outputs = G.supervise(
                embedded_response, dec_init_state,
                word_embeddings)  # [B, T, vocab_size]

            # Zero the log-probabilities of padded positions so they add
            # nothing to the summed NLL.
            outputs = log_softmax_outputs.view(-1, vocab_size)
            mask_pos = mask(references_var).view(-1).unsqueeze(-1)
            masked_output = outputs * (mask_pos.expand_as(outputs))
            loss = loss_func(masked_output,
                             references_var.view(-1)) / (posts_var.size(0))

            opt.zero_grad()
            loss.backward()
            opt.step()

            # .sum() yields a scalar whether the loss tensor is shape (1,) or ().
            total_loss += float(loss.data.cpu().numpy().sum()) * posts_var.size(0)
            total_valid_char += float(mask_pos.data.cpu().numpy().sum())

            if step % args.print_every == 0:
                logger.info(
                    'Step %5d: (per word) training perplexity %.2f (%.1f iters/sec)'
                    % (step, math.exp(total_loss / max(total_valid_char, 1.0)),
                       args.print_every / (time.time() - cur_time)))
                total_loss = 0.0
                total_valid_char = 0.0
                cur_time = time.time()
            step = step + 1


# Load the sentence pairs and vocabulary, plus GloVe vectors filtered against
# that vocabulary (presumably only words in `lang` are kept — confirm in filter_embedding).
lang, pairs = prepareData("filtered_dataset.jsonl")
filtered_embeddings = filter_embedding(lang, "glove.6B.100d.txt")
print(random.choice(pairs))  # sanity check: show one random pair

hidden_size = 256
# Encoder and attention decoder are both given the same filtered_embeddings.
encoder1 = EncoderRNN(lang.n_words, hidden_size, filtered_embeddings)
attn_decoder1 = AttnDecoderRNN(hidden_size,
                               lang.n_words,
                               dropout_p=0.1,
                               embeddings=filtered_embeddings)
# Presumably the combined parameter count of both models (see get_n_params).
print("parameters ", get_n_params(encoder1) + get_n_params(attn_decoder1))

# NOTE(review): .cuda() is unconditional, so this script requires a CUDA device.
encoder1.cuda()
attn_decoder1.cuda()

# 7 is trainIters' third positional argument (an iteration/epoch count —
# confirm against trainIters' signature).
trainIters(encoder1, attn_decoder1, 7, print_every=100)