Example 1
import argparse
import logging
import cPickle  # Python 2 (cPickle/izip); under Python 3 use pickle and zip
from itertools import izip

# Assumed importable from the surrounding project: Vocabulary, NULL,
# read_parallel_corpus, CharLM, PoissonUniformCharLM, PYP, PYPPrior,
# Uniform, AlignmentModel; run_sampler is defined in Example 2 below.


def main():
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    parser = argparse.ArgumentParser(description='Train alignment model')
    parser.add_argument('--train', help='training corpus', required=True)
    parser.add_argument('--iter',
                        help='number of iterations',
                        type=int,
                        required=True)
    parser.add_argument('--charlm', help='character language model')
    parser.add_argument('--pyp',
                        help='G_w^0 is PYP(CharLM)',
                        action='store_true')
    parser.add_argument('--output', help='model output path')
    parser.add_argument(
        '--reverse',
        help='train model in reverse direction (but output f-e)',
        action='store_true')

    args = parser.parse_args()

    source_vocabulary = Vocabulary()
    source_vocabulary[NULL]  # intern NULL first so it takes index 0
    target_vocabulary = Vocabulary()

    logging.info('Reading parallel training data')
    with open(args.train) as train:
        training_corpus = read_parallel_corpus(train, source_vocabulary,
                                               target_vocabulary, args.reverse)

    if args.charlm:
        logging.info('Preloading character language model')
        if args.charlm == 'pu':
            char_lm = PoissonUniformCharLM(target_vocabulary)
        else:
            char_lm = CharLM(args.charlm, target_vocabulary)
        if args.pyp:
            t_base = PYP(char_lm, PYPPrior(1.0, 1.0, 1.0, 1.0, 0.1, 1.0))
        else:
            t_base = char_lm
    else:
        t_base = Uniform(len(target_vocabulary))
    model = AlignmentModel(len(source_vocabulary), t_base)

    logging.info('Training alignment model')
    alignments = run_sampler(model, training_corpus, args.iter)

    if args.output:
        with open(args.output, 'wb') as f:  # 'wb': protocol=-1 is a binary pickle format
            model.source_vocabulary = source_vocabulary
            model.target_vocabulary = target_vocabulary
            cPickle.dump(model, f, protocol=-1)

    fmt = ('{e}-{f}' if args.reverse else '{f}-{e}')
    for a, (f, e) in izip(alignments, training_corpus):
        #f_sentence = ' '.join(source_vocabulary[w] for w in f[1:])
        #e_sentence = ' '.join(target_vocabulary[w] for w in e)
        # j == 0 marks a NULL link; real links are offset by one because of
        # the NULL token, so subtract 1 before printing Pharaoh-style pairs.
        print(' '.join(
            fmt.format(f=j - 1, e=i) for i, j in enumerate(a) if j > 0))
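Example 1 never defines Vocabulary, but its usage pins down the contract: source_vocabulary[NULL] works as a bare statement because __getitem__ interns unseen words, handing out consecutive ids from 0, and the commented-out lines look ids back up as words. A minimal sketch of that behavior (an assumption; the real class lives elsewhere in the project):

class Vocabulary(object):
    """Sketch of the auto-interning vocabulary assumed above (hypothetical)."""

    def __init__(self):
        self.word2id = {}
        self.id2word = []

    def __getitem__(self, key):
        if isinstance(key, int):  # id -> word (used when printing sentences)
            return self.id2word[key]
        if key not in self.word2id:  # word -> id, interned on first lookup
            self.word2id[key] = len(self.id2word)
            self.id2word.append(key)
        return self.word2id[key]

    def __len__(self):
        return len(self.id2word)

Under this reading, interning NULL before any corpus word guarantees it id 0, which is why the output loop above treats j == 0 as the NULL link.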
Example 2
import logging
import math

# mh_iter is used below but never defined in this snippet; a module-level
# constant is assumed (the value here is a placeholder).
mh_iter = 100


def run_sampler(model, corpus, n_iter):
    n_words = sum(len(e) for f, e in corpus)
    alignments = [None] * len(corpus)
    samples = []
    for it in range(n_iter):
        logging.info('Iteration %d/%d', it + 1, n_iter)
        for i, (f, e) in enumerate(corpus):
            # Collapsed Gibbs step: remove this sentence's previous
            # alignment from the model's counts, then resample it.
            if it > 0:
                model.decrement(f, e, alignments[i])
            alignments[i] = list(model.increment(f, e))
        if it % 10 == 0:
            logging.info('Model: %s', model)
            ll = model.log_likelihood()
            ppl = math.exp(-ll / n_words)  # per-word perplexity: exp(-LL / N)
            logging.info('LL=%.0f ppl=%.3f', ll, ppl)
        if it % 30 == 29:
            logging.info('Resampling hyperparameters...')
            acceptance, rejection = model.resample_hyperparemeters(mh_iter)
            arate = acceptance / float(acceptance + rejection)
            logging.info('Metropolis-Hastings acceptance rate: %.4f', arate)
            logging.info('Model: %s', model)
        if it > n_iter / 10 and it % 10 == 0:
            # Past the burn-in (first ~10% of iterations), record an
            # estimate every 10th iteration for later combination.
            logging.info('Estimating sample')
            samples.append(model.map_estimate())

    logging.info('Combining %d samples', len(samples))
    align = AlignmentModel.combine(samples)
    for i, (f, e) in enumerate(corpus):
        alignments[i] = list(align(f, e))
    return alignments
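run_sampler touches the model only through a small duck-typed surface: increment/decrement for collapsed Gibbs updates, log_likelihood, resample_hyperparemeters (spelling as called above), and map_estimate; only the final AlignmentModel.combine names a concrete class. A hypothetical stub along these lines makes the contract explicit and can smoke-test everything up to that combine step:

import random

class StubModel(object):
    """Hypothetical stand-in for the interface run_sampler calls."""

    def increment(self, f, e):
        # Sample one source position per target word (0 plays the NULL
        # role); a real model would also record the counts this implies.
        return [random.randrange(len(f)) for _ in e]

    def decrement(self, f, e, alignment):
        pass  # a real model removes this sentence's counts here

    def log_likelihood(self):
        return -1000.0  # placeholder corpus log-likelihood

    def resample_hyperparemeters(self, n_iter):
        return (n_iter, 0)  # (accepted, rejected) MH proposals

    def map_estimate(self):
        return self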
Example 3
import sys
sys.path.append('../')
import torch
import numpy as np
import argparse
from model import AlignmentModel
from torch.utils.data import DataLoader
from cc_dataset import CCDataset

device = "cuda:1"

model_path = "../checkpoints/2020_12_23_1/epch0_bidx6400.pt"
model = AlignmentModel().to(device)
# map_location loads tensors onto `device` instead of whichever GPU the
# checkpoint was saved from.
model_data = torch.load(model_path, map_location=device)["model_state_dict"]
# Drop a buffer that some transformers versions register but this model
# does not expect; the None default makes the pop safe if the key is absent.
model_data.pop("text_encoder.model.embeddings.position_ids", None)
model.load_state_dict(model_data)
model.eval()
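Popping one buffer by name works until a different library version adds or renames another key. A more tolerant loading pattern, a sketch rather than this project's code (model, model_path and device as above), drops every checkpoint entry the live model does not define:

state = torch.load(model_path, map_location=device)["model_state_dict"]
# Anything the current model does not declare is version skew
# (extra buffers, renamed modules); discard it before loading.
expected = set(model.state_dict())
for key in list(state):
    if key not in expected:
        del state[key]
model.load_state_dict(state)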


class Namespace:
    """Minimal attribute container: exposes an options dict as attributes."""
    def __init__(self, opts):
        self.__dict__.update(opts)


def collate_fn(batch):
    # Collect precomputed per-sample image and text features.
    image_feats = []
    text_feats = []
    for sample in batch:
        image_feats.append(sample["img_feat"])
        text_feats.append(sample["text_feat"])
    # Stacking gives [batch, n, dim]; the mean over dim=1 presumably pools
    # each sample's n feature rows (e.g. regions) into a single vector.
    image_feats = torch.mean(torch.stack(image_feats), dim=1).to(device)