def setUp(self):
    """
    Prepare the fixtures shared by the tests.

    Stores on ``self``: ``agent_state1`` (an AssignState built with no
    arguments), ``agent_state2`` (an AssignState built with
    ``STATUS_IN_TASK``), ``opt`` (the parsed options with the task name and
    assignment duration overridden), ``mturk_manager`` and that manager's
    ``worker_manager``.
    """
    self.agent_state1 = AssignState()
    self.agent_state2 = AssignState(status=AssignState.STATUS_IN_TASK)

    # Parse the options from a parser that only carries the data-path and
    # MTurk argument groups, then override the values the tests rely on.
    parser = ParlaiParser(False, False)
    parser.add_parlai_data_path()
    parser.add_mturk_args()
    opt = parser.parse_args()
    self.opt = opt
    opt['task'] = 'unittest'
    opt['assignment_duration_in_seconds'] = 6

    manager = MTurkManager(opt=opt, mturk_agent_ids=['mturk_agent_1'])
    self.mturk_manager = manager
    self.worker_manager = manager.worker_manager
class TestAssignState(unittest.TestCase):
    """
    Various unit tests for the AssignState class.

    Every test runs against two fresh states created in ``setUp``:
    ``agent_state1`` is built with no arguments and ``agent_state2`` is
    built with ``STATUS_IN_TASK``.
    """

    def setUp(self):
        """
        Create the two AssignState fixtures and an MTurkManager.
        """
        self.agent_state1 = AssignState()
        self.agent_state2 = AssignState(status=AssignState.STATUS_IN_TASK)
        argparser = ParlaiParser(False, False)
        argparser.add_parlai_data_path()
        argparser.add_mturk_args()
        # NOTE(review): parse_args() is called without an explicit argument
        # list; presumably it falls back to sys.argv - confirm that the test
        # runner's own command-line flags cannot leak into these options.
        self.opt = argparser.parse_args()
        self.opt['task'] = 'unittest'
        self.opt['assignment_duration_in_seconds'] = 6
        mturk_agent_ids = ['mturk_agent_1']
        self.mturk_manager = MTurkManager(opt=self.opt,
                                          mturk_agent_ids=mturk_agent_ids)
        self.worker_manager = self.mturk_manager.worker_manager

    def tearDown(self):
        """
        Shut down the MTurkManager created in ``setUp``.
        """
        self.mturk_manager.shutdown()

    def test_assign_state_init(self):
        """
        Test proper initialization of assignment states.
        """
        # Default-constructed state: STATUS_NONE and no stored messages.
        self.assertEqual(self.agent_state1.status, AssignState.STATUS_NONE)
        self.assertEqual(len(self.agent_state1.messages), 0)
        self.assertEqual(len(self.agent_state1.message_ids), 0)

        # State constructed with an explicit status: the status is kept and
        # the message containers still start out empty. These two length
        # checks previously re-tested agent_state1, leaving agent_state2's
        # containers unverified.
        self.assertEqual(self.agent_state2.status, AssignState.STATUS_IN_TASK)
        self.assertEqual(len(self.agent_state2.messages), 0)
        self.assertEqual(len(self.agent_state2.message_ids), 0)

    def test_message_management(self):
        """
        Test message management in an AssignState.
        """
        # Ensure message appends succeed and are idempotent
        self.agent_state1.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state1.get_messages()), 1)
        self.agent_state1.append_message(MESSAGE_2)
        self.assertEqual(len(self.agent_state1.get_messages()), 2)
        # Appending MESSAGE_1 a second time must not grow the list.
        self.agent_state1.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state1.get_messages()), 2)
        # Appends to agent_state1 must not show up in agent_state2.
        self.assertEqual(len(self.agent_state2.get_messages()), 0)
        self.assertIn(MESSAGE_1, self.agent_state1.get_messages())
        self.assertIn(MESSAGE_2, self.agent_state1.get_messages())
        self.assertEqual(len(self.agent_state1.message_ids), 2)
        self.agent_state2.append_message(MESSAGE_1)
        self.assertEqual(len(self.agent_state2.message_ids), 1)

        # Ensure clearing messages acts as intended and doesn't clear agent2
        self.agent_state1.clear_messages()
        self.assertEqual(len(self.agent_state1.messages), 0)
        self.assertEqual(len(self.agent_state1.message_ids), 0)
        self.assertEqual(len(self.agent_state2.message_ids), 1)

    def test_state_handles_status(self):
        """
        Ensures status updates and is_final are valid.

        ``statuses``, ``active_statuses`` and ``complete_statuses`` are
        collections defined outside this class.
        """

        # Every status must round-trip through set_status/get_status.
        for status in statuses:
            self.agent_state1.set_status(status)
            self.assertEqual(self.agent_state1.get_status(), status)

        # Active statuses are not final ...
        for status in active_statuses:
            self.agent_state1.set_status(status)
            self.assertFalse(self.agent_state1.is_final())

        # ... while complete statuses are.
        for status in complete_statuses:
            self.agent_state1.set_status(status)
            self.assertTrue(self.agent_state1.is_final())