示例#1
0
def main_worker(local_rank: int):
    """Build the MNIST dependency container for one worker and run its engine.

    The engine class is imported from the dotted path stored in the config's
    ``engine`` field, constructed through the container (so its own
    dependencies are injected), and its ``run`` method is then invoked with
    dependency injection as well.

    Args:
        local_rank: Local rank of this worker; passed straight to ``MnistModule``.
    """
    injector_container = Injector(MnistModule(local_rank))
    config = injector_container.get(TypedConfig)

    # Resolve the engine class from its import path, then let the container
    # build the instance.
    engine_type = flame.auto_builder.import_from_path(config.engine)
    engine: BaseEngine = injector_container.get(engine_type)

    injector_container.call_with_injection(engine.run)
示例#2
0
def test_explicitly_passed_parameters_override_injectable_values():
    """Explicit ``args``/``kwargs`` given to ``call_with_injection`` win over injection.

    Three targets are checked: a plain method, a method whose ``self`` is
    annotated with a forward reference, and a plain function. Called without
    explicit arguments, each must consume the ``str`` provider exactly once.
    Called with ``s`` supplied positionally or by keyword, each must return
    the supplied value and must not invoke the provider again.
    """
    # X has to live in the module globals so that the 'X' string annotation
    # on method_typed_self can be resolved.
    global X

    # Methods are included alongside a plain function because bound methods
    # are handled by a separate code path inside the injector.
    class X:
        @inject
        def method(self, s: str) -> str:
            return s

        @inject
        def method_typed_self(self: 'X', s: str) -> str:
            return s

    @inject
    def function(s: str) -> str:
        return s

    provider_calls = 0

    def provide_str() -> str:
        nonlocal provider_calls
        provider_calls += 1
        return 'injected string'

    def configure(binder: Binder) -> None:
        binder.bind(str, to=provide_str)

    injector = Injector([configure])
    instance = X()
    targets = (instance.method, instance.method_typed_self, function)
    explicit = 'passed string'

    try:
        assert provider_calls == 0

        # No explicit argument: every call draws one value from the provider.
        for expected_calls, target in enumerate(targets, start=1):
            assert injector.call_with_injection(target) == 'injected string'
            assert provider_calls == expected_calls

        # Explicit positional argument: the provider is never consulted.
        for target in targets:
            assert injector.call_with_injection(target, args=(explicit,)) == explicit
            assert provider_calls == len(targets)

        # Explicit keyword argument: likewise.
        for target in targets:
            assert injector.call_with_injection(target, kwargs={'s': explicit}) == explicit
            assert provider_calls == len(targets)
    finally:
        del X
示例#3
0
def test_explicitly_passed_parameters_override_injectable_values():
    """Values passed explicitly to ``call_with_injection`` take precedence.

    Every call made without explicit arguments must draw exactly one value
    from the ``str`` provider. Calls that pass ``s`` positionally or by
    keyword must return that value and leave the provider untouched.
    """
    # 'X' is used as a forward reference below; such references can only be
    # resolved when the class lives in the module's global namespace.
    global X

    # Methods are exercised in addition to a plain function because the
    # injector handles bound methods through their own code path.
    class X:
        @inject
        def method(self, s: str) -> str:
            return s

        @inject
        def method_typed_self(self: 'X', s: str) -> str:
            return s

    @inject
    def function(s: str) -> str:
        return s

    injection_counter = 0

    def provide_str() -> str:
        nonlocal injection_counter
        injection_counter += 1
        return 'injected string'

    def configure(binder: Binder) -> None:
        binder.bind(str, to=provide_str)

    injector = Injector([configure])
    x = X()

    def check(target, expected_result, expected_count, **call_options):
        # Call through the injector, then verify both the result and how many
        # times the provider has run so far.
        assert injector.call_with_injection(target, **call_options) == expected_result
        assert injection_counter == expected_count

    try:
        assert injection_counter == 0

        check(x.method, 'injected string', 1)
        check(x.method_typed_self, 'injected string', 2)
        check(function, 'injected string', 3)

        for target in (x.method, x.method_typed_self, function):
            check(target, 'passed string', 3, args=('passed string',))

        for target in (x.method, x.method_typed_self, function):
            check(target, 'passed string', 3, kwargs={'s': 'passed string'})
    finally:
        del X
示例#4
0
def test_forward_references_in_annotations_are_handled():
    """A string (forward-reference) parameter annotation is resolved for injection.

    ``fun`` annotates its parameter as ``'str'`` instead of ``str``. The
    injector must resolve the string and inject the value bound to ``str``.
    See https://www.python.org/dev/peps/pep-0484/#forward-references for details.
    """
    def configure(binder):
        binder.bind(str, to='hello')

    # The return annotation reflects what fun actually returns.
    @inject
    def fun(s: 'str') -> 'str':
        return s

    injector = Injector(configure)
    # The comparison must be asserted; as a bare expression its result would
    # be discarded and the test could never fail.
    assert injector.call_with_injection(fun) == 'hello'
示例#5
0
def create_injector_from_flags(args=None, modules=None, defaults=None, **kwargs):
    """Create an application Injector from command line flags.

    Calls all AppStartup hooks.

    A ``FlagsModule`` built from ``args`` is installed first, followed by
    ``modules``. Every callable contributed to ``AppStartup`` is then invoked
    through ``Injector.call_with_injection``.

    Args:
        args: Command line arguments; defaults to ``sys.argv``.
        modules: Extra modules installed after the flags module. Any iterable
            is accepted; ``None`` (the default) means no extra modules.
        defaults: Forwarded to ``FlagsModule``.
        **kwargs: Forwarded to the ``Injector`` constructor.

    Returns:
        The configured ``Injector``.
    """
    if args is None:
        args = sys.argv
    # Build a fresh list: no mutable default is shared between calls, and
    # tuples or other iterables of modules are accepted, not only lists.
    all_modules = [FlagsModule(args, defaults=defaults)]
    if modules is not None:
        all_modules.extend(modules)
    injector = Injector(all_modules, **kwargs)
    # Contribute an empty list so AppStartup always has a multibinding, even
    # when no module registered a hook (presumably so get() below cannot fail).
    injector.binder.multibind(AppStartup, to=[])
    for startup in injector.get(AppStartup):
        injector.call_with_injection(startup)
    return injector
示例#6
0
def test_forward_references_in_annotations_are_handled():
    """String annotations naming a class that is defined later are resolved.

    The provider's return annotation and ``fun``'s parameter annotation are
    both the string ``'X'``, and both are decorated before ``X`` is defined.
    Once ``X`` exists, the injector must resolve them and inject the provided
    instance. See https://www.python.org/dev/peps/pep-0484/#forward-references
    for details.
    """
    # The string -> object resolution only works for module-global names
    # (supporting locals was judged not worth it), hence the global.
    global X

    class CustomModule(Module):
        @provider
        def provide_x(self) -> 'X':
            return X('hello')

    @inject
    def fun(s: 'X') -> 'X':
        return s

    # Deliberately defined only after the decorated callables above.
    class X:
        def __init__(self, message: str) -> None:
            self.message = message

    try:
        resolved = Injector(CustomModule).call_with_injection(fun)
        assert resolved.message == 'hello'
    finally:
        del X
示例#7
0
def test_more_useful_exception_is_raised_when_parameters_type_is_any():
    """Injecting a parameter annotated with ``Any`` fails with an exception.

    Either a ``CallError`` or a ``TypeError`` is accepted.
    """
    @inject
    def fun(a: Any) -> None:
        pass

    injector = Injector()
    expected_errors = (CallError, TypeError)

    # Historically the failure surfaced as a bare
    #     TypeError: Cannot instantiate <class 'typing.AnyMeta'>
    # and was later reported as
    #     injector.CallError: Call to AnyMeta.__new__() failed: Cannot instantiate
    #       <class 'typing.AnyMeta'> (injection stack: ['injector_test_py3'])
    # The injection stack adds little in this shallow case but becomes useful
    # as the dependency chain grows deeper.
    with pytest.raises(expected_errors):
        injector.call_with_injection(fun)
示例#8
0
def test_more_useful_exception_is_raised_when_parameters_type_is_any():
    """Injecting a parameter annotated with ``Any`` fails with an exception.

    Depending on the injector release, the failure is either wrapped in a
    ``CallError`` or raised directly as a ``TypeError`` while the ``Any``
    argument is being resolved, before the call happens. Both are accepted.
    """
    @inject
    def fun(a: Any) -> None:
        pass

    injector = Injector()

    # This was the exception before:
    #
    # TypeError: Cannot instantiate <class 'typing.AnyMeta'>
    #
    # Then it became:
    #
    # injector.CallError: Call to AnyMeta.__new__() failed: Cannot instantiate
    #   <class 'typing.AnyMeta'> (injection stack: ['injector_test_py3'])
    #
    # In this case the injection stack doesn't provide too much information but
    # it quickly gets helpful when the stack gets deeper. Releases that reject
    # Any while resolving arguments raise TypeError before any CallError
    # wrapping can happen, so that is accepted too.
    with pytest.raises((CallError, TypeError)):
        injector.call_with_injection(fun)
示例#9
0
def test_forward_references_in_annotations_are_handled():
    """A string annotation naming a class defined later is resolved for injection.

    ``fun`` annotates its parameter as ``'X'`` before ``X`` exists. Once ``X``
    is defined and bound to an instance, the injector must inject that
    instance. See https://www.python.org/dev/peps/pep-0484/#forward-references
    for details.
    """
    # The class needs to be module-global in order for the string -> object
    # resolution mechanism to work. Supporting locals is possible but was not
    # considered worth it.
    global X

    def configure(binder):
        # X is looked up when Injector(configure) runs this function, i.e.
        # after the class statement below has executed.
        binder.bind(X, to=X('hello'))

    @inject
    def fun(s: 'X') -> 'X':
        return s

    class X:
        def __init__(self, message: str) -> None:
            self.message = message

    try:
        injector = Injector(configure)
        # Without `assert` this comparison would be evaluated and discarded,
        # so the test could never fail.
        assert injector.call_with_injection(fun).message == 'hello'
    finally:
        del X
示例#10
0
class InjectApplication(BaseApplication):
    """Application whose components are registered in and resolved by an ``Injector``.

    Provides decorator helpers that register classes, factory-function
    results, dict entries and list entries with ``self.injector``'s binder,
    and runs the start/shutdown callbacks through
    ``Injector.call_with_injection``.
    """
    def __init__(self):
        super(InjectApplication, self).__init__()
        self.injector = Injector()

    def install(self, module):
        """Install ``module`` into the injector's binder and return it unchanged."""
        self.injector.binder.install(module)
        logger.debug(
            f"inject.install_module:{module}, self_dict:{self.__dict__}")
        return module

    def bind(self, interface=None, scope=None):
        """Return a binding decorator; the decorated object is returned unchanged.

        * A plain function (``types.FunctionType``) is called immediately via
          ``call_with_injection`` and its return value is bound to
          ``interface``. ``interface`` is mandatory in this case; without it
          only a warning is logged and nothing is bound.
        * Any other object (e.g. a class) is bound to ``interface``, or to
          itself when ``interface`` is not given.
        """
        def decorator(target):
            if isinstance(target, FunctionType):
                if interface:
                    self.injector.binder.bind(
                        interface, self.injector.call_with_injection(target),
                        scope)
                else:
                    logger.warning(
                        f"The bind target object is the function return value, the interface must specify"
                    )
            else:
                self.injector.binder.bind(interface or target, target, scope)
            return target

        logger.debug(
            f"inject:{self.injector} bind_interface:{interface}, scope:{scope}"
        )
        return decorator

    def bind_map(self, interface, key=None, scope=None):
        """Return a decorator that adds ``{key: <function result>}`` to a dict multibinding.

        Only plain functions are registered: the function is called through
        the injector and its return value is multibound under ``key``. Without
        a ``key`` a warning is logged instead. Non-function targets are
        returned without being bound.

        Raises:
            ArgumentError: if ``interface`` is not a ``dict`` instance.
        """
        # NOTE(review): this is an isinstance check, so typing aliases such as
        # Dict[str, X] are rejected although the message talks about a
        # "subtype"; confirm what callers are expected to pass.
        if not isinstance(interface, dict):
            raise ArgumentError(interface,
                                'Interface type must be a subtype of dict')

        def decorator(target):
            if isinstance(target, FunctionType):
                if key:
                    self.injector.binder.multibind(
                        interface,
                        {key: self.injector.call_with_injection(target)},
                        scope)
                else:
                    logger.warning(
                        f"The bind target object is the function return value, key must be specify"
                    )
            return target

        logger.debug(
            f"Inject.bind_map:interface:{interface}, key:{key}, scope:{scope}")
        return decorator

    def bind_list(self, interface, scope=None):
        """Return a decorator that appends its target to a list multibinding.

        A plain function is called through the injector and its return value
        is appended; any other object is appended as-is. The decorated object
        is returned unchanged.

        Raises:
            ArgumentError: if ``interface`` is not a ``list`` instance.
        """
        if not isinstance(interface, list):
            raise ArgumentError(interface,
                                "Interface type must be a subtype of list")

        def decorator(target):
            if isinstance(target, FunctionType):
                self.injector.binder.multibind(
                    interface, [self.injector.call_with_injection(target)],
                    scope)
            else:
                self.injector.binder.multibind(interface, [target], scope)
            return target

        logger.debug(f"inject bind_list:interface:{interface}, scope:{scope}")
        return decorator

    def bind_inherit(self, interface):
        """Return a class decorator that rebinds ``interface`` to a merged subclass.

        The decorated ``cls`` is combined with the class of the object the
        injector currently resolves for ``interface`` (``cls`` first in the
        bases) into a new class with the same name. The new class is bound to
        ``interface`` via :meth:`bind` and returned in place of ``cls``.
        """
        def decorator(cls):
            # Resolving the interface here fetches/instantiates its current
            # binding at decoration time.
            inherit_cls = type(cls.__name__,
                               (cls, self.injector.get(interface).__class__),
                               {})
            self.bind(interface)(inherit_cls)
            return inherit_cls

        logger.debug(f"inject.bind_inherit.interface:{interface}")
        return decorator

    def call_start_callbacks(self):
        """Invoke every registered start callback with dependency injection.

        ``_start_callbacks`` is presumably populated by ``BaseApplication``.
        """
        for callback in self._start_callbacks:
            self.injector.call_with_injection(callback.callback)
        logger.debug(f"inject.callback.after.start:{self._start_callbacks}")

    def call_shutdown_callbacks(self):
        """Invoke every registered shutdown callback with dependency injection.

        ``_shutdown_callbacks`` is presumably populated by ``BaseApplication``.
        """
        for callback in self._shutdown_callbacks:
            self.injector.call_with_injection(callback.callback)
        logger.debug(
            f"inject.callback.before.close:{self._shutdown_callbacks}")

    @staticmethod
    def _is_filtered_dir(curdir, skip_dirs):
        """Return True when ``curdir`` matches, or lies below, an entry of ``skip_dirs``."""
        parts = curdir.split(os.path.sep)
        for checkdir in skip_dirs:
            # Bare directory name, e.g. 'tests', anywhere in the path.
            if checkdir in parts:
                return True
            # Path prefix, e.g. 'app/tests', covering the directory itself
            # and everything beneath it.
            if curdir == checkdir or curdir.startswith(checkdir + os.path.sep):
                return True
        return False

    @staticmethod
    def _module_name(path, filename):
        """Turn ``path/filename`` into the dotted module name to import."""
        stem = os.path.splitext(os.path.join(path, filename))[0]
        # normpath removes './' and duplicate separators, which would
        # otherwise become leading (relative-import) or empty dotted parts.
        name = os.path.normcase(os.path.normpath(stem)).replace(os.path.sep, '.')
        # A package's __init__ file is the package itself; importing
        # 'pkg.__init__' would execute it a second time under another name.
        if name.endswith('.__init__'):
            name = name[:-len('.__init__')]
        return name

    @staticmethod
    def import_modules(module_path, filterdir=(), filterfile=()):
        """Import every Python module file found under ``module_path``.

        The imports are performed for their side effects (e.g. running
        registration decorators). Module names are derived from file paths by
        replacing the path separator with ``.``. ``module_path`` therefore
        needs to be a relative path importable from the current ``sys.path``
        (typically relative to the working directory).

        Args:
            module_path: Root directory to walk.
            filterdir: Directories to skip, together with everything below
                them. An entry matches when it equals one of the directory's
                path components (e.g. ``'tests'``) or is a path prefix of it
                (e.g. ``'app/tests'``).
            filterfile: File-name prefixes; matching files are not imported.
        """
        logger.debug(f"inject.import.module.path:{module_path}")
        skip_dirs = [os.path.normcase(os.path.normpath(d)) for d in filterdir]
        skip_prefixes = tuple(filterfile)
        _import_m = []
        for path, subdir, files in os.walk(module_path):
            logger.debug(
                f"module_path:{module_path}, iter path:{path}, subdir:{subdir}, file:{files}"
            )
            curdir = os.path.normcase(os.path.normpath(path))
            if InjectApplication._is_filtered_dir(curdir, skip_dirs):
                # Prune in place so os.walk does not descend any further.
                subdir[:] = []
                continue
            # Bytecode caches are not importable by dotted path
            # ('pkg/__pycache__/mod.cpython-311.pyc' is not a module name).
            subdir[:] = [d for d in subdir if d != '__pycache__']
            for filename in files:
                if not filename.endswith(('.py', '.pyc')):
                    continue
                if filename.startswith(skip_prefixes):
                    continue
                m = InjectApplication._module_name(path, filename)
                if m in _import_m:
                    # e.g. a legacy 'mod.pyc' sitting next to 'mod.py'.
                    continue
                # Count loaded modules before importing, as the log line
                # always reported.
                loaded = len(sys.modules)
                module = importlib.import_module(m)
                logger.debug(
                    f'not in sys_modules:{loaded} load injector module file:{module}'
                )
                _import_m.append(m)
        logger.debug(f"had imported:{_import_m} from path:{module_path}")
示例#11
0
def test_custom_scope() -> None:
    """A user-defined scope can hold per-project bindings.

    ``ProjectScope`` is used as a context manager. While it is active,
    ``Project`` is bound to the given project and scoped providers are cached
    in a dict that exists only for the duration of the ``with`` block. Outside
    the block, resolving anything in the scope raises
    ``UnsatisfiedRequirement``.
    """
    @dataclass
    class Project:
        name: str
        special: bool

    class ProjectScope(Scope):
        def __init__(self, *args, **kwargs):
            super(ProjectScope, self).__init__(*args, **kwargs)
            # Per-activation provider cache; None while no project is active.
            self.context = None

        @contextmanager
        def __call__(self, project: Project):
            if self.context is not None:
                raise Exception('context is not None')
            self.context = {}
            binder = self.injector.get(Binder)
            binder.bind(Project, to=project, scope=ProjectScope)
            try:
                yield
            finally:
                # Deactivate even if the with-body raised; otherwise the scope
                # could never be entered again.
                self.context = None

        def get(self, key, provider):
            if self.context is None:
                raise UnsatisfiedRequirement(None, key)

            try:
                return self.context[key]
            except KeyError:
                # First lookup in this activation: build the value once and
                # cache it as an instance provider.
                provider = InstanceProvider(provider.get(self.injector))
                self.context[key] = provider
                return provider

    class Handler(ABC):
        pass

    class OrdinaryHandler(Handler):
        def __init__(self, project: Project):
            self.project = project

    class SpecialHandler(Handler):
        def __init__(self, project: Project):
            self.project = project

    class SomeSingletonService:
        pass

    project_scope = ScopeDecorator(ProjectScope)

    class ProjectScopedHandlerModule(Module):
        @project_scope
        @provider
        def handler(self, project: Project) -> Handler:
            if project.special:
                return SpecialHandler(project)
            else:
                return OrdinaryHandler(project)

        @singleton
        @provider
        def some_singleton_service(self,
                                   handler: Handler) -> SomeSingletonService:
            return SomeSingletonService()

    injector = Injector([ProjectScopedHandlerModule()], auto_bind=False)

    scope = injector.get(ProjectScope)

    with scope(Project(name='proj1', special=False)):
        handler = injector.get(Handler)
    assert isinstance(handler, OrdinaryHandler)
    assert handler.project.name == 'proj1'

    with scope(Project(name='proj2', special=True)):
        handler = injector.get(Handler)
    assert isinstance(handler, SpecialHandler)
    assert handler.project.name == 'proj2'

    with scope(Project(name='proj3', special=True)):

        @inject
        def f(handler: Handler) -> str:
            return handler.project.name

        assert injector.call_with_injection(f) == 'proj3'

    # The singleton depends on a project-scoped Handler, so it cannot be
    # built while no project scope is active.
    with pytest.raises(UnsatisfiedRequirement):
        injector.get(SomeSingletonService)

    with scope(Project(name='proj4', special=True)):
        some_singleton_service = injector.get(SomeSingletonService)
    # Once built inside a project scope, the singleton stays cached and is
    # returned even outside any scope.
    assert injector.get(SomeSingletonService) == some_singleton_service
示例#12
0
class AppUnit(RouterMixin):
    """
    Creates an application instance.

    **Parameters:**

    * **debug** - Boolean indicating if debug tracebacks should be returned on errors.
    * **exception_handlers** - A dictionary mapping either integer status codes,
    or exception class types onto callables which handle the exceptions.
    * **middleware** - A list of middleware (Middleware class or function) to run for every request.
    * **response_class** - Default response class used in routes.
    Default is `JSONResponse`.
    * **modules** - A list of configuration modules.
    * **auto_bind** - Whether to automatically bind missing types.

    Routes, startup/shutdown events, exception handlers, function middleware
    and CLI commands are all invoked through ``self.injector``, so their
    parameters are filled by dependency injection. Each of them may be a
    plain function or a coroutine function.
    """

    def __init__(
        self,
        debug: bool = False,
        exception_handlers: Optional[
            Dict[Union[int, Type[Exception]], Callable]
        ] = None,
        middleware: Optional[Sequence[Union[Middleware, Callable]]] = None,
        response_class: Optional[Type[Response]] = None,
        modules: Optional[List[ModuleType]] = None,
        auto_bind: bool = False,
    ):
        self._debug = debug
        self.injector = Injector(auto_bind=auto_bind)
        self.router = Router()
        self.exception_handlers = {
            exc_type: self._inject_exception_handler(handler)
            for exc_type, handler in (
                {} if exception_handlers is None else dict(exception_handlers)
            ).items()
        }
        self.user_middleware = (
            []
            if middleware is None
            else [self.prepare_middleware(m) for m in middleware]
        )
        # RequestScopeMiddleware goes in front of the constructor-supplied
        # middleware; add_middleware() later inserts ahead of it.
        self.user_middleware.insert(0, Middleware(RequestScopeMiddleware))
        self.middleware_stack = self.build_middleware_stack()
        self.cli = click.Group()
        self.response_class = response_class or JSONResponse

        # Make the app, the injector itself and the current request injectable.
        self.injector.binder.bind(AppUnit, to=self, scope=SingletonScope)
        self.injector.binder.bind(Injector, to=self.injector, scope=SingletonScope)
        self.injector.binder.bind(Request, to=context.get_current_request)

        modules = modules or []
        for module in modules:
            self.add_module(module)

    async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None:
        """ASGI entry point: expose the app in ``scope`` and run the middleware stack."""
        scope["app"] = self
        await self.middleware_stack(scope, receive, send)

    @property
    def debug(self) -> bool:
        return self._debug

    @debug.setter
    def debug(self, value: bool) -> None:
        # The stack captures the debug flag when it is built, so rebuild it.
        self._debug = value
        self.middleware_stack = self.build_middleware_stack()

    def add_module(self, module: ModuleType) -> None:
        """Install a configuration module into the injector's binder."""
        self.injector.binder.install(module)

    def lookup(
        self, interface: Type[InterfaceType], *, scope: Optional[ScopeType] = None
    ) -> InterfaceType:
        """Resolve ``interface`` from the injector."""
        return self.injector.get(interface, scope=scope)

    def bind(
        self,
        interface: Type[InterfaceType],
        to: Optional[Any] = None,
        *,
        scope: Optional[ScopeType] = None,
    ) -> None:
        """Bind ``interface`` to ``to`` in the injector, optionally in ``scope``."""
        self.injector.binder.bind(interface, to=to, scope=scope)

    def singleton(
        self, interface: Type[InterfaceType], to: Optional[Any] = None
    ) -> None:
        """Bind ``interface`` to ``to`` in ``SingletonScope``."""
        self.injector.binder.bind(interface, to=to, scope=SingletonScope)

    def add_startup_event(self, event: Callable) -> None:
        """Register ``event`` (with injected parameters) to run on startup."""
        event = self._inject_event(event)
        self.router.on_startup.append(event)

    def add_shutdown_event(self, event: Callable) -> None:
        """Register ``event`` (with injected parameters) to run on shutdown."""
        event = self._inject_event(event)
        self.router.on_shutdown.append(event)

    def on_startup(self) -> Callable:
        """Decorator form of :meth:`add_startup_event`."""
        def decorator(event: Callable):
            self.add_startup_event(event)
            return event

        return decorator

    def on_shutdown(self) -> Callable:
        """Decorator form of :meth:`add_shutdown_event`."""
        def decorator(event: Callable):
            self.add_shutdown_event(event)
            return event

        return decorator

    async def startup(self) -> None:
        await self.router.startup()

    async def shutdown(self) -> None:
        await self.router.shutdown()

    def add_route(
        self,
        path: str,
        route: Callable,
        methods: Optional[List[str]] = None,
        name: Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        """Register ``route`` at ``path``; its parameters are injected per request."""
        self.router.add_route(
            path,
            endpoint=self._inject_route(route),
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def add_exception_handler(
        self, status_or_exc: Union[int, Type[Exception]], *, handler: Callable
    ) -> None:
        """Register ``handler`` for a status code or exception type and rebuild the stack."""
        self.exception_handlers[status_or_exc] = self._inject_exception_handler(
            handler=handler
        )
        self.middleware_stack = self.build_middleware_stack()

    def exception_handler(self, status_or_exc: Union[int, Type[Exception]]) -> Callable:
        """Decorator form of :meth:`add_exception_handler`."""
        def decorator(handler: Callable) -> Callable:
            self.add_exception_handler(status_or_exc, handler=handler)
            return handler

        return decorator

    def add_middleware(self, middleware: Union[Middleware, Callable]) -> None:
        """Insert ``middleware`` at the front of the user middleware and rebuild the stack."""
        self.user_middleware.insert(0, self.prepare_middleware(middleware))
        self.middleware_stack = self.build_middleware_stack()

    def middleware(self) -> Callable:
        """Decorator form of :meth:`add_middleware`."""
        def decorator(middleware: Callable) -> Callable:
            self.add_middleware(middleware)
            return middleware

        return decorator

    def prepare_middleware(self, middleware: Union[Middleware, Callable]) -> Middleware:
        """Wrap a plain function as an injected ``BaseHTTPMiddleware`` dispatch.

        ``Middleware`` objects are returned unchanged.
        """
        if inspect.isfunction(middleware):
            middleware = Middleware(
                BaseHTTPMiddleware,
                dispatch=self._inject_middleware(cast(Callable, middleware)),
            )
        return cast(Middleware, middleware)

    def build_middleware_stack(self) -> ASGIApp:
        """Compose the ASGI app: error middleware, user middleware, exception middleware, router."""
        debug = self.debug
        error_handler = None
        exception_handlers = {}

        # Handlers for 500/Exception go to ServerErrorMiddleware; all others
        # go to ExceptionMiddleware.
        for key, value in self.exception_handlers.items():
            if key in (500, Exception):
                error_handler = value
            else:
                exception_handlers[key] = value

        middleware = (
            [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
            + self.user_middleware
            + [
                Middleware(
                    ExceptionMiddleware, handlers=exception_handlers, debug=debug
                )
            ]
        )

        # Wrap from the innermost outwards so the first entry ends up outermost.
        # NOTE(review): this unpacks Middleware as (cls, options) pairs; newer
        # Starlette releases iterate it as (cls, args, kwargs) — confirm the
        # pinned version.
        app = self.router
        for cls, options in reversed(middleware):
            app = cls(app=app, **options)
        return app

    def add_command(
        self,
        cmd: Callable,
        *,
        name: Optional[str] = None,
        cli: Optional[click.Group] = None,
        lifespan: bool = True,
    ) -> None:
        """Register ``cmd`` as a click command on ``cli`` (default: ``self.cli``)."""
        make_command = click.command()
        cli = cli or self.cli
        cli.add_command(
            make_command(self._inject_command(cmd, lifespan=lifespan)), name=name
        )

    def command(self, name: Optional[str] = None, *, lifespan: bool = True) -> Callable:
        """Decorator form of :meth:`add_command`."""
        def decorator(cmd: Callable) -> Callable:
            self.add_command(cmd, name=name, lifespan=lifespan)
            return cmd

        return decorator

    def run(self, host: str = "localhost", port: int = 8000, **kwargs) -> None:
        """
        Run Uvicorn server.
        """
        if uvicorn is None:
            raise RuntimeError("`uvicorn` is not installed.")

        uvicorn.run(app=self, host=host, port=port, **kwargs)

    def main(
        self,
        args: Optional[List[str]] = None,
        prog_name: Optional[str] = None,
        complete_var: Optional[str] = None,
        standalone_mode: bool = True,
        **extra,
    ) -> int:
        """
        Start application CLI.
        """
        return self.cli.main(
            args=args,
            prog_name=prog_name,
            complete_var=complete_var,
            standalone_mode=standalone_mode,
            **extra,
        )

    ############################################
    # Dependency Injection helpers
    #
    # Each helper runs inject() once, when the callable is registered, rather
    # than re-inferring its bindings on every request/invocation.
    ############################################

    def _inject_event(self, event: Callable) -> Callable:
        """Wrap a startup/shutdown event; a coroutine result is awaited."""
        injected = inject(event)

        @functools.wraps(event)
        async def wrapped():
            handler = self.injector.call_with_injection(injected)
            if inspect.iscoroutine(handler):
                return await handler

        return wrapped

    def _inject_route(self, route: Callable) -> Callable:
        """Wrap ``route`` as an endpoint whose parameters are injected.

        The request argument supplied by the router is ignored here; handlers
        can obtain it through injection (``Request`` is bound to
        ``context.get_current_request`` in ``__init__``). A coroutine result
        is awaited, and anything that is not already a ``Response`` is passed
        as ``content`` to ``self.response_class``.
        """
        injected = inject(route)

        @functools.wraps(route)
        async def wrapped(_: Request):
            response = self.injector.call_with_injection(injected)
            if inspect.iscoroutine(response):
                response = await response

            if isinstance(response, Response):
                return response
            return self.response_class(content=response)

        return wrapped

    def _inject_exception_handler(self, handler: Callable) -> Callable:
        """Wrap an exception handler; ``(request, exc)`` are passed positionally.

        Async handlers are awaited; a plain function's return value is used as-is.
        """
        injected = inject(handler)

        @functools.wraps(handler)
        async def wrapped(request: Request, exc: Exception):
            result = self.injector.call_with_injection(
                injected, args=(request, exc)
            )
            # Awaiting unconditionally would crash on sync handlers.
            if inspect.isawaitable(result):
                result = await result
            return result

        return wrapped

    def _inject_middleware(self, middleware: Callable) -> Callable:
        """Wrap a dispatch function; ``(request, call_next)`` are passed positionally.

        An awaitable result is awaited; anything else is returned as-is.
        """
        injected = inject(middleware)

        @functools.wraps(middleware)
        async def wrapped(request: Request, call_next: Callable):
            result = self.injector.call_with_injection(
                injected, args=(request, call_next)
            )
            if inspect.isawaitable(result):
                result = await result
            return result

        return wrapped

    def _inject_command(self, cmd: Callable, lifespan: bool = True) -> Callable:
        """Wrap a CLI command; click's args/kwargs are forwarded to the injector.

        Coroutine functions are run with ``asyncio.run``. When ``lifespan`` is
        true, the app's startup/shutdown hooks surround them. Plain functions
        are called directly without startup/shutdown.
        """
        injected = inject(cmd)

        @functools.wraps(cmd)
        def wrapped(*args, **kwargs):
            def run():
                return self.injector.call_with_injection(
                    injected, args=args, kwargs=kwargs
                )

            async def async_run():
                if lifespan:
                    await self.startup()
                    try:
                        await run()
                    finally:
                        await self.shutdown()
                else:
                    await run()

            if inspect.iscoroutinefunction(cmd):
                return asyncio.run(async_run())
            return run()

        return wrapped