コード例 #1
0
def test_dataset_tensor(source, config):
    """Check how TripleDataset batches ``source.train_set`` with batch_size=2.

    Three entries are expected. The first holds two int64 rows and the last
    holds a single int64 row; presumably the final batch is a partial
    remainder (TODO confirm against TripleDataset).
    """
    # `config` is unused here; presumably it is requested only for its
    # fixture side effects (NOTE(review): confirm before removing).
    batched = data.TripleDataset(source.train_set, batch_size=2)

    expected_head = np.array([[0, 0, 1], [0, 1, 2]], dtype=np.int64)
    expected_tail = np.array([[2, 0, 3]], dtype=np.int64)

    assert len(batched) == 3
    np.testing.assert_equal(batched._data[0], expected_head)
    np.testing.assert_equal(batched._data[-1], expected_tail)
コード例 #2
0
def test_SequentialBatchSampler(config, source):
    """Each fresh iterator over the sampler yields one item, then stops.

    Iterating twice checks that calling ``iter(sampler)`` again starts a
    new pass rather than resuming an exhausted one.
    """
    triples = data.TripleDataset(source.test_set)
    sampler = data.SequentialBatchSampler(triples)

    for _attempt in range(2):
        pass_iter = iter(sampler)
        next(pass_iter)  # exactly one item is expected per pass
        with pytest.raises(StopIteration):
            next(pass_iter)
コード例 #3
0
def small_triple_list(source):
    """Return the first item produced by iterating a TripleDataset.

    The dataset is built from ``source.train_set`` with batch_size=2. The
    item is presumably the first batch of training triples (TODO confirm
    what TripleDataset.__iter__ yields).
    """
    first_batch_iter = iter(data.TripleDataset(source.train_set, batch_size=2))
    return next(first_batch_iter)