示例#1
0
def _run_pipeline(cfg, runner, tmp_dir, splits=1, commands=None):
    """Update, validate, save, build and run a pipeline config.

    Args:
        cfg: the pipeline config to run.
        runner: object whose ``run`` method executes the pipeline.
        tmp_dir: temporary directory passed to ``cfg.build``.
        splits: passed to ``runner.run`` as ``num_splits``.
        commands: commands to run; when falsy, the built pipeline's own
            ``commands`` are used instead.
    """
    cfg.update()
    cfg.recursive_validate_config()
    # Validate once more from the serialized form: fields may have changed
    # after the Config was constructed, e.g. by the update() call above.
    build_config(cfg.dict())

    config_uri = cfg.get_config_uri()
    save_pipeline_config(cfg, config_uri)

    built_pipeline = cfg.build(tmp_dir)
    to_run = commands if commands else built_pipeline.commands
    runner.run(config_uri, built_pipeline, to_run, num_splits=splits)
示例#2
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle.

        The bundle at ``model_bundle_uri`` is fetched with
        ``download_if_needed`` and unzipped into ``<tmp_dir>/model-bundle``.
        Its ``pipeline-config.json`` is upgraded and built into a config, and
        the learner is built with the bundled ``model.pth``. When the config
        names an external model or loss definition, the copy stored under
        ``MODULES_DIRNAME`` inside the bundle is used.

        Args:
            model_bundle_uri: URI of the model bundle zip.
            tmp_dir: directory used for the download and the unzipped files.

        Returns:
            Whatever ``cfg.learner.build`` returns.
        """
        local_zip = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(local_zip, bundle_dir)

        weights_path = join(bundle_dir, 'model.pth')
        raw_cfg = upgrade_config(
            file_to_json(join(bundle_dir, 'pipeline-config.json')))
        cfg = build_config(raw_cfg)

        modules_dir = join(bundle_dir, MODULES_DIRNAME)

        # Use the model definition shipped inside the bundle, if referenced.
        model_def_dir = None
        model_ext = cfg.learner.model.external_def
        if model_ext is not None:
            model_def_dir = get_hubconf_dir_from_cfg(model_ext,
                                                     parent=modules_dir)
            log.info(
                f'Using model definition found in bundle: {model_def_dir}')

        # Likewise for an external loss function definition.
        loss_def_dir = None
        loss_ext = cfg.learner.solver.external_loss_def
        if loss_ext is not None:
            loss_def_dir = get_hubconf_dir_from_cfg(loss_ext,
                                                    parent=modules_dir)
            log.info(f'Using loss definition found in bundle: {loss_def_dir}')

        return cfg.learner.build(tmp_dir=tmp_dir,
                                 model_path=weights_path,
                                 model_def_path=model_def_dir,
                                 loss_def_path=loss_def_dir)
示例#3
0
 def _test_config_upgrader(self, OldCfgType, NewCfgType, upgrader,
                           curr_version):
     """Check that ``upgrader`` turns a default old config into NewCfgType.

     A default ``OldCfgType`` is serialized to a dict and passed through
     ``upgrader`` once per version ``0 .. curr_version - 1``, in order. The
     result is then built with ``build_config`` and must be an instance of
     ``NewCfgType``.
     """
     cfg_dict = OldCfgType().dict()
     for version in range(curr_version):
         cfg_dict = upgrader(cfg_dict, version=version)
     new_cfg = build_config(cfg_dict)
     # assertIsInstance reports the actual type on failure, whereas
     # assertTrue(isinstance(...)) only reports "False is not true".
     self.assertIsInstance(new_cfg, NewCfgType)
示例#4
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle.

        Fetches the bundle with ``download_if_needed``, unzips it into
        ``<tmp_dir>/model-bundle``, upgrades and builds the bundled
        ``pipeline-config.json``, and builds the learner with the bundled
        ``model.pth`` weights.
        """
        bundle_dir = None
        bundle_dir = join(tmp_dir, 'model-bundle') if bundle_dir is None \
            else bundle_dir
        unzip(download_if_needed(model_bundle_uri, tmp_dir)
              if False else None, None) if False else None
        return None
    def test_to_from(self):
        """Round-trip a CConfig through ``dict()`` and ``build_config``."""

        # Each factory returns a fresh dict per call, so no dict object is
        # shared between positions of the expected structure.
        def a_dict():
            return {'type_hint': 'a', 'x': 'x'}

        def b_dict():
            return {'type_hint': 'b', 'x': 'x'}

        def sub_dict(hint):
            return {'type_hint': hint, 'x': 'x', 'y': 'y'}

        cfg = CConfig(
            al=[AConfig(), ASub1Config(), ASub2Config()],
            bl=[BConfig()],
            a=ASub1Config(),
            b=BConfig(),
            plugin_versions=self.plugin_versions,
            root_uri=None,
            rv_config=None)

        exp_dict = {
            'plugin_versions': self.plugin_versions,
            'root_uri': None,
            'rv_config': None,
            'type_hint': 'c',
            'a': sub_dict('asub1'),
            'al': [a_dict(), sub_dict('asub1'), sub_dict('asub2')],
            'b': b_dict(),
            'bl': [b_dict()],
            'x': 'x',
        }

        self.assertDictEqual(cfg.dict(), exp_dict)
        self.assertEqual(build_config(exp_dict), cfg)
示例#6
0
def _run_command(cfg_json_uri: str,
                 command: str,
                 split_ind: Optional[int] = None,
                 num_splits: Optional[int] = None,
                 runner: Optional[str] = None):
    """Run a single command using a serialized PipelineConfig.

    Args:
        cfg_json_uri: URI of a JSON file with a serialized PipelineConfig
        command: name of command to run (looked up as a pipeline attribute)
        split_ind: the index that a split command should assume. If omitted
            while both num_splits and runner are given, it is obtained from
            the runner's ``get_split_ind()``.
        num_splits: the total number of splits to use
        runner: the name of the runner to use

    Raises:
        ValueError: if num_splits > 1 but no split index was given and none
            could be obtained from a runner.
    """
    pipeline_cfg_dict = file_to_json(cfg_json_uri)
    # Apply the rv_config overrides stored alongside the pipeline config.
    rv_config_dict = pipeline_cfg_dict.get('rv_config')
    rv_config.set_everett_config(profile=rv_config.profile,
                                 config_overrides=rv_config_dict)

    # Keep the tmp dir object (not just its name) referenced for the whole
    # run -- presumably a TemporaryDirectory that is deleted once collected.
    tmp_dir_obj = rv_config.get_tmp_dir()
    tmp_dir = tmp_dir_obj.name

    cfg = build_config(pipeline_cfg_dict)
    pipeline = cfg.build(tmp_dir)

    if num_splits is not None and split_ind is None and runner is not None:
        # Ask the named runner which split this process should handle.
        runner_obj = registry.get_runner(runner)()
        split_ind = runner_obj.get_split_ind()

    command_fn = getattr(pipeline, command)

    if num_splits is not None and num_splits > 1:
        if split_ind is None:
            # Previously this surfaced as an opaque TypeError on `None + 1`.
            raise ValueError(
                'split_ind must be provided (or obtainable from a runner) '
                'when num_splits > 1.')
        msg = 'Running {} command split {}/{}...'.format(
            command, split_ind + 1, num_splits)
        click.secho(msg, fg='green', bold=True)
        command_fn(split_ind=split_ind, num_splits=num_splits)
    else:
        msg = 'Running {} command...'.format(command)
        click.secho(msg, fg='green', bold=True)
        command_fn()
示例#7
0
    def __init__(self,
                 model_bundle_uri,
                 tmp_dir,
                 update_stats=False,
                 channel_order=None):
        """Creates a new Predictor.

        Unpacks the model bundle and builds the pipeline it contains. The
        pipeline is then reconfigured so that its first validation scene is
        the only train/validation scene, with labels and AOIs cleared and no
        test scenes.

        Args:
            model_bundle_uri: URI of the model bundle to use. Can be any
                type of URI that Raster Vision can read.
            tmp_dir: Temporary directory in which to store files that are used
                by the Predictor. This directory is not cleaned up by this
                class.
            update_stats: If True, the pipeline's analyzers are replaced by a
                single StatsAnalyzerConfig whose output is ``stats.json``
                inside the unpacked bundle directory.
            channel_order: Option for a new channel order to use for the
                imagery being predicted against. If not present, the
                channel_order from the original configuration in the predict
                package will be used.

        Raises:
            Exception: if the bundled pipeline has no ``predict`` method,
                the scene's raster_source has no ``uris`` attribute, or the
                scene's label_store has no ``uri`` attribute.
        """
        self.tmp_dir = tmp_dir
        self.update_stats = update_stats
        self.model_loaded = False

        # Obtain a local copy of the bundle and extract it into
        # <tmp_dir>/bundle.
        bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_dir = join(tmp_dir, 'bundle')
        make_dir(bundle_dir)
        with zipfile.ZipFile(bundle_path, 'r') as bundle_zip:
            bundle_zip.extractall(path=bundle_dir)

        config_path = join(bundle_dir, 'pipeline-config.json')
        config_dict = file_to_json(config_path)
        # Apply the rv_config overrides saved in the bundle before upgrading
        # the config dict (presumably to the current schema) and building it.
        rv_config.set_everett_config(
            config_overrides=config_dict.get('rv_config'))
        config_dict = upgrade_config(config_dict)

        self.pipeline = build_config(config_dict).build(tmp_dir)
        # NOTE(review): unconditionally reassigned below once the predict
        # check passes; this assignment appears redundant.
        self.scene = None

        if not hasattr(self.pipeline, 'predict'):
            raise Exception(
                'pipeline in model bundle must have predict method')

        # The first validation scene serves as the template for prediction.
        self.scene = self.pipeline.config.dataset.validation_scenes[0]

        if not hasattr(self.scene.raster_source, 'uris'):
            raise Exception(
                'raster_source in model bundle must have uris as field')

        if not hasattr(self.scene.label_store, 'uri'):
            raise Exception(
                'label_store in model bundle must have uri as field')

        # Re-root each raster transformer onto the unpacked bundle directory,
        # presumably so files they reference resolve inside it -- confirm.
        for t in self.scene.raster_source.transformers:
            t.update_root(bundle_dir)

        if self.update_stats:
            # Swap in a stats analyzer that writes into the bundle directory.
            stats_analyzer = StatsAnalyzerConfig(
                output_uri=join(bundle_dir, 'stats.json'))
            self.pipeline.config.analyzers = [stats_analyzer]

        # Restrict the dataset to this single scene with no labels or AOIs.
        self.scene.label_source = None
        self.scene.aoi_uris = None
        self.pipeline.config.dataset.train_scenes = [self.scene]
        self.pipeline.config.dataset.validation_scenes = [self.scene]
        self.pipeline.config.dataset.test_scenes = None
        # Presumably lets the pipeline load trained artifacts from the
        # unpacked bundle -- verify.
        self.pipeline.config.train_uri = bundle_dir

        if channel_order is not None:
            self.scene.raster_source.channel_order = channel_order