예제 #1
0
def test_call_overridden(selector, switch):
    """The most recent override of ``selector`` decides what a call returns."""
    first_override = providers.Selector(switch, one=providers.Object(2))
    second_override = providers.Selector(switch, one=providers.Object(3))

    # Stack the overrides in order; the last one applied must win.
    for override in (first_override, second_override):
        selector.override(override)

    with switch.override("one"):
        result = selector()
        assert result == 3
    def test_deepcopy_from_memo(self):
        """deepcopy() returns the copy already registered in ``memo`` for the provider."""
        original = providers.Selector(self.selector)
        memoized = providers.Selector(self.selector)
        memo = {id(original): memoized}

        copied = providers.deepcopy(original, memo=memo)

        self.assertIs(copied, memoized)
    def test_call_overridden(self):
        """The last override applied to the selector supplies the returned value."""
        def sample_selector(value):
            # Selector exposing a single ``sample`` provider returning ``value``.
            return providers.Selector(self.selector,
                                      sample=providers.Object(value))

        provider = sample_selector(1)
        for overriding in (sample_selector(2), sample_selector(3)):
            provider.override(overriding)

        with self.selector.override('sample'):
            self.assertEqual(provider(), 3)
예제 #4
0
def selector(selector_type, switch, one, two):
    """Build the ``Selector`` provider variant named by ``selector_type``.

    Variants:

    * ``"default"`` -- chooses between ``one`` and ``two`` using ``switch``;
    * ``"empty"`` -- a ``Selector`` built with no arguments at all;
    * ``"sys-streams"`` -- chooses between ``sys.stdin``, ``sys.stdout`` and
      ``sys.stderr``, each wrapped in ``providers.Object``, using ``switch``.

    Raises:
        ValueError: if ``selector_type`` is not one of the variants above.
    """
    if selector_type == "default":
        return providers.Selector(switch, one=one, two=two)

    if selector_type == "empty":
        return providers.Selector()

    if selector_type == "sys-streams":
        streams = {
            "stdin": providers.Object(sys.stdin),
            "stdout": providers.Object(sys.stdout),
            "stderr": providers.Object(sys.stderr),
        }
        return providers.Selector(switch, **streams)

    raise ValueError("Unknown selector type \"{0}\"".format(selector_type))
    def test_deepcopy(self):
        """deepcopy() returns a distinct object that is still a ``Selector``."""
        provider = providers.Selector(self.selector)

        provider_copy = providers.deepcopy(provider)

        self.assertIsNot(provider, provider_copy)
        # Check the type of the *copy*. Asserting on ``provider`` again only
        # re-checks the original and lets a wrongly-typed copy slip through.
        self.assertIsInstance(provider_copy, providers.Selector)
예제 #6
0
class Container(containers.DeclarativeContainer):
    """Movie-lister container whose finder is chosen by ``config.finder.type``.

    Recognized finder types: ``"csv"``, ``"sqlite"`` and ``"orm_sqlite"``.
    """

    config = providers.Configuration()
    movie: providers.Factory[entities.Movie] = providers.Factory(entities.Movie)

    csv_finder = providers.Singleton(
        finders.CsvMovieFinder,
        movie_factory=movie.provider,
        path=config.finder.csv.path,
        delimiter=config.finder.csv.delimiter,
    )

    # Both SQLite-backed finders read the same ``config.finder.sqlite.path``.
    sqlite_finder = providers.Singleton(
        finders.SqliteMovieFinder,
        movie_factory=movie.provider,
        path=config.finder.sqlite.path,
    )
    orm_sqlite_finder = providers.Singleton(
        finders.ORMSqliteFinder,
        movie_factory=movie.provider,
        path=config.finder.sqlite.path,
    )

    finder = providers.Selector(
        config.finder.type,
        csv=csv_finder,
        sqlite=sqlite_finder,
        orm_sqlite=orm_sqlite_finder,
    )

    lister = providers.Factory(
        listers.MovieLister,
        movie_finder=finder,
    )
예제 #7
0
class Container(containers.DeclarativeContainer):
    """Container for the Flask app, its data managers and ``GestorPedidos``.

    ``config.data_handler`` selects which data manager the ``gestor_pedidos``
    service is built on: ``"psycopg2"`` or ``"sqlalchemy"``.
    """

    app = flask.Application(Flask, __name__)

    # Configuration
    config = providers.Configuration('config')

    # Dependencies
    # Obtain the logger through logging.getLogger() instead of constructing
    # logging.Logger directly. A hand-built Logger is not registered with the
    # logging manager and has no parent, so its records never reach handlers
    # configured on the root logger.
    logger = providers.Singleton(logging.getLogger, name='logger')
    psycopg2_data_manager = providers.Singleton(Psycopg2DataManager,
                                                user=config.username,
                                                password=config.password,
                                                host=config.host,
                                                port=config.port,
                                                db_name=config.db_name)

    sqlalchemy_orm_data_manager = providers.Singleton(SqlAlchemyORMDataManager,
                                                      user=config.username,
                                                      password=config.password,
                                                      host=config.host,
                                                      port=config.port,
                                                      db_name=config.db_name)

    # Services
    # Both variants share the same logger; only the data manager differs.
    gestor_pedidos = providers.Selector(
        config.data_handler,
        psycopg2=providers.Singleton(GestorPedidos, psycopg2_data_manager,
                                     logger),
        sqlalchemy=providers.Singleton(GestorPedidos,
                                       sqlalchemy_orm_data_manager, logger),
    )
class Container(containers.DeclarativeContainer):
    """Movie-lister container configured from ``config.yml``.

    ``config.finder.type`` picks the finder injected into ``lister``:
    ``"csv"`` or ``"sqlite"``.
    """

    config = providers.Configuration(yaml_files=["config.yml"])

    movie = providers.Factory(entities.Movie)

    # Concrete finder implementations; both receive ``movie.provider``.
    csv_finder = providers.Singleton(
        finders.CsvMovieFinder,
        movie_factory=movie.provider,
        path=config.finder.csv.path,
        delimiter=config.finder.csv.delimiter,
    )
    sqlite_finder = providers.Singleton(
        finders.SqliteMovieFinder,
        movie_factory=movie.provider,
        path=config.finder.sqlite.path,
    )

    # Switches between the finders above according to configuration.
    finder = providers.Selector(config.finder.type, csv=csv_finder, sqlite=sqlite_finder)

    lister = providers.Factory(listers.MovieLister, movie_finder=finder)
예제 #9
0
class ApplicationContainer(containers.DeclarativeContainer):
    """Select an ``App`` entry point by the ``cmd`` value of the parsed config.

    NOTE(review): ``_config.from_dict(_configParser())`` runs while the class
    body executes, i.e. ``App.cliParser`` is invoked at import time.
    """

    _config = providers.Configuration('app')
    _configParser = providers.Callable(App.cliParser)
    # Populate the configuration from the parser's result (class-creation time).
    _config.from_dict(_configParser())

    main = providers.Selector(
        _config.cmd,
        info=providers.Callable(App.info),
        distribute=providers.Callable(App.distribute),
        cashimport=providers.Callable(App.cashimport),
    )
def test_traverse_overridden():
    """traverse() reaches own providers, the override, and the override's providers."""
    list_provider = providers.Callable(list)
    dict_provider = providers.Callable(dict)
    overriding_selector = providers.Selector(lambda: "provider1", provider1=list_provider)

    provider = providers.Selector(lambda: "provider2", provider2=dict_provider)
    provider.override(overriding_selector)

    traversed = list(provider.traverse())

    assert len(traversed) == 3
    for expected in (list_provider, dict_provider, overriding_selector):
        assert expected in traversed
    def test_traverse_overridden(self):
        """traverse() reaches own providers, the override, and the override's providers."""
        list_provider = providers.Callable(list)
        dict_provider = providers.Callable(dict)
        overriding = providers.Selector(lambda: 'provider1', provider1=list_provider)

        provider = providers.Selector(lambda: 'provider2', provider2=dict_provider)
        provider.override(overriding)

        found = list(provider.traverse())

        self.assertEqual(len(found), 3)
        for expected in (list_provider, dict_provider, overriding):
            self.assertIn(expected, found)
예제 #12
0
class CloudContainer(containers.DeclarativeContainer):
    """Pick a file-host implementation from ``config.cloud.filehost``.

    Accepted values: ``"anonfiles"``, ``"gofiles"`` and ``"none"``.
    """

    # NOTE(review): strict=True presumably makes undefined options an error
    # instead of yielding None -- confirm against the library docs.
    config = providers.Configuration(strict=True)

    filehost = providers.Selector(
        config.cloud.filehost,
        anonfiles=providers.Factory(AnonFilesHost),
        gofiles=providers.Factory(GoFilesHost),
        none=providers.Factory(NoneFilesHost),
    )
class Container(containers.DeclarativeContainer):
    """Choose between ``SomeClass`` and ``SomeOtherClass`` via ``config.one_or_another``."""

    config = providers.Configuration()

    # "one" -> SomeClass, "another" -> SomeOtherClass
    selector = providers.Selector(
        config.one_or_another,
        one=providers.Factory(SomeClass),
        another=providers.Factory(SomeOtherClass),
    )
예제 #14
0
 def getMetadataObj(cls):
     """Build and return the configured metadata-prep object.

     The implementation is chosen by ``METADATA_RULES.metadata_prepclass`` on
     ``PipelineContainer.config``:

     * ``"defaultmetadataprep"`` -> ``DefaultMetadataPrep``;
     * ``"filemetadataprep"`` -> ``FileMetadataPrep``.

     Both are constructed with ``config=config.provider``.
     """
     cfg = PipelineContainer.config
     prep_class_option = cfg.METADATA_RULES.metadata_prepclass
     prep_factories = {
         'defaultmetadataprep': providers.Factory(DefaultMetadataPrep,
                                                  config=cfg.provider),
         'filemetadataprep': providers.Factory(FileMetadataPrep,
                                               config=cfg.provider),
     }
     metadata_selector = providers.Selector(prep_class_option, **prep_factories)
     return metadata_selector()
    def test_call_undefined_provider(self):
        """Selecting a name with no registered provider raises ``errors.Error``."""
        provider = providers.Selector(
            self.selector,
            one=providers.Object(1),
            two=providers.Object(2),
        )

        with self.selector.override('three'), self.assertRaises(errors.Error):
            provider()
    def test_call_selector_is_none(self):
        """A selector that yields ``None`` makes the call raise ``errors.Error``."""
        provider = providers.Selector(
            self.selector,
            one=providers.Object(1),
            two=providers.Object(2),
        )

        with self.selector.override(None), self.assertRaises(errors.Error):
            provider()
    def test_call_with_context_args(self):
        """Positional and keyword arguments reach the selected provider unchanged."""
        def echo(*args, **kwargs):
            return args, kwargs

        provider = providers.Selector(self.selector, one=providers.Callable(echo))

        with self.selector.override('one'):
            result = provider(1, 2, three=3, four=4)

        received_args, received_kwargs = result
        self.assertEqual(received_args, (1, 2))
        self.assertEqual(received_kwargs, {'three': 3, 'four': 4})
    def test_call_any_callable(self):
        """Any zero-argument callable can act as the selector."""
        names = itertools.cycle(['one', 'two'])
        provider = providers.Selector(
            functools.partial(next, names),
            one=providers.Object(1),
            two=providers.Object(2),
        )

        # Each call advances the cycle, alternating between the two providers.
        for expected in (1, 2, 1, 2):
            self.assertEqual(provider(), expected)
    def test_getattr_attribute_error(self):
        """Accessing an unregistered provider name raises ``AttributeError``."""
        provider = providers.Selector(
            self.selector,
            one=providers.Object(1),
            two=providers.Object(2),
        )

        with self.assertRaises(AttributeError):
            getattr(provider, 'provider_three')
    def test_getattr(self):
        """Registered providers are reachable as attributes under their names."""
        registered = {
            'one': providers.Object(1),
            'two': providers.Object(2),
        }
        provider = providers.Selector(self.selector, **registered)

        for name, expected in registered.items():
            self.assertIs(getattr(provider, name), expected)
    def test_call(self):
        """Calling the selector returns the value of the currently selected provider."""
        provider = providers.Selector(
            self.selector,
            one=providers.Object(1),
            two=providers.Object(2),
        )

        for name, expected in (('one', 1), ('two', 2)):
            with self.selector.override(name):
                self.assertEqual(provider(), expected)
    def test_init_optional(self):
        """Selector and providers can be supplied after construction via setters."""
        one, two = providers.Object(1), providers.Object(2)

        provider = providers.Selector()
        provider.set_selector(self.selector)
        provider.set_providers(one=one, two=two)

        self.assertEqual(provider.providers, {'one': one, 'two': two})
        for name, source in (('one', one), ('two', two)):
            with self.selector.override(name):
                self.assertEqual(provider(), source())
    def test_providers_attribute(self):
        """``providers`` exposes the name -> provider mapping given at construction."""
        expected = {
            'one': providers.Object(1),
            'two': providers.Object(2),
        }

        provider = providers.Selector(self.selector, **expected)

        self.assertEqual(provider.providers, expected)
    def test_deepcopy_overridden(self):
        """deepcopy() copies the selector together with its overriding provider."""
        provider = providers.Selector(self.selector)
        object_provider = providers.Object(object())

        provider.override(object_provider)

        provider_copy = providers.deepcopy(provider)
        object_provider_copy = provider_copy.overridden[0]

        self.assertIsNot(provider, provider_copy)
        # Verify the type of the *copy*; re-checking ``provider`` would be vacuous.
        self.assertIsInstance(provider_copy, providers.Selector)

        self.assertIsNot(object_provider, object_provider_copy)
        self.assertIsInstance(object_provider_copy, providers.Object)
예제 #25
0
class Container(containers.DeclarativeContainer):
    """Build a component whose class is chosen by ``config.type``.

    ``"standard"`` -> ``components.Component``; ``"all_caps"`` ->
    ``components.AllCaps``. Both receive ``config.param`` as ``param``.
    """

    config = providers.Configuration()

    component = providers.Selector(
        config.type,
        standard=providers.Factory(components.Component, param=config.param),
        all_caps=providers.Factory(components.AllCaps, param=config.param),
    )
    def test_repr(self):
        """repr() names the class, the selector, each provider and the address."""
        provider = providers.Selector(
            self.selector,
            one=providers.Object(1),
            two=providers.Object(2),
        )
        rendered = repr(provider)

        expected_fragments = (
            '<dependency_injector.providers.Selector({0}'.format(
                repr(self.selector)),
            'one={0}'.format(repr(provider.one)),
            'two={0}'.format(repr(provider.two)),
            'at {0}'.format(hex(id(provider))),
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, rendered)
def test_traverse():
    """traverse() yields exactly the providers registered on the selector."""
    def switch():
        return "provider1"

    list_provider = providers.Callable(list)
    dict_provider = providers.Callable(dict)

    provider = providers.Selector(
        switch,
        provider1=list_provider,
        provider2=dict_provider,
    )

    found = list(provider.traverse())

    assert len(found) == 2
    for expected in (list_provider, dict_provider):
        assert expected in found
    def test_traverse(self):
        """traverse() yields exactly the providers registered on the selector."""
        list_provider = providers.Callable(list)
        dict_provider = providers.Callable(dict)

        provider = providers.Selector(
            lambda: 'provider1',
            provider1=list_provider,
            provider2=dict_provider,
        )

        found = list(provider.traverse())

        self.assertEqual(len(found), 2)
        for expected in (list_provider, dict_provider):
            self.assertIn(expected, found)
예제 #29
0
    def executeRulesDf(cls, df, ruleset):
        """Apply a rule set to ``df`` and return a lookup function for the results.

        The metadata-prep class is chosen by
        ``METADATA_RULES.rules_metadata_prepclass`` on ``PipelineContainer.config``
        and its ``rulesprep`` prepares ``ruleset``. Each rule adds a column
        built from ``expr(rl.getsqlexp())``. ``result_and`` is the ``and`` of
        every rule's ``getcolname()``.

        If the rule-set category is ``'filter'``, the valid frame is the rows of
        ``df`` (original columns only) whose key columns match a row where
        ``result_and == 'true'``. For any other category, ``df`` itself is the
        valid frame.

        Args:
            df: input DataFrame (presumably a Spark DataFrame, given the
                ``select``/``withColumn``/``join`` calls -- confirm).
            ruleset: rule-set object; it is prepared via ``rulesprep`` first.

        Returns:
            ``get(name)``, where ``name`` is one of ``'validdf'``,
            ``'originaldf'``, ``'keycolumnslist'``, ``'rulecategory'`` or
            ``'rulesetname'``. Unknown names raise ``KeyError``.
        """
        config = PipelineContainer.config
        metadata_selector = providers.Selector(
            config.METADATA_RULES.rules_metadata_prepclass,
            defaultmetadataprep=providers.Factory(DefaultMetadataPrep,
                                                  config=config.provider),
            filemetadataprep=providers.Factory(FileMetadataPrep,
                                               config=config.provider),
        )
        metadata_prep = metadata_selector()
        ruleset = metadata_prep.rulesprep(ruleset)

        key_columns = ruleset.keycolumns
        rule_columns = key_columns + [
            expr(rl.getsqlexp()) for rl in ruleset.rulelist
        ]
        rule_and = expr(' and '.join(rl.getcolname() for rl in ruleset.rulelist))
        rule_result_df = (
            df.select(*rule_columns)
              .withColumn('result_and', rule_and)
        )
        rule_category = ruleset.category

        if rule_category == 'filter':
            filter_df = rule_result_df.filter(
                rule_result_df.result_and == 'true')
            # Keep only rows of the original frame that passed every rule.
            result_df = (
                df.join(filter_df, key_columns, 'inner')
                  .select(*df.columns)
            )
        else:
            result_df = df

        ruledict = {
            'validdf': result_df,
            'originaldf': df,
            'keycolumnslist': key_columns,
            'rulecategory': rule_category,
            'rulesetname': ruleset.rulesetname,
        }

        def get(dfname):
            # Closure over ``ruledict`` so callers can fetch results by name.
            return ruledict[dfname]

        return get
        class TestContainer(containers.DeclarativeContainer):
            """Expose ``NestedContainer`` directly and behind a ``Selector``.

            ``settings.container`` chooses how ``not_root_container`` builds
            ``NestedContainer``: via a ``Factory`` (``"using_factory"``) or via
            a ``Container`` provider (``"using_container"``).
            """
            settings = providers.Configuration()

            # Always built through a Container provider sharing ``settings``.
            root_container = providers.Container(
                NestedContainer,
                settings=settings,
            )

            # Resolved at call time from the value of ``settings.container``.
            not_root_container = providers.Selector(
                settings.container,
                using_factory=providers.Factory(
                    NestedContainer,
                    settings=settings,
                ),
                using_container=providers.Container(
                    NestedContainer,
                    settings=settings,
                ))