def test_split(self):
    """Tests that `DataLoader.split` partitions the data by the given fraction.

    Builds a 4-element dataset and splits it 50/50, then checks that each
    half reports size 2 and contains the expected rows: the first two
    elements ([0, 1], [1, 1]) go to the train split and the last two
    ([0, 0], [1, 0]) go to the test split, in order.
    """
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataloader.DataLoader(ds, 4)
    train_data, test_data = data.split(0.5)

    self.assertEqual(train_data.size, 2)
    for i, elem in enumerate(train_data.dataset):
        # Train rows are [0, 1] and [1, 1]: first column equals the index.
        self.assertTrue((elem.numpy() == np.array([i, 1])).all())

    self.assertEqual(test_data.size, 2)
    for i, elem in enumerate(test_data.dataset):
        # Test rows are [0, 0] and [1, 0]: first column equals the index.
        self.assertTrue((elem.numpy() == np.array([i, 0])).all())
def get_dataloader(data_size, input_shape, num_classes, max_input_value=1000):
    """Gets a simple `DataLoader` object for test.

    Features are sampled uniformly at random as floats in
    [0, max_input_value); labels are sampled uniformly as ints in
    [0, num_classes).

    Args:
      data_size: Number of examples in the generated dataset.
      input_shape: Shape of a single feature example, as a list of ints.
      num_classes: Exclusive upper bound for the integer labels.
      max_input_value: Exclusive upper bound for the float feature values.

    Returns:
      A `dataloader.DataLoader` wrapping a `(features, labels)` dataset of
      `data_size` elements.
    """
    features = tf.random.uniform(
        shape=[data_size] + input_shape,
        minval=0,
        maxval=max_input_value,
        dtype=tf.float32)
    labels = tf.random.uniform(
        shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataloader.DataLoader(ds, data_size)
def test_len(self):
    """Tests that `len()` on a DataLoader reports the declared size."""
    size = 4
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataloader.DataLoader(ds, size)
    self.assertEqual(len(data), size)