def test_concats(self):
    """Check vertical, horizontal and mixed concatenation of two 2x2 images.

    ``img1`` is all ones and ``img2`` is all zeros, so every pixel of a result
    image shows which source image supplied that region. For all three modes,
    with ``mixing_coeff=0.5`` and the labels ``[0, 1]`` / ``[1, 0]``, the
    expected label is ``[0.5, 0.5]``.

    ``np.testing.assert_array_equal`` is used instead of
    ``assertTrue(np.array_equal(...))`` so a failure reports the actual
    and expected arrays rather than just "False is not true".
    """
    img1 = np.ones((2, 2), dtype='uint8')
    img2 = np.zeros((2, 2), dtype='uint8')
    images = np.array([img1, img2])
    labels = np.array([[0, 1], [1, 0]])
    # Shared by every concat mode below.
    true_label = np.array([0.5, 0.5])

    # Vertical: with vertical_concat_prob=1 the expected result is img1's
    # top row stacked over img2's bottom row.
    augm = Augmentation(mixing_coeff=0.5, vertical_concat_prob=1)
    res_image, res_label = augm.apply_random_transform(images, labels)
    true_image = np.array([[1, 1], [0, 0]], dtype='uint8')
    np.testing.assert_array_equal(res_image, true_image)
    np.testing.assert_array_equal(res_label, true_label)

    # Horizontal: left column from img1, right column from img2.
    res_image, res_label = augm.horizontal_concat(images, labels, mixing_coeff=0.5)
    true_image = np.array([[1, 0], [1, 0]], dtype='uint8')
    np.testing.assert_array_equal(res_image, true_image)
    np.testing.assert_array_equal(res_label, true_label)

    # Mixed: img1 on the main diagonal, img2 on the anti-diagonal.
    res_image, res_label = augm.mixed_concat(images, labels, mixing_coeff=0.5)
    true_image = np.array([[1, 0], [0, 1]], dtype='uint8')
    np.testing.assert_array_equal(res_image, true_image)
    np.testing.assert_array_equal(res_label, true_label)
def test_mixup(self):
    """Check mixup of two 2x2 images with ``mixing_coeff=0.5``.

    The expected image is the pixel-wise midpoint of the two inputs, and the
    expected label is ``[0.5, 0.5]`` for the labels ``[0, 1]`` / ``[1, 0]``.
    Failures are reported with ``np.testing.assert_array_equal`` so the
    mismatching arrays are shown.
    """
    # 250 is even, so the expected midpoint with 0 (125) is an exact uint8.
    img1 = np.array([[250, 250], [250, 250]], dtype='uint8')
    img2 = np.zeros((2, 2), dtype='uint8')
    images = np.array([img1, img2])
    labels = np.array([[0, 1], [1, 0]])
    augm = Augmentation(mixing_coeff=0.5, mixup_prob=1)
    res_image, res_label = augm.apply_random_transform(images, labels)
    true_image = np.array([[125, 125], [125, 125]], dtype='uint8')
    true_label = np.array([0.5, 0.5])
    np.testing.assert_array_equal(res_image, true_image)
    np.testing.assert_array_equal(res_label, true_label)
def get_batch_generator(self, batch_size, generator_args=None, **kwargs):
    """Build a ``DataFrameIterator`` over this dataset.

    Args:
        batch_size: Forwarded to ``DataFrameIterator`` as ``batch_size``.
        generator_args: Optional dict of keyword arguments used to construct
            the ``Augmentation`` passed to the iterator. ``None`` (the default)
            means ``Augmentation()`` with no arguments.
        **kwargs: Extra keyword arguments forwarded to ``DataFrameIterator``.
            ``seed`` (default 42) and ``interpolation`` (default
            ``'bilinear'``) may be overridden here. The other keywords set
            from this object (``x_col``, ``y_col``, ``target_size``,
            ``batch_size``, ``cache``) must not be repeated.

    Returns:
        The constructed ``DataFrameIterator``.
    """
    # A None default avoids sharing one mutable dict across all calls.
    if generator_args is None:
        generator_args = {}
    # These used to be hard-coded in the call, so passing them via **kwargs
    # raised "got multiple values for keyword argument". Now they are defaults.
    kwargs.setdefault('seed', 42)
    kwargs.setdefault('interpolation', 'bilinear')
    return DataFrameIterator(
        self.dataframe,
        self.dataset_dir,
        Augmentation(**generator_args),
        x_col=self.x_col,
        y_col=self.y_col,
        target_size=self.image_size,
        batch_size=batch_size,
        cache=self.cache,
        **kwargs
    )
def get_fewshot_generator(self, n_way, k_shot, query_size=None,
                          support_generator_args=None, query_generator_args=None,
                          **kwargs):
    """Build a ``FewShotDataFrameIterator`` over this dataset.

    Args:
        n_way: Forwarded to the iterator as ``n_way``.
        k_shot: Forwarded to the iterator as ``k_shot``.
        query_size: Forwarded to the iterator as ``query_size`` (default None).
        support_generator_args: Optional dict of keyword arguments for the
            first ``Augmentation`` passed to the iterator (the support-set
            augmentation, per the parameter name). ``None`` means
            ``Augmentation()``.
        query_generator_args: Optional dict of keyword arguments for the
            second ``Augmentation`` passed to the iterator (the query-set
            augmentation, per the parameter name). ``None`` means
            ``Augmentation()``.
        **kwargs: Extra keyword arguments forwarded to the iterator.
            ``seed`` (default 42) and ``interpolation`` (default
            ``'bilinear'``) may be overridden here. The other keywords set
            from this object must not be repeated.

    Returns:
        The constructed ``FewShotDataFrameIterator``.
    """
    # None defaults avoid sharing mutable dicts across calls.
    if support_generator_args is None:
        support_generator_args = {}
    if query_generator_args is None:
        query_generator_args = {}
    # Overridable defaults; previously hard-coded, so passing them via
    # **kwargs raised a duplicate-keyword TypeError.
    kwargs.setdefault('seed', 42)
    kwargs.setdefault('interpolation', 'bilinear')
    return FewShotDataFrameIterator(
        self.dataframe,
        self.dataset_dir,
        Augmentation(**support_generator_args),
        Augmentation(**query_generator_args),
        class_index=self.class_index,
        n_way=n_way,
        k_shot=k_shot,
        query_size=query_size,
        x_col=self.x_col,
        y_col=self.y_col,
        target_size=self.image_size,
        cache=self.cache,
        **kwargs
    )
def test_zero_probs(self):
    """Check that a default ``Augmentation()`` leaves a random image unchanged.

    Presumably the default constructor sets every transform probability to
    zero (per the test name); confirm against ``Augmentation``. The assertion
    uses ``np.testing.assert_array_equal`` so a failure shows the differing
    pixels.
    """
    img = np.random.uniform(0, 256, (50, 50, 3))
    # Wrap in a batch of one and quantize to uint8.
    img = np.array([img], dtype='uint8')
    augm = Augmentation()
    res_image, _ = augm.apply_random_transform(img)
    np.testing.assert_array_equal(res_image, img)
def test_crop(self):
    """A forced crop (``crop_prob=1``) of size 10 shrinks a 50x50 batch to 10x10.

    The batch dimension and the channel count must survive the crop.
    """
    batch = np.array([np.random.uniform(0, 256, (50, 50, 3))], dtype='uint8')
    cropper = Augmentation(crop_size=10, crop_prob=1)
    cropped, _ = cropper.apply_random_transform(batch)
    expected_shape = (1, 10, 10, 3)
    self.assertEqual(cropped.shape, expected_shape)