def save_model_bundle(self):
    """Write the model bundle zip to ``self.model_bundle_path``.

    The bundle is staged in a ``model-bundle`` directory under
    ``self.tmp_dir`` and then zipped. It contains:

    - ``model.pth``: a copy of the weights file at ``self.last_model_path``.
    - A copy of ``self.modules_dir`` (stored under ``MODULES_DIRNAME``),
      included only when that directory exists.
    - ``pipeline-config.json``: a serialized ``LearnerPipelineConfig``
      wrapping ``self.cfg``, which allows making predictions later.
    """
    from rastervision.pytorch_learner.learner_pipeline_config import (
        LearnerPipelineConfig)

    log.info('Creating bundle.')
    bundle_dir = join(self.tmp_dir, 'model-bundle')
    make_dir(bundle_dir)

    # Model weights.
    shutil.copyfile(self.last_model_path, join(bundle_dir, 'model.pth'))

    # Optional extra modules.
    if isdir(self.modules_dir):
        log.info('Copying modules into bundle.')
        modules_dest = join(bundle_dir, MODULES_DIRNAME)
        # shutil.copytree fails when the destination already exists, so
        # clear out any copy left by an earlier call first.
        if isdir(modules_dest):
            shutil.rmtree(modules_dest)
        shutil.copytree(self.modules_dir, modules_dest)

    # Serialized config needed to rebuild the learner from the bundle.
    bundle_cfg = LearnerPipelineConfig(learner=self.cfg)
    save_pipeline_config(bundle_cfg, join(bundle_dir, 'pipeline-config.json'))

    zipdir(bundle_dir, self.model_bundle_path)
def _run_pipeline(cfg, runner, tmp_dir, splits=1, commands=None):
    """Finalize, validate, save, build, and run a pipeline config.

    Args:
        cfg: The pipeline config. It is updated and validated in place, then
            saved to ``cfg.get_config_uri()`` before the pipeline is built.
        runner: An object whose ``run`` method executes the pipeline.
        tmp_dir: Temporary directory passed to ``cfg.build``.
        splits: Forwarded to ``runner.run`` as ``num_splits``.
        commands: Commands to run. When falsy (``None`` or empty), the
            built pipeline's own ``commands`` are used instead.
    """
    cfg.update()
    cfg.recursive_validate_config()
    # Rebuild the config from its dict form so validation runs again. This
    # catches fields that changed after the Config was constructed,
    # possibly by update().
    build_config(cfg.dict())

    config_uri = cfg.get_config_uri()
    save_pipeline_config(cfg, config_uri)

    pipeline = cfg.build(tmp_dir)
    to_run = commands or pipeline.commands
    runner.run(config_uri, pipeline, to_run, num_splits=splits)
def save_model_bundle(self):
    """Write the model bundle zip to ``self.model_bundle_path``.

    The bundle is staged in a ``model-bundle`` directory under
    ``self.tmp_dir`` and then zipped. It contains ``model.pth`` (a copy of
    ``self.last_model_path``) and ``pipeline-config.json`` (a serialized
    ``LearnerPipelineConfig`` wrapping ``self.cfg``), which together allow
    making predictions later.
    """
    from rastervision.pytorch_learner.learner_pipeline_config import (
        LearnerPipelineConfig)

    bundle_dir = join(self.tmp_dir, 'model-bundle')
    make_dir(bundle_dir)

    # Model weights.
    shutil.copyfile(self.last_model_path, join(bundle_dir, 'model.pth'))

    # Serialized config needed to rebuild the learner from the bundle.
    bundle_cfg = LearnerPipelineConfig(learner=self.cfg)
    save_pipeline_config(bundle_cfg, join(bundle_dir, 'pipeline-config.json'))

    zipdir(bundle_dir, self.model_bundle_path)