Example #1
def main():
    embedding = nn.Embedding(voc.num_words, hidden_size)
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                                  voc.num_words, decoder_n_layers, dropout)

    # Load the checkpoint once; map_location lets a GPU-trained
    # checkpoint be restored on the target device (e.g. CPU).
    checkpoint = torch.load(model_save_pth, map_location=device)

    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])

    encoder = encoder.to(device)
    decoder = decoder.to(device)
    encoder.eval()
    decoder.eval()

    searcher = GreedySearchDecoder(encoder, decoder)

    for sentence in pick_n_valid_sentences(10):
        decoded_words = evaluate(searcher, sentence)
        print('Human: {}'.format(sentence))
        print('Bot: {}'.format(''.join(decoded_words)))
Example #2
File: utils.py Project: kristogj/chatbot
def load_encoder_decoder(voc, checkpoint, configs):
    """
    Initialize encoder and decoder, and load from file if previous states exist
    :param voc: Vocabulary
    :param checkpoint: dict
    :param configs: dict
    :return: Encoder, LuongAttentionDecoderRNN
    """
    logging.info('Building encoder and decoder ...')

    # Initialize word embeddings
    embedding = nn.Embedding(voc.num_words, configs["hidden_size"])

    # Initialize encoder & decoder models
    encoder = EncoderRNN(configs["hidden_size"], embedding,
                         configs["encoder_n_layers"], configs["dropout"])
    decoder = LuongAttentionDecoderRNN(embedding, voc.num_words, configs)

    if checkpoint:
        voc.__dict__ = checkpoint['voc_dict']
        embedding.load_state_dict(checkpoint['embedding'])
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])

    logging.info('Models built and ready to go!')
    return encoder.to(get_device()), decoder.to(get_device())
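
A minimal usage sketch for the helper above; the configs keys follow the snippet, while the Voc() constructor and the checkpoint path are assumptions borrowed from the other examples on this page:

configs = {"hidden_size": 500, "encoder_n_layers": 2, "dropout": 0.1}
voc = Voc()
checkpoint = torch.load("save/4000_checkpoint.tar")  # or None to build fresh models
encoder, decoder = load_encoder_decoder(voc, checkpoint, configs)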
Example #3
File: eval2.py Project: mharwani/AMR-text
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--dev_files', default='../amr_anno_1.0/data/split/dev/*',
                    help='dev files.')
    ap.add_argument('--log_dir', default='./log',
                    help='log directory')
    ap.add_argument('--exp_name', default='experiment',
                    help='experiment name')
    args = ap.parse_args()
    
    #read dev files
    dev_files = glob.glob(args.dev_files)
    dev_pairs = AMR.read_AMR_files(dev_files, True)
    
    logdir = args.log_dir
    exp_dir = logdir + '/' + args.exp_name
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    
    max_iter = 0
    dev_bleu = 0.0
    while True:
        load_state_file = None
        state_files = glob.glob(exp_dir + '/*')
        for sf in state_files:
            iter_num = int(sf.split('_')[1].split('.')[0])
            if iter_num > max_iter:
                max_iter = iter_num
                load_state_file = sf
        if load_state_file is not None:
            state = torch.load(load_state_file)
            amr_vocab = state['amr_vocab']
            en_vocab = state['en_vocab']
            hidden_size = state['hidden_size']
            edge_size = state['edge_size']
            drop = state['dropout']
            mlength = state['max_length']
            logging.info('loaded checkpoint %s', load_state_file)
            
            encoder = EncoderRNN(amr_vocab.n_nodes, hidden_size).to(device)
            child_sum = ChildSum(amr_vocab.n_edges, edge_size, hidden_size).to(device)
            decoder = AttnDecoderRNN(hidden_size, en_vocab.n_words, dropout_p=drop, max_length=mlength).to(device)
            encoder.load_state_dict(state['enc_state'])
            child_sum.load_state_dict(state['sum_state'])
            decoder.load_state_dict(state['dec_state'])
            # translate from the dev set
            translate_random_amr(encoder, child_sum, decoder, dev_pairs, amr_vocab, en_vocab, mlength, n=10)
            translated_amrs = translate_amrs(encoder, child_sum, decoder, dev_pairs, amr_vocab, en_vocab, mlength)
            references = [[pair[0]] for pair in dev_pairs[:len(translated_amrs)]]
            candidates = [sent.split() for sent in translated_amrs]
            dev_bleu = corpus_bleu(references, candidates)
            logging.info('Dev BLEU score: %.2f', dev_bleu)
        else:
            logging.info('No new checkpoint found. Last DEV BLEU score: %.2f', dev_bleu)
        
        time.sleep(20)
Example #4
File: train.py Project: Ierezell/ExamAi
def init():
    print("\tInitialising sentences")

    print("\t\tLoading and cleaning json files")
    json_of_convs = load_all_json_conv('./Dataset/messages')

    print("\t\tLoading two person convs")
    duo_conversations = get_chat_friend_and_me(json_of_convs)

    print("\t\tMaking two person convs discussions")
    discussions = get_discussions(duo_conversations)

    print("\t\tCreating pairs for training")
    pairs_of_sentences = make_pairs(discussions)
    print(f"\t\t{len(pairs_of_sentences)} different pairs")

    print("\t\tCreating Vocabulary")
    voc = Voc()

    print("\t\tPopulating Vocabulary")
    voc.createVocFromPairs(pairs_of_sentences)
    print(f"\t\tVocabulary of : {voc.num_words} differents words")

    print('\tBuilding encoder and decoder ...')
    embedding = nn.Embedding(voc.num_words, HIDDEN_SIZE)
    encoder = EncoderRNN(HIDDEN_SIZE, embedding, ENCODER_N_LAYERS, DROPOUT)
    decoder = LuongAttnDecoderRNN(ATTN_MODEL, embedding, HIDDEN_SIZE,
                                  voc.num_words, DECODER_N_LAYERS, DROPOUT)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=LEARNING_RATE)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=LEARNING_RATE * DECODER_LEARNING_RATIO)
    checkpoint = None
    if LOADFILENAME:
        print("\t\tLoading last training")
        checkpoint = torch.load(LOADFILENAME)
        # If loading a model trained on GPU to CPU
        # checkpoint=torch.load(loadFilename,map_location=torch.device('cpu'))
        encoder_sd = checkpoint['en']
        decoder_sd = checkpoint['de']
        encoder_optimizer_sd = checkpoint['en_opt']
        decoder_optimizer_sd = checkpoint['de_opt']
        embedding_sd = checkpoint['embedding']
        voc.__dict__ = checkpoint['voc_dict']
        print("\t\tPopulating from last training")
        embedding.load_state_dict(embedding_sd)
        encoder.load_state_dict(encoder_sd)
        decoder.load_state_dict(decoder_sd)
        encoder_optimizer.load_state_dict(encoder_optimizer_sd)
        decoder_optimizer.load_state_dict(decoder_optimizer_sd)

    encoder = encoder.to(DEVICE)
    decoder = decoder.to(DEVICE)
    return (encoder, decoder, encoder_optimizer, decoder_optimizer, embedding,
            voc, pairs_of_sentences, checkpoint)
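
For context, a sketch of the save call that would produce the checkpoint keys read above; the key names come from the snippet itself, while the filename is illustrative:

torch.save({
    'en': encoder.state_dict(),
    'de': decoder.state_dict(),
    'en_opt': encoder_optimizer.state_dict(),
    'de_opt': decoder_optimizer.state_dict(),
    'embedding': embedding.state_dict(),
    'voc_dict': voc.__dict__,
}, 'checkpoint.tar')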
Example #5
def main():
    dataset = 'imdb'
    hidden_size = 325
    train_iters = 40
    pretrain_train_iters = 40
    lang, lines = cachePrepareData(dataset)

    PATH = './pretrained/'
    pretrained_filename = PATH + 'pretrained_lstm_' + dataset + '_' + str(hidden_size) + '_' + str(pretrain_train_iters) + '.pt'
    
    model_filename = 'maskmle_' + dataset + '_' + str(hidden_size) + '_' + str(train_iters) + '.pt'
    
    encoder1 = EncoderRNN(lang.n_words, hidden_size).to(device)
    encoder1.load_state_dict(torch.load(PATH + 'e_' + model_filename))
    
    attn_decoder1 = AttnDecoderRNN(hidden_size, lang.n_words, dropout_p=0.1).to(device)
    attn_decoder1.load_state_dict(torch.load(PATH + 'd_' + model_filename))
    print(evaluateRandomly(encoder1, attn_decoder1, lang, lines, 20, 0.5))
Example #6
def main(args):
    config_path = os.path.join(args.config_path, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    print('[-] Loading pickles')
    dataset_path = Path(config["dataset_path"])
    input_lang = CustomUnpickler(open(dataset_path / 'input_lang.pkl',
                                      'rb')).load()
    output_lang = CustomUnpickler(open(dataset_path / 'output_lang.pkl',
                                       'rb')).load()
    pairs = CustomUnpickler(open(dataset_path / 'pairs.pkl', 'rb')).load()
    # input_lang = load_pkl(dataset_path / 'input_lang.pkl')
    # output_lang = load_pkl(dataset_path / 'output_lang.pkl')
    # pairs = load_pkl(dataset_path / 'pairs.pkl')

    hidden_size = config["model_cfg"]["hidden_size"]
    max_len = config["max_len"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size,
                             output_lang.n_words,
                             max_len,
                             dropout_p=0.1).to(device)

    print('[-] Loading models')
    ckpt = torch.load(config["models_path"] + 'models.ckpt')
    encoder.load_state_dict(ckpt['encoder'])
    encoder.to(device)
    decoder.load_state_dict(ckpt['decoder'])
    decoder.to(device)

    evaluator = Evaluater(device, encoder, decoder, input_lang, output_lang,
                          max_len)

    # Evaluate random samples
    evaluator.evaluateRandomly(pairs)

    evaluator.evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    # evaluator.evaluateAndShowAttention("elle est trop petit .")
    # evaluator.evaluateAndShowAttention("je ne crains pas de mourir .")
    # evaluator.evaluateAndShowAttention("c est un jeune directeur plein de talent .")
    plt.savefig('attention.png')
Example #7
File: sample.py Project: qfzhu/fix_style
def main():
    # load vocabulary
    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    # build model
    encoder = EncoderRNN(voc_size=60736, emb_size=300, hidden_size=300)
    decoder = FactoredLSTM(300, 512, 512, len(vocab))

    encoder.load_state_dict(torch.load('pretrained_models/encoder-4.pkl'))
    decoder.load_state_dict(torch.load('pretrained_models/decoder-4.pkl'))

    # prepare images
    # transform = transforms.Compose([
    #     Rescale((224, 224)),
    #     transforms.ToTensor()
    #     ])
    # img_names, img_list = load_sample_images('sample_images/', transform)
    # image = to_var(img_list[30], volatile=True)

    data_loader = get_data_loader('', 'data/factual_train.txt', vocab, 1)

    # if torch.cuda.is_available():
    #     encoder = encoder.cuda()
    #     decoder = decoder.cuda()

    for i, (messages, m_lengths, targets, t_lengths) in enumerate(data_loader):
        print(''.join([vocab.i2w[x] for x in messages[0]]))
        messages = to_var(messages.long())
        targets = to_var(targets.long())

        # forward, backward and optimize
        output, features = encoder(messages, list(m_lengths))
        outputs = decoder.sample(features, mode="humorous")
        caption = [vocab.i2w[x] for x in outputs]
        print(''.join(caption))
        print('-------')
Example #8

def tensorsFromPair(pair):
    input_tensor = listTotensor(input_lang, pair[0])
    output_tensor = listTotensor(output_lang, pair[1])
    return (input_tensor, output_tensor)


hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = AttenDecoderRNN(hidden_size,
                          output_lang.n_words,
                          max_len=MAX_LENGTH,
                          dropout_p=0.1).to(device)

encoder.load_state_dict(torch.load("models/encoder_1000000.pth"))
decoder.load_state_dict(torch.load("models/decoder_1000000.pth"))
n_iters = 10

train_sen_pairs = [random.choice(pairs) for i in range(n_iters)]
training_pairs = [tensorsFromPair(train_sen_pairs[i]) for i in range(n_iters)]

for i in range(n_iters):
    input_tensor, output_tensor = training_pairs[i]
    encoder_hidden = encoder.initHidden()
    input_len = input_tensor.size(0)
    encoder_outputs = torch.zeros(MAX_LENGTH,
                                  encoder.hidden_size,
                                  device=device)
    for ei in range(input_len):
        # The snippet was truncated mid-call; the standard completion
        # (as in the PyTorch seq2seq tutorial this code follows) is:
        encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                 encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]
Example #9
    print("input_lang.n_words: " + str(input_lang.n_words))
    print("output_lang.n_words: " + str(output_lang.n_words))

    checkpoint = '{}/BEST_checkpoint.tar'.format(save_dir)  # model checkpoint
    print('checkpoint: ' + str(checkpoint))
    # Load model
    checkpoint = torch.load(checkpoint)
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']

    print('Building encoder and decoder ...')
    # Initialize encoder & decoder models
    encoder = EncoderRNN(input_lang.n_words, hidden_size, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, hidden_size, output_lang.n_words, decoder_n_layers, dropout)

    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)

    # Use appropriate device
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    print('Models built and ready to go!')

    # Set dropout layers to eval mode
    encoder.eval()
    decoder.eval()

    # Initialize search module
    searcher = GreedySearchDecoder(encoder, decoder)
    for input_sentence, target_sentence in pick_n_valid_sentences(input_lang, output_lang, 10):
        decoded_words = evaluate(searcher, input_sentence, input_lang, output_lang)
Example #10
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument(
        '--hidden_size',
        default=256,
        type=int,
        help='hidden size of encoder/decoder, also word vector size')
    ap.add_argument('--edge_size',
                    default=20,
                    type=int,
                    help='embedding dimension of edges')
    ap.add_argument('--n_iters',
                    default=100000,
                    type=int,
                    help='total number of examples to train on')
    ap.add_argument('--print_every',
                    default=5000,
                    type=int,
                    help='print loss info every this many training examples')
    ap.add_argument(
        '--checkpoint_every',
        default=10000,
        type=int,
        help='write out checkpoint every this many training examples')
    ap.add_argument('--initial_learning_rate',
                    default=0.001,
                    type=float,  # was int, which rejects command-line values like 0.0005
                    help='initial learning rate')
    ap.add_argument('--train_files',
                    default='../amr_anno_1.0/data/split/training/*',
                    help='training files.')
    ap.add_argument('--log_dir', default='./log', help='log directory')
    ap.add_argument('--exp_name', default='experiment', help='experiment name')
    ap.add_argument('--batch_size', default=5, type=int, help='batch size')
    ap.add_argument('--load_checkpoint',
                    action='store_true',
                    help='use existing checkpoint')

    args = ap.parse_args()

    logdir = args.log_dir
    exp_dir = logdir + '/' + args.exp_name
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)

    load_state_file = None
    if args.load_checkpoint:
        max_iter = 0
        state_files = glob.glob(exp_dir + '/*')
        for sf in state_files:
            iter_num = int(sf.split('_')[1].split('.')[0])
            if iter_num > max_iter:
                max_iter = iter_num
                load_state_file = sf
    # Create vocab from training data
    iter_num = 0
    train_files = glob.glob(args.train_files)
    train_pairs = AMR.read_AMR_files(train_files, True)
    amr_vocab, en_vocab = None, None
    state = None
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    edge_size = args.edge_size
    drop = DROPOUT_P
    mlength = MAX_LENGTH
    if load_state_file is not None:
        state = torch.load(load_state_file)
        iter_num = state['iter_num']
        amr_vocab = state['amr_vocab']
        en_vocab = state['en_vocab']
        hidden_size = state['hidden_size']
        edge_size = state['edge_size']
        drop = state['dropout']
        mlength = state['max_length']
        logging.info('loaded checkpoint %s', load_state_file)
    else:
        amr_vocab, en_vocab = make_vocabs(train_pairs)
    encoder = EncoderRNN(amr_vocab.n_nodes, hidden_size).to(device)
    child_sum = ChildSum(amr_vocab.n_edges, edge_size, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size,
                             en_vocab.n_words,
                             dropout_p=drop,
                             max_length=mlength).to(device)

    #load checkpoint
    if state is not None:
        encoder.load_state_dict(state['enc_state'])
        child_sum.load_state_dict(state['sum_state'])
        decoder.load_state_dict(state['dec_state'])

    # set up optimization/loss
    params = list(encoder.parameters()) + list(child_sum.parameters()) + list(
        decoder.parameters())  # .parameters() returns generator
    optimizer = optim.Adam(params, lr=args.initial_learning_rate)
    criterion = nn.NLLLoss()

    #load checkpoint
    if state is not None:
        optimizer.load_state_dict(state['opt_state'])

    start = time.time()
    print_loss_total = 0  # Reset every args.print_every

    while iter_num < args.n_iters:
        num_samples = batch_size
        remaining = args.checkpoint_every - (iter_num % args.checkpoint_every)
        remaining2 = args.print_every - (iter_num % args.print_every)
        if remaining < batch_size:
            num_samples = remaining
        elif remaining2 < batch_size:
            num_samples = remaining2
        iter_num += num_samples
        random_pairs = random.sample(train_pairs, num_samples)
        target_snt = tensors_from_batch(en_vocab, random_pairs)
        loss = train(random_pairs, target_snt, amr_vocab, encoder, child_sum,
                     decoder, optimizer, criterion)
        print_loss_total += loss

        if iter_num % args.checkpoint_every == 0:
            state = {
                'iter_num': iter_num,
                'enc_state': encoder.state_dict(),
                'sum_state': child_sum.state_dict(),
                'dec_state': decoder.state_dict(),
                'opt_state': optimizer.state_dict(),
                'amr_vocab': amr_vocab,
                'en_vocab': en_vocab,
                'hidden_size': hidden_size,
                'edge_size': edge_size,
                'dropout': drop,
                'max_length': mlength
            }
            filename = 'state_%010d.pt' % iter_num
            save_file = exp_dir + '/' + filename
            torch.save(state, save_file)
            logging.debug('wrote checkpoint to %s', save_file)

        if iter_num % args.print_every == 0:
            print_loss_avg = print_loss_total / args.print_every
            print_loss_total = 0
            logging.info(
                'time since start:%s (iter:%d iter/n_iters:%d%%) loss_avg:%.4f',
                time.time() - start, iter_num, iter_num / args.n_iters * 100,
                print_loss_avg)
Example #11

    # find the last encoder state (this block mirrors Example #12; the
    # original excerpt used encoder_last_state below without defining it)
    encoder_last_state = sorted(
        [
            x
            for x in os.listdir(
                "/Users/lena/Desktop/thesis-more/checkpoints/pov/" + lang
            )
            if x.startswith("enc")
        ]
    )[-1]
    print(encoder_last_state)

    # find the last decoder state
    decoder_last_state = sorted(
        [
            x
            for x in os.listdir(
                "/Users/lena/Desktop/thesis-more/checkpoints/pov/" + lang
            )
            if x.startswith("dec")
        ]
    )[-1]
    print(decoder_last_state)

    encoder.load_state_dict(
        torch.load(
            "/Users/lena/Desktop/thesis-more/checkpoints/pov/"
            + lang
            + "/"
            + encoder_last_state
        )
    )
    decoder.load_state_dict(
        torch.load(
            "/Users/lena/Desktop/thesis-more/checkpoints/pov/"
            + lang
            + "/"
            + decoder_last_state
        )
    )

    # predict for test

    figs_path = "figs/" + lang + "/pov_test"
Example #12
# find the last encoder state
encoder_last_state = sorted([
    x for x in os.listdir(args.path + "checkpoints/" + lang)
    if x.startswith("enc")
])[-1]
print(encoder_last_state)

# find the last decoder state
decoder_last_state = sorted([
    x for x in os.listdir(args.path + "checkpoints/" + lang)
    if x.startswith("dec")
])[-1]
print(decoder_last_state)

encoder.load_state_dict(
    torch.load(args.path + "checkpoints/" + lang + "/" + encoder_last_state))
decoder.load_state_dict(
    torch.load(args.path + "checkpoints/" + lang + "/" + decoder_last_state))

# predict for test
figs_path = args.path + "figs/" + lang + "/test"
if not os.path.exists(figs_path):
    os.makedirs(figs_path)

decoded_words_test = decode_dataset(
    args.path + "data/" + lang + "/" + lang + "_test.txt",
    encoder,
    decoder,
    dataset,
    figs_path,
)
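
Note that picking the newest file with sorted(...)[-1] only works when the iteration numbers in the filenames are zero-padded (as with 'state_%010d.pt' in Example #10). A numerically robust alternative, sketched under the assumption of filenames like 'enc_123.pt':

encoder_last_state = max(
    (x for x in os.listdir(args.path + "checkpoints/" + lang)
     if x.startswith("enc")),
    key=lambda name: int(name.split("_")[1].split(".")[0]),
)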
Example #13

def main():
    """
    Main function for the translation RNN
    """
    args = parse_args()

    eng_prefixes = (
        "i am ", "i m ",
        "he is", "he s ",
        "she is", "she s ",
        "you are", "you re ",
        "we are", "we re ",
        "they are", "they re "
    )

    input_lang, output_lang, pairs = \
        prepare_data('eng', 'fra', reverse=True, max_length=args.max_length, prefixes=eng_prefixes)
    # print(random.choice(pairs))

    encoder = EncoderRNN(input_lang.num_words, args.hidden_size).to(args.device)
    decoder = AttentionDecoderRNN(
            args.hidden_size,
            output_lang.num_words,
            args.max_length,
            args.dropout
        ).to(args.device)

    if args.train:
        train_iters(
            encoder,
            decoder,
            pairs,
            args.max_length,
            input_lang,
            output_lang,
            args.num_iters,
            device=args.device,
            print_every=args.print_every,
            teacher_forcing_ratio=args.teacher_forcing_ratio)

        torch.save(encoder.state_dict(), 'encoder.pth')
        torch.save(decoder.state_dict(), 'decoder.pth')

    encoder.load_state_dict(torch.load('encoder.pth'))
    decoder.load_state_dict(torch.load('decoder.pth'))

    encoder.eval()
    decoder.eval()

    evaluate_randomly(
        encoder,
        decoder,
        pairs,
        input_lang,
        output_lang,
        args.max_length,
        args.device,
        n=10
    )

    # visualizing attention
    _, attentions = \
        evaluate(
            encoder,
            decoder,
            'je suis trop froid .',
            input_lang,
            output_lang,
            args.max_length,
            args.device
        )

    plt.matshow(attentions.cpu().numpy())

    input_sentences = ['elle a cinq ans de moins que moi .',
                       'elle est trop petit .',
                       'je ne crains pas de mourir .',
                       'c est un jeune directeur plein de talent .']

    for input_sentence in input_sentences:
        evaluate_and_show_attention(
            encoder,
            decoder,
            input_sentence,
            input_lang,
            output_lang,
            args.max_length,
            args.device
        )
Example #14
def main():
    corpus_name = "cornell movie-dialogs corpus"
    corpus = os.path.join("data", corpus_name)

    printLines(os.path.join(corpus, "movie_lines.txt"))

    # Define path to new file
    datafile = os.path.join(corpus, "formatted_movie_lines.txt")
    linefile = os.path.join(corpus, "movie_lines.txt")
    conversationfile = os.path.join(corpus, "movie_conversations.txt")

    # Initialize lines dict, conversations list, and field ids
    MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
    MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

    # Load lines and process conversations
    preprocess = Preprocess(datafile, linefile, conversationfile, MOVIE_LINES_FIELDS, MOVIE_CONVERSATIONS_FIELDS)
    preprocess.loadLines()
    preprocess.loadConversations()
    preprocess.writeCSV()

    # Load/Assemble voc and pairs
    save_dir = os.path.join("data", "save")
    dataset = Dataset(corpus, corpus_name, datafile)
    voc, pairs = dataset.loadPrepareData()
    
    # # Print some pairs to validate
    # print("\npairs:")
    # for pair in pairs[:10]:
    #   print(pair)

    # Trim voc and pairs
    pairs = dataset.trimRareWords(voc, pairs, MIN_COUNT)

    # Example for validation
    small_batch_size = 5
    batches = dataset.batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
    input_variable, lengths, target_variable, mask, max_target_len = batches

    print("input_variable:", input_variable)
    print("lengths:", lengths)
    print("target_variable:", target_variable)
    print("mask:", mask)
    print("max_target_len:", max_target_len)

  

    # Configure models
    model_name = 'cb_model'
    attn_model = 'dot'
    #attn_model = 'general'
    #attn_model = 'concat'
    hidden_size = 500
    encoder_n_layers = 2
    decoder_n_layers = 2
    dropout = 0.1
    batch_size = 64

    # Set checkpoint to load from; set to None if starting from scratch
    loadFilename = None
    checkpoint_iter = 4000
    #loadFilename = os.path.join(save_dir, model_name, corpus_name,
    #                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
    #                            '{}_checkpoint.tar'.format(checkpoint_iter))

    if loadFilename:
        # If loading on same machine the model was trained on
        checkpoint = torch.load(loadFilename)
        # If loading a model trained on GPU to CPU
        #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
        encoder_sd = checkpoint['en']
        decoder_sd = checkpoint['de']
        encoder_optimizer_sd = checkpoint['en_opt']
        decoder_optimizer_sd = checkpoint['de_opt']
        embedding_sd = checkpoint['embedding']
        voc.__dict__ = checkpoint['voc_dict']

    print('Building encoder and decoder ...')
    # Initialize word embeddings
    embedding = nn.Embedding(voc.num_words, hidden_size)
    if loadFilename:
        embedding.load_state_dict(embedding_sd)
    # Initialize encoder & decoder models
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
    if loadFilename:
        encoder.load_state_dict(encoder_sd)
        decoder.load_state_dict(decoder_sd)
    # Use appropriate device
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    print('Models built and ready to go!')

    # Configure training/optimization
    clip = 50.0
    teacher_forcing_ratio = 1.0
    learning_rate = 0.0001
    decoder_learning_ratio = 5.0
    n_iteration = 4000
    print_every = 1
    save_every = 500

    # Ensure dropout layers are in train mode
    encoder.train()
    decoder.train()

    # Initialize optimizers
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(encoder_optimizer_sd)
        decoder_optimizer.load_state_dict(decoder_optimizer_sd)

    # Run training iterations
    print("Starting Training!")
    model = Model(dataset.batch2TrainData, teacher_forcing_ratio)
    model.trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
                     embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
                     print_every, save_every, clip, corpus_name, loadFilename)

    # Set dropout layers to eval mode
    encoder.eval()
    decoder.eval()

    # Initialize search module
    searcher = GreedySearchDecoder(encoder, decoder)
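    # The excerpt ends here. A sketch of how the searcher is typically
    # used afterwards, mirroring Example #1 (evaluate and
    # pick_n_valid_sentences are assumed, not shown in this example):
    for sentence in pick_n_valid_sentences(10):
        decoded_words = evaluate(searcher, sentence)
        print('Human: {}'.format(sentence))
        print('Bot: {}'.format(' '.join(decoded_words)))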