def setUp(self):
    """Create stubbed collaborators and the LocationsHandler under test.

    Every stub class is passed through ``Stub(<interface>)`` before it is
    instantiated. The resulting objects are stored on ``self`` and handed
    to a fresh ``LocationsHandler``, which then receives a stub request
    and a stub response.
    """
    # Session-information source: the stub defines no methods of its own.
    class StubSessionSource (object):
        pass
    StubSessionSource = Stub(_SessionSourceInterface)(StubSessionSource)

    # Vehicle-availability source: the stubbed fetch does nothing.
    class StubLocationsSource (object):
        def fetch_location_profiles(self, sessionid):
            pass
    StubLocationsSource = Stub(_LocationsSourceInterface)(StubLocationsSource)

    # View of the availability information: reports whether a truthy
    # session was supplied.
    class StubLocationsView (object):
        def render_locations(self, session, locations):
            return "Success" if session else "Failure"
    StubLocationsView = Stub(_LocationsViewInterface)(StubLocationsView)

    # Error view: the stub defines no methods of its own.
    class StubErrorView (object):
        pass
    StubErrorView = Stub(_ErrorViewInterface)(StubErrorView)

    # Collaborators are created in a fixed order and kept on the fixture so
    # individual tests can patch them.
    session_source = self.session_source = StubSessionSource()
    locations_source = self.locations_source = StubLocationsSource()
    locations_view = self.locations_view = StubLocationsView()
    error_view = self.error_view = StubErrorView()

    # The system under test, wired to the stubs above.
    handler = self.handler = LocationsHandler(
        session_source=session_source,
        locations_source=locations_source,
        locations_view=locations_view,
        error_view=error_view,
    )
    handler.request = StubRequest()
    handler.response = StubResponse()
class LocationHandlerTest (unittest.TestCase):
    """Tests for ``LocationsHandler.get()`` with every collaborator stubbed."""

    def setUp(self):
        """Build stubbed collaborators and the handler under test."""
        # A source for session information
        class StubSessionSource (object):
            pass
        StubSessionSource = Stub(_SessionSourceInterface)(StubSessionSource)

        # A source for vehicle availability information
        class StubLocationsSource (object):
            def fetch_location_profiles(self, sessionid):
                pass
        StubLocationsSource = Stub(_LocationsSourceInterface)(StubLocationsSource)

        # A generator for a representation (view) of the availability information
        class StubLocationsView (object):
            def render_locations(self, session, locations):
                if session:
                    return "Success"
                else:
                    return "Failure"
        StubLocationsView = Stub(_LocationsViewInterface)(StubLocationsView)

        # A generator for error responses
        class StubErrorView (object):
            pass
        StubErrorView = Stub(_ErrorViewInterface)(StubErrorView)

        # The system under test
        self.session_source = StubSessionSource()
        self.locations_source = StubLocationsSource()
        self.locations_view = StubLocationsView()
        self.error_view = StubErrorView()
        self.handler = LocationsHandler(session_source=self.session_source,
                                        locations_source=self.locations_source,
                                        locations_view=self.locations_view,
                                        error_view=self.error_view)
        self.handler.request = StubRequest()
        self.handler.response = StubResponse()

    def _patch_user_and_session_ids(self):
        """Patch the handler to return fixed user/session ids and record the calls.

        The patched methods set ``userid_called`` / ``sessionid_called`` on
        the handler so the tests can verify that ``get()`` asked for them.
        """
        @patch(self.handler)
        def get_user_id(self):
            self.userid_called = True
            return 'user1234'

        @patch(self.handler)
        def get_session_id(self):
            self.sessionid_called = True
            return 'ses1234'

    def _assert_ids_were_used(self):
        """Assert that get() requested both ids and passed them to get_session.

        ``getattr`` with a default is used so that a missing call shows up as
        a test failure with a message instead of an ``AttributeError``.
        """
        self.assertTrue(getattr(self.handler, 'userid_called', False),
                        'get_user_id was not called')
        self.assertTrue(getattr(self.handler, 'sessionid_called', False),
                        'get_session_id was not called')
        self.assertEqual(getattr(self.handler, 'userid', None), 'user1234')
        self.assertEqual(getattr(self.handler, 'sessionid', None), 'ses1234')

    def testShouldRespondSuccessfullyWhenGivenAValidSession(self):
        # Given...
        self._patch_user_and_session_ids()

        @patch(self.handler)
        def get_session(self, userid, sessionid):
            self.userid = userid
            self.sessionid = sessionid
            return 'my session'

        @patch(self.locations_source)
        def fetch_location_profiles(self, sessionid):
            self.sessionid = sessionid
            return 'my locations'

        @patch(self.locations_view)
        def render_locations(self, session, locations):
            self.session = session
            self.locations = locations
            return 'location profiles body'

        # The error view is patched to return None. If get() took the error
        # path, the body check below would not match.
        @patch(self.error_view)
        def render_error(self, error_code, error_msg, error_detail):
            pass

        # When...
        self.handler.get()

        # Then...
        response_body = self.handler.response.out.getvalue()
        self._assert_ids_were_used()
        self.assertEqual(self.locations_source.sessionid, 'ses1234')
        self.assertEqual(self.locations_view.locations, 'my locations')
        self.assertEqual(self.locations_view.session, 'my session')
        self.assertEqual(response_body, 'location profiles body')

    def testShouldGenerateAFailureWhenGivenAnInvalidSession(self):
        # Given...
        self._patch_user_and_session_ids()

        @patch(self.handler)
        def get_session(self, userid, sessionid):
            self.userid = userid
            self.sessionid = sessionid
            raise SessionExpiredError()

        # The error view echoes error_detail, so the body shows what get()
        # reported for the expired session.
        @patch(self.error_view)
        def render_error(self, error_code, error_msg, error_detail):
            return error_detail

        # When...
        self.handler.get()

        # Then...
        response_body = self.handler.response.out.getvalue()
        self._assert_ids_were_used()
        self.assertTrue('SessionExpiredError' in response_body,
                        'Should contain SessionExpiredError: %r' % response_body)