def setUp(self):
    """Build a CPU LearnFastRCNN learner wired to the loaders for the test-data directory."""
    self.path = 'test/test_data'
    train_loader, test_loader = get_dataloader(self.path)
    # 20 classes, run on CPU so the tests need no GPU
    self.learner = LearnFastRCNN(
        num_classes=20,
        data_loader=train_loader,
        data_loader_test=test_loader,
        device='cpu',
    )
 def test_torch_to_numpy(self):
     """Check that a loaded image tensor converts back to a numpy image of the original shape.

     Loads the first training sample with shuffling and image permutation
     disabled, converts it with ``torch_to_numpy_image`` and compares its shape
     with the reference JPEG read through OpenCV (converted from BGR to RGB).
     Assumes that first unshuffled sample is ``JPEGImages/000005.jpg``.
     """
     [data_loader, _] = get_dataloader(
         self.path, shuffle_train=False,
         batch_size_train=1,
         perm_images=False)
     self.learner.data_loader = data_loader
     images, _ = next(iter(self.learner.data_loader))
     # Tensor.to() is not in-place: it returns a tensor on the target device,
     # so the result must be re-bound or the move is silently discarded.
     image = images[0].to(torch.device('cpu'))
     img_numpy = self.learner.torch_to_numpy_image(image)

     img_path = self.path + '/JPEGImages/000005.jpg'
     img_bgr = cv2.imread(img_path)
     # cv2.imread returns None (it does not raise) for a missing or unreadable
     # file; fail here with a clear message instead of inside cvtColor.
     self.assertIsNotNone(img_bgr, 'could not read reference image ' + img_path)
     img_org = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
     self.assertEqual(img_numpy.shape, img_org.shape)
Esempio n. 3
0
    def test_get_dataloader_correct_image_format(self):
        """
        Test whether the data loader returns batches in the format the model expects.

        The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
        image, and should be in 0-1 range. Different images can have different sizes.

        The behavior of the model changes depending if it is in training or evaluation mode.

        During training, the model expects both the input tensors, as well as a targets dictionary,
        containing:
            - boxes (Tensor[N, 4]): the ground-truth boxes in [x0, y0, x1, y1] format, with values
            between 0 and H and 0 and W
            - labels (Tensor[N]): the class label for each ground-truth box

        The first batch is expected to hold 5 samples (presumably the default
        training batch size of get_dataloader -- TODO confirm).
        """
        [data_loader, _] = get_dataloader(self.path, train_test_split=0.8)
        images, targets = next(iter(data_loader))
        self.assertEqual(len(images), 5)
        self.assertEqual(len(targets), 5)

        for image, target in zip(images, targets):
            # image: a 3-D [C, H, W] tensor with values in the 0-1 range
            self.assertIsInstance(image, torch.Tensor)
            self.assertEqual(image.dim(), 3)
            self.assertGreaterEqual(image.min().item(), 0.0)
            self.assertLessEqual(image.max().item(), 1.0)

            # target: dict with N boxes of 4 coordinates and one label per box
            self.assertIn('boxes', target)
            self.assertIn('labels', target)
            boxes, labels = target['boxes'], target['labels']
            self.assertEqual(boxes.dim(), 2)
            self.assertEqual(boxes.shape[1], 4)
            self.assertEqual(labels.shape, (boxes.shape[0],))
Esempio n. 4
0
 def test_get_dataloader_train_test_split(self):
     """A train_test_split of 0.8 yields 10 training and 3 test samples from the fixture data."""
     train_loader, test_loader = get_dataloader(self.path, train_test_split=0.8)
     train_size = len(train_loader.dataset)
     self.assertEqual(train_size, 10)
     test_size = len(test_loader.dataset)
     self.assertEqual(test_size, 3)
Esempio n. 5
0
 def test_get_dataloader(self):
     """get_dataloader returns two torch DataLoader objects for the fixture directory."""
     train_loader, test_loader = get_dataloader(self.path)
     expected_type = torch.utils.data.DataLoader
     self.assertIsInstance(train_loader, expected_type)
     self.assertIsInstance(test_loader, expected_type)