コード例 #1
0
ファイル: learner.py プロジェクト: uiuc-arc/raster-vision
    def save_model_bundle(self):
        """Build the model bundle zip at ``self.model_bundle_path``.

        The bundle is staged in ``<self.tmp_dir>/model-bundle`` and holds:

        - ``model.pth``: a copy of the weights file at ``self.last_model_path``.
        - ``MODULES_DIRNAME``: a copy of ``self.modules_dir``, included only
          when that directory exists.
        - ``pipeline-config.json``: ``self.cfg`` wrapped in a
          ``LearnerPipelineConfig`` and serialized, so predictions can be made
          later from the bundle alone.

        The staging directory is then zipped to ``self.model_bundle_path``.
        """
        # Local import kept from the original; presumably avoids an import
        # cycle at module load time — TODO confirm.
        from rastervision.pytorch_learner.learner_pipeline_config import (
            LearnerPipelineConfig)

        log.info('Creating bundle.')
        bundle_dir = join(self.tmp_dir, 'model-bundle')
        make_dir(bundle_dir)

        weights_dst = join(bundle_dir, 'model.pth')
        shutil.copyfile(self.last_model_path, weights_dst)

        if isdir(self.modules_dir):
            log.info('Copying modules into bundle.')
            modules_dst = join(bundle_dir, MODULES_DIRNAME)
            # shutil.copytree refuses to copy onto an existing directory, so
            # drop any copy left over from an earlier call first.
            if isdir(modules_dst):
                shutil.rmtree(modules_dst)
            shutil.copytree(self.modules_dir, modules_dst)

        bundle_cfg = LearnerPipelineConfig(learner=self.cfg)
        cfg_dst = join(bundle_dir, 'pipeline-config.json')
        save_pipeline_config(bundle_cfg, cfg_dst)

        zipdir(bundle_dir, self.model_bundle_path)
コード例 #2
0
ファイル: cli.py プロジェクト: sumesh1/raster-vision
def _run_pipeline(cfg, runner, tmp_dir, splits=1, commands=None):
    """Finalize, validate, save and run a pipeline config.

    Args:
        cfg: pipeline config. ``cfg.update()`` is called first and may change
            its fields before validation.
        runner: object whose ``run`` method executes the built pipeline.
        tmp_dir: temporary directory passed to ``cfg.build``.
        splits: forwarded to ``runner.run`` as ``num_splits``.
        commands: commands to run. When falsy (``None`` or an empty
            sequence), all of ``pipeline.commands`` are run instead.
    """
    cfg.update()
    cfg.recursive_validate_config()
    # Rebuild the config from its dict form so validation runs again over any
    # fields that changed after construction (e.g. in update()). The rebuilt
    # object is discarded; only the validation matters here.
    build_config(cfg.dict())

    config_uri = cfg.get_config_uri()
    save_pipeline_config(cfg, config_uri)

    pipeline = cfg.build(tmp_dir)
    commands_to_run = commands if commands else pipeline.commands
    runner.run(config_uri, pipeline, commands_to_run, num_splits=splits)
コード例 #3
0
ファイル: learner.py プロジェクト: yangxhcaf/raster-vision
    def save_model_bundle(self):
        """Build the model bundle zip at ``self.model_bundle_path``.

        The bundle is staged in ``<self.tmp_dir>/model-bundle`` and holds:

        - ``model.pth``: a copy of the weights file at ``self.last_model_path``.
        - ``pipeline-config.json``: ``self.cfg`` wrapped in a
          ``LearnerPipelineConfig`` and serialized, so predictions can be made
          later from the bundle alone.

        The staging directory is then zipped to ``self.model_bundle_path``.
        """
        # Local import kept from the original; presumably avoids an import
        # cycle at module load time — TODO confirm.
        from rastervision.pytorch_learner.learner_pipeline_config import (
            LearnerPipelineConfig)

        staging_dir = join(self.tmp_dir, 'model-bundle')
        make_dir(staging_dir)

        weights_path = join(staging_dir, 'model.pth')
        shutil.copyfile(self.last_model_path, weights_path)

        bundle_cfg = LearnerPipelineConfig(learner=self.cfg)
        cfg_path = join(staging_dir, 'pipeline-config.json')
        save_pipeline_config(bundle_cfg, cfg_path)

        zipdir(staging_dir, self.model_bundle_path)