Ejemplo n.º 1
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle.

        The bundle at ``model_bundle_uri`` is fetched into ``tmp_dir`` via
        ``download_if_needed`` and unzipped into ``<tmp_dir>/model-bundle``.
        The bundled ``pipeline-config.json`` is upgraded and built into a
        config object. The learner is then built with the bundled
        ``model.pth``. If the learner config references an external model
        definition or external loss definition, the copy stored under the
        bundle's modules directory is used.

        Args:
            model_bundle_uri: URI of the zipped model bundle.
            tmp_dir: Directory that receives the downloaded and extracted
                files. It is also forwarded to ``cfg.learner.build``.

        Returns:
            The object returned by ``cfg.learner.build``.
        """
        local_zip = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_root = join(tmp_dir, 'model-bundle')
        unzip(local_zip, bundle_root)

        weights_path = join(bundle_root, 'model.pth')
        raw_config = file_to_json(join(bundle_root, 'pipeline-config.json'))
        cfg = build_config(upgrade_config(raw_config))

        # External definitions are resolved against the modules directory
        # that was packed into the bundle.
        modules_root = join(bundle_root, MODULES_DIRNAME)

        model_ext_cfg = cfg.learner.model.external_def
        if model_ext_cfg is None:
            model_def_path = None
        else:
            model_def_path = get_hubconf_dir_from_cfg(
                model_ext_cfg, parent=modules_root)
            log.info(
                f'Using model definition found in bundle: {model_def_path}')

        loss_ext_cfg = cfg.learner.solver.external_loss_def
        if loss_ext_cfg is None:
            loss_def_path = None
        else:
            loss_def_path = get_hubconf_dir_from_cfg(
                loss_ext_cfg, parent=modules_root)
            log.info(f'Using loss definition found in bundle: {loss_def_path}')

        return cfg.learner.build(
            tmp_dir=tmp_dir,
            model_path=weights_path,
            model_def_path=model_def_path,
            loss_def_path=loss_def_path)
Ejemplo n.º 2
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle.

        The bundle is fetched into ``tmp_dir`` via ``download_if_needed`` and
        extracted to ``<tmp_dir>/model-bundle``. The bundled
        ``pipeline-config.json`` is upgraded and built, and the learner is
        built with the bundled ``model.pth``.

        Args:
            model_bundle_uri: URI of the zipped model bundle.
            tmp_dir: Directory that receives the downloaded and extracted
                files. It is also passed positionally to the learner's
                ``build``.

        Returns:
            The object returned by the learner config's ``build``.
        """
        extract_dir = join(tmp_dir, 'model-bundle')
        unzip(download_if_needed(model_bundle_uri, tmp_dir), extract_dir)

        upgraded = upgrade_config(
            file_to_json(join(extract_dir, 'pipeline-config.json')))
        learner_cfg = build_config(upgraded).learner
        return learner_cfg.build(
            tmp_dir, model_path=join(extract_dir, 'model.pth'))
Ejemplo n.º 3
0
    def __init__(self,
                 model_bundle_uri,
                 tmp_dir,
                 update_stats=False,
                 channel_order=None):
        """Creates a new Predictor.

        Args:
            model_bundle_uri: URI of the model bundle to use. Can be any
                type of URI that Raster Vision can read.
            tmp_dir: Temporary directory in which to store files that are used
                by the Predictor. This directory is not cleaned up by this
                class.
            update_stats: If True, the pipeline's analyzers are replaced by a
                single StatsAnalyzerConfig whose output is ``stats.json``
                inside the extracted bundle directory.
            channel_order: Option for a new channel order to use for the
                imagery being predicted against. If not present, the
                channel_order from the original configuration in the predict
                package will be used.

        Raises:
            Exception: If the bundled pipeline has no ``predict`` method,
                the bundled dataset has no validation scenes, the first
                validation scene's raster source has no ``uris`` field, or
                its label store has no ``uri`` field.
        """
        self.tmp_dir = tmp_dir
        self.update_stats = update_stats
        self.model_loaded = False

        # Fetch the bundle (if remote) and extract it into <tmp_dir>/bundle.
        bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_dir = join(tmp_dir, 'bundle')
        make_dir(bundle_dir)
        with zipfile.ZipFile(bundle_path, 'r') as bundle_zip:
            bundle_zip.extractall(path=bundle_dir)

        config_path = join(bundle_dir, 'pipeline-config.json')
        config_dict = file_to_json(config_path)
        # Apply any rv_config overrides stored in the bundled config before
        # the config is upgraded and built.
        rv_config.set_everett_config(
            config_overrides=config_dict.get('rv_config'))
        config_dict = upgrade_config(config_dict)

        self.pipeline = build_config(config_dict).build(tmp_dir)
        self.scene = None

        if not hasattr(self.pipeline, 'predict'):
            raise Exception(
                'pipeline in model bundle must have predict method')

        # Fail with a descriptive error instead of an opaque IndexError or
        # TypeError when the bundle defines no validation scenes.
        validation_scenes = self.pipeline.config.dataset.validation_scenes
        if not validation_scenes:
            raise Exception(
                'dataset in model bundle must have at least one validation '
                'scene')

        # The first validation scene is used as the template for prediction.
        self.scene = validation_scenes[0]

        if not hasattr(self.scene.raster_source, 'uris'):
            raise Exception(
                'raster_source in model bundle must have uris as field')

        if not hasattr(self.scene.label_store, 'uri'):
            raise Exception(
                'label_store in model bundle must have uri as field')

        # Re-root the raster source's transformers at the extracted bundle.
        # Presumably this lets files they reference resolve from the bundle;
        # confirm against the transformer implementations.
        for t in self.scene.raster_source.transformers:
            t.update_root(bundle_dir)

        if self.update_stats:
            stats_analyzer = StatsAnalyzerConfig(
                output_uri=join(bundle_dir, 'stats.json'))
            self.pipeline.config.analyzers = [stats_analyzer]

        # Strip labels and AOIs from the template scene. Use it as the only
        # train and validation scene, drop test scenes, and point train_uri
        # at the extracted bundle.
        self.scene.label_source = None
        self.scene.aoi_uris = None
        self.pipeline.config.dataset.train_scenes = [self.scene]
        self.pipeline.config.dataset.validation_scenes = [self.scene]
        self.pipeline.config.dataset.test_scenes = None
        self.pipeline.config.train_uri = bundle_dir

        if channel_order is not None:
            self.scene.raster_source.channel_order = channel_order
    def test_upgrade(self):
        """Upgrading a version-0 config renames legacy fields to ``x``.

        With 'rastervision.ab' and 'rastervision2.c' pinned to version 0,
        ``upgrade_config`` is expected to do two renames:

        - the root-level ``y`` field becomes ``x``;
        - the ``z`` field of each ``a``-family instance (``a``, ``asub1``,
          ``asub2``) becomes ``x``.

        Nothing else changes.
        """
        old_versions = {
            **self.plugin_versions,
            'rastervision.ab': 0,
            'rastervision2.c': 0,
        }

        # Config as written under the version-0 schema.
        legacy = {
            'plugin_versions': old_versions,
            'root_uri': None,
            'rv_config': None,
            'type_hint': 'c',
            'a': {'type_hint': 'asub1', 'z': 'x', 'y': 'y'},
            'al': [
                {'type_hint': 'a', 'z': 'x'},
                {'type_hint': 'asub1', 'z': 'x', 'y': 'y'},
                {'type_hint': 'asub2', 'z': 'x', 'y': 'y'},
            ],
            'b': {'type_hint': 'b', 'x': 'x'},
            'bl': [{'type_hint': 'b', 'x': 'x'}],
            'y': 'x',
        }

        # Same config after the renames. plugin_versions is intentionally
        # the same object that is passed in.
        expected = {
            'plugin_versions': old_versions,
            'root_uri': None,
            'rv_config': None,
            'type_hint': 'c',
            'a': {'type_hint': 'asub1', 'x': 'x', 'y': 'y'},
            'al': [
                {'type_hint': 'a', 'x': 'x'},
                {'type_hint': 'asub1', 'x': 'x', 'y': 'y'},
                {'type_hint': 'asub2', 'x': 'x', 'y': 'y'},
            ],
            'b': {'type_hint': 'b', 'x': 'x'},
            'bl': [{'type_hint': 'b', 'x': 'x'}],
            'x': 'x',
        }

        self.assertDictEqual(upgrade_config(legacy), expected)