Example #1
# Imports assumed from the bilm-tf package layout.
from bilm.training import train, load_options_latest_checkpoint, load_vocab
from bilm.data import LMDataset, BidirectionalLMDataset


def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab, with character inputs if the model uses a char CNN
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    prefix = args.train_prefix

    kwargs = {
        'test': False,
        'shuffle_on_load': True,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(prefix, vocab, **kwargs)
    else:
        data = LMDataset(prefix, vocab, **kwargs)

    tf_save_dir = args.save_dir
    tf_log_dir = args.save_dir

    # set optional inputs
    if args.n_train_tokens > 0:
        options['n_train_tokens'] = args.n_train_tokens
    if args.n_epochs > 0:
        options['n_epochs'] = args.n_epochs
    if args.batch_size > 0:
        options['batch_size'] = args.batch_size

    train(options,
          data,
          args.n_gpus,
          tf_save_dir,
          tf_log_dir,
          restart_ckpt_file=ckpt_file)
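For context, a script like this is usually launched from the command line. Below is a minimal sketch of that wiring, assuming argparse flag names that simply mirror the attributes main() reads (save_dir, vocab_file, train_prefix, n_gpus, n_train_tokens, n_epochs, batch_size); none of it is taken from the example above.

# Hypothetical CLI wrapper; flag names are assumptions inferred from the
# attributes that main() reads off args.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Resume ELMo training from the latest checkpoint')
    parser.add_argument('--save_dir', help='directory with options.json and checkpoints')
    parser.add_argument('--vocab_file', help='vocabulary file')
    parser.add_argument('--train_prefix', help='prefix (glob) for the training files')
    parser.add_argument('--n_gpus', type=int, default=1)
    parser.add_argument('--n_train_tokens', type=int, default=0,
                        help='override the token count when > 0')
    parser.add_argument('--n_epochs', type=int, default=0,
                        help='override the epoch count when > 0')
    parser.add_argument('--batch_size', type=int, default=0,
                        help='override the batch size when > 0')

    main(parser.parse_args())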
Example #2
# Imports assumed from the bilm-tf package layout.
from bilm.training import test, load_options_latest_checkpoint, load_vocab
from bilm.data import LMDataset, BidirectionalLMDataset


def main(args):
    options, ckpt_file = load_options_latest_checkpoint(args.save_dir)

    # load the vocab
    if 'char_cnn' in options:
        max_word_length = options['char_cnn']['max_characters_per_token']
    else:
        max_word_length = None
    vocab = load_vocab(args.vocab_file, max_word_length)

    test_prefix = args.test_prefix

    kwargs = {
        'test': True,
        'shuffle_on_load': False,
    }

    if options.get('bidirectional'):
        data = BidirectionalLMDataset(test_prefix, vocab, **kwargs)
    else:
        data = LMDataset(test_prefix, vocab, **kwargs)

    test(options, ckpt_file, data, batch_size=args.batch_size)
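This evaluation main() can also be driven programmatically. Here is a minimal sketch using types.SimpleNamespace; every path below is a placeholder, not taken from the example.

# Hypothetical programmatic invocation; all paths are placeholders.
from types import SimpleNamespace

args = SimpleNamespace(
    save_dir='checkpoint/',       # directory holding options.json and the latest checkpoint
    vocab_file='vocab.txt',       # vocabulary used during training
    test_prefix='heldout/*.txt',  # glob matching the held-out evaluation files
    batch_size=128,
)
main(args)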
Example #3
print('====> word \t{}\t word to char ids result: {}'.format(
    words, vocab_unicodechars.word_to_char_ids(words)))
print('====> word \t{}\t encoded chars id result: {}'.format(
    words, vocab_unicodechars.encode_chars(words)))
ids = [1234, 3234, 22, 34, 341324, 21, 345]
print('====> decode \t{}\t to words: {}'.format(
    ids, vocab_unicodechars.decode(ids)))
'''
UE for LMDataset
'''
print('\n\n\tUE for LMDataset:')
vocab_file = '../data/vocab_seg_words_elmo.txt'
vocab_unicodechars = UnicodeCharsVocabulary(vocab_file,
                                            max_word_length=10,
                                            validate_file=True)
filepattern = '../data/example/*_seg_words.txt'
lmds = LMDataset(filepattern, vocab_unicodechars, test=True)
batch_size = 128
n_gpus = 1
unroll_steps = 10
data_gen = lmds.iter_batches(batch_size * n_gpus, unroll_steps)
for num, batch in enumerate(data_gen, start=1):
    # only inspect the first 10 batches
    if num > 10:
        break
    print('====> iter [{}]\ttoken ids shape: {}'.format(
        num, batch['token_ids'].shape))
    print('====> iter [{}]\ttokens characters shape: {}'.format(
        num, batch['tokens_characters'].shape))
    print('====> iter [{}]\tnext token ids shape: {}'.format(
        num, batch['next_token_id'].shape))
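A bidirectional dataset can be exercised the same way. The sketch below reuses the names defined above and assumes, as in the bilm-tf implementation, that BidirectionalLMDataset batches mirror the forward keys with a '_reverse' suffix; the extra import is also an assumption about the package layout.

# Sketch of the bidirectional counterpart (assumes '_reverse' batch keys as in bilm-tf).
from bilm.data import BidirectionalLMDataset

bilmds = BidirectionalLMDataset(filepattern, vocab_unicodechars, test=True)
for num, batch in enumerate(bilmds.iter_batches(batch_size * n_gpus, unroll_steps), start=1):
    if num > 2:
        break
    print('====> iter [{}]\tforward token ids shape: {}'.format(
        num, batch['token_ids'].shape))
    print('====> iter [{}]\treverse token ids shape: {}'.format(
        num, batch['token_ids_reverse'].shape))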
Example #4
        if args.model == 'hmm':
            res0, res1, res2, OOV, IN = get_perplexities(sents, model, k=k)
            res_perplexities2.append(res2)
            # accumulate in-vocabulary and out-of-vocabulary token counts
            count_in += IN
            count_oov += OOV
            res_perplexities0.append(res0)
            res_perplexities1.append(res1)

        if args.model == 'elmo':

            filepath = subdir + os.sep
            if options.get('bidirectional'):
                data = BidirectionalLMDataset(filepath, vocab, **kwargs)
                # print(data)
            else:
                data = LMDataset(filepath, vocab, **kwargs)

            res2 = test(options, ckpt_file, data, batch_size=args.batch_size)

            res_perplexities2.append(res2)

        outfile.write(file + '\t' + label + '\t' + str(res2) + '\n')

        if count % 5 == 0:
            print('I have calculated perplexities for %s files' % count,
                  file=sys.stderr)

print('=== Just a sanity check on the perplexity calculations: ')
print(labels[:5], fns[:5], res_perplexities2[:5])

print('Texts with the most extreme text-level perplexities:')