def test_dataset_tensor(source, config):
    """Verify how TripleDataset batches the training set with batch_size=2.

    Expects three batches in total; the first batch must equal
    [[0, 0, 1], [0, 1, 2]] and the final batch must equal [[2, 0, 3]].
    """
    batched = data.TripleDataset(source.train_set, batch_size=2)
    assert len(batched) == 3

    # Compare the stored batch arrays directly via the private `_data` list.
    expected_first = np.array([[0, 0, 1], [0, 1, 2]], dtype=np.int64)
    np.testing.assert_equal(batched._data[0], expected_first)

    expected_last = np.array([[2, 0, 3]], dtype=np.int64)
    np.testing.assert_equal(batched._data[-1], expected_last)
def test_SequentialBatchSampler(config, source):
    """Check SequentialBatchSampler over the test set.

    Each fresh iterator must yield exactly one batch and then be exhausted.
    The check runs twice to confirm the sampler can be iterated again
    after a previous pass has finished.
    """
    batch_sampler = data.SequentialBatchSampler(data.TripleDataset(source.test_set))

    for _pass in range(2):
        batches = iter(batch_sampler)
        next(batches)
        # A second batch must not exist.
        with pytest.raises(StopIteration):
            next(batches)
def small_triple_list(source):
    """Return the first batch produced by a TripleDataset over the training set.

    The dataset is built with batch_size=2, so this is whatever the first
    iteration step of that dataset yields.
    """
    return next(iter(data.TripleDataset(source.train_set, batch_size=2)))