예제 #1
0
def train_attention_captioner():
    """Set up the attention captioner and run its (forward-only) batch loop.

    Builds the image transform, loads the pickled vocabulary from
    ``dict_path``, creates the data loader, the CNN encoder and the attention
    decoder, then for ``NUM_EPOCHS`` epochs runs the encoder and one decoder
    step on every batch.

    NOTE(review): no loss is computed and ``optimizer.step()`` is never
    called in this body — the training step appears unfinished.
    """
    print("Training The Attention Capitoner ... ")
    # Create model directory
    if not os.path.exists(path_trained_model):
        os.makedirs(path_trained_model)

    # Image preprocessing: resize, random crop/flip for augmentation, then
    # normalize with the constants expected by the pretrained ResNet.
    transform = transforms.Compose([
        # Image.LANCZOS is the filter the old Image.ANTIALIAS alias pointed
        # to; the alias was removed in Pillow 10.
        transforms.Resize((input_resnet_size, input_resnet_size),
                          interpolation=Image.LANCZOS),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load the pickled vocabulary.
    # NOTE: pickle.load can execute arbitrary code; dict_path must be trusted.
    with open(dict_path, 'rb') as file:
        dictionary = pickle.load(file)
    # dictionary[0] is used as the word -> index mapping.
    word2idx = dictionary[0]

    # Build data loader
    data_loader = get_loader(imgs_path,
                             data_caps,
                             dictionary,
                             transform,
                             BATCH_SIZE,
                             shuffle=True,
                             num_workers=2)

    # Build the models; both must live on the same device as the inputs.
    encoder = EncoderCNN(word_embedding_size).to(device)
    attn_decoder = AttnDecoderRNN(word_embedding_size,
                                  len(word2idx)).to(device)

    # Loss and optimizer. Only the decoder and the encoder's linear/bn layers
    # are handed to the optimizer; other encoder parameters are not updated.
    criterion = nn.CrossEntropyLoss()
    params = (list(attn_decoder.parameters())
              + list(encoder.linear.parameters())
              + list(encoder.bn.parameters()))
    optimizer = torch.optim.Adam(params, lr=LEARN_RATE)

    # Initial decoder input (the START token) and hidden state.
    decoder_input = torch.tensor([[word2idx['START']]]).to(device)
    decoder_hidden = torch.zeros(word_embedding_size).to(device)
    # NOTE(review): decoder_hidden is reassigned by every decoder call and
    # never reset, so it carries across batches — confirm this is intended.

    total_steps = len(data_loader)
    for epoch in range(NUM_EPOCHS):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Tensors expose .shape (not .Size); lengths may be a plain list.
            print(images.shape, captions.shape, len(lengths))

            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            features = encoder(images)
            decoder_output, decoder_hidden, attn_weights = attn_decoder(
                decoder_input, decoder_hidden, features)
예제 #2
0
파일: eval2.py 프로젝트: mharwani/AMR-text
def main():
    """Poll an experiment directory and report dev BLEU for new checkpoints.

    Every 20 seconds, scans ``<log_dir>/<exp_name>`` for the checkpoint with
    the highest iteration number seen so far. When a newer one exists, the
    encoder, child-sum module and attention decoder are rebuilt from the
    hyper-parameters stored in the checkpoint and their weights are loaded.
    The function then prints 10 random dev translations and logs the corpus
    BLEU over the dev set. Runs until interrupted.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dev_files', default='../amr_anno_1.0/data/split/dev/*',
                    help='dev files.')
    ap.add_argument('--log_dir', default='./log',
                    help='log directory')
    ap.add_argument('--exp_name', default='experiment',
                    help='experiment name')
    args = ap.parse_args()

    # read dev files
    dev_files = glob.glob(args.dev_files)
    dev_pairs = AMR.read_AMR_files(dev_files, True)

    exp_dir = os.path.join(args.log_dir, args.exp_name)
    # makedirs creates the parent log directory as well.
    os.makedirs(exp_dir, exist_ok=True)

    max_iter = 0
    dev_bleu = 0.0
    while True:
        load_state_file = None
        for sf in glob.glob(os.path.join(exp_dir, '*')):
            # Checkpoint names are expected to look like 'state_<iter>.pt'.
            # Parse only the basename so underscores in log_dir/exp_name
            # cannot shift the split; skip files that don't match.
            try:
                iter_num = int(
                    os.path.basename(sf).split('_')[1].split('.')[0])
            except (IndexError, ValueError):
                continue
            if iter_num > max_iter:
                max_iter = iter_num
                load_state_file = sf

        if load_state_file is None:
            logging.info('No new checkpoint found. Last DEV BLEU score: %.2f',
                         dev_bleu)
        else:
            # map_location lets GPU-saved checkpoints load on any device.
            state = torch.load(load_state_file, map_location=device)
            amr_vocab = state['amr_vocab']
            en_vocab = state['en_vocab']
            hidden_size = state['hidden_size']
            edge_size = state['edge_size']
            drop = state['dropout']
            mlength = state['max_length']
            logging.info('loaded checkpoint %s', load_state_file)

            encoder = EncoderRNN(amr_vocab.n_nodes, hidden_size).to(device)
            child_sum = ChildSum(amr_vocab.n_edges, edge_size,
                                 hidden_size).to(device)
            decoder = AttnDecoderRNN(hidden_size, en_vocab.n_words,
                                     dropout_p=drop,
                                     max_length=mlength).to(device)
            encoder.load_state_dict(state['enc_state'])
            child_sum.load_state_dict(state['sum_state'])
            decoder.load_state_dict(state['dec_state'])
            # Evaluation mode: disables the decoder's dropout so scores are
            # deterministic.
            encoder.eval()
            child_sum.eval()
            decoder.eval()

            # translate from the dev set
            translate_random_amr(encoder, child_sum, decoder, dev_pairs,
                                 amr_vocab, en_vocab, mlength, n=10)
            translated_amrs = translate_amrs(encoder, child_sum, decoder,
                                             dev_pairs, amr_vocab, en_vocab,
                                             mlength)
            references = [[pair[0]]
                          for pair in dev_pairs[:len(translated_amrs)]]
            candidates = [sent.split() for sent in translated_amrs]
            dev_bleu = corpus_bleu(references, candidates)
            logging.info('Dev BLEU score: %.2f', dev_bleu)

        time.sleep(20)
예제 #3
0
def main():
    """Train and/or evaluate the English-Lojban attention seq2seq model.

    Loads the parallel data for ``--source``/``--target`` and builds an
    encoder and an attention decoder. Layer sizes come from
    ``--pretrain-input-words``/``--pretrain-output-words`` when given,
    otherwise from the loaded vocabularies. The function optionally loads
    pretrained weights, then trains for ``--iters`` iterations unless
    ``--no-train`` is set, saving both state dicts. It finally evaluates on
    the validation pairs.
    """

    def _str2bool(value):
        """Parse a command-line boolean (``bool('False')`` would be True)."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser("English - Lojban translation")
    parser.add_argument("--source", default='loj', help="source language data")
    parser.add_argument("--target", default='en', help="target language data")
    parser.add_argument("--iters",
                        type=int,
                        default=100000,
                        help="number of iterations to train")
    # Accepts both a bare `--no-train` and an explicit `--no-train true/false`.
    parser.add_argument("--no-train",
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help="Do not perform training. Only validation")
    parser.add_argument("--pretrain-encoder",
                        help="Path to pretrained encoder")
    parser.add_argument("--pretrain-decoder",
                        help="Path to pretrained decoder")
    parser.add_argument(
        "--pretrain-input-words",
        type=int,
        help="Number of source language words in pretrained model")
    parser.add_argument(
        "--pretrain-output-words",
        type=int,
        help="Number of target language words in pretrained model")
    parser.add_argument("--encoder-ckpt",
                        default="encoder.pth",
                        help="Name of encoder checkpoint filename")
    parser.add_argument("--decoder-ckpt",
                        default="decoder.pth",
                        help="Name of decoder checkpoint filename")
    parser.add_argument("--prefix",
                        default='',
                        help='Prefix, added to data files')
    args = parser.parse_args()

    input_lang, output_lang, pairs, pairs_val = prepare_data(
        args.source, args.target, prefix=args.prefix)
    langs = (input_lang, output_lang)
    print(random.choice(pairs))

    # Size the vocabulary-dependent layers like the pretrained model when
    # those sizes are given (presumably required for its weights to load),
    # otherwise like the freshly loaded vocabularies.
    input_words = args.pretrain_input_words or input_lang.n_words
    output_words = args.pretrain_output_words or output_lang.n_words

    encoder = EncoderRNN(input_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, output_words,
                             dropout_p=0.1).to(device)

    if args.pretrain_encoder and args.pretrain_decoder:
        load_pretrained_model(encoder, decoder, args.pretrain_encoder,
                              args.pretrain_decoder)

    if not args.no_train:
        train(encoder, decoder, args.iters, pairs, langs, print_every=5000)
        torch.save(encoder.state_dict(), args.encoder_ckpt)
        torch.save(decoder.state_dict(), args.decoder_ckpt)

    evaluate_all(encoder, decoder, pairs_val, langs)
예제 #4
0
def main(args):
    """Load config and pickled data, build the seq2seq model and train it.

    Reads ``config.json`` from ``args.config_path``, unpickles the input and
    output language objects and the sentence pairs from
    ``config["dataset_path"]``, builds an encoder and an attention decoder on
    CUDA device ``args.ordinal`` (or CPU), optionally restores saved models
    and runs ``args.train_iters`` training iterations.
    """
    config_path = os.path.join(args.config_path, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    print('[-] Loading pickles')
    dataset_path = Path(config["dataset_path"])

    def _unpickle(name):
        # Close each file once loaded.
        # NOTE: unpickling can execute arbitrary code; only load trusted data.
        with open(dataset_path / name, 'rb') as fh:
            return CustomUnpickler(fh).load()

    input_lang = _unpickle('input_lang.pkl')
    output_lang = _unpickle('output_lang.pkl')
    pairs = _unpickle('pairs.pkl')

    max_len = config["max_len"]
    lr = config["model_cfg"]["lr"]
    hidden_size = config["model_cfg"]["hidden_size"]
    train_iters = args.train_iters
    device = torch.device("cuda:%s" % args.ordinal if torch.cuda.is_available() else "cpu")

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, max_len, dropout_p=0.1).to(device)

    trainer = Trainer(device, encoder, attn_decoder, input_lang, output_lang, pairs, max_len, lr,
                      ckpt_path=config["models_path"])
    if args.load_models:
        trainer.load_models()
    trainer.run_epoch(train_iters)
예제 #5
0
def main():
    """Fine-tune an IMDB seq2seq model initialised from a pretrained LSTM.

    Loads the cached (lang, lines) pickle, building and caching it with
    prepareData when missing. It then loads the pretrained LSTM pickle and
    raises NotImplementedError if that pickle is absent. The pretrained
    LSTM's layer-0 weights and its embedding are shared into a fresh encoder
    and attention decoder, which are then trained with trainIters.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SOS_token, EOS_token, MASKED_token = 0, 1, 2
    MAX_LENGTH = 42

    hidden_size = 325
    train_iters = 20
    pretrain_train_iters = 2000
    dataset = 'imdb'
    lang_filename = f'./data/{dataset}_lang.pkl'

    # Build the vocabulary once and cache it; reuse the cache afterwards.
    if not os.path.exists(lang_filename):
        lang, lines = prepareData(dataset)
        with open(lang_filename, 'wb') as fh:
            pkl.dump((lang, lines), fh)
    else:
        with open(lang_filename, 'rb') as fh:
            lang, lines = pkl.load(fh)

    pretrained_filename = (f'./pretrained/pretrained_lstm_{dataset}_'
                           f'{hidden_size}_{pretrain_train_iters}.pkl')
    model_filename = (f'./pretrained/maskmle_{dataset}_'
                      f'{hidden_size}_{train_iters}.pkl')

    if not os.path.exists(pretrained_filename):
        raise NotImplementedError('pretrained lstm is not available')
    with open(pretrained_filename, 'rb') as fh:
        pretrained_lstm = pkl.load(fh)

    encoder1 = EncoderRNN(lang.n_words, hidden_size).to(device)
    attn_decoder1 = AttnDecoderRNN(hidden_size, lang.n_words,
                                   dropout_p=0.1).to(device)
    n_trainable = count_parameters(encoder1) + count_parameters(attn_decoder1)
    print("Total number of trainable parameters:", n_trainable)

    # Point both models' layer-0 LSTM weights at the pretrained parameters.
    source_lstm = pretrained_lstm.lstm
    for target_lstm in (encoder1.lstm, attn_decoder1.lstm):
        for weight_name in ('weight_ih_l0', 'weight_hh_l0',
                            'bias_ih_l0', 'bias_hh_l0'):
            setattr(target_lstm, weight_name,
                    getattr(source_lstm, weight_name))

    # Share the pretrained embedding matrix with encoder and decoder.
    shared_embedding = pretrained_lstm.embedding.weight
    encoder1.embedding.weight = shared_embedding
    attn_decoder1.embedding.weight = shared_embedding

    report_every = train_iters // 20
    trainIters(encoder1,
               attn_decoder1,
               lang,
               lines,
               train_iters,
               print_every=report_every,
               plot_every=report_every)
예제 #6
0
def main():
    """Load a trained IMDB masked-MLE model and print random evaluations.

    Restores the encoder and attention decoder state dicts from
    ``./pretrained/e_<name>.pt`` and ``./pretrained/d_<name>.pt`` and prints
    the result of ``evaluateRandomly`` on 20 samples.
    """
    dataset = 'imdb'
    hidden_size = 325
    train_iters = 40
    lang, lines = cachePrepareData(dataset)

    PATH = './pretrained/'
    model_filename = 'maskmle_' + dataset + '_' + str(hidden_size) + '_' + str(train_iters) + '.pt'

    # map_location lets GPU-saved weights load on a CPU-only machine.
    encoder1 = EncoderRNN(lang.n_words, hidden_size).to(device)
    encoder1.load_state_dict(
        torch.load(PATH + 'e_' + model_filename, map_location=device))

    attn_decoder1 = AttnDecoderRNN(hidden_size, lang.n_words, dropout_p=0.1).to(device)
    attn_decoder1.load_state_dict(
        torch.load(PATH + 'd_' + model_filename, map_location=device))

    # Evaluation mode: turns off the decoder's dropout (dropout_p=0.1).
    encoder1.eval()
    attn_decoder1.eval()
    print(evaluateRandomly(encoder1, attn_decoder1, lang, lines, 20, 0.5))
예제 #7
0
def main(args):
    """Restore a trained seq2seq model and show sample translations.

    Reads ``config.json`` from ``args.config_path`` and unpickles the
    language objects and sentence pairs. Encoder/decoder weights are
    restored from ``config["models_path"] + 'models.ckpt'``. The function
    then prints random evaluations and saves the attention plot of one
    sentence to ``attention.png``.
    """
    config_path = os.path.join(args.config_path, 'config.json')
    with open(config_path) as f:
        config = json.load(f)

    print('[-] Loading pickles')
    dataset_path = Path(config["dataset_path"])

    def _unpickle(name):
        # Close each file once loaded.
        # NOTE: unpickling can execute arbitrary code; only load trusted data.
        with open(dataset_path / name, 'rb') as fh:
            return CustomUnpickler(fh).load()

    input_lang = _unpickle('input_lang.pkl')
    output_lang = _unpickle('output_lang.pkl')
    pairs = _unpickle('pairs.pkl')

    hidden_size = config["model_cfg"]["hidden_size"]
    max_len = config["max_len"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size,
                             output_lang.n_words,
                             max_len,
                             dropout_p=0.1).to(device)

    print('[-] Loading models')
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    ckpt = torch.load(config["models_path"] + 'models.ckpt',
                      map_location=device)
    encoder.load_state_dict(ckpt['encoder'])
    decoder.load_state_dict(ckpt['decoder'])
    # Evaluation mode: turns off the decoder's dropout (dropout_p=0.1).
    encoder.eval()
    decoder.eval()

    evaluator = Evaluater(device, encoder, decoder, input_lang, output_lang,
                          max_len)

    # Evaluate random samples
    evaluator.evaluateRandomly(pairs)

    evaluator.evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    # Other sample inputs:
    # evaluator.evaluateAndShowAttention("elle est trop petit .")
    # evaluator.evaluateAndShowAttention("je ne crains pas de mourir .")
    # evaluator.evaluateAndShowAttention("c est un jeune directeur plein de talent .")
    plt.savefig('attention.png')
예제 #8
0
def main():
    """Train (or load) a German->English Multi30k seq2seq model and decode.

    Builds spaCy-tokenized vocabularies over all three splits and trains
    with early stopping on validation loss, saving the best weights under
    ``--export_dir``. For every test batch it then prints the input sentence
    and the n-best hypotheses and timings of two beam-search
    implementations: a looped one and a batched one.
    """
    # ArgumentParser {{{
    parser = argparse.ArgumentParser()
    # hyper parameters
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--n_epochs', type=int, default=10)
    parser.add_argument('--enc_embd_size', type=int, default=256)
    parser.add_argument('--dec_embd_size', type=int, default=256)
    parser.add_argument('--enc_h_size', type=int, default=512)
    parser.add_argument('--dec_h_size', type=int, default=512)
    # other parameters
    parser.add_argument('--beam_width', type=int, default=3)
    parser.add_argument('--n_best', type=int, default=3)
    parser.add_argument('--max_dec_steps', type=int, default=1000)
    parser.add_argument('--export_dir', type=str, default=modelpath)
    parser.add_argument('--model_name', type=str, default='s2s')
    parser.add_argument('--model_path',
                        type=str,
                        default=modelpath / 's2s-vanilla.pt')
    parser.add_argument('--skip_train', action='store_true')
    parser.add_argument('--attention', action='store_true')
    opts = parser.parse_args()
    # }}}
    # Attention is forced on here, so the --attention flag has no effect.
    opts.attention = True
    # NOTE(review): with --skip_train the default --model_path is
    # 's2s-vanilla.pt', while training with attention saves
    # '<model_name>-attn.pt'; pass --model_path explicitly when evaluating.

    train_dataset, valid_dataset, test_dataset = Multi30k(root=dataroot)
    # Each split is read twice (vocab counting, then numericalization);
    # tee() keeps that safe if the splits are one-shot iterators.
    train_dataset1, train_dataset2 = tee(train_dataset)
    valid_dataset1, valid_dataset2 = tee(valid_dataset)
    test_dataset1, test_dataset2 = tee(test_dataset)

    de_counter = Counter()
    en_counter = Counter()
    de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
    en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

    def build_vocab(dataset):
        """Count German and English tokens over every (src, tgt) pair."""
        for (src_sentence, tgt_sentence) in tqdm(dataset):
            de_counter.update(de_tokenizer(src_sentence))
            en_counter.update(en_tokenizer(tgt_sentence))

    def data_process(dataset):
        """Turn raw sentence pairs into (de, en) LongTensors of token ids."""
        data = []
        for (raw_de, raw_en) in tqdm(dataset):
            de_tensor_ = torch.tensor(
                [de_vocab[token] for token in de_tokenizer(raw_de)],
                dtype=torch.long)
            en_tensor_ = torch.tensor(
                [en_vocab[token] for token in en_tokenizer(raw_en)],
                dtype=torch.long)
            data.append((de_tensor_, en_tensor_))
        return data

    def generate_batch(data_batch):
        """Collate pairs: add <bos>/<eos> to each sentence, pad to (T, bs)."""
        def wrap(seq):
            return torch.cat([torch.tensor([TRG_SOS_IDX]), seq,
                              torch.tensor([TRG_EOS_IDX])], dim=0)

        de_batch = [wrap(de_item) for de_item, _ in data_batch]
        en_batch = [wrap(en_item) for _, en_item in data_batch]
        return (pad_sequence(de_batch, padding_value=TRG_PAD_IDX),
                pad_sequence(en_batch, padding_value=TRG_PAD_IDX))

    build_vocab(train_dataset1)
    build_vocab(valid_dataset1)
    build_vocab(test_dataset1)
    de_vocab = Vocab(de_counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    en_vocab = Vocab(en_counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

    # German is the source (encoder) side, English the target (decoder) side.
    src_v_size = len(de_vocab)
    trg_v_size = len(en_vocab)

    # Both vocabularies use the same specials list, so these English indices
    # are also used for the German side in generate_batch (assumed to match).
    TRG_PAD_IDX = en_vocab.stoi['<pad>']
    TRG_SOS_IDX = en_vocab.stoi['<bos>']
    TRG_EOS_IDX = en_vocab.stoi['<eos>']

    train_data = data_process(train_dataset2)
    valid_data = data_process(valid_dataset2)
    test_data = data_process(test_dataset2)

    def make_loader(data):
        return DataLoader(data,
                          batch_size=opts.batch_size,
                          shuffle=False,
                          collate_fn=generate_batch)

    train_itr = make_loader(train_data)
    valid_itr = make_loader(valid_data)
    test_itr = make_loader(test_data)

    encoder = EncoderRNN(opts.enc_embd_size, opts.enc_h_size, opts.dec_h_size,
                         src_v_size, DEVICE)

    if opts.attention:
        attn = Attention(opts.enc_h_size, opts.dec_h_size)
        decoder = AttnDecoderRNN(opts.dec_embd_size, opts.enc_h_size,
                                 opts.dec_h_size, trg_v_size, attn, DEVICE)
    else:
        decoder = DecoderRNN(opts.dec_embd_size, opts.dec_h_size, trg_v_size,
                             DEVICE)
    model = Seq2Seq(encoder, decoder, DEVICE).to(DEVICE)

    if opts.skip_train:
        # map_location lets GPU-saved weights load on any device.
        model.load_state_dict(torch.load(opts.model_path, map_location=DEVICE))
    else:
        optimizer = optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
        best_valid_loss = float('inf')
        attn_type = 'attn' if opts.attention else 'vanilla'
        model_path = os.path.join(opts.export_dir,
                                  f'{opts.model_name}-{attn_type}.pt')
        for epoch in range(opts.n_epochs):
            start_time = time.time()

            train_loss = train(model, train_itr, optimizer, criterion)
            valid_loss = evaluate(model, valid_itr, criterion)

            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            improved = valid_loss < best_valid_loss
            if improved:
                best_valid_loss = valid_loss
                print(f'Update model! Saved {model_path}')
                torch.save(model.state_dict(), model_path)

            print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
            print(
                f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}'
            )
            print(
                f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}'
            )

            # Early stopping: stop after the first epoch whose validation
            # loss does not improve (its stats are still printed above).
            if not improved:
                print('Model was not updated. Stop training')
                break

    model.eval()
    beam_searches = (('for loop beam search', beam_search_decoding),
                     ('Batch beam search', batch_beam_search_decoding))
    with torch.no_grad():
        for batch in tqdm(test_itr):
            # Only the source side is needed for decoding; each is (T, bs).
            src, _ = batch
            src = src.to(DEVICE)
            print(f'In: {" ".join(de_vocab.itos[idx] for idx in src[:, 0])}')

            enc_outs, h = model.encoder(src)  # (T, bs, H), (bs, H)
            for label, decode in beam_searches:
                start_time = time.time()
                # decoded_seqs: (bs, T)
                decoded_seqs = decode(
                    decoder=model.decoder,
                    enc_outs=enc_outs,
                    enc_last_h=h,
                    beam_width=opts.beam_width,
                    n_best=opts.n_best,
                    sos_token=TRG_SOS_IDX,
                    eos_token=TRG_EOS_IDX,
                    max_dec_steps=opts.max_dec_steps,
                    device=DEVICE)
                end_time = time.time()
                print(f'{label} time: {end_time-start_time:.3f}')
                print_n_best(decoded_seqs[0], en_vocab.itos)
예제 #9
0
# Use GPU index 3 when CUDA is available, otherwise the CPU.
# NOTE(review): the GPU index is hard-coded — confirm it exists on the host.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

# load data
# prepareReportData returns the input/output language objects and the
# dataset (presumably vocabularies plus sentence pairs — confirm in loaders).
from loaders import prepareReportData

input_lang, output_lang, ds = prepareReportData()

# instantiate models
# Imports stay interleaved with the statements above/below; moving them
# would change the order in which module side effects run.
from models import EncoderRNN, AttnDecoderRNN

hidden_size = 256
# Passed to the decoder and to trainIters; presumably the maximum sequence
# length the attention covers — confirm in models.AttnDecoderRNN.
max_length = 30
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size,
                               output_lang.n_words,
                               max_length,
                               dropout_p=0.1).to(device)

# train
from train_utils import trainIters

print("\ntraining on {}..".format(device))
# Training length and logging interval handed to trainIters.
n_iter = 500000
print_every = 10000
plot_losses = trainIters(ds,
                         encoder1,
                         attn_decoder1,
                         n_iter,
                         max_length,
                         input_lang,
                         output_lang,
예제 #10
0
def main():
    """Train the AMR-to-text model (encoder + child-sum + attention decoder).

    With ``--load_checkpoint``, resumes from the checkpoint in
    ``<log_dir>/<exp_name>`` with the highest iteration number. It then
    samples random mini-batches of training AMR pairs until ``--n_iters``
    examples have been processed. A full checkpoint (weights, optimizer,
    vocabularies, hyper-parameters) is written every ``--checkpoint_every``
    examples, and the average loss is logged every ``--print_every``
    examples.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        '--hidden_size',
        default=256,
        type=int,
        help='hidden size of encoder/decoder, also word vector size')
    ap.add_argument('--edge_size',
                    default=20,
                    type=int,
                    help='embedding dimension of edges')
    ap.add_argument('--n_iters',
                    default=100000,
                    type=int,
                    help='total number of examples to train on')
    ap.add_argument('--print_every',
                    default=5000,
                    type=int,
                    help='print loss info every this many training examples')
    ap.add_argument(
        '--checkpoint_every',
        default=10000,
        type=int,
        help='write out checkpoint every this many training examples')
    # float, not int: learning rates are fractional (int('0.001') fails).
    ap.add_argument('--initial_learning_rate',
                    default=0.001,
                    type=float,
                    help='initial learning rate')
    ap.add_argument('--train_files',
                    default='../amr_anno_1.0/data/split/training/*',
                    help='training files.')
    ap.add_argument('--log_dir', default='./log', help='log directory')
    ap.add_argument('--exp_name', default='experiment', help='experiment name')
    ap.add_argument('--batch_size', default=5, type=int, help='batch size')
    ap.add_argument('--load_checkpoint',
                    action='store_true',
                    help='use existing checkpoint')

    args = ap.parse_args()

    exp_dir = os.path.join(args.log_dir, args.exp_name)
    # makedirs creates the parent log directory as well.
    os.makedirs(exp_dir, exist_ok=True)

    load_state_file = None
    if args.load_checkpoint:
        max_iter = 0
        for sf in glob.glob(os.path.join(exp_dir, '*')):
            # Checkpoints are written below as 'state_%010d.pt'. Parse only
            # the basename so underscores in log_dir/exp_name cannot shift
            # the split; skip files that don't match.
            try:
                iter_num = int(
                    os.path.basename(sf).split('_')[1].split('.')[0])
            except (IndexError, ValueError):
                continue
            if iter_num > max_iter:
                max_iter = iter_num
                load_state_file = sf
        if load_state_file is None:
            logging.info('no checkpoint found in %s; training from scratch',
                         exp_dir)

    # Create vocab from training data
    iter_num = 0
    train_files = glob.glob(args.train_files)
    train_pairs = AMR.read_AMR_files(train_files, True)
    amr_vocab, en_vocab = None, None
    state = None
    batch_size = args.batch_size
    hidden_size = args.hidden_size
    edge_size = args.edge_size
    drop = DROPOUT_P
    mlength = MAX_LENGTH
    if load_state_file is not None:
        # map_location lets GPU-saved checkpoints resume on any device.
        state = torch.load(load_state_file, map_location=device)
        iter_num = state['iter_num']
        amr_vocab = state['amr_vocab']
        en_vocab = state['en_vocab']
        hidden_size = state['hidden_size']
        edge_size = state['edge_size']
        drop = state['dropout']
        mlength = state['max_length']
        logging.info('loaded checkpoint %s', load_state_file)
    else:
        amr_vocab, en_vocab = make_vocabs(train_pairs)
    encoder = EncoderRNN(amr_vocab.n_nodes, hidden_size).to(device)
    child_sum = ChildSum(amr_vocab.n_edges, edge_size, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size,
                             en_vocab.n_words,
                             dropout_p=drop,
                             max_length=mlength).to(device)

    # set up optimization/loss
    params = (list(encoder.parameters()) + list(child_sum.parameters())
              + list(decoder.parameters()))  # .parameters() is a generator
    optimizer = optim.Adam(params, lr=args.initial_learning_rate)
    criterion = nn.NLLLoss()

    # Restore weights and optimizer state from the checkpoint.
    if state is not None:
        encoder.load_state_dict(state['enc_state'])
        child_sum.load_state_dict(state['sum_state'])
        decoder.load_state_dict(state['dec_state'])
        optimizer.load_state_dict(state['opt_state'])

    start = time.time()
    print_loss_total = 0  # Reset every args.print_every

    while iter_num < args.n_iters:
        # Never step past the next checkpoint boundary, the next print
        # boundary, or n_iters, so the modulo checks below fire exactly.
        num_samples = min(
            batch_size,
            args.checkpoint_every - (iter_num % args.checkpoint_every),
            args.print_every - (iter_num % args.print_every),
            args.n_iters - iter_num,
        )
        iter_num += num_samples
        random_pairs = random.sample(train_pairs, num_samples)
        target_snt = tensors_from_batch(en_vocab, random_pairs)
        loss = train(random_pairs, target_snt, amr_vocab, encoder, child_sum,
                     decoder, optimizer, criterion)
        print_loss_total += loss

        if iter_num % args.checkpoint_every == 0:
            state = {
                'iter_num': iter_num,
                'enc_state': encoder.state_dict(),
                'sum_state': child_sum.state_dict(),
                'dec_state': decoder.state_dict(),
                'opt_state': optimizer.state_dict(),
                'amr_vocab': amr_vocab,
                'en_vocab': en_vocab,
                'hidden_size': hidden_size,
                'edge_size': edge_size,
                'dropout': drop,
                'max_length': mlength
            }
            filename = 'state_%010d.pt' % iter_num
            save_file = os.path.join(exp_dir, filename)
            torch.save(state, save_file)
            logging.debug('wrote checkpoint to %s', save_file)

        if iter_num % args.print_every == 0:
            print_loss_avg = print_loss_total / args.print_every
            print_loss_total = 0
            logging.info(
                'time since start:%s (iter:%d iter/n_iters:%d%%) loss_avg:%.4f',
                time.time() - start, iter_num, iter_num / args.n_iters * 100,
                print_loss_avg)
예제 #11
0
    # load model($python3 train.py encoder decoder)
    # With exactly two CLI arguments, load the saved encoder/decoder objects
    # (used directly as models afterwards); otherwise build fresh modules.
    if len(sys.argv) == 3:
        if use_cuda:
            imdb_encoder = torch.load(sys.argv[1])
            imdb_decoder = torch.load(sys.argv[2])
        else:
            # Remap storages from every CUDA device (not only 'cuda:0') to
            # CPU so checkpoints saved on any GPU load on a CPU-only host.
            imdb_encoder = torch.load(sys.argv[1], map_location='cpu')
            imdb_decoder = torch.load(sys.argv[2], map_location='cpu')
    else:
        imdb_encoder = EncoderRNN(input_lang.n_words, hidden_size,
                                  embedding_matrix)
        # NOTE(review): meaning of the positional `1` (layer count?) —
        # confirm against AttnDecoderRNN's signature.
        imdb_decoder = AttnDecoderRNN(hidden_size,
                                      output_lang.n_words,
                                      1,
                                      dropout_p=0.1)

    if use_cuda:
        imdb_encoder = imdb_encoder.cuda()
        imdb_decoder = imdb_decoder.cuda()

    trainIters(imdb_encoder,
               imdb_decoder,
               1500000,
               print_every=100,
               plot_every=100,
               learning_rate=0.01)

    # save model
    torch.save(
예제 #12
0
def datasetBleu(multi_pairs_test, encoder, decoder):
    """Average sentence-level BLEU of the model over a test mapping.

    Args:
        multi_pairs_test: mapping from each source sentence to the value
            handed to ``bleu_score`` as its references.
        encoder, decoder: the model passed through to ``evaluate``.

    Returns:
        The mean ``bleu_score`` over all source sentences. An empty mapping
        raises ZeroDivisionError.
    """
    def _score(source, references):
        words, _attentions = evaluate(encoder, decoder, source)
        # The last output token is EOS; join the rest into one string.
        hypothesis = ' '.join(words[:-1])
        return bleu_score(hypothesis, references)

    total = sum((_score(src, refs) for src, refs in multi_pairs_test.items()),
                0.)
    return total / len(multi_pairs_test.keys())


# Model: encoder plus attention decoder (uniform attention when configured).
hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size, device).to(device)
attn_decoder1 = AttnDecoderRNN(
    hidden_size,
    output_lang.n_words,
    device=device,
    dropout_p=0.1,
    uniform_attention=config.use_uniform,
).to(device)

# Train or load if possible; the checkpoint folder must exist first.
import os
if not os.path.exists('saved_models'):
    os.makedirs('saved_models')

# Checkpoint filename tag: 'uni' for uniform attention, 'att' otherwise.
att_type = 'uni' if config.use_uniform else 'att'
try:
    encoder1 = torch.load(
        f'./saved_models/{config.source}_encoder_{att_type}.pt')
    attn_decoder1 = torch.load(
예제 #13
0
# Selects which encoder/decoder pair is built below; the if/elif chain that
# follows checks for 'AIAYN', 'positional' and 'LSTM'.
architecture = 'positional'

## Create Byte Pair Encodings
# `create_bpe` (flag) and `make_bpe` are defined elsewhere in the project.
if create_bpe:
    make_bpe()

## Load corpus
# `input_lang` / `output_lang` supply the vocabulary sizes (`n_words`) used
# to size the encoder and decoder below.
input_lang, output_lang, pairs = load_data('train')

## Define encoder and decoder
if architecture == 'AIAYN':
    hidden_size = word_embed_size + pos_embed_size
    encoder = EncoderPositional_AIAYN(input_lang.n_words, word_embed_size,
                                      pos_embed_size, max_sent_len)
    decoder = AttnDecoderRNN(hidden_size,
                             output_lang.n_words,
                             max_sent_len,
                             dropout_p=0.1)

elif architecture == 'positional':
    hidden_size = word_embed_size + pos_embed_size
    encoder = EncoderPositional(input_lang.n_words, word_embed_size,
                                pos_embed_size, max_sent_len)
    decoder = AttnDecoderRNN(hidden_size,
                             output_lang.n_words,
                             max_sent_len,
                             dropout_p=0.1)

elif architecture == 'LSTM':
    hidden_size = word_embed_size
    encoder = EncoderLSTM(input_lang.n_words, hidden_size)
    decoder = AttnDecoderRNN(hidden_size,
예제 #14
0
def main():
    """Optionally train, then test, a German-to-English seq2seq model.

    Builds the SRC/TRG ``Field``s and Multi30k ``BucketIterator``s, constructs
    an ``EncoderRNN`` plus either an ``AttnDecoderRNN`` (``--attention``) or a
    plain ``DecoderRNN``, and optionally loads weights from ``--model_path``.
    Unless ``--skip_train`` is given, it trains with Adam. Training stops at
    the first epoch whose validation loss does not improve, and every
    improvement is saved to ``<export_dir>/<model_name>-<attn|vanilla>.pt``.

    Finally, for every test batch it decodes with both the per-sentence and
    the batched beam-search implementations. For each, it prints the elapsed
    time and the n-best list for the first sentence of the batch.
    """
    # ArgumentParser {{{
    parser = argparse.ArgumentParser()
    # hyper parameters
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--n_epochs', type=int, default=10)
    parser.add_argument('--enc_embd_size', type=int, default=256)
    parser.add_argument('--dec_embd_size', type=int, default=256)
    parser.add_argument('--enc_h_size', type=int, default=512)
    parser.add_argument('--dec_h_size', type=int, default=512)
    # other parameters
    parser.add_argument('--beam_width', type=int, default=10)
    parser.add_argument('--n_best', type=int, default=5)
    parser.add_argument('--max_dec_steps', type=int, default=1000)
    parser.add_argument('--export_dir', type=str, default='./ckpts/')
    parser.add_argument('--model_name', type=str, default='s2s')
    parser.add_argument('--model_path', type=str, default='')
    parser.add_argument('--skip_train', action='store_true')
    parser.add_argument('--attention', action='store_true')
    opts = parser.parse_args()
    # }}}

    # Source and target share the same special tokens and are lower-cased.
    SOS_token = '<SOS>'
    EOS_token = '<EOS>'
    SRC = Field(tokenize=tokenize_de,
                init_token=SOS_token,
                eos_token=EOS_token,
                lower=True)
    TRG = Field(tokenize=tokenize_en,
                init_token=SOS_token,
                eos_token=EOS_token,
                lower=True)
    train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'),
                                                        fields=(SRC, TRG))
    print(f'Number of training examples: {len(train_data.examples)}')
    print(f'Number of validation examples: {len(valid_data.examples)}')
    print(f'Number of testing examples: {len(test_data.examples)}')

    # Vocabularies come from the training split only (min_freq=2).
    SRC.build_vocab(train_data, min_freq=2)
    TRG.build_vocab(train_data, min_freq=2)
    print(f'Unique tokens in source (de) vocabulary: {len(SRC.vocab)}')
    print(f'Unique tokens in target (en) vocabulary: {len(TRG.vocab)}')

    train_itr, valid_itr, test_itr = BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=opts.batch_size,
        device=DEVICE)

    enc_v_size = len(SRC.vocab)
    dec_v_size = len(TRG.vocab)

    encoder = EncoderRNN(opts.enc_embd_size, opts.enc_h_size, opts.dec_h_size,
                         enc_v_size, DEVICE)
    if opts.attention:
        attn = Attention(opts.enc_h_size, opts.dec_h_size)
        decoder = AttnDecoderRNN(opts.dec_embd_size, opts.enc_h_size,
                                 opts.dec_h_size, dec_v_size, attn, DEVICE)
    else:
        decoder = DecoderRNN(opts.dec_embd_size, opts.dec_h_size, dec_v_size,
                             DEVICE)
    model = Seq2Seq(encoder, decoder, DEVICE).to(DEVICE)

    TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

    if opts.model_path:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # host (and vice versa).
        model.load_state_dict(torch.load(opts.model_path,
                                         map_location=DEVICE))

    if not opts.skip_train:
        optimizer = optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
        best_valid_loss = float('inf')

        # The checkpoint path is the same for every epoch; compute it once
        # and make sure its directory exists so torch.save cannot fail on it.
        attn_type = 'attn' if opts.attention else 'vanilla'
        model_path = os.path.join(opts.export_dir,
                                  f'{opts.model_name}-{attn_type}.pt')
        if opts.export_dir:
            os.makedirs(opts.export_dir, exist_ok=True)

        for epoch in range(opts.n_epochs):
            start_time = time.time()

            train_loss = train(model, train_itr, optimizer, criterion)
            valid_loss = evaluate(model, valid_itr, criterion)

            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Report this epoch before deciding whether to stop, so the epoch
            # that triggers early stopping is logged as well.
            print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
            print(
                f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}'
            )
            print(
                f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}'
            )

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                print(f'Update model! Saved {model_path}')
                torch.save(model.state_dict(), model_path)
            else:
                # Early stopping with zero patience.
                print('Model was not updated. Stop training')
                break

    TRG_SOS_IDX = TRG.vocab.stoi[TRG.init_token]
    TRG_EOS_IDX = TRG.vocab.stoi[TRG.eos_token]
    # Both decoders take identical arguments; run them back to back on the
    # same encoder outputs so their timings and n-best lists can be compared.
    beam_searches = (('for loop beam search', beam_search_decoding),
                     ('Batch beam search', batch_beam_search_decoding))
    model.eval()
    with torch.no_grad():
        for batch in test_itr:
            src = batch.src  # (T, bs)
            print(f'In: {" ".join(SRC.vocab.itos[idx] for idx in src[:, 0])}')

            enc_outs, h = model.encoder(src)  # (T, bs, H), (bs, H)
            for label, search in beam_searches:
                # perf_counter is the monotonic clock meant for timing.
                start_time = time.perf_counter()
                # decoded_seqs: (bs, T)
                decoded_seqs = search(decoder=model.decoder,
                                      enc_outs=enc_outs,
                                      enc_last_h=h,
                                      beam_width=opts.beam_width,
                                      n_best=opts.n_best,
                                      sos_token=TRG_SOS_IDX,
                                      eos_token=TRG_EOS_IDX,
                                      max_dec_steps=opts.max_dec_steps,
                                      device=DEVICE)
                elapsed = time.perf_counter() - start_time
                print(f'{label} time: {elapsed:.3f}')
                print_n_best(decoded_seqs[0], TRG.vocab.itos)
예제 #15
0
def main():
    """Train a small attention seq2seq model and visualise its attention.

    Loads the 'eng'/'fra' sentence pairs via ``prepareData``. It builds an
    ``EncoderRNN`` and an ``AttnDecoderRNN`` on ``device`` and trains them for
    100 iterations with ``trainIters``. It then prints random evaluations and
    plots the attention matrix for a few hard-coded French input sentences.
    """

    # TODO CHECK THE EFFECT OF LOADING DATA HERE:
    input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
    lang_pack = input_lang, output_lang, pairs
    print(random.choice(pairs))

    # NOTE(review): hidden_size=2 is tiny; presumably a quick smoke-test
    # setting rather than a real training configuration — confirm.
    hidden_size = 2
    encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    attn_decoder1 = AttnDecoderRNN(hidden_size,
                                   output_lang.n_words,
                                   dropout_p=0.1).to(device)

    # NOTE(review): with only 100 iterations, print_every=5000 presumably
    # never triggers a progress print — confirm this is intended.
    trainIters(encoder1,
               attn_decoder1,
               100,
               print_every=5000,
               plot_every=1,
               lang_pack=lang_pack)

    # Qualitative check on randomly chosen pairs.
    evaluateRandomly(encoder1, attn_decoder1, lang_pack=lang_pack)

    # NOTE(review): this result is never used below.
    output_words, attentions = evaluate(encoder1,
                                        attn_decoder1,
                                        "je suis trop froid .",
                                        lang_pack=lang_pack)

    def showAttention(input_sentence, output_words, attentions):
        """Plot ``attentions`` as a heat map.

        The x axis is labelled with the words of ``input_sentence`` plus a
        trailing '<EOS>', and the y axis with ``output_words``.
        """
        # Set up figure with colorbar
        fig = plt.figure()
        ax = fig.add_subplot(111)
        # Tensor.numpy() only accepts CPU tensors that do not require grad;
        # the models live on `device` (possibly CUDA), so detach and move
        # the tensor to the host first.
        cax = ax.matshow(attentions.detach().cpu().numpy(), cmap='bone')
        fig.colorbar(cax)

        # Place exactly one tick per label before assigning the labels, so
        # each label lands on its own column/row of the matrix. This avoids
        # padding the labels to match implicit tick positions.
        x_labels = input_sentence.split(' ') + ['<EOS>']
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90)
        ax.set_yticks(range(len(output_words)))
        ax.set_yticklabels(output_words)

        plt.show()

    def evaluateAndShowAttention(input_sentence):
        """Translate ``input_sentence``, print the result and plot attention."""
        output_words, attentions = evaluate(encoder1,
                                            attn_decoder1,
                                            input_sentence,
                                            lang_pack=lang_pack)
        print('input =', input_sentence)
        print('output =', ' '.join(output_words))
        showAttention(input_sentence, output_words, attentions)

    evaluateAndShowAttention("elle a cinq ans de moins que moi .")
    evaluateAndShowAttention("elle est trop petit .")
    evaluateAndShowAttention("je ne crains pas de mourir .")
    evaluateAndShowAttention("c est un jeune directeur plein de talent .")