def _test(self, is_random=False):
    """Check StatsAnalyzer output against statistics computed with numpy.

    Three mock scenes are built from constant 600x600x3 images whose channel
    values are (1 + i, 2 + i, 3 + i) for scene i. Unless ``is_random`` is set,
    the bottom-right quadrant of every image is filled with NaN. The expected
    per-channel mean/std are then computed with ``np.nanmean``/``np.nanstd``
    and compared with what the analyzer writes to ``stats.json``. When
    ``is_random`` is set, the analyzer is configured with ``sample_prob=0.5``
    and the number of ``_get_chip`` calls per raster source is checked too.

    Args:
        is_random: whether to configure the analyzer with ``sample_prob``.
    """
    stats_uri = os.path.join(self.tmp_dir.name, 'stats.json')
    sample_prob = 0.5

    scenes, raster_sources, imgs = [], [], []
    for idx in range(3):
        source = MockRasterSource([0, 1, 2], 3)
        img = np.zeros((600, 600, 3))
        # Broadcast the per-channel constants 1+idx, 2+idx, 3+idx.
        img[...] = np.arange(1, 4) + idx
        if not is_random:
            img[300:, 300:, :] = np.nan
        imgs.append(img)
        source.set_raster(img)
        raster_sources.append(source)
        scenes.append(Scene(str(idx), source))

    # Stack to (scene, row, col, channel), move channels to the front and
    # flatten the rest so that each row holds every value of one channel.
    pixels = np.moveaxis(np.stack(imgs, axis=0), -1, 0).reshape(3, -1)
    exp_means = np.nanmean(pixels, axis=1)
    exp_stds = np.nanstd(pixels, axis=1)

    analyzer_cfg = StatsAnalyzerConfig(
        output_uri=stats_uri,
        sample_prob=sample_prob if is_random else None)
    analyzer = analyzer_cfg.build()
    analyzer.process(scenes, self.tmp_dir.name)

    stats = RasterStats.load(stats_uri)
    np.testing.assert_array_almost_equal(stats.means, exp_means, decimal=3)
    np.testing.assert_array_almost_equal(stats.stds, exp_stds, decimal=3)

    if not is_random:
        return
    for source in raster_sources:
        width, height = (source.get_extent().get_width(),
                         source.get_extent().get_height())
        # NOTE(review): chip_sz is not defined in this method; it must come
        # from module scope elsewhere in this file -- confirm it exists and
        # matches the chip size the analyzer samples with.
        exp_num_chips = round(width * height / chip_sz**2 * sample_prob)
        self.assertEqual(source.mock._get_chip.call_count, exp_num_chips)
def _insert_analyzers(self):
    """Append a default StatsAnalyzerConfig when one is needed but missing.

    A StatsAnalyzerConfig is needed when any scene's raster source lists a
    StatsTransformerConfig among its transformers. Nothing is appended if
    ``self.analyzers`` already contains a StatsAnalyzerConfig.
    """
    # A list (not a generator) so every scene's transformers are visited.
    transformer_flags = [
        isinstance(transformer, StatsTransformerConfig)
        for scene in self.dataset.get_all_scenes()
        for transformer in scene.raster_source.transformers
    ]
    needs_stats = any(transformer_flags)
    has_stats = any(
        isinstance(analyzer, StatsAnalyzerConfig)
        for analyzer in self.analyzers)
    if needs_stats and not has_stats:
        self.analyzers.append(StatsAnalyzerConfig())
def _insert_analyzers(self):
    """Insert a StatsAnalyzer if a RasterSource needs one.

    A RasterSource needs one when it has a StatsTransformer, but there isn't
    a StatsAnalyzer in the list of Analyzers. In that case a default
    StatsAnalyzerConfig is appended to ``self.analyzers``.
    """
    all_transformers = [
        transformer
        for scene in self.dataset.get_all_scenes()
        for transformer in scene.raster_source.transformers
    ]
    wants_stats = any(
        isinstance(t, StatsTransformerConfig) for t in all_transformers)
    already_present = any(
        isinstance(a, StatsAnalyzerConfig) for a in self.analyzers)
    if wants_stats and not already_present:
        self.analyzers.append(StatsAnalyzerConfig())
def __init__(self,
             model_bundle_uri,
             tmp_dir,
             update_stats=False,
             channel_order=None):
    """Creates a new Predictor.

    The model bundle is fetched, extracted into ``<tmp_dir>/bundle`` and its
    ``pipeline-config.json`` is used to build the pipeline. The first
    validation scene of that config becomes the single scene used for
    prediction.

    Args:
        model_bundle_uri: URI of the model bundle to use. Can be any
            type of URI that Raster Vision can read.
        tmp_dir: Temporary directory in which to store files that are used
            by the Predictor. This directory is not cleaned up by this
            class.
        update_stats: If True, the pipeline's analyzers are replaced by a
            single StatsAnalyzerConfig whose output_uri is ``stats.json``
            inside the extracted bundle directory. If False, the analyzers
            from the bundled configuration are left as they are.
        channel_order: Option for a new channel order to use for the
            imagery being predicted against. If not present, the
            channel_order from the original configuration in the predict
            package will be used.

    Raises:
        Exception: if the bundled pipeline has no ``predict`` attribute, or
            the scene's raster_source lacks ``uris`` or its label_store
            lacks ``uri``.
    """
    self.tmp_dir = tmp_dir
    self.update_stats = update_stats
    self.model_loaded = False
    self.scene = None

    # Obtain a local copy of the bundle and unpack it under tmp_dir.
    local_zip = download_if_needed(model_bundle_uri, tmp_dir)
    bundle_dir = join(tmp_dir, 'bundle')
    make_dir(bundle_dir)
    with zipfile.ZipFile(local_zip, 'r') as archive:
        archive.extractall(path=bundle_dir)

    # Apply the bundle's rv_config overrides, then upgrade the raw config
    # dict (presumably migrating older config versions) and build it.
    cfg_dict = file_to_json(join(bundle_dir, 'pipeline-config.json'))
    rv_config.set_everett_config(config_overrides=cfg_dict.get('rv_config'))
    cfg_dict = upgrade_config(cfg_dict)
    self.pipeline = build_config(cfg_dict).build(tmp_dir)

    if not hasattr(self.pipeline, 'predict'):
        raise Exception(
            'pipeline in model bundle must have predict method')

    dataset = self.pipeline.config.dataset
    self.scene = dataset.validation_scenes[0]

    if not hasattr(self.scene.raster_source, 'uris'):
        raise Exception(
            'raster_source in model bundle must have uris as field')
    if not hasattr(self.scene.label_store, 'uri'):
        raise Exception(
            'label_store in model bundle must have uri as field')

    # Re-root each transformer at the extracted bundle directory.
    for transformer in self.scene.raster_source.transformers:
        transformer.update_root(bundle_dir)

    if update_stats:
        self.pipeline.config.analyzers = [
            StatsAnalyzerConfig(output_uri=join(bundle_dir, 'stats.json'))
        ]

    # Prediction runs without ground-truth labels or AOI restrictions and
    # uses this one scene for both the train and validation splits.
    self.scene.label_source = None
    self.scene.aoi_uris = None
    dataset.train_scenes = [self.scene]
    dataset.validation_scenes = [self.scene]
    dataset.test_scenes = None

    if channel_order is not None:
        self.scene.raster_source.channel_order = channel_order