Example #1
def _run_pipeline(cfg, runner, tmp_dir, splits=1, commands=None):
    cfg.update()
    cfg.recursive_validate_config()
    # Re-run validation to catch any fields that may have changed after the
    # Config was constructed, e.g. by the update() method above.
    build_config(cfg.dict())
    cfg_json_uri = cfg.get_config_uri()
    save_pipeline_config(cfg, cfg_json_uri)
    pipeline = cfg.build(tmp_dir)
    if not commands:
        commands = pipeline.commands

    runner.run(cfg_json_uri, pipeline, commands, num_splits=splits)
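Given the call signature above, a minimal usage sketch might look like the following. The LocalRunner import path and the make_config() factory are assumptions, not part of the original source.

# Usage sketch only; make_config() is a hypothetical factory returning a
# PipelineConfig with root_uri set, and the LocalRunner import path is assumed.
from tempfile import TemporaryDirectory

from rastervision.pipeline.runner import LocalRunner  # import path assumed

cfg = make_config()
runner = LocalRunner()
with TemporaryDirectory() as tmp_dir:
    # Run only 'train' and 'predict', splitting splittable commands in two.
    _run_pipeline(cfg, runner, tmp_dir, splits=2, commands=['train', 'predict'])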
Example #2
def _run_command(cfg_json_uri,
                 command,
                 split_ind=None,
                 num_splits=None,
                 runner=None):
    pipeline_cfg_dict = file_to_json(cfg_json_uri)
    rv_config_dict = pipeline_cfg_dict.get('rv_config')
    rv_config.reset(
        config_overrides=rv_config_dict,
        verbosity=rv_config.verbosity,
        tmp_dir=rv_config.tmp_dir)

    tmp_dir_obj = rv_config.get_tmp_dir()
    tmp_dir = tmp_dir_obj.name

    cfg = build_config(pipeline_cfg_dict)
    pipeline = cfg.build(tmp_dir)

    if num_splits is not None and split_ind is None and runner is not None:
        runner = registry.get_runner(runner)()
        split_ind = runner.get_split_ind()

    command_fn = getattr(pipeline, command)

    if num_splits is not None and num_splits > 1:
        msg = 'Running {} command split {}/{}...'.format(
            command, split_ind + 1, num_splits)
        click.echo(click.style(msg, fg='green'))
        command_fn(split_ind=split_ind, num_splits=num_splits)
    else:
        msg = 'Running {} command...'.format(command)
        click.echo(click.style(msg, fg='green'))
        command_fn()
Example #3
    def from_model_bundle(model_bundle_uri, tmp_dir):
        model_bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        model_bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(model_bundle_path, model_bundle_dir)

        config_path = join(model_bundle_dir, 'config.json')
        model_path = join(model_bundle_dir, 'model.pth')
        cfg = build_config(file_to_json(config_path))
        return cfg.build(tmp_dir, model_path=model_path)
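From the loader above, the bundle appears to be a zip archive with config.json and model.pth at its root. A rough sketch of producing such an archive by hand follows; the 'learner' type_hint is a placeholder, not a value taken from the source.

# Sketch (assumption): the bundle is a zip containing config.json and model.pth.
import json
import zipfile

with zipfile.ZipFile('model-bundle.zip', 'w') as bundle_zip:
    # 'learner' is a hypothetical type_hint; a real config would be produced
    # by serializing the corresponding Config with cfg.json() / cfg.dict().
    bundle_zip.writestr('config.json', json.dumps({'type_hint': 'learner'}))
    bundle_zip.write('model.pth')  # assumes a saved model file already exists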
Example #4
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle."""
        model_bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        model_bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(model_bundle_path, model_bundle_dir)

        config_path = join(model_bundle_dir, 'learner-config.json')
        model_path = join(model_bundle_dir, 'model.pth')
        cfg = build_config(file_to_json(config_path))
        return cfg.build(tmp_dir, model_path=model_path)
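A sketch of restoring a Learner from such a bundle and running its model; the class name, import path, and the model attribute are assumptions based on the method above.

# Usage sketch; import path and attribute names are assumed.
from tempfile import TemporaryDirectory

import torch
from rastervision.pytorch_learner import Learner  # import path assumed

with TemporaryDirectory() as tmp_dir:
    learner = Learner.from_model_bundle(
        'https://example.com/model-bundle.zip',  # hypothetical bundle URI
        tmp_dir)
    # The restored object is assumed to expose the trained PyTorch model;
    # run it on a dummy 3-band, 256x256 chip.
    chip = torch.zeros(1, 3, 256, 256)
    with torch.no_grad():
        out = learner.model(chip)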
Example #5
def run(runner: str, cfg_module: str, commands: List[str],
        arg: List[Tuple[str, str]], splits: int):
    """Run COMMANDS within pipelines in CFG_MODULE using RUNNER.

    RUNNER: name of the Runner to use

    CFG_MODULE: the module with `get_configs` function that returns PipelineConfigs.
    This can either be a Python module path or a local path to a .py file.

    COMMANDS: space separated sequence of commands to run within pipeline. The order in
    which to run them is based on the Pipeline.commands attribute. If this is omitted,
    all commands will be run.
    """
    tmp_dir_obj = rv_config.get_tmp_dir()
    tmp_dir = tmp_dir_obj.name

    args = dict(arg)
    args = convert_bool_args(args)
    cfgs = get_configs(cfg_module, runner, args)
    runner = registry.get_runner(runner)()

    for cfg in cfgs:
        cfg.update()
        cfg.rv_config = rv_config.get_config_dict(registry.rv_config_schema)
        cfg.recursive_validate_config()
        # Re-run validation to catch any fields that may have changed after
        # the Config was constructed, e.g. by the update() method above.
        build_config(cfg.dict())

        cfg_json = cfg.json()
        cfg_json_uri = cfg.get_config_uri()
        str_to_file(cfg_json, cfg_json_uri)

        pipeline = cfg.build(tmp_dir)
        if not commands:
            commands = pipeline.commands

        runner.run(cfg_json_uri, pipeline, commands, num_splits=splits)
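The docstring above says CFG_MODULE must expose a get_configs function. A minimal sketch of such a module follows; the keyword-argument handling and the PipelineConfig import path are assumptions.

# my_configs.py: hypothetical CFG_MODULE
# Assumptions: get_configs receives the runner name plus any -a key/value
# pairs as keyword arguments, and PipelineConfig lives at this import path.
from rastervision.pipeline.pipeline_config import PipelineConfig


def get_configs(runner: str, root_uri: str = '/tmp/rv-output', **kwargs):
    # Return one or more PipelineConfigs; `run` iterates over this list.
    return [PipelineConfig(root_uri=root_uri)]

Invocation would then look something like "rastervision run local my_configs.py -a root_uri s3://bucket/out -s 2", though the exact CLI flag names are assumed here.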
Example #6
    def test_to_from(self):
        cfg = CConfig(
            al=[AConfig(), ASub1Config(), ASub2Config()],
            bl=[BConfig()],
            a=ASub1Config(),
            b=BConfig(),
            plugin_versions=self.plugin_versions,
            root_uri=None,
            rv_config=None)

        exp_dict = {
            'plugin_versions': self.plugin_versions,
            'root_uri': None,
            'rv_config': None,
            'type_hint': 'c',
            'a': {
                'type_hint': 'asub1',
                'x': 'x',
                'y': 'y'
            },
            'al': [{
                'type_hint': 'a',
                'x': 'x'
            }, {
                'type_hint': 'asub1',
                'x': 'x',
                'y': 'y'
            }, {
                'type_hint': 'asub2',
                'x': 'x',
                'y': 'y'
            }],
            'b': {
                'type_hint': 'b',
                'x': 'x'
            },
            'bl': [{
                'type_hint': 'b',
                'x': 'x'
            }],
            'x': 'x'
        }

        self.assertDictEqual(cfg.dict(), exp_dict)
        self.assertEqual(build_config(exp_dict), cfg)
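For readers unfamiliar with the fixtures, here is a rough sketch of how the Config classes exercised above could be declared. These definitions are reconstructed from exp_dict, not the actual test fixtures, and the register_config import path and signature are assumed.

from typing import List

from rastervision.pipeline.config import Config, register_config  # path assumed


@register_config('a')
class AConfig(Config):
    x: str = 'x'


@register_config('asub1')
class ASub1Config(AConfig):
    y: str = 'y'


@register_config('asub2')
class ASub2Config(AConfig):
    y: str = 'y'


@register_config('b')
class BConfig(Config):
    x: str = 'x'


@register_config('c')
class CConfig(Config):
    a: AConfig
    al: List[AConfig]
    b: BConfig
    bl: List[BConfig]
    x: str = 'x'

In this example CConfig also accepts plugin_versions, root_uri, and rv_config, which suggests it derives from a pipeline-level config in that version; the sketch leaves those fields out.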
Example #7
    def test_to_from(self):
        cfg = CConfig(
            al=[AConfig(), ASub1Config(), ASub2Config()],
            bl=[BConfig()],
            a=ASub1Config(),
            b=BConfig())

        exp_dict = {
            'type_hint': 'c',
            'version': 1,
            'a': {
                'type_hint': 'asub1',
                'x': 'x',
                'y': 'y'
            },
            'al': [{
                'x': 'x'
            }, {
                'type_hint': 'asub1',
                'x': 'x',
                'y': 'y'
            }, {
                'type_hint': 'asub2',
                'x': 'x',
                'y': 'y'
            }],
            'b': {
                'x': 'x'
            },
            'bl': [{
                'x': 'x'
            }],
            'x': 'x'
        }

        self.assertDictEqual(cfg.dict(), exp_dict)
        self.assertEqual(build_config(exp_dict), cfg)
Example #8
def _run_command(cfg_json_uri: str,
                 command: str,
                 split_ind: Optional[int] = None,
                 num_splits: Optional[int] = None,
                 runner: Optional[str] = None):
    """Run a single command using a serialized PipelineConfig.

    Args:
        cfg_json_uri: URI of a JSON file with a serialized PipelineConfig
        command: name of command to run
        split_ind: the index that a split command should assume
        num_splits: the total number of splits to use
        runner: the name of the runner to use
    """
    pipeline_cfg_dict = file_to_json(cfg_json_uri)
    rv_config_dict = pipeline_cfg_dict.get('rv_config')
    rv_config.set_everett_config(profile=rv_config.profile,
                                 config_overrides=rv_config_dict)

    tmp_dir_obj = rv_config.get_tmp_dir()
    tmp_dir = tmp_dir_obj.name

    cfg = build_config(pipeline_cfg_dict)
    pipeline = cfg.build(tmp_dir)

    if num_splits is not None and split_ind is None and runner is not None:
        runner = registry.get_runner(runner)()
        split_ind = runner.get_split_ind()

    command_fn = getattr(pipeline, command)

    if num_splits is not None and num_splits > 1:
        msg = 'Running {} command split {}/{}...'.format(
            command, split_ind + 1, num_splits)
        click.echo(click.style(msg, fg='green'))
        command_fn(split_ind=split_ind, num_splits=num_splits)
    else:
        msg = 'Running {} command...'.format(command)
        click.echo(click.style(msg, fg='green'))
        command_fn()
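A usage sketch of the worker-side call a runner would make; the URIs below are hypothetical.

# Run the whole 'train' command from a serialized pipeline config.
_run_command('s3://my-bucket/pipeline-config.json', 'train')

# Run slice 3 of 8 of a splittable 'chip' command (split_ind is zero-based).
_run_command(
    's3://my-bucket/pipeline-config.json', 'chip', split_ind=2, num_splits=8)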
Example #9
    def __init__(self,
                 model_bundle_uri,
                 tmp_dir,
                 update_stats=False,
                 channel_order=None):
        """Creates a new Predictor.

        Args:
            model_bundle_uri: URI of the model bundle to use. Can be any
                type of URI that Raster Vision can read.
            tmp_dir: Temporary directory in which to store files that are used
                by the Predictor. This directory is not cleaned up by this
                class.
            update_stats: if True, a StatsAnalyzer is run on the imagery
                being predicted against to regenerate raster stats, instead
                of relying on the stats saved in the model bundle.
            channel_order: Option for a new channel order to use for the
                imagery being predicted against. If not present, the
                channel_order from the original configuration in the predict
                package will be used.
        """
        self.tmp_dir = tmp_dir
        self.update_stats = update_stats
        self.model_loaded = False

        bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        bundle_dir = join(tmp_dir, 'bundle')
        make_dir(bundle_dir)
        with zipfile.ZipFile(bundle_path, 'r') as bundle_zip:
            bundle_zip.extractall(path=bundle_dir)

        config_path = join(bundle_dir, 'pipeline-config.json')
        config_dict = file_to_json(config_path)
        rv_config.set_everett_config(
            config_overrides=config_dict.get('rv_config'))
        config_dict = upgrade_config(config_dict)

        self.pipeline = build_config(config_dict).build(tmp_dir)
        self.scene = None

        if not hasattr(self.pipeline, 'predict'):
            raise Exception(
                'pipeline in model bundle must have predict method')

        self.scene = self.pipeline.config.dataset.validation_scenes[0]

        if not hasattr(self.scene.raster_source, 'uris'):
            raise Exception(
                'raster_source in model bundle must have uris as field')

        if not hasattr(self.scene.label_store, 'uri'):
            raise Exception(
                'label_store in model bundle must have uri as field')

        for t in self.scene.raster_source.transformers:
            t.update_root(bundle_dir)

        if self.update_stats:
            stats_analyzer = StatsAnalyzerConfig(
                output_uri=join(bundle_dir, 'stats.json'))
            self.pipeline.config.analyzers = [stats_analyzer]

        self.scene.label_source = None
        self.scene.aoi_uris = None
        self.pipeline.config.dataset.train_scenes = [self.scene]
        self.pipeline.config.dataset.validation_scenes = [self.scene]
        self.pipeline.config.dataset.test_scenes = None
        if channel_order is not None:
            self.scene.raster_source.channel_order = channel_order
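A usage sketch follows; the predict() call and its arguments are assumptions, since only the constructor appears above, and the URIs are hypothetical.

from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp_dir:
    predictor = Predictor(
        's3://my-bucket/model-bundle.zip',  # hypothetical bundle URI
        tmp_dir,
        update_stats=False,
        channel_order=[0, 1, 2])
    # Assumed signature: predict over one image URI and write labels out.
    predictor.predict(
        ['s3://my-bucket/scene.tif'], 's3://my-bucket/predictions.json')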