Exemplo n.º 1
0
    def _test(self, is_random=False):
        """Compare StatsAnalyzer output with stats computed directly by numpy.

        Three mock scenes are built, each backed by a 600x600x3 raster whose
        channel ``c`` of scene ``i`` is filled with the constant ``c + 1 + i``.
        The analyzer is run over the scenes and the means/stds it saved to
        ``stats.json`` are checked against ``np.nanmean`` / ``np.nanstd`` of
        the same pixels.

        Args:
            is_random: If False, the analyzer is configured with
                ``sample_prob=None`` and part of each raster is set to NaN.
                If True, the analyzer is configured with ``sample_prob=0.5``
                and the number of ``_get_chip`` calls made on each raster
                source is also checked.
        """
        sample_prob = 0.5
        stats_uri = os.path.join(self.tmp_dir.name, 'stats.json')

        imgs, raster_sources, scenes = [], [], []
        for scene_ind in range(3):
            img = np.zeros((600, 600, 3))
            for channel in range(3):
                img[:, :, channel] = channel + 1 + scene_ind
            if not is_random:
                # NaN out rows/cols from 300 onward; the expected values
                # below are computed with the NaN-ignoring numpy functions.
                img[300:, 300:, :] = np.nan
            imgs.append(img)

            source = MockRasterSource([0, 1, 2], 3)
            source.set_raster(img)
            raster_sources.append(source)
            scenes.append(Scene(str(scene_ind), source))

        # Stack to (scene, row, col, channel), move the channel axis to the
        # front, then flatten so that each row holds one channel's pixels.
        stacked = np.stack(imgs, axis=0)
        per_channel = np.transpose(stacked, [3, 0, 1, 2]).reshape(3, -1)
        exp_means = np.nanmean(per_channel, axis=1)
        exp_stds = np.nanstd(per_channel, axis=1)

        analyzer_cfg = StatsAnalyzerConfig(
            output_uri=stats_uri,
            sample_prob=sample_prob if is_random else None)
        analyzer_cfg.build().process(scenes, self.tmp_dir.name)

        stats = RasterStats.load(stats_uri)
        np.testing.assert_array_almost_equal(stats.means, exp_means, decimal=3)
        np.testing.assert_array_almost_equal(stats.stds, exp_stds, decimal=3)

        if not is_random:
            return

        for source in raster_sources:
            width = source.get_extent().get_width()
            height = source.get_extent().get_height()
            # Number of chip_sz x chip_sz chips that fit in the extent,
            # scaled by the sampling probability.
            num_chips_in_extent = (width * height) / (chip_sz**2)
            exp_num_chips = round(num_chips_in_extent * sample_prob)
            self.assertEqual(source.mock._get_chip.call_count, exp_num_chips)
Exemplo n.º 2
0
    def _insert_analyzers(self):
        """Append a default StatsAnalyzerConfig if a stats transformer needs it.

        A StatsAnalyzerConfig is appended to ``self.analyzers`` only when some
        scene's raster source has a StatsTransformerConfig among its
        transformers and ``self.analyzers`` does not already contain a
        StatsAnalyzerConfig.
        """
        # Built as a list (not a generator) so that every transformer of
        # every scene is visited, with no early exit.
        transformer_checks = [
            isinstance(transformer, StatsTransformerConfig)
            for scene in self.dataset.get_all_scenes()
            for transformer in scene.raster_source.transformers
        ]
        needs_stats_analyzer = any(transformer_checks)

        # Stops at the first StatsAnalyzerConfig found.
        already_has_stats_analyzer = any(
            isinstance(analyzer, StatsAnalyzerConfig)
            for analyzer in self.analyzers)

        if needs_stats_analyzer and not already_has_stats_analyzer:
            self.analyzers.append(StatsAnalyzerConfig())
    def _insert_analyzers(self):
        """Insert a StatsAnalyzerConfig when one is required but missing.

        One is required when any raster source in the dataset's scenes has a
        StatsTransformerConfig among its transformers. It is missing when
        ``self.analyzers`` holds no StatsAnalyzerConfig. When both hold, a
        default-constructed StatsAnalyzerConfig is appended to
        ``self.analyzers``.
        """
        # Gather the transformers of every scene's raster source up front.
        all_transformers = [
            transformer
            for scene in self.dataset.get_all_scenes()
            for transformer in scene.raster_source.transformers
        ]
        stats_transformer_present = any(
            isinstance(transformer, StatsTransformerConfig)
            for transformer in all_transformers)

        stats_analyzer_missing = not any(
            isinstance(analyzer, StatsAnalyzerConfig)
            for analyzer in self.analyzers)

        if stats_transformer_present and stats_analyzer_missing:
            self.analyzers.append(StatsAnalyzerConfig())
Exemplo n.º 4
0
    def __init__(self,
                 model_bundle_uri,
                 tmp_dir,
                 update_stats=False,
                 channel_order=None):
        """Creates a new Predictor.

        Downloads the model bundle, extracts it into a 'bundle' directory
        under tmp_dir, builds the pipeline from the bundled
        'pipeline-config.json', and rewrites the pipeline's dataset config so
        that its train and validation scenes are just the first validation
        scene found in the bundle (with label source and AOI removed).

        Args:
            model_bundle_uri: URI of the model bundle to use. Can be any
                type of URI that Raster Vision can read.
            tmp_dir: Temporary directory in which to store files that are used
                by the Predictor. This directory is not cleaned up by this
                class.
            update_stats: If True, the pipeline config's analyzers are
                replaced with a single StatsAnalyzerConfig whose output_uri
                is 'stats.json' inside the extracted bundle directory.
            channel_order: Option for a new channel order to use for the
                imagery being predicted against. If not present, the
                channel_order from the original configuration in the predict
                package will be used.

        Raises:
            Exception: If the built pipeline has no 'predict' attribute, if
                the scene's raster_source has no 'uris' attribute, or if the
                scene's label_store has no 'uri' attribute.
        """
        self.tmp_dir = tmp_dir
        self.update_stats = update_stats
        self.model_loaded = False

        # Fetch the bundle and unpack the zip into <tmp_dir>/bundle.
        bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_dir = join(tmp_dir, 'bundle')
        make_dir(bundle_dir)
        with zipfile.ZipFile(bundle_path, 'r') as bundle_zip:
            bundle_zip.extractall(path=bundle_dir)

        config_path = join(bundle_dir, 'pipeline-config.json')
        config_dict = file_to_json(config_path)
        # Pass the bundle's 'rv_config' entry (None if absent) as config
        # overrides before the pipeline is built.
        rv_config.set_everett_config(
            config_overrides=config_dict.get('rv_config'))
        # NOTE(review): presumably migrates an older config schema to the
        # current one -- confirm against upgrade_config.
        config_dict = upgrade_config(config_dict)

        self.pipeline = build_config(config_dict).build(tmp_dir)
        # Placeholder; the real scene is assigned after the pipeline check.
        self.scene = None

        if not hasattr(self.pipeline, 'predict'):
            raise Exception(
                'pipeline in model bundle must have predict method')

        # The first validation scene in the bundle is used as the template
        # scene for prediction.
        self.scene = self.pipeline.config.dataset.validation_scenes[0]

        if not hasattr(self.scene.raster_source, 'uris'):
            raise Exception(
                'raster_source in model bundle must have uris as field')

        if not hasattr(self.scene.label_store, 'uri'):
            raise Exception(
                'label_store in model bundle must have uri as field')

        # Let each transformer update its root to the extracted bundle dir
        # (presumably so bundled files resolve there -- confirm).
        for t in self.scene.raster_source.transformers:
            t.update_root(bundle_dir)

        if self.update_stats:
            # Replace any configured analyzers with one stats analyzer that
            # writes into the bundle directory.
            stats_analyzer = StatsAnalyzerConfig(
                output_uri=join(bundle_dir, 'stats.json'))
            self.pipeline.config.analyzers = [stats_analyzer]

        # Drop the label source and AOI from the scene, and make it the only
        # train/validation scene in the dataset config.
        self.scene.label_source = None
        self.scene.aoi_uris = None
        self.pipeline.config.dataset.train_scenes = [self.scene]
        self.pipeline.config.dataset.validation_scenes = [self.scene]
        self.pipeline.config.dataset.test_scenes = None
        if channel_order is not None:
            self.scene.raster_source.channel_order = channel_order