def bundle(self):
        """Assemble the model bundle zip and publish it.

        Every file named by the backend (read from ``train_uri``) and by each
        analyzer (read from ``analyze_uri``) is copied into a scratch
        ``bundle`` directory, together with the pipeline config, which is
        stored as ``pipeline-config.json``. The directory is then zipped and
        the zip is uploaded or copied to the model bundle URI.

        The model bundle is a zip file and it is used by the Predictor and
        predict CLI subcommand.
        """
        with tempfile.TemporaryDirectory(dir=self.tmp_dir) as work_dir:
            staging_dir = join(work_dir, 'bundle')
            make_dir(staging_dir)
            cfg = self.config

            def stage(src_uri, dest_name):
                # Fetch src_uri (if needed) and place it in the staging
                # directory under dest_name.
                local_copy = download_if_needed(src_uri, work_dir)
                shutil.copy(local_copy, join(staging_dir, dest_name))

            # Files produced by the backend live under train_uri.
            for filename in cfg.backend.get_bundle_filenames():
                stage(join(cfg.train_uri, filename), filename)

            # Files produced by the analyzers live under analyze_uri.
            for analyzer in cfg.analyzers:
                for filename in analyzer.get_bundle_filenames():
                    stage(join(cfg.analyze_uri, filename), filename)

            # The pipeline config is always stored under a fixed name.
            stage(cfg.get_config_uri(), 'pipeline-config.json')

            bundle_uri = cfg.get_model_bundle_uri()
            # NOTE(review): the zip's local path is resolved against
            # self.tmp_dir, not the scratch directory above -- confirm that
            # is intended.
            bundle_zip_path = get_local_path(bundle_uri, self.tmp_dir)
            zipdir(staging_dir, bundle_zip_path)
            upload_or_copy(bundle_zip_path, bundle_uri)
Exemple #2
0
def main(args, extra_args):
    """Run VISSL, extract the backbone, and sync outputs to the output URI.

    Args:
        args: Parsed arguments. The attributes read here are ``tmp_root``,
            ``cache_dir``, ``output_uri``, ``pretrained_uri``,
            ``dataset_uri`` and ``config``.
        extra_args: Additional arguments passed through to ``run_vissl``.

    Whatever is in the local output directory is synced to
    ``args.output_uri`` even if the run or the backbone extraction raises.
    """
    make_dir(args.tmp_root)
    make_dir(args.cache_dir)

    with tempfile.TemporaryDirectory(dir=args.tmp_root) as scratch_dir:
        local_output = get_local_path(args.output_uri, scratch_dir)

        # Pretrained weights are optional; only fetch them when a URI is set.
        pretrained_path = None
        if args.pretrained_uri:
            pretrained_path = get_file(args.pretrained_uri, args.cache_dir)

        dataset_path = open_zip_file(args.dataset_uri, args.cache_dir)

        try:
            run_vissl(
                args.config,
                dataset_path,
                local_output,
                extra_args,
                pretrained_path=pretrained_path)
            checkpoint_path = join(local_output, 'checkpoint.torch')
            backbone_path = join(local_output, 'backbone.torch')
            extract_backbone(checkpoint_path, backbone_path)
        finally:
            # Always push the local output, including after a failure.
            sync_to_dir(local_output, args.output_uri)
Exemple #3
0
    def __init__(self,
                 model_bundle_uri,
                 tmp_dir,
                 update_stats=False,
                 channel_order=None):
        """Build a Predictor from a zipped model bundle.

        The bundle is fetched, unzipped into ``<tmp_dir>/bundle``, and the
        ``pipeline-config.json`` found inside is used to build the pipeline.
        The first validation scene of that config is kept as ``self.scene``
        and the pipeline's dataset config is rewritten so that its train and
        validation scenes contain only that scene.

        Args:
            model_bundle_uri: URI of the model bundle to use. Can be any
                type of URI that Raster Vision can read.
            tmp_dir: Temporary directory in which to store files that are used
                by the Predictor. This directory is not cleaned up by this
                class.
            update_stats: If truthy, the pipeline's analyzers are replaced by
                a single StatsAnalyzerConfig whose output is ``stats.json``
                inside the extracted bundle directory.
            channel_order: Optional channel order to set on the scene's
                raster source. If None, the value coming from the bundled
                configuration is left untouched.

        Raises:
            Exception: If the built pipeline has no ``predict`` attribute, if
                the scene's raster_source has no ``uris`` attribute, or if
                its label_store has no ``uri`` attribute.
        """
        self.tmp_dir = tmp_dir
        self.update_stats = update_stats
        self.model_loaded = False

        # Fetch the bundle zip and unpack it under tmp_dir.
        local_zip = download_if_needed(model_bundle_uri, tmp_dir)
        extract_dir = join(tmp_dir, 'bundle')
        make_dir(extract_dir)
        with zipfile.ZipFile(local_zip, 'r') as archive:
            archive.extractall(path=extract_dir)

        raw_config = file_to_json(join(extract_dir, 'pipeline-config.json'))
        # The stored 'rv_config' overrides are applied before the config is
        # upgraded and built.
        rv_config.set_everett_config(
            config_overrides=raw_config.get('rv_config'))
        upgraded_config = upgrade_config(raw_config)

        self.pipeline = build_config(upgraded_config).build(tmp_dir)
        self.scene = None

        if not hasattr(self.pipeline, 'predict'):
            raise Exception(
                'pipeline in model bundle must have predict method')

        pipeline_cfg = self.pipeline.config
        dataset_cfg = pipeline_cfg.dataset
        scene = dataset_cfg.validation_scenes[0]
        self.scene = scene

        # (scene attribute, required field, error message); evaluated lazily
        # and in order so the first missing field is the one reported.
        required_fields = (
            ('raster_source', 'uris',
             'raster_source in model bundle must have uris as field'),
            ('label_store', 'uri',
             'label_store in model bundle must have uri as field'),
        )
        for component, field, message in required_fields:
            if not hasattr(getattr(scene, component), field):
                raise Exception(message)

        # Point each raster transformer at the extracted bundle directory.
        for transformer in scene.raster_source.transformers:
            transformer.update_root(extract_dir)

        if update_stats:
            pipeline_cfg.analyzers = [
                StatsAnalyzerConfig(output_uri=join(extract_dir, 'stats.json'))
            ]

        # Clear label source and AOI, and make this scene the only train and
        # validation scene.
        scene.label_source = None
        scene.aoi_uris = None
        dataset_cfg.train_scenes = [scene]
        dataset_cfg.validation_scenes = [scene]
        dataset_cfg.test_scenes = None
        pipeline_cfg.train_uri = extract_dir

        if channel_order is not None:
            scene.raster_source.channel_order = channel_order