示例#1
0
  def test_split(self):
    """Splits a four-element loader in half and checks both parts."""
    rows = [[0, 1], [1, 1], [0, 0], [1, 0]]
    loader = dataloader.DataLoader(tf.data.Dataset.from_tensor_slices(rows), 4)
    first_half, second_half = loader.split(0.5)

    # The first two rows end in 1 and the last two end in 0, so each part is
    # expected to hold [0, v] then [1, v], where v identifies the part.
    for part, last_value in ((first_half, 1), (second_half, 0)):
      self.assertEqual(part.size, 2)
      for index, element in enumerate(part.dataset):
        expected = np.array([index, last_value])
        self.assertTrue((element.numpy() == expected).all())
示例#2
0
def get_dataloader(data_size, input_shape, num_classes, max_input_value=1000):
    """Gets a simple `DataLoader` object for test.

    Args:
        data_size: Number of examples to generate; also passed to
            `DataLoader` as the dataset size.
        input_shape: Iterable of ints giving the shape of one feature
            example (a list, a tuple, or any other sequence).
        num_classes: Upper bound passed as `maxval` when sampling the
            integer labels.
        max_input_value: Upper bound passed as `maxval` when sampling the
            float features.

    Returns:
        A `dataloader.DataLoader` wrapping a dataset of
        `(features, labels)` slices.
    """
    # Unpack instead of concatenating with `+` so that tuples and other
    # non-list sequences are accepted as `input_shape` as well.
    feature_shape = [data_size, *input_shape]
    features = tf.random.uniform(shape=feature_shape,
                                 minval=0,
                                 maxval=max_input_value,
                                 dtype=tf.float32)

    labels = tf.random.uniform(shape=[data_size],
                               minval=0,
                               maxval=num_classes,
                               dtype=tf.int32)

    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataloader.DataLoader(ds, data_size)
示例#3
0
 def test_len(self):
     """Checks that `len()` of a loader reports the size it was built with."""
     rows = [[0, 1], [1, 1], [0, 0], [1, 0]]
     expected_len = len(rows)
     loader = dataloader.DataLoader(
         tf.data.Dataset.from_tensor_slices(rows), expected_len)
     self.assertEqual(len(loader), expected_len)