def test_bptt_batch_sampler_example():
    """The sampler splits range(100) into three evenly offset chunks; the
    'target' variant yields the same slices shifted right by one position."""
    source_sampler = BPTTBatchSampler(
        range(100), bptt_length=2, batch_size=3, drop_last=False)
    first_source_batch = list(source_sampler)[0]
    assert first_source_batch == [slice(0, 2), slice(34, 36), slice(67, 69)]

    target_sampler = BPTTBatchSampler(
        range(100), bptt_length=2, batch_size=3, drop_last=False, type_='target')
    first_target_batch = list(target_sampler)[0]
    assert first_target_batch == [slice(1, 3), slice(35, 37), slice(68, 70)]
def test_bptt_batch_sampler(alphabet):
    """First batch covers four evenly spaced two-letter chunks of the alphabet."""
    sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=False)
    batches = list(sampler_to_iterator(alphabet, sampler))
    assert batches[0] == [['a', 'b'], ['h', 'i'], ['o', 'p'], ['u', 'v']]
    assert len(batches) == len(sampler)
def test_bptt_batch_sampler_example():
    """Each batch element is a (source, target) slice pair with the target
    shifted one position past its source; consecutive batches advance by the
    BPTT length (2)."""
    sampler = BPTTBatchSampler(range(100), bptt_length=2, batch_size=3, drop_last=False)
    batches = list(sampler)

    expected_first = [
        (slice(0, 2), slice(1, 3)),
        (slice(34, 36), slice(35, 37)),
        (slice(67, 69), slice(68, 70)),
    ]
    expected_second = [
        (slice(2, 4), slice(3, 5)),
        (slice(36, 38), slice(37, 39)),
        (slice(69, 71), slice(70, 72)),
    ]
    assert batches[0] == expected_first
    assert batches[1] == expected_second
def test_bptt_batch_sampler_drop_last():
    """With drop_last=True the alphabet splits into equal 6-letter chunks.

    Mirrors the chunked iteration scheme of the upstream example:
    https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
    """
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=True)
    batches = list(sampler_to_iterator(alphabet, sampler))
    assert batches[0] == [['a', 'b'], ['g', 'h'], ['m', 'n'], ['s', 't']]
    assert len(batches) == len(sampler)
def test_bptt_batch_sampler():
    """Decoding each (source, target) index pair back to letters shows the
    target chunk trailing its source chunk by one position."""
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=False)
    decoded_batches = [
        [[alphabet[source], alphabet[target]] for source, target in batch]
        for batch in sampler
    ]
    assert decoded_batches[0] == [
        [['a', 'b'], ['b', 'c']],
        [['h', 'i'], ['i', 'j']],
        [['o', 'p'], ['p', 'q']],
        [['u', 'v'], ['v', 'w']],
    ]
    assert len(decoded_batches) == len(sampler)
def test_bptt_batch_sampler_drop_last():
    """Decoded (source, target) pairs with drop_last=True use 6-letter chunks.

    Mirrors the chunked iteration scheme of the upstream example:
    https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
    """
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=True)
    decoded_batches = [
        [[alphabet[source], alphabet[target]] for source, target in batch]
        for batch in sampler
    ]
    assert decoded_batches[0] == [
        [['a', 'b'], ['b', 'c']],
        [['g', 'h'], ['h', 'i']],
        [['m', 'n'], ['n', 'o']],
        [['s', 't'], ['t', 'u']],
    ]
    assert len(decoded_batches) == len(sampler)
# Load the dataset named by args.data and encode every split with a single
# encoder built over the concatenated corpus, so all splits share one vocabulary.
print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True, dev=True, test=True)

encoder = IdentityEncoder(train + val + test)
train_data, val_data, test_data = (
    encoder.encode(split) for split in (train, val, test))

eval_batch_size = 10
test_batch_size = 1

# One BPTT sampler pair per split: type_='source' yields the input positions,
# type_='target' the corresponding prediction positions.
train_source_sampler, val_source_sampler, test_source_sampler = (
    BPTTBatchSampler(split, args.bptt, args.batch_size, True, 'source')
    for split in (train, val, test))

train_target_sampler, val_target_sampler, test_target_sampler = (
    BPTTBatchSampler(split, args.bptt, args.batch_size, True, 'target')
    for split in (train, val, test))

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
criterion = None
def sampler(alphabet):
    """Return a BPTT batch sampler over *alphabet* (bptt 2, batch 4, drop last)."""
    bptt_sampler = BPTTBatchSampler(
        alphabet, bptt_length=2, batch_size=4, drop_last=True)
    return bptt_sampler
# NOTE(review): this chunk begins mid-call — the lines below are the trailing
# keyword arguments of a wikitext-2 dataset-loading call whose opening is
# outside this view.
dev=True,
test=True,
extracted_name="wikitext-2",
url="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",  # noqa: E501
unknown_token=SPECIAL_TOKENS.UNK.value,
eos_token=SPECIAL_TOKENS.EOS.value,
)

# Language-model data module: batching is driven by BPTT samplers (train batch
# size 20, val/test 10, each dropping the final partial chunk) instead of
# shuffled per-example indices.
ldm = PLDataModuleFromCorpus(
    train,
    val=dev,
    test=test,
    drop_last=True,
    max_len=-1,  # presumably -1 disables length truncation — TODO confirm
    batch_sampler_train=BPTTBatchSampler(train, bptt, 20, True),
    batch_sampler_val=BPTTBatchSampler(dev, bptt, 10, True),
    batch_sampler_test=BPTTBatchSampler(test, bptt, 10, True),
    pin_memory=True,
    num_workers=0,
    language_model=True,
    tokenizer="tokenized",
    collate_fn=collate_fn,
)

# Transformer language model sized from the data module's vocabulary.
model = TransformerLM(
    vocab_size=ldm.vocab_size,
    num_layers=2,
    hidden_size=200,
    num_heads=2,
    inner_size=256,
    # NOTE(review): constructor call continues beyond this chunk.
from torchnlp import datasets
from torchnlp.encoders import LabelEncoder
from torchnlp.samplers import BPTTBatchSampler

# Load the dataset named by args.data and build one label encoder over the
# concatenated splits so train/val/test share a single vocabulary.
print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True, dev=True, test=True)
encoder = LabelEncoder(train + val + test)
train_data = encoder.batch_encode(train)
val_data = encoder.batch_encode(val)
test_data = encoder.batch_encode(test)

# One BPTT sampler pair per split: type_='source' yields the input positions,
# type_='target' the corresponding prediction positions.
train_source_sampler, val_source_sampler, test_source_sampler = tuple(
    [BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'source')
     for d in (train, val, test)])
train_target_sampler, val_target_sampler, test_target_sampler = tuple(
    [BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'target')
     for d in (train, val, test)])

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
criterion = None
ntokens = encoder.vocab_size
# NOTE(review): `model` on the right-hand side appears to be a module imported
# outside this chunk that is shadowed by the constructed RNNModel instance —
# confirm the import above this view.
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                       args.nlayers, args.dropout)
###
# NOTE(review): the `if args.resume:` body continues beyond this chunk.
if args.resume: