示例#1
0
    def test_concats(self):
        # Batch of two 2x2 single-channel images (all ones, all zeros) whose
        # one-hot labels are swapped relative to each other.
        ones = np.ones((2, 2), dtype='uint8')
        zeros = np.zeros((2, 2), dtype='uint8')
        batch = np.array([ones, zeros])
        one_hot = np.array([[0, 1], [1, 0]])
        augmenter = Augmentation(mixing_coeff=0.5, vertical_concat_prob=1)

        # With mixing_coeff=0.5 every concatenation variant is expected to
        # yield the average of the two one-hot labels.
        expected_label = np.array([0.5, 0.5])

        # vertical_concat_prob=1: the random transform is expected to put the
        # first image's top row over the second image's bottom row.
        image, label = augmenter.apply_random_transform(batch, one_hot)
        expected_image = np.array([[1, 1], [0, 0]], dtype='uint8')
        self.assertTrue(np.array_equal(expected_image, image))
        self.assertTrue(np.array_equal(expected_label, label))

        # The concatenation variants can also be invoked directly; each entry
        # pairs a variant with the pixel layout it should produce
        # (horizontal: left/right columns, mixed: alternating quadrants).
        direct_cases = [
            (augmenter.horizontal_concat, [[1, 0], [1, 0]]),
            (augmenter.mixed_concat, [[1, 0], [0, 1]]),
        ]
        for concat, rows in direct_cases:
            image, label = concat(batch, one_hot, mixing_coeff=0.5)
            expected_image = np.array(rows, dtype='uint8')
            self.assertTrue(np.array_equal(expected_image, image))
            self.assertTrue(np.array_equal(expected_label, label))
示例#2
0
 def test_mixup(self):
     # A constant-250 image mixed with an all-zeros image at coefficient 0.5
     # is expected to average the pixels (250 * 0.5 = 125) and the swapped
     # one-hot labels (-> [0.5, 0.5]).
     bright = np.full((2, 2), 250, dtype='uint8')
     dark = np.zeros((2, 2), dtype='uint8')
     batch = np.array([bright, dark])
     one_hot = np.array([[0, 1], [1, 0]])

     # With mixup_prob=1 the test expects mixup to be applied on every call.
     augmenter = Augmentation(mixing_coeff=0.5, mixup_prob=1)
     mixed_image, mixed_label = augmenter.apply_random_transform(batch, one_hot)

     expected_image = np.full((2, 2), 125, dtype='uint8')
     expected_label = np.array([0.5, 0.5])
     self.assertTrue(np.array_equal(expected_image, mixed_image))
     self.assertTrue(np.array_equal(expected_label, mixed_label))
示例#3
0
 def get_batch_generator(self, batch_size, generator_args=None, **kwargs):
     """Build a ``DataFrameIterator`` over this dataset's dataframe.

     Args:
         batch_size: Number of samples per batch, forwarded to the iterator.
         generator_args: Keyword arguments used to construct the
             ``Augmentation`` instance passed to the iterator. ``None`` (the
             default) is treated as an empty dict, i.e. ``Augmentation()``.
         **kwargs: Extra keyword arguments forwarded unchanged to
             ``DataFrameIterator``.

     Returns:
         A ``DataFrameIterator`` built from ``self.dataframe`` and
         ``self.dataset_dir``, using ``self.x_col``/``self.y_col``,
         ``self.image_size`` as target size, ``self.cache``, a fixed seed of
         42 and bilinear interpolation.
     """
     # A None sentinel instead of a ``{}`` default avoids a single mutable
     # dict being shared by every call that omits the argument.
     if generator_args is None:
         generator_args = {}
     return DataFrameIterator(self.dataframe,
                              self.dataset_dir,
                              Augmentation(**generator_args),
                              x_col=self.x_col,
                              y_col=self.y_col,
                              target_size=self.image_size,
                              batch_size=batch_size,
                              seed=42,
                              interpolation='bilinear',
                              cache=self.cache,
                              **kwargs)
示例#4
0
 def get_fewshot_generator(self,
                           n_way,
                           k_shot,
                           query_size=None,
                           support_generator_args=None,
                           query_generator_args=None,
                           **kwargs):
     """Build a ``FewShotDataFrameIterator`` over this dataset's dataframe.

     Args:
         n_way: Forwarded to the iterator as ``n_way``.
         k_shot: Forwarded to the iterator as ``k_shot``.
         query_size: Forwarded to the iterator as ``query_size``.
         support_generator_args: Keyword arguments for the first
             ``Augmentation`` passed to the iterator (the support-set one,
             going by its name). ``None`` (the default) means
             ``Augmentation()``.
         query_generator_args: Keyword arguments for the second
             ``Augmentation`` passed to the iterator (the query-set one,
             going by its name). ``None`` (the default) means
             ``Augmentation()``.
         **kwargs: Extra keyword arguments forwarded unchanged to
             ``FewShotDataFrameIterator``.

     Returns:
         A ``FewShotDataFrameIterator`` built from ``self.dataframe`` and
         ``self.dataset_dir``, using ``self.class_index``,
         ``self.x_col``/``self.y_col``, ``self.image_size`` as target size,
         ``self.cache``, a fixed seed of 42 and bilinear interpolation.
     """
     # None sentinels instead of ``{}`` defaults avoid mutable dicts being
     # shared by every call that omits these arguments.
     if support_generator_args is None:
         support_generator_args = {}
     if query_generator_args is None:
         query_generator_args = {}
     return FewShotDataFrameIterator(self.dataframe,
                                     self.dataset_dir,
                                     Augmentation(**support_generator_args),
                                     Augmentation(**query_generator_args),
                                     class_index=self.class_index,
                                     n_way=n_way,
                                     k_shot=k_shot,
                                     query_size=query_size,
                                     x_col=self.x_col,
                                     y_col=self.y_col,
                                     target_size=self.image_size,
                                     seed=42,
                                     interpolation='bilinear',
                                     cache=self.cache,
                                     **kwargs)
示例#5
0
 def test_zero_probs(self):
     """``Augmentation()`` with default arguments must return the batch unchanged.

     Going by the test name, the defaults presumably leave every transform
     probability at zero. The expected value is a copy taken *before* the
     call, so an augmentation that modified the input in place (and returned
     it) can no longer pass by comparing the array with itself.
     """
     pixels = np.random.uniform(0, 256, (50, 50, 3))
     img = np.array([pixels], dtype='uint8')
     expected = img.copy()
     augm = Augmentation()
     res_image, _ = augm.apply_random_transform(img)
     self.assertTrue(np.array_equal(expected, res_image))
示例#6
0
 def test_crop(self):
     # Random 50x50 RGB image cast to uint8 and wrapped into a batch of one.
     pixels = np.random.uniform(0, 256, (50, 50, 3))
     batch = np.array([pixels], dtype='uint8')

     # crop_prob=1 with crop_size=10: only the spatial dimensions should
     # shrink; the batch and channel dimensions are expected to be preserved.
     cropper = Augmentation(crop_prob=1, crop_size=10)
     cropped, _ = cropper.apply_random_transform(batch)
     self.assertEqual(cropped.shape, (1, 10, 10, 3))