class SMTPSession(object):
    """The SMTPSession processes all input data which was extracted from
    sockets previously. The idea behind this is that this class only knows
    about different SMTP commands and does not have to know things like
    command mode and data mode.

    Each SMTP command is modelled as an action of a StateMachine; a valid
    transition dispatches to the matching ``smtp_<command>`` method.

    The protocol parser will create a new session instance for every new
    connection so this class does not have to be thread-safe.
    """

    def __init__(self, command_parser, deliverer, policy=None, authenticator=None):
        """
        command_parser -- lower protocol layer; used for pushing replies,
                          switching command/data mode and closing the
                          connection.
        deliverer      -- receives accepted messages via
                          ``new_message_accepted(message)``.
        policy         -- optional object whose ``accept_*`` methods decide
                          whether a command is acceptable (see is_allowed).
        authenticator  -- optional object used by AUTH to check credentials;
                          AUTH is answered with 535 when it is None.
        """
        self._command_parser = command_parser
        self._deliverer = deliverer
        self._policy = policy
        self._authenticator = authenticator

        # Raw argument string of the command currently being handled
        # (set/reset by handle_input).
        self._command_arguments = None
        self._close_connection_after_response = False
        self._is_connected = True
        # Message currently being assembled; created in new_connection().
        self._message = None
        self.state = None
        self.valid_commands = None
        self._build_state_machine()

    # -------------------------------------------------------------------------
    # State machine building

    def _add_state(self, from_state, to_state, smtp_command, **kwargs):
        """Register a transition whose handler is _dispatch_commands.

        NOTE(review): the parameter names do not match how every caller in
        this class uses them. Callers pass the SMTP command as the 2nd
        argument and the target state as the 3rd, e.g.
        ``_add_state('new', 'GREET', 'greeted')``. The call below forwards
        them to StateMachine.add(from_state, to_state, action, handler), so
        the machine is wired correctly despite the misleading names.
        Extra keyword arguments (``operations``, ``condition``) are passed
        through to StateMachine.add.
        """
        handler_function = self._dispatch_commands
        self.state.add(from_state, smtp_command, to_state, handler_function, **kwargs)

    def _get_all_commands(self, including_quit=False):
        """Return the set of all action names defined in the state machine.

        Skips actions that lead to 'new' (RSET issued before the greeting).
        Skips actions leading to 'finished' (QUIT) unless ``including_quit``
        is true.

        Reads the StateMachine's private ``_transitions`` mapping. It assumes
        each value maps action name -> transition tuple whose first item is
        the target state; confirm against StateMachine.
        """
        commands = set()
        for actions in self.state._transitions.values():
            for command_name, transition in actions.items():
                target_state = transition[0]
                if target_state in ['new']:
                    continue
                if including_quit or (target_state != 'finished'):
                    commands.add(command_name)
        return commands

    def get_all_allowed_internal_commands(self):
        """Returns an iterable which includes all allowed commands. This does
        not mean that a specific command from the result is executable right
        now in this session state (or that it can be executed at all in this
        connection).

        Please note that the returned values are /internal/ commands, not SMTP
        commands (use get_all_allowed_smtp_commands for that) so there will be
        'MAIL FROM' instead of 'MAIL'. The pseudo commands GREET and MSGDATA
        are never included."""
        states = set()
        for command_name in self._get_all_commands(including_quit=True):
            if command_name not in ['GREET', 'MSGDATA']:
                states.add(command_name)
        return states

    def get_all_allowed_smtp_commands(self):
        """Return the SMTP verbs (first word of each internal command, e.g.
        'MAIL' for 'MAIL FROM', 'AUTH' for both AUTH variants) as a set."""
        states = set()
        for command_name in self.get_all_allowed_internal_commands():
            command_name = command_name.split(' ')[0]
            states.add(command_name)
        return states

    def _add_rset_transitions(self):
        """Add RSET to every state that has outgoing transitions.

        RSET returns to 'initialized', except before the greeting where it
        stays in 'new'."""
        for state_name in self.state.known_non_final_states():
            target_state = 'initialized' if (state_name != 'new') else 'new'
            self._add_state(state_name, 'RSET', target_state)

    def _add_help_noop_and_quit_transitions(self):
        """HELP, NOOP and QUIT should be possible from everywhere so we
        need to add these transitions to all states configured so far
        (except 'new' and 'finished')."""
        states = set()
        for state_name in self.state.known_states():
            if state_name not in ['new', 'finished']:
                states.add(state_name)
        for state in states:
            self._add_state(state, 'NOOP', state)
            self._add_state(state, 'HELP', state)
            self._add_state(state, 'QUIT', 'finished')

    def _build_state_machine(self):
        """Create the SMTP state machine and cache its action names in
        ``valid_commands``. That cache lets _handle_command tell unknown
        commands (500) apart from commands that are out of sequence (503)."""
        self.state = StateMachine(initial_state='new')
        self._add_state('new', 'GREET', 'greeted')
        self._add_state('greeted', 'HELO', 'initialized')
        # EHLO sets the 'esmtp' flag, which the AUTH transitions require.
        self._add_state('greeted', 'EHLO', 'initialized', operations=('set_esmtp', ))
        # ----
        self._add_state('initialized', 'MAIL FROM', 'sender_known')
        self._add_state('initialized', 'AUTH PLAIN', 'authenticated', condition='if_esmtp')
        self._add_state('initialized', 'AUTH LOGIN', 'authenticated', condition='if_esmtp')
        self._add_state('authenticated', 'AUTH LOGIN', 'authenticated')
        self._add_state('authenticated', 'MAIL FROM', 'sender_known')
        # ----
        self._add_state('sender_known', 'RCPT TO', 'recipient_known')
        # multiple recipients
        self._add_state('recipient_known', 'RCPT TO', 'recipient_known')
        self._add_state('recipient_known', 'DATA',
                        'receiving_message')
        self._add_state('receiving_message', 'MSGDATA', 'initialized')

        # Order matters: RSET is added to every non-final state, which
        # includes the states created by the HELP/NOOP/QUIT pass.
        self._add_help_noop_and_quit_transitions()
        self._add_rset_transitions()
        self.valid_commands = self.state.known_actions()

    # -------------------------------------------------------------------------

    def get_ehlo_lines(self):
        """Return the capabilities to be advertised after EHLO.

        With a policy, its ``ehlo_lines(peer)`` result is used. Without a
        policy but with an authenticator, 'AUTH PLAIN' is advertised.
        'HELP' is always appended."""
        lines = []
        if self._policy is not None:
            lines.extend(self._policy.ehlo_lines(self._message.peer))
        elif self._authenticator is not None:
            # fallback to ease testing in case no policy was specified explicitly
            lines.append('AUTH PLAIN')
        lines.append('HELP')
        return lines

    def _set_size_restrictions(self):
        """Set the maximum allowed message size in the underlying layer so
        that big messages are not held in memory before they are rejected."""
        self._command_parser.set_maximum_message_size(self._max_message_size())

    def _dispatch_commands(self, from_state, to_state, smtp_command):
        """This method dispatches a SMTP command to the appropriate handler
        method. It is called after a new command was received and a valid
        transition was found.

        The handler name is 'smtp_' + the lower-cased command with spaces
        replaced by underscores (e.g. 'MAIL FROM' -> smtp_mail_from)."""
        name_handler_method = 'smtp_%s' % smtp_command.lower().replace(
            ' ', '_')
        try:
            handler_method = getattr(self, name_handler_method)
        except AttributeError:
            # A transition exists but no handler method does; answer with a
            # temporary error instead of crashing.
            self.reply(451, 'Temporary Local Problem: Please come back later')
        else:
            # Don't catch InvalidDataError here - else the state would be moved
            # forward. Instead the handle_input will catch it and send out the
            # appropriate reply.
            handler_method()

    def _evaluate_decision(self, decision):
        """Map a raw policy answer to a boolean: True and None mean accept."""
        return (decision in [True, None])

    def _is_multiline_reply(self, reply_message):
        """Anything that is not a (Python 2) string is treated as a
        sequence of reply lines."""
        return (not isinstance(reply_message, basestring))

    def _send_custom_response(self, reply):
        """Send a (code, text_or_lines) reply supplied by the policy."""
        code, custom_response = reply
        if self._is_multiline_reply(custom_response):
            self.multiline_reply(code, custom_response)
        else:
            self.reply(code, custom_response)

    def _evaluate_policydecision_result(self, result):
        """Apply a policy-decision object and return (decision, response_sent).

        ``result`` must provide is_command_acceptable(), use_custom_reply(),
        get_custom_reply(), close_connection_before_response() and
        close_connection_after_response(). When close-before-response is
        requested, response_sent is forced to True. Any custom reply is
        still pushed after close_connection() was called."""
        decision = self._evaluate_decision(result.is_command_acceptable())
        response_sent = result.use_custom_reply()
        if result.close_connection_before_response():
            self.close_connection()
            response_sent = True
        if result.use_custom_reply():
            self._send_custom_response(result.get_custom_reply())
        if result.close_connection_after_response():
            self.please_close_connection_after_response()
        return decision, response_sent

    def is_allowed(self, acl_name, *args):
        """Ask the policy method ``acl_name(*args)`` whether to proceed.

        Returns a tuple (decision, response_sent). Without a policy the
        answer is always (True, False). The policy method may return:
          - True / False / None (None counts as accept), no response sent;
          - an object with is_command_acceptable() (see
            _evaluate_policydecision_result);
          - a 2-item sequence (decision, (code, reply)); the reply is sent.
        Raises ValueError for any other sequence length. Results without
        len() would raise TypeError from len() instead.
        """
        if self._policy is not None:
            decider = getattr(self._policy, acl_name)
            result = decider(*args)
            if result in [True, False, None]:
                return self._evaluate_decision(result), False
            elif hasattr(result, 'is_command_acceptable'):
                return self._evaluate_policydecision_result(result)
            elif len(result) == 2:
                decision = self._evaluate_decision(result[0])
                self._send_custom_response(result[1])
                return decision, True
            raise ValueError('Unknown policy response')
        return True, False

    # -------------------------------------------------------------------------

    def new_connection(self, remote_ip, remote_port):
        """This method is called when a new SMTP session is opened.

        Resets the state machine to 'new' and starts a fresh Message for the
        peer. If the policy accepts the connection and has not replied
        itself, the GREET pseudo command is run (220 greeting). Otherwise 554
        is sent (unless the policy already replied) and the connection is
        closed.

        [PUBLIC API]
        """
        self.state.set_state('new')
        self._message = Message(Peer(remote_ip, remote_port))

        decision, response_sent = self.is_allowed('accept_new_connection',
                                                  self._message.peer)
        if decision:
            if not response_sent:
                self.handle_input('greet')
            self._set_size_restrictions()
        else:
            if not response_sent:
                self.reply(554, 'SMTP service not available')
            self.close_connection()

    def handle_input(self, smtp_command, data=None):
        """Processes the given SMTP command with the (optional) data.

        ``data`` is the raw argument string, available to handlers via
        arguments() only while this call runs. If a handler requested it,
        the connection is closed after the command was processed.

        [PUBLIC API]
        """
        self._command_arguments = data
        self.please_close_connection_after_response(False)
        try:
            self._handle_command(smtp_command)
        finally:
            if self.should_close_connection_after_response():
                self.close_connection()
            self._command_arguments = None

    def _handle_command(self, smtp_command):
        """Execute the command in the state machine and turn the expected
        exceptions into SMTP error replies:

          - StateMachineError: 500 for unknown commands, 503 for known
            commands that are not allowed in the current state;
          - InvalidDataError: 501 with the error's message;
          - InvalidParametersError: 501 unless a reply was already sent;
          - PolicyDenial: the denial's code/text unless a reply was already
            sent.
        """
        # SMTP commands must be treated as case-insensitive
        command = smtp_command.upper()
        try:
            self.state.execute(command)
        except StateMachineError:
            if command not in self.valid_commands:
                self.reply(500, 'unrecognized command "%s"' % smtp_command)
            else:
                msg = 'Command "%s" is not allowed here' % smtp_command
                allowed_transitions = self.state.allowed_actions()
                if len(allowed_transitions) > 0:
                    # NOTE(review): 'on of' is probably a typo for 'one of'.
                    # It is kept as is because clients/tests may match the
                    # exact text.
                    msg += ', expected on of %s' % allowed_transitions
                self.reply(503, msg)
        except InvalidDataError:
            # sys.exc_info() instead of 'except ... as e' keeps the syntax
            # compatible with old Python versions.
            e = sys.exc_info()[1]
            self.reply(501, e.msg())
        except InvalidParametersError:
            # TODO: Get rid of InvalidParametersError, shouldn't be
            # necessary anymore
            e = sys.exc_info()[1]
            if not e.response_sent:
                msg = 'Syntactically invalid %s argument(s)' % smtp_command
                self.reply(501, msg)
        except PolicyDenial:
            e = sys.exc_info()[1]
            if not e.response_sent:
                self.reply(e.code, e.reply_text)

    def input_exceeds_limits(self):
        """Called when the client sent a message that exceeded the maximum
        size; replies 552."""
        self.reply(552, 'message exceeds fixed maximum message size')

    def reply(self, code, text):
        """Send a single-line reply to the client. The command parser is
        responsible for actually pushing the bits."""
        self._command_parser.push(code, text)

    def multiline_reply(self, code, responses):
        """Send a multi-line reply (one entry of ``responses`` per line) to
        the client via the command parser."""
        self._command_parser.multiline_push(code, responses)

    def please_close_connection_after_response(self, value=None):
        """Request (default) or cancel (value=False) closing the connection
        once the current command has been handled."""
        if value is None:
            value = True
        self._close_connection_after_response = value

    def should_close_connection_after_response(self):
        return self._close_connection_after_response

    def close_connection(self):
        "Request a connection close from the SMTP session handling instance."
        # _is_connected guards against calling close_when_done() twice.
        if self._is_connected:
            self._is_connected = False
            self._command_parser.close_when_done()

    # -------------------------------------------------------------------------
    # Protocol handling functions (not public)

    def arguments(self):
        """Return the given parameters for the command as a string or an empty
        string"""
        return self._command_arguments or ''

    def smtp_greet(self):
        """This method does not handle a real SMTP command. It is called when
        a new connection was accepted by the server and sends the 220
        greeting."""
        # Policy check was done when accepting the connection so we don't have
        # to do it here again.
        primary_hostname = self._command_parser.primary_hostname
        reply_text = '%s Hello %s' % (primary_hostname, self._message.peer.remote_ip)
        self.reply(220, reply_text)

    def validate(self, schema_class):
        """Validate the current command arguments with ``schema_class`` and
        return the processed data. The schema context carries whether the
        session uses ESMTP."""
        context = dict(esmtp=self.uses_esmtp())
        return schema_class().process(self.arguments(), context=context)

    def smtp_quit(self):
        """Reply 221 and ask the command parser to close the connection.

        NOTE(review): this calls close_when_done() directly instead of
        close_connection(), so _is_connected stays True."""
        self.validate(SMTPCommandArgumentsSchema)
        primary_hostname = self._command_parser.primary_hostname
        reply_text = '%s closing connection' % primary_hostname
        self.reply(221, reply_text)
        self._command_parser.close_when_done()

    def smtp_noop(self):
        self.validate(SMTPCommandArgumentsSchema)
        self.reply(250, 'OK')

    def smtp_help(self):
        """Reply 214 listing all SMTP verbs this session knows about."""
        # deliberately no checking for additional parameters because RFC 821
        # says:
        # "The command may take an argument (e.g., any command name) and
        # return more specific information as a response."
        states = self.get_all_allowed_smtp_commands()
        self.multiline_reply(214, ('Commands supported', ' '.join(states)))

    def _reply_to_helo(self, helo_string, response_sent):
        """Store the HELO name and reply 250 <hostname> unless the policy
        already replied."""
        self._message.smtp_helo = helo_string
        if not response_sent:
            primary_hostname = self._command_parser.primary_hostname
            self.reply(250, primary_hostname)

    def _process_helo_or_ehlo(self, policy_methodname, reply_method):
        """Shared HELO/EHLO flow: validate the argument, ask the policy
        (``policy_methodname``), then reply via ``reply_method`` or raise
        PolicyDenial."""
        validated_data = self.validate(HeloSchema)
        helo_string = validated_data['helo']
        decision, response_sent = self.is_allowed(policy_methodname, helo_string, self._message)
        if decision:
            reply_method(helo_string, response_sent)
        elif not decision:
            raise PolicyDenial(response_sent)

    def smtp_helo(self):
        self._process_helo_or_ehlo('accept_helo', self._reply_to_helo)

    def _reply_to_ehlo(self, helo_string, response_sent):
        """Store the EHLO name and send the multi-line 250 capability list
        unless the policy already replied."""
        self._message.smtp_helo = helo_string
        if not response_sent:
            primary_hostname = self._command_parser.primary_hostname
            lines = [primary_hostname] + self.get_ehlo_lines()
            self.multiline_reply(250, lines)

    def smtp_ehlo(self):
        self._process_helo_or_ehlo('accept_ehlo', self._reply_to_ehlo)

    def _check_password(self, username, password):
        """Check the credentials with the authenticator.

        Replies 235 and records the username on success, otherwise 535.
        Replies 535 and raises InvalidParametersError when no authenticator
        is configured."""
        if self._authenticator is None:
            code = 535
            reply_text = 'AUTH not available'
            self.reply(code, reply_text)
            raise InvalidParametersError(response_sent=True, code=code, reply_text=reply_text)
        credentials_correct = \
            self._authenticator.authenticate(username, password, self._message.peer)
        if credentials_correct:
            self._message.username = username
            self.reply(235, 'Authentication successful')
        else:
            self.reply(535, 'Bad username or password')

    def smtp_auth_plain(self):
        """Handle AUTH PLAIN: validate the credentials argument, ask the
        policy ('accept_auth_plain'), then check the password."""
        if self._authenticator is None:
            self.reply(535, 'AUTH not available')
            raise InvalidParametersError(response_sent=True)
        validated_data = self.validate(AuthPlainSchema)
        username = validated_data['username']
        password = validated_data['password']
        decision, response_sent = self.is_allowed('accept_auth_plain', username, password, self._message)
        if not decision:
            raise PolicyDenial(response_sent)
        elif not response_sent:
            self._check_password(username, password)

    def smtp_auth_login(self):
        """Handle AUTH LOGIN; the username may be given as an argument."""
        validated_data = self.validate(
            AuthLoginSchema) if self.arguments() else {}
        username = validated_data.get('username')
        decision, response_sent = self.is_allowed('accept_auth_login', username, self._message)
        if not decision:
            raise PolicyDenial(response_sent)
        elif not response_sent:
            if not username:
                # NOTE(review): the source below is CORRUPTED in this copy of
                # the file. A masking step replaced the original text after
                # each quoted 'Username:'/'Password:' prompt with '******'.
                # That swallowed the rest of smtp_auth_login and the start of
                # the following method. The surviving fragments reference
                # decoded_input and unvalidated_input['username'], so that
                # method presumably handled the AUTH LOGIN username/password
                # continuation; confirm. Kept verbatim - restore from version
                # control.
                next_ = 'Username:'******'username'] = username next_ = 'Password:'******'username'] username = self._message.unvalidated_input.get('username') if username is None: self._message.unvalidated_input['username'] = decoded_input next_ = 'Password:'******'username'] self._command_parser.switch_to_command_mode() self._check_password(username, password)

    def _check_size_restriction(self, extensions):
        """Reject MAIL FROM (552) when the announced SIZE parameter exceeds
        the configured maximum; no-op if either value is missing."""
        announced_size = extensions.get('size')
        if announced_size is None:
            return
        max_message_size = self._max_message_size()
        if max_message_size is None:
            return
        if announced_size > max_message_size:
            self.reply(552, 'message exceeds fixed maximum message size')
            raise InvalidParametersError('MAIL FROM', response_sent=True)

    def uses_esmtp(self):
        """True once the client greeted with EHLO (the 'esmtp' flag)."""
        return self.state.is_set('esmtp')

    def smtp_mail_from(self):
        """Handle MAIL FROM: validate, check SIZE, ask the policy
        ('accept_from'), then record the sender."""
        validated_data = self.validate(MailFromSchema)
        sender = validated_data['email']
        self._check_size_restriction(validated_data)
        decision, response_sent = self.is_allowed('accept_from', sender, self._message)
        if not decision:
            raise PolicyDenial(response_sent)
        self._message.smtp_from = sender
        if not response_sent:
            self.reply(250, 'OK')

    def smtp_rcpt_to(self):
        """Handle RCPT TO: append the recipient if the policy
        ('accept_rcpt_to') accepts it; otherwise deny with 550 by default."""
        validated_data = self.validate(RcptToSchema)
        email_address = validated_data['email']
        decision, response_sent = self.is_allowed('accept_rcpt_to', email_address, self._message)
        if decision:
            self._message.smtp_to.append(email_address)
            if not response_sent:
                self.reply(250, 'OK')
        elif not decision:
            raise PolicyDenial(response_sent, 550, 'relay not permitted')

    def smtp_data(self):
        """Handle DATA: if the policy ('accept_data') accepts, switch the
        parser to data mode and reply 354.

        NOTE(review): when the policy accepts but already sent its own
        response, the parser is NOT switched to data mode even though the
        state machine moves to 'receiving_message'. Confirm this is
        intended."""
        self.validate(SMTPCommandArgumentsSchema)
        decision, response_sent = self.is_allowed('accept_data', self._message)
        if decision and not response_sent:
            self._command_parser.switch_to_data_mode()
            self.reply(354, 'Enter message, ending with "." on a line by itself')
        elif not decision:
            raise PolicyDenial(response_sent)

    def _max_message_size(self):
        """Return the policy's max_message_size(peer), or None when there is
        no policy or no peer."""
        max_message_size = None
        if (self._policy is not None) and (self._message.peer is not None):
            max_message_size = self._policy.max_message_size(
                self._message.peer)
        return max_message_size

    def _check_size_restrictions(self, msg_data):
        """Raise PolicyDenial(552) when the received message body is longer
        than the configured maximum size."""
        max_message_size = self._max_message_size()
        if max_message_size is None:
            return
        msg_too_big = (len(msg_data) > int(max_message_size))
        if msg_too_big:
            msg = 'message exceeds fixed maximum message size'
            raise PolicyDenial(False, 552, msg)

    def _copy_basic_settings(self, msg):
        """Return a fresh Message that keeps peer, HELO name and username of
        the current message, for the next transaction on this connection.

        NOTE(review): the ``msg`` parameter is unused; the copy is always
        taken from self._message."""
        peer = self._message.peer
        new_message = Message(peer=Peer(peer.remote_ip, peer.remote_port),
                              smtp_helo=self._message.smtp_helo,
                              username=self._message.username)
        return new_message

    def smtp_msgdata(self):
        """This method does not handle a real SMTP command.
        It is called when the whole message was received (the multi-line DATA
        command is completed). An accepted message is handed to the
        deliverer; a rejected one raises PolicyDenial (550 by default)."""
        msg_data = self.arguments()
        self._command_parser.switch_to_command_mode()
        self._check_size_restrictions(msg_data)
        decision, response_sent = self.is_allowed('accept_msgdata', msg_data, self._message)
        if decision:
            self._message.msg_data = msg_data
            new_message = self._copy_basic_settings(self._message)
            self._deliverer.new_message_accepted(self._message)
            if not response_sent:
                self.reply(250, 'OK')
            # Now we must not lose the message anymore!
            self._message = new_message
        elif not decision:
            raise PolicyDenial(response_sent, 550, 'Message content is not acceptable')

    def smtp_rset(self):
        """Handle RSET: start a new Message keeping only the peer and the
        HELO name (the authenticated username is not carried over)."""
        self.validate(SMTPCommandArgumentsSchema)
        self._message = Message(peer=self._message.peer, smtp_helo=self._message.smtp_helo)
        self.reply(250, 'Reset OK')
class StateMachineTest(PythonicTestCase):
    """Unit tests for StateMachine: defining transitions, introspecting
    states and actions, executing actions, and flag-setting or conditional
    transitions.

    StateMachine.add is called as add(from_state, to_state, action[, handler],
    operations=..., condition=...).
    """

    def setUp(self):
        # Every test starts with a machine in state 'new' and no transitions.
        self.state = StateMachine(initial_state='new')

    def test_can_initialize_statemachine(self):
        StateMachine(initial_state='foo')

    # --- adding states ------------------------------------------------------

    def test_can_add_states(self):
        self.state.add('new', 'processed', 'process')
        self.state.add('new', 'new', 'noop')
        # Both newly added actions must be known to the machine.
        assert_equals(set(('process', 'noop')), self.state.known_actions())

    def test_raise_exception_if_duplicate_action_is_defined(self):
        self.state.add('new', 'processed', 'process')
        with assert_raises(StateMachineDefinitionError):
            self.state.add('new', 'new', 'process')

    # --- introspection ------------------------------------------------------

    def test_can_ask_for_current_state(self):
        state = StateMachine(initial_state='foo')
        state.add('foo', 'foo', 'noop')
        assert_equals('foo', state.state())
        assert_false(state.is_impossible_state())

    def test_no_state_if_initial_state_not_available(self):
        state = StateMachine(initial_state='invalid')
        assert_none(state.state())
        assert_true(state.is_impossible_state())

    def test_can_ask_for_all_known_actions(self):
        self.state.add('new', 'new', 'noop')
        self.state.add('new', 'processed', 'process')
        self.state.add('processed', 'new', 'rework')
        assert_equals(set(('noop', 'process', 'rework')), self.state.known_actions())

    def test_can_ask_for_all_currently_allowed_actions(self):
        self.state.add('new', 'new', 'noop')
        self.state.add('new', 'processed', 'process')
        self.state.add('processed', 'new', 'rework')
        assert_equals(set(('noop', 'process')), self.state.allowed_actions())
        self.state.set_state('processed')
        assert_equals(set(('rework', )), self.state.allowed_actions())

    def test_can_ask_for_all_known_states(self):
        assert_equals(set(), self.state.known_states())
        self.state.add('new', 'processed', 'process')
        self.state.add('processed', 'done', 'finalize')
        assert_equals(set(('new', 'processed', 'done')), self.state.known_states())

    def test_can_ask_for_all_non_final_states(self):
        assert_equals(set(), self.state.known_non_final_states())
        self.state.add('new', 'processed', 'process')
        self.state.add('processed', 'done', 'finalize')
        # 'done' has no outgoing transition, so it is a final state.
        assert_equals(set(('new', 'processed')), self.state.known_non_final_states())

    # --- handling states ----------------------------------------------------

    def test_can_not_set_state_to_invalid_state(self):
        with assert_raises(StateMachineError):
            self.state.set_state('invalid')

    # --- executing ----------------------------------------------------------

    def test_can_execute_states(self):
        self.state.add('new', 'processed', 'process')
        self.state.execute('process')
        assert_equals('processed', self.state.state())

    def test_handler_is_called_for_state_transition(self):
        transitions = []

        def handler(from_state, to_state, action_name):
            transitions.append((from_state, to_state, action_name))
        self.state.add('new', 'new', 'noop', handler)

        self.state.execute('noop')
        assert_equals('new', self.state.state())
        # The handler must be invoked exactly once, with (from, to, action).
        assert_equals([('new', 'new', 'noop')], transitions)

    def test_raise_exception_for_invalid_action(self):
        self.state.add('new', 'processed', 'process')
        with assert_raises(StateMachineError):
            self.state.execute('invalid')

        # 'rework' is defined, but only from 'processed', not from 'new'.
        self.state.add('processed', 'new', 'rework')
        with assert_raises(StateMachineError):
            self.state.execute('rework')
        self.state.execute('process')
        with assert_raises(StateMachineError):
            self.state.execute('process')
        self.state.execute('rework')

    def test_raise_exception_if_in_impossible_state(self):
        state = StateMachine(initial_state='invalid')
        state.add('new', 'processed', 'process')
        # Execute on the local machine (not self.state, which has no
        # transitions at all) so that the error is caused by the impossible
        # current state rather than by an undefined action.
        with assert_raises(StateMachineError):
            state.execute('process')

    def test_raise_exception_if_no_outgoing_transition_defined_when_executing(self):
        self.state.add('new', 'processed', 'process')
        self.state.set_state('processed')
        with assert_raises(StateMachineError):
            self.state.execute('rework')

    # --- transition with operations and conditions --------------------------

    def test_can_add_transition_with_additional_operation(self):
        self.state.add('new', 'processed', 'process', operations=('set_foo', ))
        assert_equals(set(('process', )), self.state.known_actions())
        # Operations only run when the transition is executed, not when added.
        assert_false(self.state.is_set('foo'))

    def test_can_tell_if_flag_is_set(self):
        assert_false(self.state.is_set(None))
        assert_false(self.state.is_set('foo'))

    def test_transition_can_also_set_flags(self):
        self.state.add('new', 'processed', 'process', operations=('set_foo', ))
        assert_false(self.state.is_set('foo'))
        self.state.execute('process')
        assert_true(self.state.is_set('foo'))

    def test_can_add_conditional_transition(self):
        self.state.add('new', 'authenticated', 'authenticate', condition='if_tls')
        # The 'tls' flag is not set yet, so the conditional action is not allowed.
        assert_equals(set(), self.state.allowed_actions())

    def test_allowed_actions_obeys_condition(self):
        self.state.add('new', 'new', 'use_tls', operations=('set_tls', ))
        self.state.add('new', 'authenticated', 'authenticate', condition='if_tls')
        assert_equals(set(('use_tls', )), self.state.allowed_actions())

        self.state.execute('use_tls')
        assert_equals(set(('use_tls', 'authenticate')), self.state.allowed_actions())

    def test_conditional_transition_is_only_executed_if_flag_is_true(self):
        self.state.add('new', 'new', 'use_tls', operations=('set_tls', ))
        self.state.add('new', 'authenticated', 'authenticate', condition='if_tls')

        assert_equals('new', self.state.state())
        with assert_raises(StateMachineError):
            self.state.execute('authenticate')
        self.state.execute('use_tls')
        assert_true(self.state.is_set('tls'))
        self.state.execute('authenticate')

    def test_can_also_specify_negative_flag_checks_for_transitions(self):
        # 'if_not_tls' allows use_tls only while the 'tls' flag is unset.
        self.state.add('new', 'new', 'use_tls', operations=('set_tls', ), condition='if_not_tls')
        self.state.add('new', 'authenticated', 'authenticate', condition='if_tls')

        with assert_raises(StateMachineError):
            self.state.execute('authenticate')
        self.state.execute('use_tls')
        with assert_raises(StateMachineError):
            self.state.execute('use_tls')
        self.state.execute('authenticate')