示例#1
0
def get_squad(tmpdir_factory):
    """Write a synthetic features cache into a fresh temp directory.

    A pickled set of randomly generated features is stored as
    ``inputfile.128.cache`` inside a directory created through
    ``tmpdir_factory.mktemp``, so callers do not need the real dataset.
    The cache path is printed as a side effect.

    Args:
        tmpdir_factory: Object exposing ``mktemp(name)``
            (presumably pytest's ``tmpdir_factory`` fixture -- confirm
            against the caller).

    Returns:
        The directory object returned by ``tmpdir_factory.mktemp``.
    """
    num_samples = 3333
    work_dir = tmpdir_factory.mktemp("ndr_test_tmp_data")

    # The cache file name is the input file path plus a ".128.cache" suffix.
    cache_path = str(work_dir) + "/inputfile" + f".{128}.cache"

    # NOTE(review): the positional arguments look like
    # (sequence length, vocabulary size, sample count) -- confirm against
    # generate_random_features.
    synthetic = generate_random_features(128, 30, num_samples)

    with open(cache_path, "wb") as cache_fh:
        pickle.dump(synthetic, cache_fh)

    print(cache_path)
    return work_dir
示例#2
0
def test_generated_data_squad():
    """Check the first batch SquadDataLoader yields for generated features.

    Generates exactly one batch worth of random features, verifies the
    loader reports a length of 1, and then checks that every element of
    the first batch has shape ``(batch_size, width)`` and holds only
    values strictly below its expected upper bound.
    """
    seq_len, n_batch, vocab = 128, 2, 4864

    synthetic = generate_random_features(seq_len, vocab, n_batch)
    loader = SquadDataLoader(synthetic, batch_size=n_batch)

    # One batch worth of features was generated, so one batch is expected.
    assert len(loader) == 1

    # (trailing dimension, exclusive upper bound) for each batch element,
    # in the order the loader yields them.
    expectations = [
        (seq_len, vocab),
        (seq_len, seq_len + 1),
        (seq_len, 2),
        (1, seq_len + 1),
        (1, seq_len),
        (1, seq_len),
    ]

    batches = iter(loader)
    first_batch = next(batches)

    for field, (width, upper_bound) in zip(first_batch, expectations):
        assert np.all(field < upper_bound)
        assert field.shape == (n_batch, width)