def test_concat_arrays(self):
    """Check that ``concat_arrays`` stacks arrays along their first axis.

    For several kinds of input (1D ints, 2-column ints, one-hot vectors,
    images) it concatenates ``array_count`` arrays of ``sample_size`` rows
    each and verifies that:

    * the trailing dimensions of the result match those of the inputs, and
    * the leading dimension equals the total number of input rows.
    """
    array_count = 4
    sample_size = 128

    def check_concat(array_list):
        # Concatenation must keep the per-row shape and stack every row.
        array_concat = concat_arrays(*array_list)
        self.assertEqual(
            shape_of_array(array_list[0])[1:],
            shape_of_array(array_concat)[1:])
        self.assertEqual(len(array_concat), sample_size * array_count)

    # 1D arrays
    check_concat([
        generate_array_ints(n=sample_size) for _ in range(array_count)
    ])

    # 2D arrays (2 columns)
    check_concat([
        generate_array_ints(n=sample_size * 2).reshape(-1, 2)
        for _ in range(array_count)
    ])

    # N-D arrays (array of vectors). Pass `array_count` separate one-hot
    # matrices (presumably `sample_size` rows each -- TODO confirm against
    # generate_onehot_matrix) so the 2D concat path is exercised; unpacking
    # a single matrix would only concatenate its 1D rows.
    check_concat([
        generate_onehot_matrix(n=sample_size, ndim=array_count)
        for _ in range(array_count)
    ])

    # N-D arrays (array of images)
    check_concat([
        generate_images(n=sample_size) for _ in range(array_count)
    ])
def test_override_batch_size(self):
    """A per-call ``batch_size`` overrides the sampler's configured one.

    The dataset is built from lists of rows. Asking for the whole dataset
    in one call must return exactly the dataset. The single epoch is then
    exhausted, so a further call must raise ``IndexError``.
    """
    batch_size, num_batches, num_features = 10, 5, 5
    ones = numpy.ones((batch_size, num_features))
    row_chunks = [(ones * i).tolist() for i in range(num_batches)]
    dataset = concat_arrays(*row_chunks)

    sampler = OrderedSampler(dataset, batch_size=batch_size, epochs=1)
    whole = sampler(batch_size=len(dataset))

    self.assertEqual(len(dataset), len(whole))
    self.assertListEqual(whole, dataset)
    # Everything was consumed by the single oversized request.
    self.assertRaises(IndexError, sampler)
def test_from_numpy(self):
    """Sample fixed-size batches, in order, from numpy-built data.

    Block ``i`` of the dataset is filled with the value ``i``. Each call
    must return the next block as a list of rows. A call after the single
    epoch is exhausted must raise ``IndexError``.
    """
    batch_size = 10
    num_batches = 5
    num_features = 5
    ones = numpy.ones((batch_size, num_features))
    input_fn = OrderedSampler(
        concat_arrays(*[ones * i for i in range(num_batches)]),
        batch_size=batch_size,
        epochs=1,
    )

    for batch_index in range(num_batches):
        expected_rows = (ones * batch_index).tolist()
        batch = input_fn()
        self.assertEqual(batch_size, len(batch))
        self.assertListEqual(batch, expected_rows)

    # The one configured epoch has been fully consumed.
    self.assertRaises(IndexError, input_fn)
def test_cross_validation_2d(self):
    """Split an ordered 2D dataset into training batches and a test subset.

    ``test_split`` is one batch's share of the data, so the last block is
    expected under ``DataSplit.TEST``. Default calls return the remaining
    blocks in order and then raise ``IndexError``.
    """
    batch_size = 10
    num_batches = 5
    num_features = 5
    train_batches = num_batches - 1
    ones = numpy.ones((batch_size, num_features))
    dataset = concat_arrays(*[ones * k for k in range(num_batches)])

    # 1 / 5 == 0.2: hold out exactly one block for testing.
    input_fn = OrderedCVSampler(dataset, batch_size=batch_size, epochs=1,
                                test_split=1 / num_batches)

    for k in range(train_batches):
        batch = input_fn()
        self.assertEqual(batch_size, len(batch))
        self.assertListEqual(batch, (ones * k).tolist())

    held_out = input_fn(subset=DataSplit.TEST)
    self.assertListEqual(held_out, (ones * train_batches).tolist())

    # Training data for the single epoch is exhausted.
    self.assertRaises(IndexError, input_fn)