class SQLAuthDBConnector(SQLAuthDBConfigMixin):

    """
    An Auth DB connector (based on SQLAlchemy) that is also a context
    manager providing a database session.

    Connection parameters can be given either as the `db_host`,
    `db_name` and `db_user` (plus, optionally, `db_password`)
    constructor arguments -- then a MySQL URL is built from them and
    stored as the `<config_section>.url` option (overriding that
    option in `settings`, if present) -- or not at all (then `settings`
    is passed on as it is).

    Context management (delegated to a `ThreadLocalContextDeposit`):
    `__enter__()` passes `db_session_factory` as the *outermost*
    context factory and `make_nested_savepoint()` as the (nested)
    context factory, and returns the current (outermost) session.
    On exit, each context is finalized with `commit_or_rollback()`
    (commit if no exception occurred, rollback otherwise); the
    outermost session is additionally closed.
    """

    def __init__(self,
                 db_host=None,
                 db_name=None,
                 db_user=None,
                 db_password=None,
                 settings=None,
                 config_section=None):
        if config_section is None:
            config_section = self.default_config_section
        settings = self._get_actual_settings(db_host, db_name, db_user, db_password,
                                             settings, config_section)
        self.context_deposit = ThreadLocalContextDeposit(
            repr_token=self.__class__.__name__,
            attr_factories={'audit_log_external_meta_items': dict})
        self.db_session_factory = None  # to be set in configure_db()
        self._audit_log = None          # to be set in configure_db()
        super(SQLAuthDBConnector, self).__init__(settings, config_section)

    def _get_actual_settings(self, db_host, db_name, db_user, db_password,
                             settings, config_section):
        """
        Return the settings to be used, taking into account the
        connection-related arguments (if any).

        If `db_host`, `db_name` and `db_user` (plus, optionally,
        `db_password`) are given, a `mysql+mysqldb://...` URL is built
        from them and put under the `'<config_section>.url'` key. The
        key goes into a new dict if `settings` is None; otherwise it
        goes into a *copy* of `settings`, where it overrides any value
        already stored under that key. If none of those arguments are
        given, `settings` is returned unchanged.

        Raises:
            TypeError: if the connection arguments are only partially
                given (see `_verify_args_for_connection()`).
        """
        if self._verify_args_for_connection(db_host=db_host,
                                            db_name=db_name,
                                            db_user=db_user,
                                            db_password=db_password):
            # NOTE(review): credentials are interpolated verbatim, so
            # characters such as `@`, `/`, `:` or `%` in them would make
            # the URL ambiguous -- consider URL-quoting them.
            password_part = (':{}'.format(db_password) if db_password else '')
            option_val = 'mysql+mysqldb://{user}{password_part}@{host}/{name}'.format(
                user=db_user,
                password_part=password_part,
                host=db_host,
                name=db_name)
            option_key = '{}.url'.format(config_section)
            if settings is None:
                settings = {
                    option_key: option_val,
                }
            else:
                # keep a config from `settings` (it is most likely
                # a Pyramid-style config dict), but config options
                # made with kwargs should have higher priority.
                # (Note: `dict(settings, option_key=option_val)` would
                # set the *literal* key 'option_key', so the item is
                # assigned explicitly, on a copy of `settings`.)
                settings = dict(settings)
                settings[option_key] = option_val
        return settings

    def _verify_args_for_connection(self, **kwargs):
        """
        Check the consistency of the connection-related arguments.

        Returns:
            True if `db_host`, `db_name` and `db_user` are all non-empty;
            False if all of the given arguments are empty/false.

        Raises:
            TypeError: if only some of them are given. The value of
                `db_password` is *not* included in the message, so that
                the secret does not end up in logs/tracebacks.
        """
        if kwargs['db_host'] and kwargs['db_name'] and kwargs['db_user']:
            return True
        incorrectly_specified_args = {
            name: val
            for name, val in kwargs.items()
            if val}
        if incorrectly_specified_args:
            args_repr = ', '.join(
                ('{}=<hidden>'.format(name) if name == 'db_password'
                 else '{}={!r}'.format(name, val))
                for name, val in sorted(incorrectly_specified_args.items()))
            raise TypeError(
                '{!r}\'s constructor: *either* the `db_host`, `db_name` '
                'and `db_user` arguments, plus optionally `db_password`, '
                'should be given (as non-empty strings), *or* none of '
                'them! (got: {})'.format(self, args_repr))
        return False

    def configure_db(self):
        """
        Extend the base `configure_db()`: create the session factory
        (bound to `self.engine`, with autocommit and autoflush disabled)
        and the `AuditLog` instance that uses it.
        """
        super(SQLAuthDBConnector, self).configure_db()
        self.db_session_factory = sqlalchemy.orm.sessionmaker(bind=self.engine,
                                                              autocommit=False,
                                                              autoflush=False)
        self._audit_log = AuditLog(
            session_factory=self.db_session_factory,
            external_meta_items_getter=self._get_audit_log_external_meta_items)

    def _get_audit_log_external_meta_items(self):
        # A deep copy, so that the audit log cannot mutate the stored items.
        return copy.deepcopy(
            self.context_deposit.audit_log_external_meta_items)

    # Public methods (to be called by client code; they can also be
    # overridden/extended and/or called in subclasses):

    def set_audit_log_external_meta_items(self, n6_module, **other_external_meta_items):
        """
        Store (in the context deposit) the external meta items to be
        passed to the audit log: `n6_module` plus any other keyword
        arguments.
        """
        external_meta_items = dict(n6_module=n6_module, **other_external_meta_items)
        self.context_deposit.audit_log_external_meta_items = external_meta_items

    def __enter__(self):
        self.context_deposit.on_enter(
            outermost_context_factory=self.db_session_factory,
            context_factory=self.make_nested_savepoint)
        return self.get_current_session()

    def __exit__(self, exc_type, exc, tb):
        self.context_deposit.on_exit(
            exc_type, exc, tb,
            context_finalizer=self.finalize_nested_savepoint,
            outermost_context_finalizer=self.finalize_session)

    def get_current_session(self):
        """Return the outermost context kept by the deposit (the session)."""
        return self.context_deposit.outermost_context

    # Context-management-related methods that can be overridden/extended
    # and/or called in subclasses but do *not* belong to the public
    # interface of `SQLAuthDBConnector` instances:

    def make_nested_savepoint(self):
        """Begin a nested transaction (savepoint) in the current session."""
        session = self.get_current_session()
        assert isinstance(session, Session)
        return session.begin_nested()

    def finalize_nested_savepoint(self, _savepoint, exc_type, exc_value, tb):
        """Commit or roll back (via the current session) on nested-context exit."""
        session = self.get_current_session()
        assert isinstance(session, Session)
        self.commit_or_rollback(session, exc_type, exc_value, tb)

    def finalize_session(self, session, exc_type, exc_value, tb):
        """Commit or roll back the outermost session, then always close it."""
        assert isinstance(session, Session)
        try:
            self.commit_or_rollback(session, exc_type, exc_value, tb)
        finally:
            session.close()

    def commit_or_rollback(self, session, exc_type, _exc_value, _tb):
        """
        Commit `session` if no exception occurred (`exc_type` is None),
        otherwise roll it back.
        """
        assert isinstance(session, Session)
        if exc_type is None:
            with self.commit_wrapper(session):
                session.commit()
        else:
            session.rollback()

    @contextlib.contextmanager
    def commit_wrapper(self, session):
        """Roll `session` back (and re-raise) if the wrapped commit fails."""
        try:
            yield
        except:  # any exception (also non-`Exception` ones) -> rollback, then re-raise
            session.rollback()
            raise
class MailSendingAPI(ConfigMixin, _MessageHelpersMixin):

    """
    An API to send e-mail messages through an SMTP server.

    The SMTP server address, and optionally the login credentials,
    are taken from the `[mail_sending_api]` config section (see
    `config_spec`). `smtp_login` and `smtp_password` must be either
    both set or both empty.

    An instance must be used as a context manager. Entering the
    context opens an SMTP connection, which is kept in a
    `ThreadLocalContextDeposit`. Exiting the context closes that
    connection. `send_message()` works only within such a context.
    """

    config_spec = '''
        [mail_sending_api]
        smtp_host :: str
        smtp_port :: int
        smtp_login = "" :: str
        smtp_password = "" :: str
    '''

    def __init__(self, settings=None):
        self.config = self._get_config(settings)
        self._smtp_client_deposit = ThreadLocalContextDeposit(
            repr_token=self.__class__.__name__)  #3: `__name__` -> `__qualname__`

    def __enter__(self):
        self._smtp_client_deposit.on_enter(context_factory=self._connect)
        return self

    def __exit__(self, exc_type, exc, tb):
        self._smtp_client_deposit.on_exit(exc_type, exc, tb,
                                          context_finalizer=self._disconnect)

    def send_message(self,
                     message,              # type: EmailMessage
                     sender=None,          # type: Optional[UsualHeaderRaw]
                     to=None,              # type: Optional[AddressHeaderRaw]
                     subject=None,         # type: Optional[UsualHeaderRaw]
                     extra_headers=None):  # type: Optional[DictyCollectionOfHeaderRaw]
        # type: (...) -> Tuple[Set[String], Dict[String, Tuple[int, String]]]
        """
        Send `message` using the SMTP connection of the current context.

        The work is done on a copy of `message`, made with
        `copy_message()`. If `sender`, `to` or `subject` is given, it
        replaces the respective header (`From`/`To`/`Subject`). Passing
        `sender` also drops any `Sender` header. `extra_headers`, if
        given, are added. `Date` and `Message-ID` are generated if
        missing. A `Sender` header is added when there are several
        `From` addresses and no `Sender` header (RFC 5322, section
        3.6.2).

        Returns:
            A pair: (set of accepted `To` addresses, dict that maps
            rejected `To` addresses to `(SMTP error code, message)`
            pairs).

        Raises:
            TypeError: if `message` (or, on Python 3, `message.policy`)
                is of an unsupported type.
            ValueError: if the `Sender`/`From` or `To` addresses are
                missing or do not match the address regex.
            NotImplementedError: if `message` contains an unsupported header.
            RuntimeError: if called outside the context manager.
        Any exception from `client.sendmail()` is propagated.
        """
        if not isinstance(message, EmailMessage):
            raise TypeError('unsupported type of `message`: '
                            '{!r}'.format(type(message)))
        if sys.version_info[0] >= 3 and not isinstance(message.policy, EmailPolicy):  #3: `sys.version_info[0] >= 3 and `--
            raise TypeError('unsupported type of `message.policy`: '
                            '{!r}'.format(type(message.policy)))
        message = self.copy_message(message)
        if sender is not None:
            del message['Sender']
            self.drop_and_add_header(message, 'From', sender)
        if to is not None:
            self.drop_and_add_header(message, 'To', to)
        if subject is not None:
            self.drop_and_add_header(message, 'Subject', subject)
        if extra_headers is not None:
            self.add_multiple_headers(message, extra_headers)
        from_addr_items = self._from_addr_items(message)
        from_addr = self._pure_addr(from_addr_items[0])
        to_addr_items = self._to_addr_items(message)
        to_addrs = list(map(self._pure_addr, to_addr_items))
        if 'Date' not in message:
            message['Date'] = self.prepare_header_value(
                'Date', datetime.datetime.utcnow())
        if 'Message-ID' not in message:
            message['Message-ID'] = self._generate_message_id(
                message['Date'], from_addr)
        # On Python 2 the address items are `(name, addr)` tuples, which
        # must be formatted with `formataddr()` (on Python 3 they are
        # `Address` objects, for which the builtin `str()` does the job).
        str = formataddr  #3--
        if 'Sender' not in message and len(from_addr_items) >= 2:
            # see: RFC 5322, section 3.6.2
            message['Sender'] = str(from_addr_items[0])
        self.validate_subject_value(message['Subject'])
        self._verify_no_unsupported_headers(message)
        client = self._get_current_smtp_client()
        (flattening_policy,
         mail_options) = self._prepare_flattening_policy_and_mail_options(message, client)
        flattened_message = self._flatten_message(message, flattening_policy)
        recipient_problems = client.sendmail(from_addr, to_addrs, flattened_message,
                                             mail_options=mail_options)
        if recipient_problems:
            # Some recipients have been rejected by the SMTP server,
            # but not all (if all of them had been rejected an exception
            # would have been raised).
            LOGGER.warning('Unable to send e-mail to some recipient(s): %r',
                           recipient_problems)
        ok_recipients = set(to_addrs).difference(recipient_problems.keys())
        return (
            # A set of successful dispatch recipient (`To`) addresses
            # (`str` objects).
            ok_recipients,        # type: Set[String]

            # A dict that collects information on *recipient problems*,
            # i.e., that maps failed dispatch recipient (`To`) addresses
            # (`str` objects) to pairs (2-tuples) consisting of the
            # following error data from the SMTP server:
            # * (1) error code (`int`),
            # * (2) message (`str`).
            recipient_problems,   # type: Dict[String, Tuple[int, String]]
        )

    def _get_config(self, settings):
        """
        Get the `[mail_sending_api]` config section.

        Raises:
            ConfigError: if only one of `smtp_login`/`smtp_password` is set.
        """
        config = self.get_config_section(settings=settings)
        if config['smtp_login'] or config['smtp_password']:
            if not config['smtp_login']:
                raise ConfigError('[mail_sending_api] `smtp_login` is missing '
                                  '(though `smtp_password` is present)')
            if not config['smtp_password']:
                raise ConfigError('[mail_sending_api] `smtp_password` is missing '
                                  '(though `smtp_login` is present)')
        return config

    def _connect(self):
        """Open an SMTP connection, logging in if credentials are configured."""
        # TODO: add SSL/TLS as an option.
        client = smtplib.SMTP(self.config['smtp_host'], self.config['smtp_port'])
        try:
            if self.config['smtp_login']:
                assert self.config['smtp_password']  # guaranteed by `_get_config()`
                client.login(self.config['smtp_login'], self.config['smtp_password'])
        except BaseException:
            # Do not leak the already opened connection if, e.g.,
            # the authentication fails.
            client.close()
            raise
        return client

    def _disconnect(self, client, *_exc_info):
        """Terminate the SMTP session; the socket is released in any case."""
        try:
            client.quit()
        finally:
            # `quit()` may fail before it reaches its own `close()` call;
            # `close()` is safe to call again if it has already been done.
            client.close()

    def _from_addr_items(self, message):
        """
        Return the address items of the `Sender` header (if present)
        or else of the `From` header(s).

        Raises:
            ValueError: if no such address is found.
        """
        addr_header_values = (message.get_all('Sender', []) if 'Sender' in message
                              else message.get_all('From', []))
        addr_items = self._addr_items_from_header_values(addr_header_values)
        if not addr_items:
            raise ValueError('no `Sender`/`From` address(es) specified')
        return addr_items

    def _to_addr_items(self, message):
        """
        Return the address items of the `To` header(s).

        Raises:
            ValueError: if no `To` address is found (also if there is
                no `To` header at all).
        """
        # (`get_all()` returns None -- not an empty list -- for a missing
        # header, hence the explicit `[]` fallback.)
        addr_header_values = message.get_all('To', [])
        addr_items = self._addr_items_from_header_values(addr_header_values)
        if not addr_items:
            raise ValueError('no `To` address(es) specified')
        return addr_items

    def _addr_items_from_header_values(self, addr_header_values):
        """
        Extract address items from the given header values. On Python 2
        these are `(name, addr)` pairs from `getaddresses()`; on Python 3
        they are the header values' `addresses` items.
        """
        if sys.version_info[0] < 3:                                     #3--
            raw_addrs = [self._header_value_as_ascii_str(header_value)  #3--
                         for header_value in addr_header_values]        #3--
            return list(getaddresses(raw_addrs))                        #3--
        return [
            addr_item
            for header_value in addr_header_values
            for addr_item in header_value.addresses]

    def _header_value_as_ascii_str(self, header_value):  #3: remove whole method definition
        """
        (Python 2 only.) Convert a header value to an ASCII `str`.

        Raises:
            ValueError: if the value is None or cannot be converted.
        """
        if header_value is None:
            raise ValueError(
                'no header value given (got: {!r})'.format(header_value))
        try:
            return unicode(header_value).encode('ascii')
        except (UnicodeError, LookupError) as exc:  # Header.__unicode__() may raise LookupError
            for_repr = (header_value.encode() if isinstance(header_value, Header)
                        else header_value)
            exc_ascii_str = make_exc_ascii_str(exc)
            raise ValueError('cannot obtain the actual header value from: '
                             '{!r} ({})'.format(for_repr, exc_ascii_str))

    def _pure_addr(self, addr_item):
        """
        (Python 3 version.) Return the bare e-mail address of an address item.

        Raises:
            ValueError: if it does not match EMAIL_OVERRESTRICTED_SIMPLE_REGEX.
        """
        addr = addr_item.addr_spec
        assert isinstance(addr, str)
        if not EMAIL_OVERRESTRICTED_SIMPLE_REGEX.search(addr):
            raise ValueError(
                'e-mail address {!r} does not match '
                'n6lib.common_helpers.EMAIL_OVERRESTRICTED_SIMPLE_REGEX'.format(addr))
        assert addr == ascii_str(addr)
        return addr

    # NOTE: until the `#3` migration is applied, the following definition
    # *overrides* the one above (the class body keeps the last one).
    def _pure_addr(self, addr_item):  #3: remove whole method definition (keeping the above one)
        """
        (Python 2 version.) Return the bare e-mail address of a
        `(name, addr)` pair.

        Raises:
            ValueError: if it does not match EMAIL_OVERRESTRICTED_SIMPLE_REGEX.
        """
        _, addr = addr_item
        if not EMAIL_OVERRESTRICTED_SIMPLE_REGEX.search(addr):
            raise ValueError(
                'e-mail address {!r} does not match '
                'n6lib.common_helpers.EMAIL_OVERRESTRICTED_SIMPLE_REGEX'.format(addr))
        return addr

    def _generate_message_id(self, date_header, from_addr):
        """
        Generate a `Message-ID` value of the form
        `<n6.{truncated timestamp of date_header}.{32-char hex id}@{domain of from_addr}>`.
        """
        if sys.version_info[0] < 3:                                   #3--
            date_str = self._header_value_as_ascii_str(date_header)   #3--
            timestamp = mktime_tz(parsedate_tz(date_str))             #3--
        else:                                                         #3: remove *this* line and *dedent* the next line
            timestamp = date_header.datetime.timestamp()
        msg_id = 'n6.{}.{}'.format(trunc(timestamp), make_hex_id(length=32))
        _, domain = from_addr.split('@')
        return '<{}@{}>'.format(msg_id, domain)

    def _verify_no_unsupported_headers(self, message):
        """
        Raises:
            NotImplementedError: if `message` has any of the headers
                listed in UNSUPPORTED_HEADER_NAMES.
        """
        for header_name in UNSUPPORTED_HEADER_NAMES:
            values = message.get_all(header_name)
            if values:
                raise NotImplementedError(
                    'header {!r} is not supported by {} '
                    '(got: {!r})'.format(
                        header_name,
                        self.__class__.__name__,  #3: `__name__` -> `__qualname__`
                        values))

    def _get_current_smtp_client(self):
        """
        Return the SMTP client of the innermost active context.

        Raises:
            RuntimeError: if there is no active context.
        """
        client = self._smtp_client_deposit.innermost_context
        if client is None:
            raise RuntimeError(
                'no SMTP connection is active (you need '
                'to make use of the {}\'s context manager '
                'interface)'.format(self.__class__.__name__))  #3: `__name__` -> `__qualname__`
        return client

    def _prepare_flattening_policy_and_mail_options(self, message, client):
        """
        Return a `(flattening_policy, mail_options)` pair.

        On Python 3, an 8-bit policy is kept (with `BODY=8BITMIME`) only
        if the server supports the `8bitmime` ESMTP extension; otherwise
        it is downgraded to `cte_type='7bit'`.
        """
        if sys.version_info[0] < 3:  #3--
            return None, []          #3--
        flattening_policy = message.policy
        mail_options = ()
        # See:
        # * https://bugs.python.org/issue32814
        # * https://github.com/python/cpython/pull/8303/files
        if flattening_policy.cte_type == '8bit':
            client.ehlo_or_helo_if_needed()
            if client.does_esmtp and client.has_extn('8bitmime'):
                mail_options += ('BODY=8BITMIME',)
            else:
                flattening_policy = flattening_policy.clone(cte_type='7bit')
        return flattening_policy, mail_options

    def _flatten_message(self, message, flattening_policy):
        """(Python 3 version.) Serialize `message` to bytes with CRLF line endings."""
        bytes_io = BytesIO()
        generator = BytesGenerator(bytes_io, mangle_from_=False, policy=flattening_policy)
        generator.flatten(message, unixfrom=False, linesep='\r\n')
        return bytes_io.getvalue()

    # NOTE: until the `#3` migration is applied, the following definition
    # *overrides* the one above (the class body keeps the last one).
    def _flatten_message(self, message, _):  #3: remove whole method definition (keeping the above one)
        """(Python 2 version.) Serialize `message` to bytes."""
        bytes_io = BytesIO()
        generator = BytesGenerator(bytes_io, mangle_from_=False)
        generator.flatten(message, unixfrom=False)
        return bytes_io.getvalue()