def setUp(self):
    """
    Prepare fixtures: two AssignState objects and an MTurkManager.

    ``agent_state1`` is built with no arguments, while ``agent_state2`` is
    built with ``status=AssignState.STATUS_IN_TASK``. The manager is created
    from parsed ParlAI/MTurk options, overridden for the 'unittest' task with
    a 6-second assignment duration.
    """
    # Fixture states: a default-constructed one and an explicitly in-task one.
    self.agent_state1 = AssignState()
    self.agent_state2 = AssignState(status=AssignState.STATUS_IN_TASK)

    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    # Item-assign each override so any custom __setitem__ on the options
    # object still runs, exactly as individual assignments would.
    for key, value in (
        ('task', 'unittest'),
        ('assignment_duration_in_seconds', 6),
    ):
        opt[key] = value
    self.opt = opt

    self.mturk_manager = MTurkManager(
        opt=self.opt, mturk_agent_ids=['mturk_agent_1']
    )
    self.worker_manager = self.mturk_manager.worker_manager
class TestAssignState(unittest.TestCase):
    """
    Various unit tests for the AssignState class.
    """

    def setUp(self):
        """
        Create the fixtures shared by every test.

        ``agent_state1`` is a default AssignState, and ``agent_state2`` starts
        with status ``STATUS_IN_TASK``. An MTurkManager is also built for the
        'unittest' task with a 6-second assignment duration; tearDown shuts
        it down.
        """
        self.agent_state1 = AssignState()
        self.agent_state2 = AssignState(status=AssignState.STATUS_IN_TASK)
        argparser = ParlaiParser(False, False)
        argparser.add_parlai_data_path()
        argparser.add_mturk_args()
        self.opt = argparser.parse_args()
        self.opt['task'] = 'unittest'
        self.opt['assignment_duration_in_seconds'] = 6
        mturk_agent_ids = ['mturk_agent_1']
        self.mturk_manager = MTurkManager(
            opt=self.opt, mturk_agent_ids=mturk_agent_ids
        )
        self.worker_manager = self.mturk_manager.worker_manager

    def tearDown(self):
        """
        Shut down the MTurkManager created in setUp.
        """
        self.mturk_manager.shutdown()

    def test_assign_state_init(self):
        """
        Test proper initialization of assignment states.
        """
        self.assertEqual(self.agent_state1.status, AssignState.STATUS_NONE)
        self.assertEqual(len(self.agent_state1.messages), 0)
        self.assertEqual(len(self.agent_state1.message_ids), 0)
        # agent_state2 must be checked on its own containers: building a
        # state with a non-default status should not pre-populate messages.
        self.assertEqual(self.agent_state2.status, AssignState.STATUS_IN_TASK)
        self.assertEqual(len(self.agent_state2.messages), 0)
        self.assertEqual(len(self.agent_state2.message_ids), 0)

    def test_message_management(self):
        """
        Test message management in an AssignState.
        """
        # Ensure message appends succeed and that re-appending an
        # already-stored message does not add a duplicate.
        self.agent_state1.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state1.get_messages()), 1)
        self.agent_state1.append_message(MESSAGE_2)
        self.assertEqual(len(self.agent_state1.get_messages()), 2)
        self.agent_state1.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state1.get_messages()), 2)
        # Appends to agent_state1 must not leak into agent_state2.
        self.assertEqual(len(self.agent_state2.get_messages()), 0)
        self.assertIn(MESSAGE_1, self.agent_state1.get_messages())
        self.assertIn(MESSAGE_2, self.agent_state1.get_messages())
        self.assertEqual(len(self.agent_state1.message_ids), 2)
        self.agent_state2.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state2.message_ids), 1)

        # Ensure clearing messages acts as intended and doesn't clear agent2
        self.agent_state1.clear_messages()
        self.assertEqual(len(self.agent_state1.messages), 0)
        self.assertEqual(len(self.agent_state1.message_ids), 0)
        self.assertEqual(len(self.agent_state2.message_ids), 1)

    def test_state_handles_status(self):
        """
        Ensures status updates and is_final are valid.
        """
        # subTest reports the offending status on failure instead of only
        # the first failing loop iteration's assertion.
        for status in statuses:
            with self.subTest(status=status):
                self.agent_state1.set_status(status)
                self.assertEqual(self.agent_state1.get_status(), status)
        for status in active_statuses:
            with self.subTest(status=status, expect_final=False):
                self.agent_state1.set_status(status)
                self.assertFalse(self.agent_state1.is_final())
        for status in complete_statuses:
            with self.subTest(status=status, expect_final=True):
                self.agent_state1.set_status(status)
                self.assertTrue(self.agent_state1.is_final())