예제 #1
0
    def test_add_source(self):
        """A source added under a name can be fetched back and compares equal."""
        name = 'source'
        expected = Source({'url': 'path', 'format': 'ext'})
        project = Project()

        project.add_source(name, expected)

        # The lookup must yield something, and that something must
        # compare equal to the source that was registered.
        fetched = project.get_source(name)
        self.assertIsNotNone(fetched)
        self.assertEqual(fetched, expected)
예제 #2
0
    def test_added_source_can_be_saved(self):
        """A source added to a project shows up in the project config."""
        name = 'source'
        expected = Source({'url': 'path'})
        project = Project()
        project.add_source(name, expected)

        # The config's source table is looked up by the same name
        # that was used for registration.
        self.assertEqual(expected, project.config.sources[name])
예제 #3
0
    def test_added_source_can_be_dumped(self):
        """A source added to a project is still there after save and load."""
        name = 'source'
        expected = Source({'url': 'path'})
        project = Project()
        project.add_source(name, expected)

        with TestDir() as project_dir:
            project.save(project_dir)

            # Read the project back from the same directory and compare
            # the restored source with the one that was registered.
            restored_project = Project.load(project_dir)
            restored_source = restored_project.get_source(name)
            self.assertEqual(expected, restored_source)
예제 #4
0
    def import_from(cls,
                    path: str,
                    format: str = None,
                    env: Environment = None,
                    **kwargs) -> 'Dataset':
        """
        Builds a dataset from the data located at `path`.

        Args:
            path - Location of the input data
            format - Name of the dataset format. When empty, the result
                of `cls.detect(path, env)` is used
            env - Environment providing importers and extractors.
                A new `Environment()` is created when omitted
            **kwargs - Extra options, passed to the importer or stored
                as the options of the extractor source

        Returns: the dataset produced by `cls.from_extractors`, with
            the source path and format recorded on it

        Raises:
            DatumaroError - `format` is neither an importer nor
                an extractor known to `env`
        """
        from datumaro.components.config_model import Source

        env = Environment() if env is None else env
        format = format or cls.detect(path, env)

        # TODO: remove importers, put this logic into extractors
        if format in env.importers:
            run_import = env.make_importer(format)
            with logging_disabled(log.INFO):
                imported = run_import(path, **kwargs)
            source_confs = list(imported.config.sources.values())
        elif format in env.extractors:
            # No importer: describe the input as a single source
            source_confs = [
                {'url': path, 'format': format, 'options': kwargs},
            ]
        else:
            message = (
                "Unknown source format '%s'. To make it "
                "available, add the corresponding Extractor implementation "
                "to the environment") % format
            raise DatumaroError(message)

        # The generator is lazy, so every source description is normalized
        # right before its extractor is created.
        normalized = (
            conf if isinstance(conf, Source) else Source(conf)
            for conf in source_confs
        )
        extractors = [
            env.make_extractor(conf.format, conf.url, **conf.options)
            for conf in normalized
        ]

        dataset = cls.from_extractors(*extractors, env=env)
        dataset._source_path, dataset._format = path, format
        return dataset
예제 #5
0
 def add_source(self, name, value=None):
     """
     Registers a source under `name` in the project config and environment.

     Args:
         name - Key the source is stored under
         value - The source to store. `None`, a `dict` or a `Config`
             is first passed to `Source()`; anything else is stored as is
     """
     needs_wrapping = value is None or isinstance(value, (dict, Config))
     source = Source(value) if needs_wrapping else value

     # Keep the config entry and the environment registry in step
     self.config.sources[name] = source
     self.env.sources.register(name, source)
예제 #6
0
    def import_from(cls,
                    path: str,
                    format: Optional[str] = None,
                    *,
                    env: Optional[Environment] = None,
                    progress_reporter: Optional[ProgressReporter] = None,
                    error_policy: Optional[ImportErrorPolicy] = None,
                    **kwargs) -> Dataset:
        """
        Creates a `Dataset` instance from a dataset on the disk.

        Args:
            path - The input file or directory path
            format - Dataset format.
                If a string is passed, it is treated as a plugin name,
                which is searched for in the `env` plugin context.
                If not set, will try to detect automatically,
                using the `env` plugin context.
            env - A plugin collection. If not set, the built-in plugins are used
            progress_reporter - An object to report progress.
                Implies eager loading.
            error_policy - An object to report format-related errors.
                Implies eager loading.
            **kwargs - Parameters for the format

        Returns: the dataset built by `cls.from_extractors` from
            the extractors of all detected sources, with the source path
            and format recorded on it

        Raises:
            UnknownFormatError - `format` is neither an importer
                nor an extractor registered in `env`
        """

        if env is None:
            env = Environment()

        if not format:
            format = cls.detect(path, env=env)

        # TODO: remove importers, put this logic into extractors
        if format in env.importers:
            importer = env.make_importer(format)
            with logging_disabled(log.INFO):
                detected_sources = importer(path, **kwargs)
        elif format in env.extractors:
            # No importer: describe the input as a single source
            detected_sources = [{
                'url': path,
                'format': format,
                'options': kwargs
            }]
        else:
            raise UnknownFormatError(format)

        # TODO: probably, should not be available in lazy mode, because it
        # becomes unreliable and error-prone. For progress reporting it
        # makes little sense, because loading stage is spread over other
        # operations. Error reporting is going to be unreliable.
        has_ctx_args = progress_reporter is not None or error_policy is not None
        eager = has_ctx_args

        # Test against None, exactly like `has_ctx_args` does: a reporter
        # that switched eager loading on must not be replaced afterwards
        # just because it happens to be falsy.
        if progress_reporter is None:
            progress_reporter = NullProgressReporter()
        # One sub-reporter per detected source
        pbars = progress_reporter.split(len(detected_sources))

        try:
            extractors = []
            for src_conf, pbar in zip(detected_sources, pbars):
                if not isinstance(src_conf, Source):
                    src_conf = Source(src_conf)

                # Copy the options, the source description stays untouched
                extractor_kwargs = dict(src_conf.options)

                # 'ctx' is reserved for the import context created here
                assert 'ctx' not in extractor_kwargs
                extractor_kwargs['ctx'] = ImportContext(
                    progress_reporter=pbar, error_policy=error_policy)

                try:
                    extractors.append(
                        env.make_extractor(src_conf.format, src_conf.url,
                                           **extractor_kwargs))
                except TypeError as e:
                    # TODO: for backward compatibility. To be removed after 0.3
                    if "unexpected keyword argument 'ctx'" not in str(e):
                        raise

                    # Warn only if the caller actually asked for reporting
                    if has_ctx_args:
                        warnings.warn(
                            "It seems that '%s' extractor "
                            "does not support progress and error reporting, "
                            "it will be disabled" % src_conf.format,
                            DeprecationWarning)
                    extractor_kwargs.pop('ctx')

                    # Retry without the unsupported 'ctx' argument
                    extractors.append(
                        env.make_extractor(src_conf.format, src_conf.url,
                                           **extractor_kwargs))

            dataset = cls.from_extractors(*extractors, env=env)
            if eager:
                dataset.init_cache()
        except _ImportFail as e:
            # Surface the error wrapped by _ImportFail to the caller
            cause = e.__cause__
            if cause is None:
                # `raise None` would fail with an unrelated TypeError and
                # hide the failure, so re-raise the wrapper itself
                raise
            raise cause

        dataset._source_path = path
        dataset._format = format

        return dataset