Esempio n. 1
0
    def test_worker_state_agent_management(self):
        """
        Test public state management methods of worker_state.
        """
        agent_1 = MTurkAgent(
            self.opt,
            self.mturk_manager,
            TEST_HIT_ID_1,
            TEST_ASSIGNMENT_ID_1,
            TEST_WORKER_ID_1,
        )
        agent_2 = MTurkAgent(
            self.opt,
            self.mturk_manager,
            TEST_HIT_ID_2,
            TEST_ASSIGNMENT_ID_2,
            TEST_WORKER_ID_1,
        )
        agent_3 = MTurkAgent(
            self.opt,
            self.mturk_manager,
            TEST_HIT_ID_3,
            TEST_ASSIGNMENT_ID_3,
            TEST_WORKER_ID_3,
        )

        self.assertEqual(self.work_state_1.active_conversation_count(), 0)
        self.work_state_1.add_agent(agent_1)
        self.assertEqual(self.work_state_1.active_conversation_count(), 1)
        self.work_state_1.add_agent(agent_2)
        self.assertEqual(self.work_state_1.active_conversation_count(), 2)

        with self.assertRaises(AssertionError):
            self.work_state_1.add_agent(agent_3)

        self.assertEqual(self.work_state_1.active_conversation_count(), 2)
        self.assertEqual(self.work_state_1.completed_assignments(), 0)

        self.assertTrue(self.work_state_1.has_assignment(agent_1.assignment_id))
        self.assertTrue(self.work_state_1.has_assignment(agent_2.assignment_id))
        self.assertFalse(self.work_state_1.has_assignment(agent_3.assignment_id))
        self.assertEqual(
            agent_1, self.work_state_1.get_agent_for_assignment(agent_1.assignment_id)
        )
        self.assertEqual(
            agent_2, self.work_state_1.get_agent_for_assignment(agent_2.assignment_id)
        )
        self.assertIsNone(
            self.work_state_1.get_agent_for_assignment(agent_3.assignment_id)
        )

        agent_1.set_status(AssignState.STATUS_DONE)
        self.assertEqual(self.work_state_1.active_conversation_count(), 1)
        self.assertEqual(self.work_state_1.completed_assignments(), 1)
        agent_2.set_status(AssignState.STATUS_DISCONNECT)
        self.assertEqual(self.work_state_1.active_conversation_count(), 0)
        self.assertEqual(self.work_state_1.completed_assignments(), 1)
Esempio n. 2
0
class TestMTurkAgent(unittest.TestCase):
    """
    Various unit tests for the MTurkAgent class.
    """
    def setUp(self):
        argparser = ParlaiParser(False, False)
        argparser.add_parlai_data_path()
        argparser.add_mturk_args()
        self.opt = argparser.parse_args()
        self.opt['task'] = 'unittest'
        self.opt['assignment_duration_in_seconds'] = 6
        mturk_agent_ids = ['mturk_agent_1']
        self.mturk_manager = MTurkManager(opt=self.opt.copy(),
                                          mturk_agent_ids=mturk_agent_ids)
        self.worker_manager = self.mturk_manager.worker_manager
        self.mturk_manager.send_message = mock.MagicMock()
        self.mturk_manager.send_state_change = mock.MagicMock()
        self.mturk_manager.send_command = mock.MagicMock()

        self.turk_agent = MTurkAgent(
            self.opt.copy(),
            self.mturk_manager,
            TEST_HIT_ID_1,
            TEST_ASSIGNMENT_ID_1,
            TEST_WORKER_ID_1,
        )

    def tearDown(self):
        self.mturk_manager.shutdown()

        disconnect_path = os.path.join(parent_dir, 'disconnect-test.pickle')
        if os.path.exists(disconnect_path):
            os.remove(disconnect_path)

    def test_init(self):
        """
        Test initialization of an agent.
        """
        self.assertIsNotNone(self.turk_agent.creation_time)
        self.assertIsNone(self.turk_agent.id)
        self.assertIsNone(self.turk_agent.message_request_time)
        self.assertIsNone(self.turk_agent.conversation_id)
        self.assertFalse(self.turk_agent.some_agent_disconnected)
        self.assertFalse(self.turk_agent.hit_is_expired)
        self.assertFalse(self.turk_agent.hit_is_abandoned)
        self.assertFalse(self.turk_agent.hit_is_returned)
        self.assertFalse(self.turk_agent.hit_is_complete)
        self.assertFalse(self.turk_agent.disconnected)

    def test_state_wrappers(self):
        """
        Test the mturk agent wrappers around its state.
        """
        for status in statuses:
            self.turk_agent.set_status(status)
            self.assertEqual(self.turk_agent.get_status(), status)
        for status in [
                AssignState.STATUS_DONE, AssignState.STATUS_PARTNER_DISCONNECT
        ]:
            self.turk_agent.set_status(status)
            self.assertTrue(self.turk_agent.submitted_hit())

        for status in active_statuses:
            self.turk_agent.set_status(status)
            self.assertFalse(self.turk_agent.is_final())

        for status in complete_statuses:
            self.turk_agent.set_status(status)
            self.assertTrue(self.turk_agent.is_final())

        self.turk_agent.state.append_message(MESSAGE_1)
        self.assertEqual(len(self.turk_agent.get_messages()), 1)
        self.turk_agent.state.append_message(MESSAGE_2)
        self.assertEqual(len(self.turk_agent.get_messages()), 2)
        self.turk_agent.state.append_message(MESSAGE_1)
        self.assertEqual(len(self.turk_agent.get_messages()), 2)
        self.assertIn(MESSAGE_1, self.turk_agent.get_messages())
        self.assertIn(MESSAGE_2, self.turk_agent.get_messages())

        self.turk_agent.clear_messages()
        self.assertEqual(len(self.turk_agent.get_messages()), 0)

    def test_connection_id(self):
        """
        Ensure the connection_id hasn't changed.
        """
        connection_id = "{}_{}".format(self.turk_agent.worker_id,
                                       self.turk_agent.assignment_id)
        self.assertEqual(self.turk_agent.get_connection_id(), connection_id)

    def test_status_change(self):
        self.turk_agent.set_status(AssignState.STATUS_ONBOARDING)
        time.sleep(0.07)
        self.assertEqual(self.turk_agent.get_status(),
                         AssignState.STATUS_ONBOARDING)
        self.turk_agent.set_status(AssignState.STATUS_WAITING)
        time.sleep(0.07)
        self.assertEqual(self.turk_agent.get_status(),
                         AssignState.STATUS_WAITING)

    def test_message_queue(self):
        """
        Ensure observations and acts work as expected.
        """
        self.turk_agent.observe(ACT_1)
        self.mturk_manager.send_message.assert_called_with(
            TEST_WORKER_ID_1, TEST_ASSIGNMENT_ID_1, ACT_1)

        # First act comes through the queue and returns properly
        self.assertTrue(self.turk_agent.msg_queue.empty())
        self.turk_agent.id = AGENT_ID
        self.turk_agent.put_data(MESSAGE_ID_1, ACT_1)
        self.assertTrue(self.turk_agent.recieved_packets[MESSAGE_ID_1])
        self.assertFalse(self.turk_agent.msg_queue.empty())
        returned_act = self.turk_agent.get_new_act_message()
        self.assertEqual(returned_act, ACT_1)

        # Repeat act is ignored
        self.turk_agent.put_data(MESSAGE_ID_1, ACT_1)
        self.assertTrue(self.turk_agent.msg_queue.empty())

        for i in range(100):
            self.turk_agent.put_data(str(i), ACT_1)
        self.assertEqual(self.turk_agent.msg_queue.qsize(), 100)
        self.turk_agent.flush_msg_queue()
        self.assertTrue(self.turk_agent.msg_queue.empty())

        # Test non-act messages
        blank_message = self.turk_agent.get_new_act_message()
        self.assertIsNone(blank_message)

        self.turk_agent.disconnected = True
        with self.assertRaises(AgentDisconnectedError):
            # Expect this to be a disconnect message
            self.turk_agent.get_new_act_message()
        self.turk_agent.disconnected = False

        self.turk_agent.hit_is_returned = True
        with self.assertRaises(AgentReturnedError):
            # Expect this to be a returned message
            self.turk_agent.get_new_act_message()
        self.turk_agent.hit_is_returned = False

        # Reduce state
        self.turk_agent.reduce_state()
        self.assertIsNone(self.turk_agent.msg_queue)
        self.assertIsNone(self.turk_agent.recieved_packets)

    def test_message_acts(self):
        self.mturk_manager.handle_turker_timeout = mock.MagicMock()

        # non-Blocking check
        self.assertIsNone(self.turk_agent.message_request_time)
        returned_act = self.turk_agent.act(blocking=False)
        self.assertIsNotNone(self.turk_agent.message_request_time)
        self.assertIsNone(returned_act)
        self.turk_agent.id = AGENT_ID
        self.turk_agent.put_data(MESSAGE_ID_1, ACT_1)
        returned_act = self.turk_agent.act(blocking=False)
        self.assertIsNone(self.turk_agent.message_request_time)
        self.assertEqual(returned_act, ACT_1)
        self.mturk_manager.send_command.assert_called_once()

        # non-Blocking timeout check
        with self.assertRaises(AgentTimeoutError):
            self.mturk_manager.send_command = mock.MagicMock()
            returned_act = self.turk_agent.act(timeout=0.07, blocking=False)
            self.assertIsNotNone(self.turk_agent.message_request_time)
            self.assertIsNone(returned_act)
            while returned_act is None:
                returned_act = self.turk_agent.act(timeout=0.07,
                                                   blocking=False)

        # Blocking timeout check
        with self.assertRaises(AgentTimeoutError):
            returned_act = self.turk_agent.act(timeout=0.07)