Example #1
def test_bptt_batch_sampler_example():
    sampler = BPTTBatchSampler(range(100),
                               bptt_length=2,
                               batch_size=3,
                               drop_last=False)
    assert list(sampler)[0] == [slice(0, 2), slice(34, 36), slice(67, 69)]

    sampler = BPTTBatchSampler(range(100),
                               bptt_length=2,
                               batch_size=3,
                               drop_last=False,
                               type_='target')
    assert list(sampler)[0] == [slice(1, 3), slice(35, 37), slice(68, 70)]
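
In short: the sampler splits the dataset into batch_size contiguous chunks, and each batch carries one bptt_length slice per chunk; type_='target' shifts every slice forward by one token. A minimal decoding sketch based on the assertions above (assuming torchnlp's BPTTBatchSampler, where each batch is a list of slice objects):

from torchnlp.samplers import BPTTBatchSampler

data = list(range(100))
sampler = BPTTBatchSampler(data, bptt_length=2, batch_size=3, drop_last=False)

# The 100 items fall into 3 contiguous chunks starting at 0, 34 and 67,
# so the first batch reads the first two items of each chunk.
first_batch = next(iter(sampler))
print([data[s] for s in first_batch])  # [[0, 1], [34, 35], [67, 68]]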
Example #2
def test_bptt_batch_sampler(alphabet):
    sampler = BPTTBatchSampler(alphabet,
                               bptt_length=2,
                               batch_size=4,
                               drop_last=False)
    list_ = list(sampler_to_iterator(alphabet, sampler))
    assert list_[0] == [['a', 'b'], ['h', 'i'], ['o', 'p'], ['u', 'v']]
    assert len(sampler) == len(list_)
Example #3
def test_bptt_batch_sampler_example():
    sampler = BPTTBatchSampler(range(100),
                               bptt_length=2,
                               batch_size=3,
                               drop_last=False)
    assert list(sampler)[0] == [(slice(0, 2), slice(1, 3)),
                                (slice(34, 36), slice(35, 37)),
                                (slice(67, 69), slice(68, 70))]
    assert list(sampler)[1] == [(slice(2, 4), slice(3, 5)),
                                (slice(36, 38), slice(37, 39)),
                                (slice(69, 71), slice(70, 72))]
Example #4
def test_bptt_batch_sampler_drop_last():
    # Test samplers iterate over chunks similar to:
    # https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet,
                               bptt_length=2,
                               batch_size=4,
                               drop_last=True)
    list_ = list(sampler_to_iterator(alphabet, sampler))
    assert list_[0] == [['a', 'b'], ['g', 'h'], ['m', 'n'], ['s', 't']]
    assert len(sampler) == len(list_)
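
sampler_to_iterator in the tests above is a small torchnlp test helper; roughly, it resolves each index or slice the sampler yields against the dataset. An approximate sketch of the behavior these tests rely on (not the library's exact code):

def sampler_to_iterator(dataset, sampler):
    # Each batch is a list of indices or slices; look them up in order.
    for batch in sampler:
        yield [dataset[i] for i in batch]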
Example #5
def test_bptt_batch_sampler():
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet,
                               bptt_length=2,
                               batch_size=4,
                               drop_last=False)
    decoded_batches = []
    for batch in list(sampler):
        decoded_batch = []
        for source, target in batch:
            decoded_batch.append([alphabet[source], alphabet[target]])
        decoded_batches.append(decoded_batch)
    assert decoded_batches[0] == [[['a', 'b'], ['b', 'c']],
                                  [['h', 'i'], ['i', 'j']],
                                  [['o', 'p'], ['p', 'q']],
                                  [['u', 'v'], ['v', 'w']]]
    assert len(sampler) == len(decoded_batches)
Example #6
def test_bptt_batch_sampler_drop_last():
    # Test samplers iterate over chunks similar to:
    # https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
    alphabet = list('abcdefghijklmnopqrstuvwxyz')
    sampler = BPTTBatchSampler(alphabet,
                               bptt_length=2,
                               batch_size=4,
                               drop_last=True)
    decoded_batches = []
    for batch in list(sampler):
        decoded_batch = []
        for source, target in batch:
            decoded_batch.append([alphabet[source], alphabet[target]])
        decoded_batches.append(decoded_batch)
    assert decoded_batches[0] == [[['a', 'b'], ['b', 'c']],
                                  [['g', 'h'], ['h', 'i']],
                                  [['m', 'n'], ['n', 'o']],
                                  [['s', 't'], ['t', 'u']]]

    assert len(sampler) == len(decoded_batches)
Example #7
print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True,
                                                dev=True,
                                                test=True)

encoder = IdentityEncoder(train + val + test)

train_data = encoder.encode(train)
val_data = encoder.encode(val)
test_data = encoder.encode(test)

eval_batch_size = 10
test_batch_size = 1

train_source_sampler, val_source_sampler, test_source_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'source')
    for d in (train, val, test)
])

train_target_sampler, val_target_sampler, test_target_sampler = tuple([
    BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'target')
    for d in (train, val, test)
])

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
criterion = None
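
The excerpt stops before the training loop; a hedged sketch of how the paired source/target samplers are typically consumed (assuming train_data is a 1-D tensor of encoded tokens, as produced above):

import torch

for source_batch, target_batch in zip(train_source_sampler, train_target_sampler):
    # Each batch is a list of slices, one per chunk of the corpus; with
    # drop_last=True every slice has the same length, so stacking yields
    # (batch_size, bptt_length) tensors offset from each other by one token.
    source = torch.stack([train_data[s] for s in source_batch])
    target = torch.stack([train_data[s] for s in target_batch])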
Example #8
def sampler(alphabet):
    return BPTTBatchSampler(alphabet,
                            bptt_length=2,
                            batch_size=4,
                            drop_last=True)
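
Since sampler takes an alphabet argument, it reads like a pytest fixture whose decorator was lost in extraction; a hedged sketch of how such a fixture pair would be declared and consumed (the alphabet fixture and test name here are hypothetical):

import pytest
from torchnlp.samplers import BPTTBatchSampler

@pytest.fixture
def alphabet():
    return list('abcdefghijklmnopqrstuvwxyz')

@pytest.fixture
def sampler(alphabet):
    return BPTTBatchSampler(alphabet,
                            bptt_length=2,
                            batch_size=4,
                            drop_last=True)

def test_sampler_has_batches(sampler):  # hypothetical consumer of the fixture
    assert len(sampler) > 0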
Example #9
    # NOTE: the opening lines of this snippet were truncated; the call head
    # below is an assumption inferred from the keyword arguments, which
    # match torchnlp's wikitext-2 dataset loader.
    train, dev, test = wikitext_2_dataset(
        train=True,
        dev=True,
        test=True,
        extracted_name="wikitext-2",
        url="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",  # noqa: E501
        unknown_token=SPECIAL_TOKENS.UNK.value,
        eos_token=SPECIAL_TOKENS.EOS.value,
    )

    ldm = PLDataModuleFromCorpus(
        train,
        val=dev,
        test=test,
        drop_last=True,
        max_len=-1,
        batch_sampler_train=BPTTBatchSampler(train, bptt, 20, True),
        batch_sampler_val=BPTTBatchSampler(dev, bptt, 10, True),
        batch_sampler_test=BPTTBatchSampler(test, bptt, 10, True),
        pin_memory=True,
        num_workers=0,
        language_model=True,
        tokenizer="tokenized",
        collate_fn=collate_fn,
    )

    model = TransformerLM(
        vocab_size=ldm.vocab_size,
        num_layers=2,
        hidden_size=200,
        num_heads=2,
        inner_size=256,
Example #10
from torchnlp import datasets
from torchnlp.encoders import LabelEncoder
from torchnlp.samplers import BPTTBatchSampler

print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True, dev=True, test=True)

encoder = LabelEncoder(train + val + test)

train_data = encoder.batch_encode(train)
val_data = encoder.batch_encode(val)
test_data = encoder.batch_encode(test)

train_source_sampler, val_source_sampler, test_source_sampler = tuple(
    [BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'source') for d in (train, val, test)])

train_target_sampler, val_target_sampler, test_target_sampler = tuple(
    [BPTTBatchSampler(d, args.bptt, args.batch_size, True, 'target') for d in (train, val, test)])

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
criterion = None

ntokens = encoder.vocab_size
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout)
###
if args.resume: