示例#1
0
def main(opt):
    """Train (or load) an LSTM-CNN language model and report test-set perplexity.

    ``opt`` is an options object. This function reads ``tokens``, ``data_dir``,
    ``batch_size``, ``seq_length``, ``max_word_l``, ``n_words``, ``n_chars``,
    ``skip_train``, ``num_layers``, ``checkpoint_dir``, ``savefile`` and
    ``max_epochs`` from it. Before the model is built or saved, it writes
    ``word_vocab_size``, ``char_vocab_size`` and ``max_word_l`` back onto it.

    If ``opt.skip_train`` is false:

    * a new ``LSTMCNN`` is built;
    * the options are pickled to ``<checkpoint_dir>/<savefile>.pkl``;
    * the model is saved to ``.json`` and trained with the loader's
      Train/Validation splits;
    * its weights are written to ``.h5``.

    Otherwise the model and weights are loaded from those same paths.
    In both cases the model is then evaluated on the Test split.
    """
    loader = BatchLoaderUnk(opt.tokens, opt.data_dir, opt.batch_size, opt.seq_length, opt.max_word_l, opt.n_words, opt.n_chars)
    # Effective vocabulary sizes are capped by the requested limits.
    opt.word_vocab_size = min(opt.n_words, len(loader.idx2word))
    opt.char_vocab_size = min(opt.n_chars, len(loader.idx2char))
    # Use the loader's max word length (which includes padding).
    opt.max_word_l = loader.max_word_l
    print('Word vocab size: ', opt.word_vocab_size,
          ', Char vocab size: ', opt.char_vocab_size,
          ', Max word length (incl. padding): ', opt.max_word_l)

    # Every checkpoint artifact (.pkl / .json / .h5) shares this path prefix.
    prefix = '{}/{}'.format(opt.checkpoint_dir, opt.savefile)

    # define the model
    if not opt.skip_train:
        print('creating an LSTM-CNN with ', opt.num_layers, ' layers')
        model = LSTMCNN(opt)
        # make sure output directory exists
        if not os.path.exists(opt.checkpoint_dir):
            os.makedirs(opt.checkpoint_dir)
        # Save the (vocabulary-augmented) options next to the model. The
        # context manager guarantees the pickle is flushed and the handle closed.
        with open(prefix + '.pkl', "wb") as opt_file:
            pickle.dump(opt, opt_file)
        model.save(prefix + '.json')
        model.fit_generator(loader.next_batch(Train), loader.split_sizes[Train], opt.max_epochs,
                            loader.next_batch(Validation), loader.split_sizes[Validation], opt)
        model.save_weights(prefix + '.h5', overwrite=True)
    else:
        model = load_model(prefix + '.json')
        model.load_weights(prefix + '.h5')
        # Keras' summary() prints the table itself and returns None;
        # wrapping it in print() would emit a stray "None" line.
        model.summary()

    # evaluate on full test set.
    # NOTE(review): assumes evaluate_generator returns the mean per-token
    # cross-entropy (natural log), so exp() of it is perplexity -- confirm.
    test_perp = model.evaluate_generator(loader.next_batch(Test), loader.split_sizes[Test])
    print('Perplexity on test set: ', exp(test_perp))
示例#2
0
文件: train.py 项目: jarfo/kchar
def main(opt):
    """Train (or load) an LSTM-CNN language model and report test-set perplexity.

    ``opt`` is an options object. This function reads ``tokens``, ``data_dir``,
    ``batch_size``, ``seq_length``, ``max_word_l``, ``n_words``, ``n_chars``,
    ``skip_train``, ``num_layers``, ``checkpoint_dir``, ``savefile`` and
    ``max_epochs`` from it. Before the model is built or saved, it writes
    ``word_vocab_size``, ``char_vocab_size`` and ``max_word_l`` back onto it.

    If ``opt.skip_train`` is false:

    * a new ``LSTMCNN`` is built;
    * the options are pickled to ``<checkpoint_dir>/<savefile>.pkl``;
    * the model is saved to ``.json`` and trained with the loader's
      Train/Validation splits;
    * its weights are written to ``.h5``.

    Otherwise the model and weights are loaded from those same paths.
    In both cases the model is then evaluated on the Test split.
    """
    loader = BatchLoaderUnk(opt.tokens, opt.data_dir, opt.batch_size, opt.seq_length, opt.max_word_l, opt.n_words, opt.n_chars)
    # Effective vocabulary sizes are capped by the requested limits.
    opt.word_vocab_size = min(opt.n_words, len(loader.idx2word))
    opt.char_vocab_size = min(opt.n_chars, len(loader.idx2char))
    # Use the loader's max word length (which includes padding).
    opt.max_word_l = loader.max_word_l
    print('Word vocab size:', opt.word_vocab_size,
          ', Char vocab size:', opt.char_vocab_size,
          ', Max word length (incl. padding):', opt.max_word_l)

    # Every checkpoint artifact (.pkl / .json / .h5) shares this path prefix.
    prefix = '{}/{}'.format(opt.checkpoint_dir, opt.savefile)

    # define the model
    if not opt.skip_train:
        print('creating an LSTM-CNN with', opt.num_layers, 'layers')
        model = LSTMCNN(opt)
        # make sure output directory exists
        if not os.path.exists(opt.checkpoint_dir):
            os.makedirs(opt.checkpoint_dir)
        # Save the (vocabulary-augmented) options next to the model. The
        # context manager guarantees the pickle is flushed and the handle closed.
        with open(prefix + '.pkl', "wb") as opt_file:
            pickle.dump(opt, opt_file)
        model.save(prefix + '.json')
        model.fit_generator(loader.next_batch(Train), loader.split_sizes[Train], opt.max_epochs,
                            loader.next_batch(Validation), loader.split_sizes[Validation], opt)
        model.save_weights(prefix + '.h5', overwrite=True)
    else:
        model = load_model(prefix + '.json')
        model.load_weights(prefix + '.h5')
        # Keras' summary() prints the table itself and returns None;
        # wrapping it in print() would emit a stray "None" line.
        model.summary()

    # evaluate on full test set.
    # NOTE(review): assumes evaluate_generator returns the mean per-token
    # cross-entropy (natural log), so exp() of it is perplexity -- confirm.
    test_perp = model.evaluate_generator(loader.next_batch(Test), loader.split_sizes[Test])
    print('Perplexity on test set:', exp(test_perp))