예제 #1
0
파일: config.py 프로젝트: scottwedge/n6
 def __init__(self,
              db_host=None,
              db_name=None,
              db_user=None,
              db_password=None,
              settings=None,
              config_section=None):
     """
     Initialize the connector.

     The `db_*` arguments and `settings` are combined into the actual
     settings by `_get_actual_settings()`; when `config_section` is
     None, `self.default_config_section` is used instead. The session
     factory and the audit log start as None -- `configure_db()` is
     the place where they get set.
     """
     section = (self.default_config_section if config_section is None
                else config_section)
     actual_settings = self._get_actual_settings(
         db_host, db_name, db_user, db_password, settings, section)
     self.context_deposit = ThreadLocalContextDeposit(
         repr_token=self.__class__.__name__,
         attr_factories={'audit_log_external_meta_items': dict})
     # Both placeholders are to be set in configure_db().
     self.db_session_factory = self._audit_log = None
     super(SQLAuthDBConnector, self).__init__(actual_settings, section)
예제 #2
0
 def __init__(self, settings=None):
     """
     Load the config (via `_get_config()`) and create the context
     deposit that keeps the SMTP client context(s).
     """
     config = self._get_config(settings)
     self.config = config
     deposit = ThreadLocalContextDeposit(
         repr_token=self.__class__.__name__)  #3: `__name__` -> `__qualname__`
     self._smtp_client_deposit = deposit
예제 #3
0
파일: config.py 프로젝트: scottwedge/n6
class SQLAuthDBConnector(SQLAuthDBConfigMixin):
    """
    Auth DB connector providing SQLAlchemy sessions via a context manager.

    The DB connection settings come either from the `db_host`, `db_name`,
    `db_user` (and, optionally, `db_password`) constructor arguments, or
    from `settings`/the config (handled by the superclass).

    Once `configure_db()` has set `db_session_factory`, use:

        with connector as session:
            ...

    Entering the outermost `with` block creates a new session (by calling
    `db_session_factory`); entering a nested one creates a savepoint (see
    `make_nested_savepoint()`). On exit, changes are committed if no
    exception occurred, or rolled back otherwise; exiting the outermost
    block also closes the session. The per-context state is kept in a
    `ThreadLocalContextDeposit` (presumably per-thread -- see that class).
    """

    def __init__(self,
                 db_host=None,
                 db_name=None,
                 db_user=None,
                 db_password=None,
                 settings=None,
                 config_section=None):
        """
        Args/kwargs:
            `db_host`, `db_name`, `db_user`, `db_password` (default: None):
                DB connection parameters -- either all of `db_host`,
                `db_name` and `db_user` (plus, optionally, `db_password`)
                or none of them should be given.
            `settings` (default: None):
                A config dict (most likely a Pyramid-style one) or None.
            `config_section` (default: None):
                The config section name; if None,
                `self.default_config_section` is used.

        Raises:
            TypeError: if only some of the connection arguments are given.
        """
        if config_section is None:
            config_section = self.default_config_section
        settings = self._get_actual_settings(db_host, db_name, db_user,
                                             db_password, settings,
                                             config_section)
        self.context_deposit = ThreadLocalContextDeposit(
            repr_token=self.__class__.__name__,
            attr_factories={'audit_log_external_meta_items': dict})
        self.db_session_factory = None  # to be set in configure_db()
        self._audit_log = None  # to be set in configure_db()
        super(SQLAuthDBConnector, self).__init__(settings, config_section)

    def _get_actual_settings(self, db_host, db_name, db_user, db_password,
                             settings, config_section):
        """
        Return the settings to be passed to the superclass constructor.

        If the `db_*` connection arguments were given, build a
        `mysql+mysqldb://...` URL from them and put it under the
        `<config_section>.url` key. It goes either into a new dict (if
        `settings` is None) or into a *copy* of `settings`; in the latter
        case it overrides any value of that key already in `settings`.
        Otherwise, return `settings` unchanged.

        Raises:
            TypeError: (from `_verify_args_for_connection()`) if only
                some of the connection arguments were given.
        """
        if self._verify_args_for_connection(db_host=db_host,
                                            db_name=db_name,
                                            db_user=db_user,
                                            db_password=db_password):
            password_part = (':{}'.format(db_password) if db_password else '')
            option_val = 'mysql+mysqldb://{user}{password_part}@{host}/{name}'.format(
                user=db_user,
                password_part=password_part,
                host=db_host,
                name=db_name)
            option_key = '{}.url'.format(config_section)
            if settings is None:
                settings = {
                    option_key: option_val,
                }
            else:
                # keep a config from `settings` (it is most likely
                # a Pyramid-style config dict), but config options
                # made with kwargs should have higher priority.
                # Note: the key must be the *value* of `option_key`
                # (`dict(settings, option_key=...)` would add a literal
                # 'option_key' key instead); the caller's dict is not
                # mutated -- we work on a copy.
                settings = dict(settings)
                settings[option_key] = option_val
        return settings

    def _verify_args_for_connection(self, **kwargs):
        """
        Check the `db_*` connection arguments (given as keyword arguments).

        Returns:
            True if `db_host`, `db_name` and `db_user` are all non-empty;
            False if none of the given arguments is non-empty.

        Raises:
            TypeError: if only some of them are non-empty. In the error
                message the `db_password` value is masked, so the password
                does not end up in logs/tracebacks.
        """
        if kwargs['db_host'] and kwargs['db_name'] and kwargs['db_user']:
            return True
        # `items()` (unlike `iteritems()`) works on both Python 2 and 3;
        # sorting makes the error message deterministic.
        incorrectly_specified_args = sorted(
            (name, val)
            for name, val in kwargs.items() if val)
        if incorrectly_specified_args:
            args_repr = ', '.join(
                '{}={}'.format(
                    name,
                    '<hidden>' if name == 'db_password' else repr(val))
                for name, val in incorrectly_specified_args)
            raise TypeError(
                '{!r}\'s constructor: *either* the `db_host`, `db_name` '
                'and `db_user` arguments, plus optionally `db_password`, '
                'should be given (as non-empty strings), *or* none of '
                'them! (got: {})'.format(self, args_repr))
        return False

    def configure_db(self):
        """
        Extend the superclass's `configure_db()`.

        Additionally create `db_session_factory` (a `sessionmaker` bound
        to `self.engine`, with autocommit and autoflush turned off) and
        the `AuditLog` instance (which obtains external meta items through
        `_get_audit_log_external_meta_items()`).
        """
        super(SQLAuthDBConnector, self).configure_db()
        self.db_session_factory = sqlalchemy.orm.sessionmaker(bind=self.engine,
                                                              autocommit=False,
                                                              autoflush=False)
        self._audit_log = AuditLog(
            session_factory=self.db_session_factory,
            external_meta_items_getter=self._get_audit_log_external_meta_items)

    def _get_audit_log_external_meta_items(self):
        """
        Return a deep copy of the audit-log external meta items kept in
        the context deposit. Being a copy, it cannot be affected by later
        changes made through `set_audit_log_external_meta_items()`.
        """
        return copy.deepcopy(
            self.context_deposit.audit_log_external_meta_items)

    # Public methods (to be called by client code; they can also be
    # overridden/extended and/or called in subclasses):

    def set_audit_log_external_meta_items(self, n6_module,
                                          **other_external_meta_items):
        """
        Replace the audit-log external meta items kept in the context
        deposit with a new dict containing `n6_module` and all the given
        keyword arguments.
        """
        external_meta_items = dict(n6_module=n6_module,
                                   **other_external_meta_items)
        self.context_deposit.audit_log_external_meta_items = external_meta_items

    def __enter__(self):
        """
        Enter a (possibly nested) DB context: the outermost one gets a new
        session from `db_session_factory`; nested ones get a savepoint.

        Returns:
            The current (outermost) session -- also in nested contexts.
        """
        self.context_deposit.on_enter(
            outermost_context_factory=self.db_session_factory,
            context_factory=self.make_nested_savepoint)
        return self.get_current_session()

    def __exit__(self, exc_type, exc, tb):
        """
        Exit the current context: finalize the savepoint (nested context)
        or the session (outermost context). Returns None, so any exception
        is propagated.
        """
        self.context_deposit.on_exit(
            exc_type,
            exc,
            tb,
            context_finalizer=self.finalize_nested_savepoint,
            outermost_context_finalizer=self.finalize_session)

    def get_current_session(self):
        """Return the session of the outermost active context."""
        return self.context_deposit.outermost_context

    # Context-management-related methods that can be overridden/extended
    # and/or called in subclasses but do *not* belong to the public
    # interface of `SQLAuthDBConnector` instances:

    def make_nested_savepoint(self):
        """Begin (and return) a nested transaction on the current session."""
        session = self.get_current_session()
        assert isinstance(session, Session)
        return session.begin_nested()

    def finalize_nested_savepoint(self, _savepoint, exc_type, exc_value, tb):
        """
        Commit or roll back (depending on `exc_type`) via the current
        session. The savepoint object itself is not used -- this relies on
        the session's commit/rollback applying to the innermost nested
        transaction (NOTE(review): holds for the SQLAlchemy version in
        use? -- confirm).
        """
        session = self.get_current_session()
        assert isinstance(session, Session)
        self.commit_or_rollback(session, exc_type, exc_value, tb)

    def finalize_session(self, session, exc_type, exc_value, tb):
        """Commit or roll back, then *always* close the session."""
        assert isinstance(session, Session)
        try:
            self.commit_or_rollback(session, exc_type, exc_value, tb)
        finally:
            session.close()

    def commit_or_rollback(self, session, exc_type, _exc_value, _tb):
        """
        Commit if no exception occurred (`exc_type` is None) -- within
        `commit_wrapper()` -- otherwise roll back.
        """
        assert isinstance(session, Session)
        if exc_type is None:
            with self.commit_wrapper(session):
                session.commit()
        else:
            session.rollback()

    @contextlib.contextmanager
    def commit_wrapper(self, session):
        """
        Context manager around a commit: if anything is raised, roll the
        session back and re-raise. Subclasses may extend it.
        """
        try:
            yield
        except:
            # Deliberately catch *anything* -- it is always re-raised.
            session.rollback()
            raise
예제 #4
0
class MailSendingAPI(ConfigMixin, _MessageHelpersMixin):
    """
    API for sending e-mail messages (`EmailMessage` instances) via SMTP.

    Config (section `[mail_sending_api]`, see `config_spec`):
    `smtp_host`, `smtp_port` and -- optionally -- `smtp_login` together
    with `smtp_password`. If one of these two is given, the other one is
    required too (see `_get_config()`).

    Messages can be sent only from within a `with` block that uses the
    instance as a context manager. Entering the block makes an SMTP
    connection (`_connect()`: connect and, if configured, log in);
    exiting it closes that connection (`_disconnect()`):

        with mail_sending_api:
            ok_recipients, recipient_problems = (
                mail_sending_api.send_message(message))

    Note: this code is in the middle of a Python 2 -> 3 migration. The
    `#3...` comments are migration markers (e.g., `#3--` apparently
    means "remove this line for Python 3"), so they must be kept intact.
    For the same reason, some methods are defined twice: on Python 2 the
    later definition is the one in effect.
    """

    config_spec = '''
        [mail_sending_api]
        smtp_host :: str
        smtp_port :: int
        smtp_login = "" :: str
        smtp_password = "" :: str
    '''

    def __init__(self, settings=None):
        """
        Load the config (via `_get_config()`) and create the context
        deposit that keeps the SMTP client(s) of active `with` blocks.
        """
        self.config = self._get_config(settings)
        self._smtp_client_deposit = ThreadLocalContextDeposit(
            repr_token=self.__class__.__name__
        )  #3: `__name__` -> `__qualname__`

    def __enter__(self):
        """Connect to the SMTP server (see `_connect()`); return `self`."""
        self._smtp_client_deposit.on_enter(context_factory=self._connect)
        return self

    def __exit__(self, exc_type, exc, tb):
        """Disconnect (see `_disconnect()`); exceptions are propagated."""
        self._smtp_client_deposit.on_exit(exc_type,
                                          exc,
                                          tb,
                                          context_finalizer=self._disconnect)

    def send_message(
        self,
        message,  # type: EmailMessage
        sender=None,  # type: Optional[UsualHeaderRaw]
        to=None,  # type: Optional[AddressHeaderRaw]
        subject=None,  # type: Optional[UsualHeaderRaw]
        extra_headers=None  # type: Optional[DictyCollectionOfHeaderRaw]
    ):  # type: (...) -> Tuple[Set[String], Dict[String, Tuple[int, String]]]
        """
        Send the given message through the current SMTP connection.

        Args:
            `message`:
                An `EmailMessage` instance (on Python 3, its `policy` must
                be an `EmailPolicy` instance). All header changes described
                below are made on the copy obtained with `copy_message()`.
            `sender` (default: None):
                If not None, replaces the `From` header (any `Sender`
                header is then dropped).
            `to`, `subject` (default: None):
                If not None, replaces the `To`/`Subject` header.
            `extra_headers` (default: None):
                If not None, headers to be added (`add_multiple_headers()`).

        The `Date` and `Message-ID` headers are added if missing; so is
        `Sender` when there is no `Sender` and `From` contains 2 or more
        addresses (RFC 5322, section 3.6.2).

        Returns:
            A pair `(ok_recipients, recipient_problems)` -- see the
            comments in the `return` statement below.

        Raises:
            TypeError: for an unsupported type of `message`/`message.policy`.
            ValueError: e.g., if there are no sender or recipient addresses,
                or an address does not match
                `EMAIL_OVERRESTRICTED_SIMPLE_REGEX`.
            NotImplementedError: if the message contains headers listed in
                `UNSUPPORTED_HEADER_NAMES`.
            RuntimeError: if no SMTP connection is active (i.e., when not
                called within a `with` block).
        """

        if not isinstance(message, EmailMessage):
            raise TypeError('unsupported type of `message`: '
                            '{!r}'.format(type(message)))
        if sys.version_info[0] >= 3 and not isinstance(
                message.policy,
                EmailPolicy):  #3: `sys.version_info[0] >= 3 and `--
            raise TypeError('unsupported type of `message.policy`: '
                            '{!r}'.format(type(message.policy)))

        message = self.copy_message(message)

        if sender is not None:
            del message['Sender']
            self.drop_and_add_header(message, 'From', sender)
        if to is not None:
            self.drop_and_add_header(message, 'To', to)
        if subject is not None:
            self.drop_and_add_header(message, 'Subject', subject)
        if extra_headers is not None:
            self.add_multiple_headers(message, extra_headers)

        from_addr_items = self._from_addr_items(message)
        from_addr = self._pure_addr(from_addr_items[0])
        to_addr_items = self._to_addr_items(message)
        to_addrs = list(map(self._pure_addr, to_addr_items))

        if 'Date' not in message:
            message['Date'] = self.prepare_header_value(
                'Date', datetime.datetime.utcnow())
        if 'Message-ID' not in message:
            message['Message-ID'] = self._generate_message_id(
                message['Date'], from_addr)
        # Py2-only trick: on Python 2 the address items are `(realname,
        # addr)` pairs, so "stringifying" one means `formataddr()` (on
        # Python 3 the builtin `str()` of an `Address` does the job).
        str = formataddr  #3--
        if 'Sender' not in message and len(
                from_addr_items) >= 2:  # see: RFC 5322, section 3.6.2
            message['Sender'] = str(from_addr_items[0])

        self.validate_subject_value(message['Subject'])
        self._verify_no_unsupported_headers(message)

        client = self._get_current_smtp_client()
        (flattening_policy,
         mail_options) = self._prepare_flattening_policy_and_mail_options(
             message, client)
        flattened_message = self._flatten_message(message, flattening_policy)
        recipient_problems = client.sendmail(from_addr,
                                             to_addrs,
                                             flattened_message,
                                             mail_options=mail_options)
        if recipient_problems:
            # Some recipients have been rejected by the SMTP server,
            # but not all (if all of them had been rejected an exception
            # would have been raised).
            LOGGER.warning('Unable to send e-mail to some recipient(s): %r',
                           recipient_problems)

        ok_recipients = set(to_addrs).difference(recipient_problems.keys())
        return (
            # A set of successful dispatch recipient (`To`) addresses
            # (`str` objects).
            ok_recipients,  # type: Set[String]

            # A dict that collects information on *recipient problems*,
            # i.e., that maps failed dispatch recipient (`To`) addresses
            # (`str` objects) to pairs (2-tuples) consisting of the
            # following error data from the SMTP server:
            #   * (1) error code (`int`),
            #   * (2) message (`str`).
            recipient_problems,  # type: Dict[String, Tuple[int, String]]
        )

    def _get_config(self, settings):
        """
        Get the `[mail_sending_api]` config section.

        Raises:
            ConfigError: if only one of `smtp_login`/`smtp_password` is set.
        """
        config = self.get_config_section(settings=settings)
        if config['smtp_login'] or config['smtp_password']:
            if not config['smtp_login']:
                raise ConfigError('[mail_sending_api] `smtp_login` is missing '
                                  '(though `smtp_password` is present)')
            if not config['smtp_password']:
                raise ConfigError(
                    '[mail_sending_api] `smtp_password` is missing '
                    '(though `smtp_login` is present)')
        return config

    def _connect(self):
        """
        Make a new SMTP connection, logging in if `smtp_login` is set.

        Returns the `smtplib.SMTP` client. If logging in fails, the
        connection is closed before the exception is propagated (so that
        the socket does not leak).
        """
        # TODO: add SSL/TLS as an option.
        client = smtplib.SMTP(self.config['smtp_host'],
                              self.config['smtp_port'])
        try:
            if self.config['smtp_login']:
                # ensured by `_get_config()`
                assert self.config['smtp_password']
                client.login(self.config['smtp_login'],
                             self.config['smtp_password'])
        except BaseException:
            client.close()
            raise
        return client

    def _disconnect(self, client, *_exc_info):
        """
        Close the given SMTP connection (the exception info is ignored).

        `quit()` may raise (e.g., if the server has dropped the connection);
        `close()` is called anyway, so the socket is always released (an
        extra `close()` after a successful `quit()` is harmless).
        """
        try:
            client.quit()
        finally:
            client.close()

    def _from_addr_items(self, message):
        """
        Return the sender address items -- taken from the `Sender` headers
        if any, otherwise from the `From` headers.

        Raises:
            ValueError: if no such address is specified.
        """
        addr_header_values = (message.get_all('Sender') if 'Sender' in message
                              else message.get_all('From'))
        addr_items = self._addr_items_from_header_values(addr_header_values)
        if not addr_items:
            raise ValueError('no `Sender`/`From` address(es) specified')
        return addr_items

    def _to_addr_items(self, message):
        """
        Return the recipient address items (from the `To` headers).

        Raises:
            ValueError: if no such address is specified.
        """
        addr_header_values = message.get_all('To')
        addr_items = self._addr_items_from_header_values(addr_header_values)
        if not addr_items:
            raise ValueError('no `To` address(es) specified')
        return addr_items

    def _addr_items_from_header_values(self, addr_header_values):
        """
        Return a list of address items extracted from the given header
        values: `(realname, addr)` pairs (from `getaddresses()`) on
        Python 2; objects from the headers' `addresses` on Python 3.
        """
        if sys.version_info[0] < 3:  #3--
            raw_addrs = [
                self._header_value_as_ascii_str(header_value)  #3--
                for header_value in addr_header_values
            ]  #3--
            return list(getaddresses(raw_addrs))  #3--
        return [
            addr_item for header_value in addr_header_values
            for addr_item in header_value.addresses
        ]

    def _header_value_as_ascii_str(
            self, header_value):  #3: remove whole method definition
        """
        (Python 2 only.) Return the header value as an ASCII `str`.

        Raises:
            ValueError: if the value is None or cannot be converted.
        """
        if header_value is None:
            raise ValueError(
                'no header value given (got: {!r})'.format(header_value))
        try:
            return unicode(header_value).encode('ascii')
        except (UnicodeError, LookupError
                ) as exc:  # Header.__unicode__() may raise LookupError
            for_repr = (header_value.encode() if isinstance(
                header_value, Header) else header_value)
            exc_ascii_str = make_exc_ascii_str(exc)
            raise ValueError('cannot obtain the actual header value from: '
                             '{!r} ({})'.format(for_repr, exc_ascii_str))

    def _pure_addr(self, addr_item):
        """
        (Python 3 version -- overridden by the next definition on
        Python 2.) Return the bare e-mail address of the given address
        item (via its `addr_spec`).

        Raises:
            ValueError: if the address does not match
                `EMAIL_OVERRESTRICTED_SIMPLE_REGEX`.
        """
        addr = addr_item.addr_spec
        assert isinstance(addr, str)
        if not EMAIL_OVERRESTRICTED_SIMPLE_REGEX.search(addr):
            raise ValueError(
                'e-mail address {!r} does not match '
                'n6lib.common_helpers.EMAIL_OVERRESTRICTED_SIMPLE_REGEX'.
                format(addr))
        assert addr == ascii_str(addr)
        return addr

    def _pure_addr(
            self, addr_item
    ):  #3: remove whole method definition (keeping the above one)
        """
        (Python 2 version -- the one in effect on Python 2.) Return the
        bare e-mail address of the given `(realname, addr)` pair.

        Raises:
            ValueError: if the address does not match
                `EMAIL_OVERRESTRICTED_SIMPLE_REGEX`.
        """
        _, addr = addr_item
        if not EMAIL_OVERRESTRICTED_SIMPLE_REGEX.search(addr):
            raise ValueError(
                'e-mail address {!r} does not match '
                'n6lib.common_helpers.EMAIL_OVERRESTRICTED_SIMPLE_REGEX'.
                format(addr))
        return addr

    def _generate_message_id(self, date_header, from_addr):
        """
        Return a new `<n6.<timestamp>.<random hex>@<domain>>` message id,
        where the timestamp is taken from the `Date` header value and the
        domain is the part of `from_addr` after the '@'.
        """
        if sys.version_info[0] < 3:  #3--
            date_str = self._header_value_as_ascii_str(date_header)  #3--
            timestamp = mktime_tz(parsedate_tz(date_str))  #3--
        else:  #3: remove *this* line and *dedent* the next line
            timestamp = date_header.datetime.timestamp()
        msg_id = 'n6.{}.{}'.format(trunc(timestamp), make_hex_id(length=32))
        _, domain = from_addr.split('@')
        return '<{}@{}>'.format(msg_id, domain)

    def _verify_no_unsupported_headers(self, message):
        """
        Raises:
            NotImplementedError: if `message` contains any of the headers
                listed in `UNSUPPORTED_HEADER_NAMES`.
        """
        for header_name in UNSUPPORTED_HEADER_NAMES:
            values = message.get_all(header_name)
            if values:
                raise NotImplementedError(
                    'header {!r} is not supported by {} '
                    '(got: {!r})'.format(
                        header_name,
                        self.__class__.
                        __name__,  #3: `__name__` -> `__qualname__`
                        values))

    def _get_current_smtp_client(self):
        """
        Return the SMTP client of the innermost active `with` block.

        Raises:
            RuntimeError: if there is no such client.
        """
        client = self._smtp_client_deposit.innermost_context
        if client is None:
            raise RuntimeError(
                'no SMTP connection is active (you need '
                'to make use of the {}\'s context manager '
                'interface)'.format(
                    self.__class__.__name__))  #3: `__name__` -> `__qualname__`
        return client

    def _prepare_flattening_policy_and_mail_options(self, message, client):
        """
        Return a `(flattening_policy, mail_options)` pair.

        On Python 2: `(None, [])`. On Python 3: the message's policy, and
        -- if its `cte_type` is '8bit' -- either the 'BODY=8BITMIME' mail
        option (if the server supports that ESMTP extension) or, if it
        does not, a clone of the policy with `cte_type='7bit'`.
        """
        if sys.version_info[0] < 3:  #3--
            return None, []  #3--
        flattening_policy = message.policy
        mail_options = ()
        # See:
        # * https://bugs.python.org/issue32814
        # * https://github.com/python/cpython/pull/8303/files
        if flattening_policy.cte_type == '8bit':
            client.ehlo_or_helo_if_needed()
            if client.does_esmtp and client.has_extn('8bitmime'):
                mail_options += ('BODY=8BITMIME', )
            else:
                flattening_policy = flattening_policy.clone(cte_type='7bit')
        return flattening_policy, mail_options

    def _flatten_message(self, message, flattening_policy):
        """
        (Python 3 version -- overridden by the next definition on
        Python 2.) Serialize the message to bytes, using the given policy
        and CRLF line separators.
        """
        bytes_io = BytesIO()
        generator = BytesGenerator(bytes_io,
                                   mangle_from_=False,
                                   policy=flattening_policy)
        generator.flatten(message, unixfrom=False, linesep='\r\n')
        return bytes_io.getvalue()

    def _flatten_message(
            self, message,
            _):  #3: remove whole method definition (keeping the above one)
        """
        (Python 2 version -- the one in effect on Python 2.) Serialize
        the message to bytes; the policy argument is ignored.
        """
        bytes_io = BytesIO()
        generator = BytesGenerator(bytes_io, mangle_from_=False)
        generator.flatten(message, unixfrom=False)
        return bytes_io.getvalue()