def test_givenStopErrorStrategy_whenMissingTarget_shouldRaise(self):
        """With ErrorStrategy.stop, a missing 3h target aborts iteration.

        The second metadata entry is built with ``target_ghi_3h=None``;
        draining the generator must raise ``MissingTargetException``.
        """
        stop_config = DataloaderConfig(error_strategy=ErrorStrategy.stop)
        self.dataloader = DataLoader(
            lambda: [
                self._metadata(),
                self._metadata(target_ghi_3h=None),
            ],
            self.image_reader,
            stop_config,
        )

        with self.assertRaises(MissingTargetException):
            list(self.dataloader.generator())
    def test_givenMultipleImages_shouldHaveOneMoreDimension(self):
        """Several image paths in one entry stack into an extra leading axis.

        The single metadata entry references ``IMAGE_PATH`` twice, so the
        yielded image is expected to have shape ``[2] + FAKE_IMAGE.shape``.
        """
        self.image_reader.read = mock.Mock(return_value=FAKE_IMAGE)

        self.dataloader = DataLoader(
            lambda: [self._metadata(image_paths=[IMAGE_PATH, IMAGE_PATH])],
            self.image_reader,
            config=DataloaderConfig(features=[Feature.image]),
        )

        items = list(self.dataloader.generator())

        # Check the count first: iterating an empty generator would otherwise
        # let the shape assertion below pass without ever running.
        self.assertEqual(1, len(items))
        for (image, ) in items:
            self.assertEqual([2] + list(FAKE_IMAGE.shape), list(image.shape))
    def test_givenIgnoreErrorStrategy_whenMissingTarget_shouldReturnDummyTarget(
            self):
        """With ErrorStrategy.ignore, a missing target yields a placeholder.

        The entry is built with ``target_ghi=None``. The item must still be
        produced, and none of its targets may be ``None``.
        """
        self.dataloader = DataLoader(
            lambda: [self._metadata(target_ghi=None)],
            self.image_reader,
            DataloaderConfig(error_strategy=ErrorStrategy.ignore),
        )

        items = list(self.dataloader.generator())

        # The ignore strategy must keep the entry. Without this check an empty
        # generator would make the loop below pass vacuously.
        self.assertEqual(1, len(items))
        for _image, targets in items:
            for target in targets:
                self.assertIsNotNone(target)
    def test_givenStopErrorStrategy_whenImageLoaderException_shouldRaise(self):
        """With ErrorStrategy.stop, an image-reader failure propagates.

        The first read raises ``AN_EXCEPTION``. Consuming the generator must
        surface it as ``AN_EXCEPTION_TYPE``.
        """
        read_results = [AN_EXCEPTION, FAKE_IMAGE]
        self.image_reader.read = mock.Mock(side_effect=read_results)

        stop_config = DataloaderConfig(error_strategy=ErrorStrategy.stop)
        self.dataloader = DataLoader(
            lambda: [self._metadata(), self._metadata()],
            self.image_reader,
            stop_config,
        )

        with self.assertRaises(AN_EXCEPTION_TYPE):
            list(self.dataloader.generator())
    def test_givenSkipErrorStrategy_whenMissingTarget_shouldSkipToNextItem(
            self):
        """With ErrorStrategy.skip, an entry missing its 1h target is dropped.

        The first entry is built with ``target_ghi_1h=None`` and the second
        uses defaults, so exactly one item should be yielded.
        """
        skip_config = DataloaderConfig(error_strategy=ErrorStrategy.skip)
        self.dataloader = DataLoader(
            lambda: [
                self._metadata(target_ghi_1h=None),
                self._metadata(),
            ],
            self.image_reader,
            skip_config,
        )

        produced = sum(1 for _ in self.dataloader.generator())

        self.assertEqual(1, produced)
    def test_givenNoLocalPath_shouldUseOriginalPath(self):
        """Without ``local_path``, the reader receives the original path."""
        no_local_config = DataloaderConfig(local_path=None)
        self.dataloader = DataLoader(
            lambda: [self._metadata()],
            self.image_reader,
            no_local_config,
        )

        # The generator is lazy; drain it so the image reader is invoked.
        list(self.dataloader.generator())

        self.image_reader.read.assert_called_with(
            IMAGE_PATH, mock.ANY, mock.ANY, mock.ANY)
    def test_givenMultipleImagesWithMissingImage_shouldCreateZerosImage(self):
        """A failed read inside a multi-image entry is replaced by zeros.

        The mocked reader returns ``FAKE_IMAGE`` once and then raises
        ``AN_EXCEPTION``. The first yielded item must hold one zero-filled
        slot and one slot equal to ``FAKE_IMAGE``.
        """
        self.image_reader.read = mock.Mock(
            side_effect=[FAKE_IMAGE, AN_EXCEPTION])

        image_only = DataloaderConfig(features=[Feature.image])
        self.dataloader = DataLoader(
            lambda: [self._metadata(image_paths=[IMAGE_PATH, IMAGE_PATH])],
            self.image_reader,
            config=image_only,
        )

        first_item = next(self.dataloader.generator())
        (image, ) = first_item

        # NOTE(review): the first mocked read succeeds, yet slot 0 is expected
        # to be the zero image. This presumably relies on the loader reading
        # image_paths in reverse order; confirm against DataLoader.
        zeros = np.zeros(FAKE_IMAGE.shape)
        self.assertTrue(np.array_equal(image[0], zeros))
        self.assertTrue(np.array_equal(image[1], FAKE_IMAGE))
    def test_givenSkipErrorStrategy_whenImageLoaderException_shouldSkipToNextItem(
            self):
        """With ErrorStrategy.skip, an entry whose image read fails is dropped.

        The reader raises on its first call and returns ``FAKE_IMAGE`` on the
        second, so only one of the two entries should be yielded.
        """
        self.image_reader.read = mock.Mock(
            side_effect=[AN_EXCEPTION, FAKE_IMAGE])

        skip_config = DataloaderConfig(error_strategy=ErrorStrategy.skip)
        self.dataloader = DataLoader(
            lambda: [self._metadata(), self._metadata()],
            self.image_reader,
            skip_config,
        )

        produced = sum(1 for _ in self.dataloader.generator())
        self.assertEqual(1, produced)
    def test_givenSkipErrorStrategy_whenImageCacheMiss_shouldSkipToNextItem(
            self):
        """With skip + forced caching, a cache miss drops the entry.

        The reader raises ``ImageNotCached`` for the first entry and returns
        ``FAKE_IMAGE`` for the second; a single item should remain.
        """
        self.image_reader.read = mock.Mock(
            side_effect=[ImageNotCached, FAKE_IMAGE])

        cached_skip_config = DataloaderConfig(
            error_strategy=ErrorStrategy.skip,
            force_caching=True,
        )
        self.dataloader = DataLoader(
            lambda: [self._metadata(), self._metadata()],
            self.image_reader,
            cached_skip_config,
        )

        yielded = list(self.dataloader.generator())
        self.assertEqual(1, len(yielded))
    def test_givenLocalPath_shouldUseLocalPathAsRoot(self):
        """A configured ``local_path`` is prepended to ``IMAGE_NAME``."""
        local_path = "local/path/"
        expected_path = f"{local_path}{IMAGE_NAME}"

        self.dataloader = DataLoader(
            lambda: [self._metadata()],
            self.image_reader,
            DataloaderConfig(local_path=local_path),
        )

        # Drain the lazy generator so the reader is actually called.
        list(self.dataloader.generator())

        self.image_reader.read.assert_called_with(
            expected_path, mock.ANY, mock.ANY, mock.ANY)
    def test_givenFeatures_whenCreateDataset_shouldReturnSameNumberOfFeatures(
            self):
        """Each yielded item has exactly one element per requested feature."""
        features = [
            Feature.image,
            Feature.target_ghi,
            Feature.metadata,
        ]

        self.dataloader = DataLoader(
            lambda: [self._metadata()],
            self.image_reader,
            DataloaderConfig(features=features),
        )

        items = list(self.dataloader.generator())

        # Guard against a vacuous pass: with no items the loop never asserts.
        self.assertEqual(1, len(items))
        for data in items:
            self.assertEqual(len(data), len(features))
    def test_metadata_format(self):
        """The metadata feature exposes the expected GHI values for BND.

        A ``Metadata`` entry for station BND at 2010-06-19 22:15 is fed
        through the loader with only ``Feature.metadata`` requested. The
        ``GHI_T``, ``GHI_T_1h``, ``GHI_T_3h`` and ``GHI_T_6h`` entries of the
        yielded vector are compared against reference values.
        """
        config = cf.read_configuration_file(config_test.DUMMY_TEST_CFG_PATH)
        # NOTE(review): the leading positional fields look like placeholders;
        # presumably only ``datetime`` and ``coordinates`` feed the metadata
        # feature. Confirm against Metadata's definition.
        metadata = Metadata(
            "",
            [],
            False,
            "",
            0,
            datetime=datetime(2010, 6, 19, 22, 15),
            coordinates=config.stations[cf.Station.BND],
        )

        self.dataloader = DataLoader(
            lambda: [metadata],
            self.image_reader,
            DataloaderConfig(features=[Feature.metadata]),
        )

        items = list(self.dataloader.generator())

        # Without this check an empty generator would skip every GHI assertion.
        self.assertEqual(1, len(items))
        for (meta, ) in items:
            self.assertCloseTo(meta[MetadataFeatureIndex.GHI_T], 471.675670)
            self.assertCloseTo(meta[MetadataFeatureIndex.GHI_T_1h], 280.165857)
            self.assertCloseTo(meta[MetadataFeatureIndex.GHI_T_3h], 0.397029)
            self.assertCloseTo(meta[MetadataFeatureIndex.GHI_T_6h], 0.0)
    def test_givenIgnoreErrorStrategy_whenImageLoaderException_shouldReturnDummyImage(
            self):
        """With ErrorStrategy.ignore, a failed image read yields a dummy image.

        The first read raises ``AN_HANDLED_EXCEPTION``. Both entries must
        still be produced, and the first entry's image must have shape
        ``crop_size + [len(channels)]``.
        """
        self.image_reader.read = mock.Mock(
            side_effect=[AN_HANDLED_EXCEPTION, FAKE_IMAGE])
        channels = ["ch1", "ch2"]
        crop_size = [40, 40]
        expected_image_shape = crop_size + [len(channels)]

        self.dataloader = DataLoader(
            lambda: [self._metadata(), self._metadata()],
            self.image_reader,
            DataloaderConfig(
                error_strategy=ErrorStrategy.ignore,
                crop_size=crop_size,
                channels=channels,
            ),
        )

        items = list(self.dataloader.generator())

        # Check the count before indexing so a short result fails with a clear
        # assertion message instead of an IndexError.
        self.assertEqual(2, len(items))
        first_image = items[0][0]
        # ``shape`` may be a tuple (e.g. NumPy) rather than a list, and a tuple
        # never compares equal to a list; normalise before comparing.
        self.assertEqual(list(first_image.shape), expected_image_shape)