Exemplo n.º 1
0
    def test_can_copy_recursively(self):
        """A copy of a Config holding a DictConfig member is independent
        of its source: entries added to the copy do not leak back."""
        # plain 'dict' values are copied shallowly
        inner_schema = SchemaBuilder() \
            .add('options', dict) \
            .build()

        # DictConfig values are copied deeply
        outer_schema = SchemaBuilder() \
            .add('container', lambda: DictConfig(
                lambda v: Config(v, schema=inner_schema))) \
            .build()

        original = Config(
            {'container': {'x': {'options': {'k': 1}}}},
            schema=outer_schema)

        duplicate = Config(original, schema=outer_schema)
        duplicate['container']['y'] = {'options': {'k': 2}}

        self.assertNotEqual(duplicate, original)
Exemplo n.º 2
0
    def test_can_produce_multilayer_config_from_dict(self):
        """A nested plain dict is converted into nested Config objects
        according to a three-layer schema."""
        bottom = SchemaBuilder() \
            .add('options', dict) \
            .build()
        middle = SchemaBuilder() \
            .add('desc', lambda: Config(schema=bottom)) \
            .build()
        top = SchemaBuilder() \
            .add('container', lambda: DefaultConfig(
                lambda v: Config(v, schema=middle))) \
            .build()

        value = 1
        source = Config(
            {'container': {'elem': {'desc': {'options': {'k': value}}}}},
            schema=top)

        self.assertEqual(value, source.container['elem'].desc.options['k'])
Exemplo n.º 3
0
    def test_cant_dump_custom_types(self):
        """Dumping a value of a user-defined type is rejected.

        The reason for this is safety.
        """
        class X:
            pass

        conf = Config({'x': X()})

        with self.assertRaises(yaml.representer.RepresenterError):
            conf.dump(StringIO())
Exemplo n.º 4
0
    def make_source_project(self, name):
        """Build a new single-source Project for the source 'name',
        inheriting this project's configuration (minus source lists)."""
        source = self.get_source(name)

        conf = Config(self.config)
        for key in ('sources', 'subsets'):
            conf.remove(key)

        project = Project(conf)
        project.add_source(name, source)
        return project
Exemplo n.º 5
0
 def __init__(self, config=None, env=None):
     """Initialize the project.

     'env' may only be given when 'config' is omitted; otherwise
     an Environment is built from the resulting configuration.
     """
     self.config = Config(config,
                          fallback=PROJECT_DEFAULT_CONFIG,
                          schema=PROJECT_SCHEMA)
     if env is not None:
         if config is not None:
             raise ValueError(
                 "env can only be provided when no config provided")
     else:
         env = Environment(self.config)
     self.env = env
Exemplo n.º 6
0
    def test_cant_import_custom_types(self):
        """Parsing YAML that references arbitrary Python objects fails.

        The reason for this is safety. The problem is mostly about
        importing, because it can result in remote code execution or
        cause unpredictable problems.
        """
        buffer = StringIO()
        yaml.dump({'x': os.system}, buffer, Dumper=yaml.Dumper)
        buffer.seek(0)

        with self.assertRaises(yaml.constructor.ConstructorError):
            Config.parse(buffer)
Exemplo n.º 7
0
    def test_cant_set_incorrect_value(self):
        """A value that violates the declared field type is rejected."""
        schema = SchemaBuilder().add('k', int).build()

        with self.assertRaises(ValueError):
            Config({'k': 'srf'}, schema=schema)
Exemplo n.º 8
0
    def test_cant_set_incorrect_key(self):
        """A key that is absent from the schema raises KeyError."""
        schema = SchemaBuilder().add('k', int).build()

        with self.assertRaises(KeyError):
            Config({'v': 11}, schema=schema)
Exemplo n.º 9
0
    def __init__(self, config=None):
        """Build the plugin environment for a project.

        Loads built-in plugins and project-local (custom) plugins from
        the project's plugin directory, and groups them into
        per-category registries (extractors, importers, launchers,
        converters, transforms).
        """
        config = Config(config,
                        fallback=PROJECT_DEFAULT_CONFIG,
                        schema=PROJECT_SCHEMA)

        self.models = ModelRegistry(config)
        self.sources = SourceRegistry(config)

        self.git = GitWrapper(config)

        # custom plugins live under <project_dir>/<env_dir>/<plugins_dir>
        env_dir = osp.join(config.project_dir, config.env_dir)
        builtin = self._load_builtin_plugins()
        custom = self._load_plugins2(osp.join(env_dir, config.plugins_dir))
        # pick the plugins of one category from a mixed plugin list
        select = lambda seq, t: [e for e in seq if issubclass(e, t)]
        # imported locally, presumably to avoid circular imports at
        # module load time -- TODO confirm
        from datumaro.components.extractor import Transform
        from datumaro.components.extractor import SourceExtractor
        from datumaro.components.extractor import Importer
        from datumaro.components.converter import Converter
        from datumaro.components.launcher import Launcher
        self.extractors = PluginRegistry(builtin=select(
            builtin, SourceExtractor),
                                         local=select(custom, SourceExtractor))
        # a project itself can be read back as a dataset source
        self.extractors.register(self.PROJECT_EXTRACTOR_NAME,
                                 load_project_as_dataset)

        self.importers = PluginRegistry(builtin=select(builtin, Importer),
                                        local=select(custom, Importer))
        self.launchers = PluginRegistry(builtin=select(builtin, Launcher),
                                        local=select(custom, Launcher))
        self.converters = PluginRegistry(builtin=select(builtin, Converter),
                                         local=select(custom, Converter))
        self.transforms = PluginRegistry(builtin=select(builtin, Transform),
                                         local=select(custom, Transform))
Exemplo n.º 10
0
    def _save_branch_project(self, extractor, save_dir=None):
        """Save 'extractor' as a project in 'save_dir', or into the
        current project's directory when 'save_dir' is None.

        Raises:
            ValueError: if neither 'save_dir' nor the project's own
                directory is available.
        """
        if not isinstance(extractor, Dataset):
            extractor = Dataset.from_extractors(
                extractor
            )  # apply lazy transforms to avoid repeating traversals

        # NOTE: probably this function should be in the ViewModel layer
        if save_dir:
            dst_project = Project()
        else:
            if not self.config.project_dir:
                raise ValueError("Either a save directory or a project "
                                 "directory should be specified")
            save_dir = self.config.project_dir

            dst_project = Project(Config(self.config))
            dst_project.config.remove('project_dir')
            dst_project.config.remove('sources')
        # BUG FIX: abspath() used to be called before the fallback above,
        # so the default call (save_dir=None) crashed with TypeError
        save_dir = osp.abspath(save_dir)
        dst_project.config.project_name = osp.basename(save_dir)

        dst_dataset = dst_project.make_dataset()
        dst_dataset._categories = extractor.categories()
        dst_dataset.update(extractor)

        dst_dataset.save(save_dir=save_dir, merge=True)
Exemplo n.º 11
0
    def __init__(self, url):
        """Create an extractor over a local export directory 'url'.

        Reads 'config.json' and 'images_meta.json' from the directory
        and prepares lazy per-item image loaders; no images are loaded
        here.
        """
        super().__init__()

        local_dir = url
        self._local_dir = local_dir
        # retrieved images are cached under <url>/images
        self._cache_dir = osp.join(local_dir, 'images')

        with open(osp.join(url, 'config.json'), 'r') as config_file:
            config = json.load(config_file)
            config = Config(config,
                            fallback=DEFAULT_CONFIG,
                            schema=CONFIG_SCHEMA)
        self._config = config

        with open(osp.join(url, 'images_meta.json'), 'r') as images_file:
            images_meta = json.load(images_file)
            image_list = images_meta['images']

        items = []
        for entry in image_list:
            item_id = entry['id']
            item = datumaro.DatasetItem(id=item_id,
                                        image=self._make_image_loader(item_id))
            items.append((item.id, item))

        # sort by id (as-is, i.e. lexicographically for string ids)
        # for a deterministic iteration order
        items = sorted(items, key=lambda e: e[0])
        items = OrderedDict(items)
        self._items = items

        # None until needed; set up elsewhere on demand -- TODO confirm
        self._cvat_cli = None
        self._session = None
Exemplo n.º 12
0
    def __init__(self, url):
        """Create an extractor over a local export directory 'url'.

        Reads 'config.json' and 'images_meta.json' from the directory
        and builds DatasetItems with lazy image loaders; image size and
        file name are taken from the metadata when available.
        """
        super().__init__()

        local_dir = url
        self._local_dir = local_dir
        # retrieved images are cached under <url>/images
        self._cache_dir = osp.join(local_dir, 'images')

        with open(osp.join(url, 'config.json'), 'r') as config_file:
            config = json.load(config_file)
            config = Config(config, schema=CONFIG_SCHEMA)
        self._config = config

        with open(osp.join(url, 'images_meta.json'), 'r') as images_file:
            images_meta = json.load(images_file)
            image_list = images_meta['images']

        items = []
        for entry in image_list:
            item_id = entry['id']
            # fall back to the id when no file name is recorded
            item_filename = entry.get('name', str(item_id))
            size = None
            if entry.get('height') and entry.get('width'):
                size = (entry['height'], entry['width'])
            image = Image(data=self._make_image_loader(item_id),
                          path=item_filename,
                          size=size)
            item = DatasetItem(id=item_id, image=image)
            items.append((item.id, item))

        # sort numerically by id for a deterministic iteration order
        items = sorted(items, key=lambda e: int(e[0]))
        items = OrderedDict(items)
        self._items = items

        # None until needed; set up elsewhere on demand -- TODO confirm
        self._cvat_cli = None
        self._session = None
Exemplo n.º 13
0
 def load(cls, path):
     """Load a Project whose config file lives under
     <path>/<env_dir>/<project_filename>."""
     project_dir = osp.abspath(path)
     config_file = osp.join(
         project_dir,
         PROJECT_DEFAULT_CONFIG.env_dir,
         PROJECT_DEFAULT_CONFIG.project_filename)
     config = Config.parse(config_file)
     config.project_dir = project_dir
     config.project_filename = osp.basename(config_file)
     return Project(config)
Exemplo n.º 14
0
 def load(path):
     """Load a Project from a config file path, or from a directory
     containing the default project config file."""
     config_file = osp.abspath(path)
     if osp.isdir(config_file):
         config_file = osp.join(
             config_file, PROJECT_DEFAULT_CONFIG.project_filename)
     config = Config.parse(config_file)
     config.project_dir = osp.dirname(config_file)
     config.project_filename = osp.basename(config_file)
     return Project(config)
Exemplo n.º 15
0
    def save(self,
             save_dir=None,
             merge=False,
             recursive=True,
             save_images=False):
        """Save the dataset.

        When 'save_dir' is None, saves into the owning project's
        directory; otherwise saving to an explicit directory forces a
        merged save. On failure, a directory created by this call is
        removed again.
        """
        if save_dir is None:
            assert self.config.project_dir
            save_dir = self.config.project_dir
            project = self._project
        else:
            # an explicit target directory always gets a merged copy
            merge = True

        if merge:
            # detach a copy of the project config without its sources
            project = Project(Config(self.config))
            project.config.remove('sources')

        save_dir = osp.abspath(save_dir)
        dataset_save_dir = osp.join(save_dir, project.config.dataset_dir)

        converter_kwargs = {
            'save_images': save_images,
        }

        # remembered so the rollback below only removes what we created
        save_dir_existed = osp.exists(save_dir)
        try:
            os.makedirs(save_dir, exist_ok=True)
            os.makedirs(dataset_save_dir, exist_ok=True)

            if merge:
                # merge and save the resulting dataset
                self.env.converters.get(DEFAULT_FORMAT).convert(
                    self, dataset_save_dir, **converter_kwargs)
            else:
                if recursive:
                    # children items should already be updated
                    # so we just save them recursively
                    for source in self._sources.values():
                        if isinstance(source, ProjectDataset):
                            source.save(**converter_kwargs)

                self.env.converters.get(DEFAULT_FORMAT).convert(
                    self.iterate_own(), dataset_save_dir, **converter_kwargs)

            project.save(save_dir)
        except BaseException:
            # roll back a partially-written directory, then re-raise
            if not save_dir_existed and osp.isdir(save_dir):
                shutil.rmtree(save_dir, ignore_errors=True)
            raise
Exemplo n.º 16
0
    def test_project_generate(self):
        """Project.generate() creates a loadable on-disk project that
        preserves the supplied configuration values."""
        src_config = Config({
            'project_name': 'test_project',
            'format_version': 1,
        })

        with TestDir() as test_dir:
            Project.generate(test_dir, src_config)

            self.assertTrue(osp.isdir(test_dir))

            loaded = Project.load(test_dir).config
            self.assertEqual(src_config.project_name,
                             loaded.project_name)
            self.assertEqual(src_config.format_version,
                             loaded.format_version)
Exemplo n.º 17
0
    def save(self,
             save_dir=None,
             merge=False,
             recursive=True,
             save_images=False):
        """Save the dataset.

        When 'save_dir' is None, saves into the owning project's
        directory; otherwise saving to an explicit directory forces a
        merged save.
        """
        if save_dir is None:
            assert self.config.project_dir
            save_dir = self.config.project_dir
            project = self._project
        else:
            # an explicit target directory always gets a merged copy
            merge = True

        if merge:
            # detach a copy of the project config without its sources
            project = Project(Config(self.config))
            project.config.remove('sources')

        save_dir = osp.abspath(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        dataset_save_dir = osp.join(save_dir, project.config.dataset_dir)
        os.makedirs(dataset_save_dir, exist_ok=True)

        converter_kwargs = {
            'save_images': save_images,
        }

        if merge:
            # merge and save the resulting dataset
            converter = self.env.make_converter(DEFAULT_FORMAT,
                                                **converter_kwargs)
            converter(self, dataset_save_dir)
        else:
            if recursive:
                # children items should already be updated
                # so we just save them recursively
                for source in self._sources.values():
                    if isinstance(source, ProjectDataset):
                        source.save(**converter_kwargs)

            converter = self.env.make_converter(DEFAULT_FORMAT,
                                                **converter_kwargs)
            converter(self.iterate_own(), dataset_save_dir)

        project.save(save_dir)
Exemplo n.º 18
0
    def __init__(self, config=None):
        """Build the module environment for a project.

        Merges the on-disk env config (if any) over the defaults, then
        fills per-category module registries from built-in packages and
        from custom modules found under the project's env directory.
        """
        config = Config(config,
            fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)

        env_dir = osp.join(config.project_dir, config.env_dir)
        env_config_path = osp.join(env_dir, config.env_filename)
        env_config = Config(fallback=ENV_DEFAULT_CONFIG, schema=ENV_SCHEMA)
        if osp.isfile(env_config_path):
            # project-local env settings override the defaults
            env_config.update(Config.parse(env_config_path))

        self.config = env_config

        self.models = ModelRegistry(env_config)
        self.sources = SourceRegistry(config)

        # imported locally, presumably to avoid circular imports at
        # module load time -- TODO confirm
        import datumaro.components.importers as builtin_importers
        builtin_importers = builtin_importers.items
        custom_importers = self._get_custom_module_items(
            env_dir, env_config.importers_dir)
        self.importers = ModuleRegistry(config,
            builtin=builtin_importers, local=custom_importers)

        import datumaro.components.extractors as builtin_extractors
        builtin_extractors = builtin_extractors.items
        custom_extractors = self._get_custom_module_items(
            env_dir, env_config.extractors_dir)
        self.extractors = ModuleRegistry(config,
            builtin=builtin_extractors, local=custom_extractors)
        # a project itself can be read back as a dataset source
        self.extractors.register(self.PROJECT_EXTRACTOR_NAME,
            load_project_as_dataset)

        import datumaro.components.launchers as builtin_launchers
        builtin_launchers = builtin_launchers.items
        custom_launchers = self._get_custom_module_items(
            env_dir, env_config.launchers_dir)
        self.launchers = ModuleRegistry(config,
            builtin=builtin_launchers, local=custom_launchers)

        import datumaro.components.converters as builtin_converters
        builtin_converters = builtin_converters.items
        custom_converters = self._get_custom_module_items(
            env_dir, env_config.converters_dir)
        # NOTE(review): only the converters take an extra '.items' here,
        # unlike the other custom categories above -- verify whether
        # _get_custom_module_items returns a different shape for them
        if custom_converters is not None:
            custom_converters = custom_converters.items
        self.converters = ModuleRegistry(config,
            builtin=builtin_converters, local=custom_converters)

        self.statistics = ModuleRegistry(config)
        self.visualizers = ModuleRegistry(config)
        self.git = GitWrapper(config)
Exemplo n.º 19
0
    def test_can_save_and_load(self):
        """A Config survives a dump/parse round-trip, except that
        tuples come back as lists."""
        with TestDir() as test_dir:
            schema_low = SchemaBuilder().add('options', dict).build()
            schema_mid = SchemaBuilder() \
                .add('desc', lambda: Config(schema=schema_low)) \
                .build()
            schema_top = SchemaBuilder() \
                .add('container', lambda: DictConfig(
                    lambda v: Config(v, schema=schema_mid))) \
                .build()

            source = Config(
                {'container': {'elem': {'desc': {'options': {
                    'k': (1, 2, 3),
                    'd': 'asfd',
                }}}}},
                schema=schema_top)
            path = osp.join(test_dir, 'f.yaml')

            source.dump(path)

            loaded = Config.parse(path, schema=schema_top)

            options = loaded.container['elem'].desc.options
            # YAML has no tuple type, so 'k' is read back as a list
            self.assertTrue(isinstance(options['k'], list))
            options['k'] = tuple(options['k'])
            self.assertEqual(source, loaded)
Exemplo n.º 20
0
    \
    .add('subsets', list) \
    .add('sources', lambda: _DefaultConfig(
        lambda v=None: Source(v))) \
    .add('models', lambda: _DefaultConfig(
        lambda v=None: Model(v))) \
    \
    .add('models_dir', str, internal=True) \
    .add('plugins_dir', str, internal=True) \
    .add('sources_dir', str, internal=True) \
    .add('dataset_dir', str, internal=True) \
    .add('project_filename', str, internal=True) \
    .add('project_dir', str, internal=True) \
    .add('env_dir', str, internal=True) \
    .build()

# Immutable default values for a project configuration; each key is
# declared in PROJECT_SCHEMA.
PROJECT_DEFAULT_CONFIG = Config(
    {
        'project_name': 'undefined',
        'format_version': 1,
        'sources_dir': 'sources',
        'dataset_dir': 'dataset',
        'models_dir': 'models',
        'plugins_dir': 'plugins',
        'project_filename': 'config.yaml',
        'project_dir': '',
        'env_dir': '.datumaro',
    },
    mutable=False,
    schema=PROJECT_SCHEMA)
Exemplo n.º 21
0
 def __init__(self, config=None):
     """Create a project from 'config' (defaults serve as fallback)."""
     project_config = Config(config,
         fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
     self.config = project_config
     self.env = Environment(project_config)
Exemplo n.º 22
0
class Project:
    """A dataset project: a configuration plus a plugin environment,
    with registries of data sources and models."""

    @classmethod
    def load(cls, path):
        """Load a Project whose config lives under
        <path>/<env_dir>/<project_filename>."""
        path = osp.abspath(path)
        config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir,
                               PROJECT_DEFAULT_CONFIG.project_filename)
        config = Config.parse(config_path)
        config.project_dir = path
        config.project_filename = osp.basename(config_path)
        return Project(config)

    def save(self, save_dir=None):
        """Write the project config to <dir>/<env_dir>/; on failure,
        remove only the directories this call created."""
        config = self.config

        if save_dir is None:
            assert config.project_dir
            project_dir = config.project_dir
        else:
            project_dir = save_dir

        env_dir = osp.join(project_dir, config.env_dir)
        save_dir = osp.abspath(env_dir)

        # remembered so the rollback below only removes what we created
        project_dir_existed = osp.exists(project_dir)
        env_dir_existed = osp.exists(env_dir)
        try:
            os.makedirs(save_dir, exist_ok=True)

            config_path = osp.join(save_dir, config.project_filename)
            config.dump(config_path)
        except BaseException:
            if not env_dir_existed:
                shutil.rmtree(save_dir, ignore_errors=True)
            if not project_dir_existed:
                shutil.rmtree(project_dir, ignore_errors=True)
            raise

    @staticmethod
    def generate(save_dir, config=None):
        """Create a new project in 'save_dir' and save it there."""
        config = Config(config)
        config.project_dir = save_dir
        project = Project(config)
        project.save(save_dir)
        return project

    @staticmethod
    def import_from(path, dataset_format=None, env=None, **format_options):
        """Create a project from a dataset at 'path'.

        Detects the format when 'dataset_format' is not given; uses an
        importer when one exists for the format, otherwise registers
        the path as a plain source.

        Raises:
            DatumaroError: detection failed or was ambiguous, or the
                format has no available implementation.
            KeyError: an explicitly given format is unknown.
        """
        if env is None:
            env = Environment()

        if not dataset_format:
            matches = env.detect_dataset(path)
            if not matches:
                raise DatumaroError(
                    "Failed to detect dataset format automatically")
            if 1 < len(matches):
                raise DatumaroError(
                    "Failed to detect dataset format automatically:"
                    " data matches more than one format: %s" % \
                    ', '.join(matches))
            dataset_format = matches[0]
        elif not env.is_format_known(dataset_format):
            raise KeyError("Unknown dataset format '%s'" % dataset_format)

        if dataset_format in env.importers:
            project = env.make_importer(dataset_format)(path, **format_options)
        elif dataset_format in env.extractors:
            project = Project(env=env)
            project.add_source('source', {
                'url': path,
                'format': dataset_format,
                'options': format_options,
            })
        else:
            raise DatumaroError(
                "Unknown format '%s'. To make it "
                "available, add the corresponding Extractor implementation "
                "to the environment" % dataset_format)
        return project

    def __init__(self, config=None, env=None):
        """Initialize the project; 'env' may only be given when
        'config' is omitted."""
        self.config = Config(config,
                             fallback=PROJECT_DEFAULT_CONFIG,
                             schema=PROJECT_SCHEMA)
        if env is None:
            env = Environment(self.config)
        elif config is not None:
            raise ValueError(
                "env can only be provided when no config provided")
        self.env = env

    def make_dataset(self):
        """Build the merged dataset view of this project."""
        return ProjectDataset(self)

    def add_source(self, name, value=None):
        """Register a data source; dict/Config values are wrapped in
        a Source."""
        if value is None or isinstance(value, (dict, Config)):
            value = Source(value)
        self.config.sources[name] = value
        self.env.sources.register(name, value)

    def remove_source(self, name):
        """Remove the source 'name' from config and environment."""
        self.config.sources.remove(name)
        self.env.sources.unregister(name)

    def get_source(self, name):
        """Return the source 'name'; raise KeyError with a clearer
        message when it is absent."""
        try:
            return self.config.sources[name]
        except KeyError:
            raise KeyError("Source '%s' is not found" % name)

    def get_subsets(self):
        """Return the configured subset list."""
        return self.config.subsets

    def set_subsets(self, value):
        """Set the subset list; an empty value clears the setting."""
        if not value:
            self.config.remove('subsets')
        else:
            self.config.subsets = value

    def add_model(self, name, value=None):
        """Register a model; dict/Config values are wrapped in a Model."""
        if value is None or isinstance(value, (dict, Config)):
            value = Model(value)
        self.env.register_model(name, value)
        self.config.models[name] = value

    def get_model(self, name):
        """Return the model 'name'; raise KeyError with a clearer
        message when it is absent."""
        try:
            return self.env.models.get(name)
        except KeyError:
            raise KeyError("Model '%s' is not found" % name)

    def remove_model(self, name):
        """Remove the model 'name' from config and environment."""
        self.config.models.remove(name)
        self.env.unregister_model(name)

    def make_executable_model(self, name):
        """Instantiate the launcher of the model 'name' with its
        options and local model directory."""
        model = self.get_model(name)
        return self.env.make_launcher(model.launcher,
                                      **model.options,
                                      model_dir=osp.join(
                                          self.config.project_dir,
                                          self.local_model_dir(name)))

    def make_source_project(self, name):
        """Build a new single-source Project for the source 'name',
        inheriting this project's configuration."""
        source = self.get_source(name)

        config = Config(self.config)
        config.remove('sources')
        config.remove('subsets')
        project = Project(config)
        project.add_source(name, source)
        return project

    def local_model_dir(self, model_name):
        """Return the project-relative directory of a model."""
        return osp.join(self.config.env_dir, self.config.models_dir,
                        model_name)

    def local_source_dir(self, source_name):
        """Return the project-relative directory of a source."""
        return osp.join(self.config.sources_dir, source_name)
Exemplo n.º 23
0
class Project:
    """A dataset project: a configuration plus a plugin environment,
    with registries of data sources and models."""

    @classmethod
    def load(cls, path):
        """Load a Project whose config lives under
        <path>/<env_dir>/<project_filename>."""
        path = osp.abspath(path)
        config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir,
            PROJECT_DEFAULT_CONFIG.project_filename)
        config = Config.parse(config_path)
        config.project_dir = path
        config.project_filename = osp.basename(config_path)
        return Project(config)

    def save(self, save_dir=None):
        """Write the project config to <dir>/<env_dir>/; on failure,
        remove only the directories this call created.

        NOTE(review): KeyboardInterrupt is not caught here (unlike a
        sibling version that uses BaseException), so an interrupt can
        leave a partial directory behind -- confirm if intended.
        """
        config = self.config

        if save_dir is None:
            assert config.project_dir
            project_dir = config.project_dir
        else:
            project_dir = save_dir

        env_dir = osp.join(project_dir, config.env_dir)
        save_dir = osp.abspath(env_dir)

        # remembered so the rollback below only removes what we created
        project_dir_existed = osp.exists(project_dir)
        env_dir_existed = osp.exists(env_dir)
        try:
            os.makedirs(save_dir, exist_ok=True)

            config_path = osp.join(save_dir, config.project_filename)
            config.dump(config_path)
        except Exception:
            if not env_dir_existed:
                shutil.rmtree(save_dir, ignore_errors=True)
            if not project_dir_existed:
                shutil.rmtree(project_dir, ignore_errors=True)
            raise

    @staticmethod
    def generate(save_dir, config=None):
        """Create a new project in 'save_dir' and save it there.

        NOTE(review): project_dir is assigned only after save(), so the
        dumped config was written without it -- confirm against the
        sibling version that sets it before saving.
        """
        project = Project(config)
        project.save(save_dir)
        project.config.project_dir = save_dir
        return project

    @staticmethod
    def import_from(path, dataset_format, env=None, **kwargs):
        """Create a project from a dataset at 'path' using the
        importer registered for 'dataset_format'."""
        if env is None:
            env = Environment()
        importer = env.make_importer(dataset_format)
        return importer(path, **kwargs)

    def __init__(self, config=None):
        """Create a project from 'config' (defaults as fallback)."""
        self.config = Config(config,
            fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
        self.env = Environment(self.config)

    def make_dataset(self):
        """Build the merged dataset view of this project."""
        return ProjectDataset(self)

    def add_source(self, name, value=None):
        """Register a data source; dict/Config values are wrapped in
        a Source."""
        if value is None or isinstance(value, (dict, Config)):
            value = Source(value)
        self.config.sources[name] = value
        self.env.sources.register(name, value)

    def remove_source(self, name):
        """Remove the source 'name' from config and environment."""
        self.config.sources.remove(name)
        self.env.sources.unregister(name)

    def get_source(self, name):
        """Return the source 'name'; raise KeyError with a clearer
        message when it is absent."""
        try:
            return self.config.sources[name]
        except KeyError:
            raise KeyError("Source '%s' is not found" % name)

    def get_subsets(self):
        """Return the configured subset list."""
        return self.config.subsets

    def set_subsets(self, value):
        """Set the subset list; an empty value clears the setting."""
        if not value:
            self.config.remove('subsets')
        else:
            self.config.subsets = value

    def add_model(self, name, value=None):
        """Register a model; dict/Config values are wrapped in a Model."""
        if value is None or isinstance(value, (dict, Config)):
            value = Model(value)
        self.env.register_model(name, value)
        self.config.models[name] = value

    def get_model(self, name):
        """Return the model 'name'; raise KeyError with a clearer
        message when it is absent."""
        try:
            return self.env.models.get(name)
        except KeyError:
            raise KeyError("Model '%s' is not found" % name)

    def remove_model(self, name):
        """Remove the model 'name' from config and environment."""
        self.config.models.remove(name)
        self.env.unregister_model(name)

    def make_executable_model(self, name):
        """Instantiate the launcher of the model 'name' with its
        options and local model directory."""
        model = self.get_model(name)
        model.model_dir = self.local_model_dir(name)
        return self.env.make_launcher(model.launcher,
            **model.options, model_dir=model.model_dir)

    def make_source_project(self, name):
        """Build a new single-source Project for the source 'name',
        inheriting this project's configuration."""
        source = self.get_source(name)

        config = Config(self.config)
        config.remove('sources')
        config.remove('subsets')
        project = Project(config)
        project.add_source(name, source)
        return project

    def local_model_dir(self, model_name):
        """Return the project-relative directory of a model."""
        return osp.join(
            self.config.env_dir, self.config.models_dir, model_name)

    def local_source_dir(self, source_name):
        """Return the project-relative directory of a source."""
        return osp.join(self.config.sources_dir, source_name)
Exemplo n.º 24
0
 def generate(save_dir, config=None):
     """Create a new project in 'save_dir' and save it there."""
     project_config = Config(config)
     project_config.project_dir = save_dir
     new_project = Project(project_config)
     new_project.save(save_dir)
     return new_project
Exemplo n.º 25
0
    def test_cant_change_immutable(self):
        """Setting an attribute on an immutable Config raises."""
        conf = Config({'x': 42}, mutable=False)

        with self.assertRaises(ImmutableObjectError):
            conf.y = 5
Exemplo n.º 26
0
    SchemaBuilder as _SchemaBuilder,
)
import datumaro.components.extractor as datumaro
from datumaro.util.image import lazy_image, load_image, Image

from cvat.utils.cli.core import CLI as CVAT_CLI, CVAT_API_V1


# Connection/task parameters required by the CVAT REST API source.
CONFIG_SCHEMA = _SchemaBuilder() \
    .add('task_id', int) \
    .add('server_host', str) \
    .add('server_port', int) \
    .build()

# Immutable defaults; only the server port has a default value.
DEFAULT_CONFIG = Config({'server_port': 80},
                        schema=CONFIG_SCHEMA,
                        mutable=False)


class cvat_rest_api_task_images(datumaro.SourceExtractor):
    def _image_local_path(self, item_id):
        """Return the local cache file path for the frame 'item_id'."""
        frame_name = 'task_{}_frame_{:06d}.jpg'.format(
            self._config.task_id, int(item_id))
        return osp.join(self._cache_dir, frame_name)

    def _make_image_loader(self, item_id):
        """Return a lazy loader that resolves the image on first use."""
        load = lambda item_id: self._image_loader(item_id, self)
        return lazy_image(item_id, load)

    def _is_image_cached(self, item_id):
Exemplo n.º 27
0
        return res


# Layout of a project tree config: its sources, build targets, and
# (internal-only) location fields.
TREE_SCHEMA = _SchemaBuilder() \
    .add('format_version', int) \
    \
    .add('sources', lambda: _DictConfig(lambda v=None: Source(v))) \
    .add('build_targets', lambda: _DictConfig(lambda v=None: BuildTarget(v))) \
    \
    .add('base_dir', str, internal=True) \
    .add('config_path', str, internal=True) \
    .build()

# Immutable defaults for a tree config.
TREE_DEFAULT_CONFIG = Config({
    'format_version': 2,
    'config_path': '',
},
                             mutable=False,
                             schema=TREE_SCHEMA)


class TreeConfig(Config):
    """A Config bound to TREE_SCHEMA, with TREE_DEFAULT_CONFIG as
    the fallback."""

    def __init__(self, config=None, mutable=True):
        super().__init__(
            config=config, mutable=mutable,
            fallback=TREE_DEFAULT_CONFIG, schema=TREE_SCHEMA)


PROJECT_SCHEMA = _SchemaBuilder() \
    .add('format_version', int) \
    \
Exemplo n.º 28
0
class Project:
    """A dataset project: a configuration plus a module environment,
    with registries of data sources and models."""

    @staticmethod
    def load(path):
        """Load a Project from a config file path, or from a directory
        containing the default project config file."""
        path = osp.abspath(path)
        if osp.isdir(path):
            path = osp.join(path, PROJECT_DEFAULT_CONFIG.project_filename)
        config = Config.parse(path)
        config.project_dir = osp.dirname(path)
        config.project_filename = osp.basename(path)
        return Project(config)

    def save(self, save_dir=None):
        """Write the project config and its environment to disk,
        defaulting to the project's own directory."""
        config = self.config
        if save_dir is None:
            assert config.project_dir
            save_dir = osp.abspath(config.project_dir)
        config_path = osp.join(save_dir, config.project_filename)

        env_dir = osp.join(save_dir, config.env_dir)
        os.makedirs(env_dir, exist_ok=True)
        self.env.save(osp.join(env_dir, config.env_filename))

        config.dump(config_path)

    @staticmethod
    def generate(save_dir, config=None):
        """Create a new project in 'save_dir' and save it there."""
        project = Project(config)
        project.save(save_dir)
        project.config.project_dir = save_dir
        return project

    @staticmethod
    def import_from(path, dataset_format, env=None, **kwargs):
        """Create a project from a dataset at 'path' using the
        importer registered for 'dataset_format'."""
        if env is None:
            env = Environment()
        importer = env.make_importer(dataset_format)
        return importer(path, **kwargs)

    def __init__(self, config=None):
        """Create a project from 'config' (defaults as fallback)."""
        self.config = Config(config,
            fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
        self.env = Environment(self.config)

    def make_dataset(self):
        """Build the merged dataset view of this project."""
        return ProjectDataset(self)

    # BUG FIX: the default used to be 'value=Source()', a mutable default
    # argument shared by every call; now a fresh Source is built per call
    # (matches the later revisions of this class).
    def add_source(self, name, value=None):
        """Register a data source; dict/Config/None values are wrapped
        in a Source."""
        if value is None or isinstance(value, (dict, Config)):
            value = Source(value)
        self.config.sources[name] = value
        self.env.sources.register(name, value)

    def remove_source(self, name):
        """Remove the source 'name' from config and environment."""
        self.config.sources.remove(name)
        self.env.sources.unregister(name)

    def get_source(self, name):
        """Return the source 'name' (KeyError when absent)."""
        return self.config.sources[name]

    def get_subsets(self):
        """Return the configured subset list."""
        return self.config.subsets

    def set_subsets(self, value):
        """Set the subset list; an empty value clears the setting."""
        if not value:
            self.config.remove('subsets')
        else:
            self.config.subsets = value

    # BUG FIX: same mutable-default problem as add_source ('value=Model()').
    def add_model(self, name, value=None):
        """Register a model; dict/Config/None values are wrapped in
        a Model."""
        if value is None or isinstance(value, (dict, Config)):
            value = Model(value)
        self.env.register_model(name, value)

    def get_model(self, name):
        """Return the model 'name'."""
        return self.env.models.get(name)

    def remove_model(self, name):
        """Remove the model 'name' from the environment."""
        self.env.unregister_model(name)

    def make_executable_model(self, name):
        """Instantiate the launcher of the model 'name' with its
        options and local model directory."""
        model = self.get_model(name)
        model.model_dir = self.local_model_dir(name)
        return self.env.make_launcher(model.launcher,
            **model.options, model_dir=model.model_dir)

    def make_source_project(self, name):
        """Build a new single-source Project for the source 'name',
        inheriting this project's configuration."""
        source = self.get_source(name)

        config = Config(self.config)
        config.remove('sources')
        config.remove('subsets')
        config.remove('filter')
        project = Project(config)
        project.add_source(name, source)
        return project

    def get_filter(self):
        """Return the configured dataset filter, or '' when unset."""
        if 'filter' in self.config:
            return self.config.filter
        return ''

    def set_filter(self, value=None):
        """Set the dataset filter; an empty value clears it.

        Raises if the filter expression is not a valid XPath.
        """
        if not value:
            self.config.remove('filter')
        else:
            # check filter
            XPathDatasetFilter(value)
            self.config.filter = value

    def local_model_dir(self, model_name):
        """Return the project-relative directory of a model."""
        return osp.join(
            self.config.env_dir, self.env.config.models_dir, model_name)

    def local_source_dir(self, source_name):
        """Return the project-relative directory of a source."""
        return osp.join(self.config.sources_dir, source_name)
Exemplo n.º 29
0
 def test_empty_config_is_ok():
     """A Project can be built from a completely empty Config."""
     Project(Config())
Exemplo n.º 30
0
# Environment config: per-category module directory names plus the
# registered models.
ENV_SCHEMA = _SchemaBuilder() \
    .add('models_dir', str) \
    .add('importers_dir', str) \
    .add('launchers_dir', str) \
    .add('converters_dir', str) \
    .add('extractors_dir', str) \
    \
    .add('models', lambda: _DefaultConfig(
        lambda v=None: Model(v))) \
    .build()

# Immutable default directory names for the environment.
ENV_DEFAULT_CONFIG = Config(
    {
        'models_dir': 'models',
        'importers_dir': 'importers',
        'launchers_dir': 'launchers',
        'converters_dir': 'converters',
        'extractors_dir': 'extractors',
    },
    mutable=False,
    schema=ENV_SCHEMA)


PROJECT_SCHEMA = _SchemaBuilder() \
    .add('project_name', str) \
    .add('format_version', int) \
    \
    .add('sources_dir', str) \
    .add('dataset_dir', str) \
    .add('build_dir', str) \
    .add('subsets', list) \
    .add('sources', lambda: _DefaultConfig(