Beispiel #1
0
def load_vocab(args):
    """Build a vocabulary from the training corpus described by ``args``.

    The corpus at ``args.train`` is read with the reader class selected by
    ``args.reader_mode``, using BEGIN_TOKEN / END_TOKEN as sentence markers.
    The entries the vocabulary returns for those two tokens are stored on it
    as ``START_TOK`` and ``END_TOK``, and ``add_unk`` is then called with
    ``args.unk_thresh``.

    Returns the resulting vocabulary object.
    """
    reader_cls = util.get_reader(args.reader_mode)
    corpus_reader = reader_cls(
        args.train,
        mode=args.reader_mode,
        begin=BEGIN_TOKEN,
        end=END_TOKEN,
    )

    # ``remake`` is forwarded from args.rebuild_vocab unchanged.
    vocabulary = util.Vocab.load_from_corpus(corpus_reader,
                                             remake=args.rebuild_vocab)

    # Look the markers up before add_unk runs, in the same order as before.
    vocabulary.START_TOK = vocabulary[BEGIN_TOKEN]
    vocabulary.END_TOK = vocabulary[END_TOKEN]
    vocabulary.add_unk(args.unk_thresh)
    return vocabulary
Beispiel #2
0
# define model

model = dynet.Model()
sgd = dynet.SimpleSGDTrainer(model)

S2SModel = seq2seq.get_s2s(args.model)
if args.load:
    print "Loading model..."
    s2s = S2SModel.load(model, args.load)
    src_vocab = s2s.src_vocab
    tgt_vocab = s2s.tgt_vocab
else:
    print "fresh model. getting vocab...",
    src_reader = util.get_reader(args.reader_mode)(args.train,
                                                   mode=args.reader_mode,
                                                   begin=BEGIN_TOKEN,
                                                   end=END_TOKEN)
    src_vocab = util.Vocab.load_from_corpus(src_reader,
                                            remake=args.rebuild_vocab,
                                            src_or_tgt="src")
    src_vocab.START_TOK = src_vocab[BEGIN_TOKEN]
    src_vocab.END_TOK = src_vocab[END_TOKEN]
    src_vocab.add_unk(args.unk_thresh)
    tgt_reader = util.get_reader(args.reader_mode)(args.train,
                                                   mode=args.reader_mode,
                                                   end=END_TOKEN)
    tgt_vocab = util.Vocab.load_from_corpus(tgt_reader,
                                            remake=args.rebuild_vocab,
                                            src_or_tgt="tgt")
    tgt_vocab.END_TOK = tgt_vocab[END_TOKEN]
    tgt_vocab.add_unk(args.unk_thresh)
Beispiel #3
0
elif args.trainer == "adadelta":
    trainer = dynet.AdadeltaTrainer(model)
elif args.trainer == "adagrad":
    trainer = dynet.AdagradTrainer(model)
elif args.trainer == "adam":
    trainer = dynet.AdamTrainer(model)
else:
    raise Exception("Trainer not recognized! Please use one of {simple_sgd, momentum_sgd, adadelta, adagrad, adam}")

# Set sparse updates for efficiency
trainer.set_sparse_updates(True)

# Load train/valid corpus

print("Loading corpus...")
train_data_src = list(util.get_reader(args.reader_mode)(args.train_src, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
train_data_tgt = list(util.get_reader(args.reader_mode)(args.train_tgt, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
if args.valid_src and args.valid_tgt:
    # Explicit validation files were supplied for both sides: read them as-is.
    valid_data_src = list(util.get_reader(args.reader_mode)(args.valid_src, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
    valid_data_tgt = list(util.get_reader(args.reader_mode)(args.valid_tgt, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
else:
    # No validation files: hold out the tail of the training data instead.
    # A percent_valid above 1 is taken as an absolute number of examples,
    # anything else as a fraction of the training set.
    if args.percent_valid > 1:
        # int() so a float-typed argument can still be used to slice.
        cutoff = int(args.percent_valid)
    else:
        cutoff = int(len(train_data_src) * args.percent_valid)

    # Split at an explicit non-negative index instead of slicing with
    # -cutoff: for cutoff == 0, x[-0:] is the whole list and x[:-0] is empty,
    # which would send every example to validation and leave nothing to
    # train on. max(..., 0) keeps the "cutoff larger than the data" case
    # behaving as before (everything goes to validation).
    src_split = max(len(train_data_src) - cutoff, 0)
    tgt_split = max(len(train_data_tgt) - cutoff, 0)

    valid_data_src = train_data_src[src_split:]
    valid_data_tgt = train_data_tgt[tgt_split:]

    train_data_src = train_data_src[:src_split]
    train_data_tgt = train_data_tgt[:tgt_split]

print("Train set of size", len(train_data_src), "/ Validation set of size", len(valid_data_src))
Beispiel #4
0
elif args.trainer == "adadelta":
    trainer = dynet.AdadeltaTrainer(model)
elif args.trainer == "adagrad":
    trainer = dynet.AdagradTrainer(model)
elif args.trainer == "adam":
    trainer = dynet.AdamTrainer(model)
else:
    raise Exception("Trainer not recognized! Please use one of {simple_sgd, momentum_sgd, adadelta, adagrad, adam}")

trainer.set_clip_threshold(-1.0)
trainer.set_sparse_updates(True)

# load corpus

print "Loading corpus..."
train_data = list(util.get_reader(args.reader_mode)(args.train, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
if args.valid:
    valid_data = list(util.get_reader(args.reader_mode)(args.valid, mode=args.reader_mode, begin=BEGIN_TOKEN, end=END_TOKEN))
else:
    if args.percent_valid > 1: cutoff = args.percent_valid
    else: cutoff = int(len(train_data)*(args.percent_valid))
    valid_data = train_data[-cutoff:]
    train_data = train_data[:-cutoff]
    print "Train set of size", len(train_data), "/ Validation set of size", len(valid_data)
print "done."

# Initialize model
# Look up the seq2seq model class selected by args.model.
S2SModel = seq2seq.get_s2s(args.model)
if args.load:
    # args.load is set (presumably a saved-model location -- confirm against
    # the argument parser): restore the model into the existing dynet model.
    print "Loading existing model..."
    s2s = S2SModel.load(model, args.load)
Beispiel #5
0
elif args.trainer == "adam":
    trainer = dynet.AdamTrainer(model)
    learning_rate = .001
elif args.trainer == "adagrad":
    trainer = dynet.AdagradTrainer(model)
    learning_rate = .01

# A learning rate given on args takes precedence over any value set earlier.
if args.learning_rate is not None:
    learning_rate = args.learning_rate
################################### LOAD THE MODELS

if args.load:
    # Restore a saved language model of the architecture named by args.arch.
    lm = rnnlm.get_model(args.arch).load(model, args.load)
    # OVERRIDES
else:
    # Fresh model: build the vocabulary from the training corpus first.
    reader = util.get_reader(CORPUS_READ_STYLE)(
        args.train, mode=CORPUS_READ_STYLE, begin=BEGIN_TOKEN, end=END_TOKEN)
    vocab = util.Vocab.load_from_corpus(reader, remake=args.rebuild_vocab)
    vocab.START_TOK = vocab[BEGIN_TOKEN]
    vocab.END_TOK = vocab[END_TOKEN]
    # add_unk is skipped entirely when the threshold is not positive.
    if args.unk_thresh > 0:
        vocab.add_unk(args.unk_thresh, "<UNK>")
    lm = rnnlm.get_model(args.arch)(model, vocab, args)

################################### LOAD THE DATA
# Read the whole training corpus into a list.
train_data = list(util.get_reader(CORPUS_READ_STYLE)(
    args.train, mode=CORPUS_READ_STYLE, begin=BEGIN_TOKEN, end=END_TOKEN))
if not args.split_train:
    valid_data = list(