コード例 #1
0
    def bundle(self):
        """Build and publish the model bundle used for making predictions.

        Collects, into a staging directory:
          * every file listed by ``self.config.backend.get_bundle_filenames()``
            (read from ``self.config.train_uri``),
          * every file listed by each analyzer's ``get_bundle_filenames()``
            (read from ``self.config.analyze_uri``),
          * the pipeline config, stored as ``pipeline-config.json``.
        The staging directory is zipped and passed to ``upload_or_copy`` with
        the destination ``self.config.get_model_bundle_uri()``. The bundle is
        what the Predictor and the ``predict`` CLI subcommand consume.
        """
        with tempfile.TemporaryDirectory(dir=self.tmp_dir) as work_dir:
            staging_dir = join(work_dir, 'bundle')
            make_dir(staging_dir)

            def stage(src_uri, dest_name):
                # Obtain a local copy of src_uri (download_if_needed is given
                # work_dir as its download location) and copy it into the
                # staging directory under dest_name.
                local_src = download_if_needed(src_uri, work_dir)
                shutil.copy(local_src, join(staging_dir, dest_name))

            cfg = self.config

            for name in cfg.backend.get_bundle_filenames():
                stage(join(cfg.train_uri, name), name)

            for analyzer in cfg.analyzers:
                for name in analyzer.get_bundle_filenames():
                    stage(join(cfg.analyze_uri, name), name)

            stage(cfg.get_config_uri(), 'pipeline-config.json')

            bundle_uri = cfg.get_model_bundle_uri()
            # NOTE(review): the local zip path is derived from self.tmp_dir,
            # not from work_dir, so it is not removed when this `with` block
            # exits. Presumably intentional; confirm before changing.
            bundle_path = get_local_path(bundle_uri, self.tmp_dir)
            zipdir(staging_dir, bundle_path)
            upload_or_copy(bundle_path, bundle_uri)
コード例 #2
0
def get_file(uri, cache_dir):
    """Return a local path for ``uri``, reusing an existing copy if there is one.

    The candidate path is whatever ``get_local_path`` maps ``uri`` to for
    ``cache_dir``. If a file already exists there, a message is printed and
    that path is returned. Otherwise the path returned by
    ``download_if_needed(uri, cache_dir)`` is returned.

    Args:
        uri: Location of the file (local path or remote URI).
        cache_dir: Directory passed to the path/download helpers as the cache.

    Returns:
        Local filesystem path to the file.
    """
    cached_path = get_local_path(uri, cache_dir)
    if not file_exists(cached_path):
        return download_if_needed(uri, cache_dir)
    print(f'Using cached file in {cached_path}')
    return cached_path
コード例 #3
0
def open_zip_file(zip_uri, cache_dir):
    """Return a local directory holding the extracted contents of ``zip_uri``.

    The extraction directory is the local path that ``get_local_path`` assigns
    to ``zip_uri`` for ``cache_dir``, with its final extension removed. If that
    directory already exists it is reused as-is, and the zip itself is not
    fetched. Otherwise the zip is obtained via ``get_file`` and extracted.

    Extraction first goes into a uniquely named sibling directory. That
    directory is renamed to the final location only after ``unzip`` returns.
    As a result, a crash or error mid-extraction cannot leave a partial
    directory that a later call would mistake for a complete cached copy.

    Args:
        zip_uri: URI (local or remote) of the zip archive.
        cache_dir: Directory used for cached downloads and extractions.

    Returns:
        Path of the directory containing the extracted files.
    """
    import os
    import shutil
    import tempfile

    zip_dir = splitext(get_local_path(zip_uri, cache_dir))[0]
    if isdir(zip_dir):
        # Already extracted: no need to fetch (possibly large) zip again.
        print(f'Using cached data in {zip_dir}')
        return zip_dir

    zip_path = get_file(zip_uri, cache_dir)

    # Stage the extraction next to its final location. Both directories share
    # a parent, so the final os.rename stays on the same filesystem.
    parent_dir = os.path.dirname(zip_dir) or os.curdir
    os.makedirs(parent_dir, exist_ok=True)
    staging_dir = tempfile.mkdtemp(
        prefix=os.path.basename(zip_dir) + '.partial-', dir=parent_dir)
    try:
        # NOTE(review): assumes unzip() accepts an existing, empty target
        # directory (mkdtemp creates it). TODO confirm against unzip().
        unzip(zip_path, staging_dir)
        os.rename(staging_dir, zip_dir)
    finally:
        # No-op after a successful rename. Otherwise this discards the
        # partial output so it is never treated as a cache hit.
        shutil.rmtree(staging_dir, ignore_errors=True)
    return zip_dir
コード例 #4
0
ファイル: vissl_wrapper.py プロジェクト: lewfish/ssl
def main(args, extra_args):
    """Run ``run_vissl`` on a cached dataset, then sync its outputs.

    Steps:
      1. Create ``args.tmp_root`` and ``args.cache_dir``.
      2. Inside a temporary directory under ``args.tmp_root``:
         * resolve a local output directory for ``args.output_uri``,
         * fetch ``args.pretrained_uri`` into the cache (only when it is set),
         * obtain an extracted copy of the ``args.dataset_uri`` zip.
      3. Call ``run_vissl``, then ``extract_backbone`` to derive
         ``backbone.torch`` from ``checkpoint.torch``.
      4. Whether or not step 3 raises, sync the local output directory to
         ``args.output_uri``.

    Args:
        args: Parsed arguments; presumably an argparse namespace (TODO
            confirm). Reads ``tmp_root``, ``cache_dir``, ``output_uri``,
            ``pretrained_uri``, ``dataset_uri`` and ``config``.
        extra_args: Passed through unchanged to ``run_vissl``.
    """
    for root_dir in (args.tmp_root, args.cache_dir):
        make_dir(root_dir)

    with tempfile.TemporaryDirectory(dir=args.tmp_root) as work_dir:
        local_output_dir = get_local_path(args.output_uri, work_dir)

        pretrained_path = None
        if args.pretrained_uri:
            pretrained_path = get_file(args.pretrained_uri, args.cache_dir)

        dataset_dir = open_zip_file(args.dataset_uri, args.cache_dir)

        try:
            run_vissl(
                args.config,
                dataset_dir,
                local_output_dir,
                extra_args,
                pretrained_path=pretrained_path,
            )
            extract_backbone(
                join(local_output_dir, 'checkpoint.torch'),
                join(local_output_dir, 'backbone.torch'),
            )
        finally:
            # Publish whatever was produced, even if training or backbone
            # extraction failed part-way.
            sync_to_dir(local_output_dir, args.output_uri)