예제 #1
0
    def backend(self):
        """
        Return the username/password auth backend, creating it on first access.

        The backend class is resolved from the dotted path in
        ``config["USERNAMEPASSWORD_AUTH_BACKEND"]`` and built with
        ``from_auth(self)``; the result is cached on ``self._backend``.

        :return: the cached backend instance
        """
        if self._backend:
            return self._backend

        backend_path = self.config["USERNAMEPASSWORD_AUTH_BACKEND"]
        self._backend = load_object(backend_path).from_auth(self)
        return self._backend
예제 #2
0
    async def data_received(self, data: bytes) -> None:
        """
        Handle the SOCKS5 version identifier/method selection message:

        +----+----------+----------+
        |VER | NMETHODS | METHODS  |
        +----+----------+----------+
        | 1  |    1     | 1 to 255 |
        +----+----------+----------+

        The server's preference wins: the first method in
        ``config["AUTH_METHODS"]`` iteration order that the client also
        offers is selected. It is echoed back as ``VER | METHOD``, and the
        handler class configured for it is stored on
        ``protocol.cls_auth_method``.

        :param data: raw bytes of the method selection message
        :type data: bytes
        :return:
        :rtype: None
        :raises ProtocolVersionNotSupportedException: if ``validate`` rejects
            the message
        :raises Socks5NoAcceptableMethodsException: if none of the offered
            methods is configured (``VER | 0xFF`` is written first)
        """
        if not self.validate(data):
            raise ProtocolVersionNotSupportedException

        self.logger.debug(
            "[%s] [INIT] [%s:%s] received: %s",
            hex(id(self.protocol))[-4:],
            *self.protocol.info_peername,
            repr(data),
        )

        nmethods: int = data[1]
        methods: List[int] = list(data[2:2 + nmethods])
        offered = set(methods)

        auth_methods = self.protocol.config["AUTH_METHODS"]
        # One ordered pass over the configured methods. This replaces a sort
        # whose key rebuilt the configured list for every comparison.
        auth_method = next((m for m in auth_methods if m in offered), None)

        if auth_method is None:
            self.logger.debug(
                "No acceptable methods found. "
                "The following methods are supported:\n%s",
                pprint.pformat(auth_methods),
            )
            self.protocol.transport.write(pack("!BB", VERSION,
                                               0xFF))  # NO ACCEPTABLE METHODS
            raise Socks5NoAcceptableMethodsException

        self.protocol.transport.write(pack("!BB", VERSION, auth_method))
        self.protocol.cls_auth_method = load_object(auth_methods[auth_method])
예제 #3
0
    def __init__(self, settings: Settings):
        """
        Build the service and all of its managers from ``settings``.

        Construction order: runtime info, event loop (followed by
        ``_configure_loop``), signal manager plus the service start/stop
        handlers, extension manager, middleware manager, and finally channels.

        :param settings: configuration for the whole service
        :type settings: Settings
        """
        get_runtime_info()

        self.settings: Settings = settings

        # The loop is created before any component is constructed.
        self.loop: AbstractEventLoop = get_event_loop(settings)
        adopted_from = settings["LOOP"].upper()
        self.logger.info(
            "In this service the loop is adopted from: %s", adopted_from)
        self._configure_loop()

        def resolve(setting_key: str):
            """Import the class named by the dotted path in ``settings[setting_key]``."""
            return load_object(settings[setting_key])

        self.signal_manager: SignalManager = resolve(
            "CLS_SIGNAL_MANAGER").from_settings(settings)

        # The signal manager is built via from_settings (it never sees the
        # service), so the service registers its own lifecycle handlers here.
        self.signal_manager.connect(self.service_started, service_started)
        self.signal_manager.connect(self.service_stopped, service_stopped)

        self.extension_manager: ExtensionManager = resolve(
            "CLS_EXTENSION_MANAGER").from_service(self)
        self.middleware_manager: MiddlewareManager = resolve(
            "CLS_MIDDLEWARE_MANAGER").from_service(self)

        self.channels: Dict[str, Channel] = self._get_channels()
예제 #4
0
    async def send_data(self, data: bytes) -> None:
        """
        Open the upstream client connection on first use.

        If ``self.client_transport`` is already set, this does nothing.
        Otherwise it:

        - opens a ``CLIENT_PROTOCOL`` connection to
          ``CLIENT_ADDRESS``:``CLIENT_PORT``, verified over TLS against
          ``CLIENT_SSL_CERT_FILE`` when that setting is truthy;
        - logs and counts the peer certificate once per serial number;
        - links the two transports together.

        NOTE(review): ``data`` is not referenced in this body. Confirm that
        the payload is written to the upstream transport somewhere else.

        :param data: payload received from the peer
        :type data: bytes
        :return:
        :rtype: None
        """
        if self.client_transport is not None:
            return

        ca_file = self.config["CLIENT_SSL_CERT_FILE"]
        ssl_context: Optional[ssl.SSLContext] = None
        if ca_file:
            ssl_context = ssl.create_default_context(
                purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_file)

        cls_client = load_object(self.config["CLIENT_PROTOCOL"])
        loop = get_event_loop()

        transport: Transport
        client: Protocol
        transport, client = await loop.create_connection(
            protocol_factory=lambda: cls_client.from_channel(self.channel),
            host=self.config.get("CLIENT_ADDRESS"),
            port=self.config.get("CLIENT_PORT"),
            ssl=ssl_context,
        )

        peer_cert: Optional[Dict[str, Union[Tuple, int, str]]] = \
            transport.get_extra_info("peercert")
        if peer_cert:
            serial = peer_cert["serialNumber"]
            if serial not in self.certificates:
                self.logger.info("Enabled a certificate:\n%s",
                                 pprint.pformat(peer_cert))
                self.certificates.add(serial)
                self.stats.increase("certificates")

        client.server_transport = self.transport
        self.client_transport = transport
예제 #5
0
    def _get_channels(self) -> Dict[str, Channel]:
        """
        Instantiate every channel listed in ``settings["CHANNELS"]``.

        Each channel is built from the ``CLS_CHANNEL`` class via
        ``from_service`` with its name and a ``CHANNEL_<NAME>_`` settings
        prefix.

        :return: channels keyed by their configured name
        :rtype: Dict[str, Channel]
        """
        # load_object returns the channel *class*; instances are made below.
        channel_class = load_object(self.settings["CLS_CHANNEL"])
        names = self.settings["CHANNELS"]

        channels: Dict[str, Channel] = {
            name: channel_class.from_service(
                self, name=name, setting_prefix=f"CHANNEL_{name.upper()}_")
            for name in names
        }

        self.logger.info("Enable channels:\n%s",
                         pprint.pformat(self.settings["CHANNELS"]))
        return channels
예제 #6
0
def get_event_loop(settings: Settings = None,
                   func: str = "new_event_loop",
                   **kwargs) -> AbstractEventLoop:
    """
    Return the process-wide asyncio loop, creating it on first use.

    The first call imports ``<settings["LOOP"]>.<func>`` and calls it with
    ``**kwargs``. Every later call returns that same object and ignores
    ``settings``, ``func`` and ``kwargs``. Because of this, callers that only
    need the existing loop may omit ``settings``.

    :param settings: needed only for the call that creates the loop;
        ``settings["LOOP"]`` is the dotted path of the module providing ``func``
    :type settings: Settings
    :param func: name of the loop factory inside that module
    :type func: str
    :param kwargs: forwarded to the factory when the loop is created
    :return: the shared event loop
    :rtype: AbstractEventLoop
    :raises RuntimeError: if no loop exists yet and ``settings`` was not given
    """

    global _LOOP  # pylint: disable=global-statement

    if _LOOP is None:
        if settings is None:
            raise RuntimeError(
                "get_event_loop() needs settings to create the event loop "
                "on first use")
        loop_path: str = ".".join([settings["LOOP"], func])
        _LOOP = load_object(loop_path)(**kwargs)

    return _LOOP
예제 #7
0
    async def start(self) -> None:
        """
        Start listening on this channel's interface.

        Creates a server on ``INTERFACE_ADDRESS``:``INTERFACE_PORT`` whose
        connections are handled by ``INTERFACE_PROTOCOL`` (role
        ``"interface"``). The server uses TLS when ``INTERFACE_SSL_CERT_FILE``
        is configured. The server object is stored on ``self.server``.

        :return:
        :rtype: None
        """
        cls_interface = load_object(self.config["INTERFACE_PROTOCOL"])
        loop = get_event_loop()

        ssl_context: Optional[ssl.SSLContext] = None
        cert_file = self.config.get("INTERFACE_SSL_CERT_FILE")
        if cert_file:
            ssl_context = ssl.create_default_context(
                purpose=ssl.Purpose.CLIENT_AUTH)
            ssl_context.load_cert_chain(
                certfile=cert_file,
                keyfile=self.config.get("INTERFACE_SSL_KEY_FILE"),
                password=self.config.get("INTERFACE_SSL_PASSWORD"),
            )

        def make_protocol():
            """Create one interface-side protocol instance per connection."""
            return cls_interface.from_channel(self, role="interface")

        self.server = await loop.create_server(
            protocol_factory=make_protocol,
            host=self.config["INTERFACE_ADDRESS"],
            port=self.config["INTERFACE_PORT"],
            ssl=ssl_context,
        )

        self.logger.info(
            "Channel [%s] is open; "
            "Protocol [%s] is listening on the interface: [%s:%s]",
            self.name,
            self.config["INTERFACE_PROTOCOL"],
            self.config["INTERFACE_ADDRESS"],
            self.config["INTERFACE_PORT"],
        )
예제 #8
0
    def __init__(self, service, name: str = None, setting_prefix: str = None):
        """
        Load and instantiate the components this manager is responsible for.

        ``self.settings[self.manage]`` is read as a mapping from dotted class
        path to an order value. Classes are loaded in ascending order of that
        value, each is built with ``from_service(self.service)``, and the
        instance is keyed by the class's ``name`` attribute.

        :param service: the owning service
        :type service:
        :param name: manager name, forwarded to the base class
        :type name: str
        :param setting_prefix: settings prefix, forwarded to the base class
        :type setting_prefix: str
        """
        super().__init__(service, name, setting_prefix)  # type: ignore

        configured = self.settings[self.manage].items()  # type: ignore
        ordered = sorted(configured, key=lambda pair: pair[1])
        self._cls_components: Dict[str, int] = dict(ordered)

        # Built in a local dict and assigned at the end, so self._components
        # only appears once every component has been constructed.
        components: Dict[str, object] = {}
        for class_path in self._cls_components:
            component_cls = load_object(class_path)
            component_name = component_cls.name
            components[component_name] = component_cls.from_service(
                self.service)  # type: ignore
        self._components: Dict[str, object] = components
예제 #9
0
    async def data_received(self, data: bytes) -> None:
        """
        Handle a SOCKS5 request and open the upstream connection.

        This method:

        - parses the request and connects to DST.ADDR:DST.PORT with the
          channel's ``CLIENT_PROTOCOL``;
        - links the two transports;
        - replies with REP=0x00 and the BND.ADDR/BND.PORT of the upstream
          socket (RFC 1928, section 6).

        :param data: raw bytes of the request
        :type data: bytes
        :return:
        :rtype: None
        :raises ProtocolVersionNotSupportedException: if ``validate`` rejects
            the message
        :raises Socks5CMDNotSupportedException: if CMD is not in
            ``supported_cmd``
        :raises Socks5NetworkUnreachableException: if connecting fails with
            ENETUNREACH (a REP=0x03 reply is written first)
        :raises OSError: any other connection failure is re-raised unchanged
        """
        import errno  # local import keeps the module's import block untouched

        if not self.validate(data):
            raise ProtocolVersionNotSupportedException

        (
            ver,  # pylint: disable=unused-variable
            cmd,
            rsv,  # pylint: disable=unused-variable
            atyp,  # pylint: disable=unused-variable
            dst_addr,
            dst_port,
        ) = await self.parse_host_data(data)

        if cmd not in self.supported_cmd:
            raise Socks5CMDNotSupportedException

        self.logger.debug(
            "[%s] [HOST] [%s:%s] [%s:%s] received: %s",
            hex(id(self.protocol))[-4:],
            *self.protocol.info_peername,
            to_str(dst_addr),
            dst_port,
            repr(data),
        )

        cls_client = load_object(self.protocol.config["CLIENT_PROTOCOL"])

        try:
            (
                client_transport,
                client_protocol,
            ) = await self.protocol.loop.create_connection(
                lambda: cls_client.from_channel(self.protocol.channel,
                                                role="client"),
                dst_addr,
                dst_port,
            )
        except OSError as exc:
            # Compare errno, not the message: the text and the numeric value
            # (101 on Linux) are platform dependent.
            if exc.errno != errno.ENETUNREACH:
                raise
            self.logger.error(
                "The target is unreachable: %s:%s",
                dst_addr,
                dst_port,
            )
            # REP=0x03 (network unreachable). The BND part is always sent as
            # ATYP=0x01 with a 4-byte address, so its length matches its type
            # whatever ATYP the request carried.
            self.protocol.transport.write(
                pack("!BBBBIH", VERSION, 0x03, 0x00, 0x01, 0xFF,
                     0xFF)  # Network unreachable
            )
            raise Socks5NetworkUnreachableException from exc

        client_protocol.server_transport = self.protocol.transport
        self.protocol.client_transport = client_transport

        # BND.ADDR/BND.PORT describe the upstream socket, so its own address
        # family decides the encoding. The inbound socket may be a different
        # family. An AF_INET6 sockname is a 4-tuple
        # (host, port, flowinfo, scope_id); only host and port are used.
        upstream_family = client_transport.get_extra_info("socket").family
        sockname = client_transport.get_extra_info("sockname")
        bnd_host: str = sockname[0].split("%", 1)[0]  # drop any IPv6 zone id
        bnd_port: int = sockname[1]

        bnd_addr: bytes = socket.inet_pton(upstream_family, bnd_host)

        # RFC 1928 ATYP: 0x01 = IPv4, 0x03 = DOMAINNAME, 0x04 = IPv6.
        atyp_: int = 0x04 if upstream_family == socket.AF_INET6 else 0x01

        self.protocol.transport.write(
            pack(
                f"!BBBB{len(bnd_addr)}sH",
                VERSION,
                0x00,
                0x00,
                atyp_,
                bnd_addr,
                bnd_port,
            ))
예제 #10
0
 def test_load_object(self):
     """A dotted path resolves to the very object it names."""
     resolved = load_object("bifrost.utils.misc.load_object")
     self.assertIs(resolved, load_object)