shuffle_on_load=False,
                              curriculum=True,
                              num_steps=20)
data_gen = data.curr_iter_batches(2, 1, 0.0)
print("Data done loading starting exp")
# Pull only the first curriculum batch and dump its contents, then stop.
for idx, first in enumerate(data_gen, start=1):
    print(idx)
    print(first['tokens_characters'].shape)
    print(first['ignore'])
    print(first['next_token_id'])
    # Shapes of the remaining arrays, in the same order as before.
    for key in ('ignore', 'token_ids', 'next_token_id',
                'token_ids_reverse', 'tokens_characters_reverse'):
        print(first[key].shape)
    break

# Rebuild the dataset without the curriculum options and inspect one
# batch from the plain iterator.
data = BidirectionalLMDataset('dummy.txt',
                              vocab,
                              test=False,
                              shuffle_on_load=False)
data_gen = data.iter_batches(2, 20)
for idx, first in enumerate(data_gen, start=1):
    print(idx)
    # Print the shape of every array in the batch, same order as before.
    for key in ('tokens_characters', 'token_ids', 'next_token_id',
                'token_ids_reverse', 'tokens_characters_reverse'):
        print(first[key].shape)
    break
# Exemple #2
# NOTE(review): the following lines were extraction artifacts — a stray "0"
# and an orphaned print fragment duplicating the next-token-id shape print
# that appears inside the loop later in this example:
#     print('====> iter [{}]\tnext token ids shape: {}'.format(
#         num, batch['next_token_id'].shape))
'''
UE for BidirectionalLMDataset
'''
print('\n\n\tUE for BidirectionalLMDataset:')

# Build a character-level vocabulary and a bidirectional LM dataset over
# the example shards, then create a batch generator for the loop below.
vocab_file = '../data/vocab_seg_words_elmo.txt'
vocab_unicodechars = UnicodeCharsVocabulary(
    vocab_file, max_word_length=10, validate_file=True)
filepattern = '../data/example/*_seg_words.txt'
bilmds = BidirectionalLMDataset(filepattern, vocab_unicodechars, test=True)

# Each batch holds batch_size * n_gpus sequences of unroll_steps tokens.
batch_size, n_gpus, unroll_steps = 128, 1, 10
data_gen = bilmds.iter_batches(batch_size * n_gpus, unroll_steps)
jump_cnt = 0  # iteration counter used to stop the inspection loop below
for num, batch in enumerate(data_gen, start=1):
    jump_cnt += 1
    if jump_cnt > 10:
        break
    print('\n')
    print('====> iter [{}]\ttoken ids shape: {}'.format(
        num, batch['token_ids'].shape))
    print('====> iter [{}]\ttokens characters shape: {}'.format(
        num, batch['tokens_characters'].shape))
    print('====> iter [{}]\tnext token ids shape: {}'.format(
        num, batch['next_token_id'].shape))
    print('====> iter [{}]\ttoken ids reverse shape: {}'.format(
        num, batch['token_ids_reverse'].shape))
    print('====> iter [{}]\ttokens characters reverse shape: {}'.format(