class MTurkManager():
    """Manages interactions between MTurk agents as well as direct
    interactions between a world and the MTurk server.

    Tracks per-worker state (WorkerState / per-assignment agents), a pool of
    waiting workers, onboarding threads, and a SocketManager used to talk to
    the MTurk-facing web server.
    """

    def __init__(self, opt, mturk_agent_ids, is_test=False):
        """Create an MTurkManager using the given setup opts and a list of
        agent_ids that will participate in each conversation

        opt: dict of run options (is_sandbox, num_conversations, reward, ...)
        mturk_agent_ids: list of agent ids, one per role in a conversation
        is_test: suppresses some console output when True
        """
        self.opt = opt
        self.server_url = None
        self.port = 443
        self.task_group_id = None
        self.run_id = None
        self.mturk_agent_ids = mturk_agent_ids
        self.task_files_to_copy = None
        self.is_sandbox = opt['is_sandbox']
        self.worker_pool_change_condition = threading.Condition()
        self.onboard_function = None
        self.num_conversations = opt['num_conversations']
        # Over-provision HITs by HIT_MULT so enough workers show up; extras
        # are expired once the desired number of conversations is reached
        self.required_hits = math.ceil(
            self.num_conversations * len(self.mturk_agent_ids) * HIT_MULT
        )
        self.socket_manager = None
        self.is_test = is_test
        self._init_logs()

    ### Helpers and internal manager methods ###

    def _init_state(self):
        """Initialize everything in the worker, task, and thread states"""
        self.hit_id_list = []
        self.worker_pool = []
        self.assignment_to_onboard_thread = {}
        self.task_threads = []
        self.conversation_index = 0
        self.started_conversations = 0
        self.completed_conversations = 0
        self.mturk_workers = {}
        self.conv_to_agent = {}
        self.accepting_workers = True
        self._load_disconnects()

    def _init_logs(self):
        """Initialize logging settings from the opt"""
        shared_utils.set_is_debug(self.opt['is_debug'])
        shared_utils.set_log_level(self.opt['log_level'])

    def _load_disconnects(self):
        """Load disconnects from file, populate the disconnects field for any
        worker_id that has disconnects in the list. Any disconnect that
        occurred longer ago than the disconnect persist length is ignored
        """
        self.disconnects = []
        # Load disconnects from file
        file_path = os.path.join(parent_dir, DISCONNECT_FILE_NAME)
        compare_time = time.time()
        if os.path.exists(file_path):
            with open(file_path, 'rb') as f:
                old_disconnects = pickle.load(f)
                self.disconnects = [
                    d
                    for d in old_disconnects
                    if (compare_time - d['time']) < DISCONNECT_PERSIST_LENGTH
                ]
        # Initialize worker states with proper number of disconnects
        for disconnect in self.disconnects:
            worker_id = disconnect['id']
            if not worker_id in self.mturk_workers:
                # add this worker to the worker state
                self.mturk_workers[worker_id] = WorkerState(worker_id)
            self.mturk_workers[worker_id].disconnects += 1

    def _save_disconnects(self):
        """Saves the local list of disconnects to file"""
        file_path = os.path.join(parent_dir, DISCONNECT_FILE_NAME)
        if os.path.exists(file_path):
            os.remove(file_path)
        with open(file_path, 'wb') as f:
            pickle.dump(self.disconnects, f, pickle.HIGHEST_PROTOCOL)

    def _handle_bad_disconnect(self, worker_id):
        """Update the number of bad disconnects for the given worker, block
        them if they've exceeded the disconnect limit
        """
        # Sandbox disconnects are not penalized (testing environment)
        if not self.is_sandbox:
            self.mturk_workers[worker_id].disconnects += 1
            self.disconnects.append({'time': time.time(), 'id': worker_id})
            if self.mturk_workers[worker_id].disconnects > MAX_DISCONNECTS:
                text = (
                    'This worker has repeatedly disconnected from these tasks,'
                    ' which require constant connection to complete properly '
                    'as they involve interaction with other Turkers. They have'
                    ' been blocked to ensure a better experience for other '
                    'workers who don\'t disconnect.'
                )
                self.block_worker(worker_id, text)
                shared_utils.print_and_log(
                    logging.INFO,
                    'Worker {} was blocked - too many disconnects'.format(
                        worker_id
                    ),
                    True
                )

    def _get_agent_from_pkt(self, pkt):
        """Get sender, assignment, and conv ids from a packet.

        Returns the matching MTurkAgent, or None (after logging) if no agent
        is registered for that worker/assignment pair.
        """
        worker_id = pkt.sender_id
        assignment_id = pkt.assignment_id
        agent = self._get_agent(worker_id, assignment_id)
        if agent == None:
            self._log_missing_agent(worker_id, assignment_id)
        return agent

    def _change_worker_to_conv(self, pkt):
        """Update a worker to a new conversation given a packet from the
        conversation to be switched to
        """
        agent = self._get_agent_from_pkt(pkt)
        if agent is not None:
            self._assign_agent_to_conversation(agent, agent.conversation_id)

    def _set_worker_status_to_onboard(self, pkt):
        """Changes assignment status to onboarding based on the packet"""
        agent = self._get_agent_from_pkt(pkt)
        if agent is not None:
            agent.state.status = AssignState.STATUS_ONBOARDING

    def _set_worker_status_to_waiting(self, pkt):
        """Changes assignment status to waiting based on the packet"""
        agent = self._get_agent_from_pkt(pkt)
        if agent is not None:
            agent.state.status = AssignState.STATUS_WAITING

            # Add the worker to pool
            with self.worker_pool_change_condition:
                shared_utils.print_and_log(
                    logging.DEBUG,
                    "Adding worker {} to pool...".format(agent.worker_id)
                )
                self.worker_pool.append(agent)

    def _move_workers_to_waiting(self, workers):
        """Put all workers into waiting worlds, expire them if no longer
        accepting workers. If the worker is already final, delete it
        """
        for worker in workers:
            worker_id = worker.worker_id
            assignment_id = worker.assignment_id
            if worker.state.is_final():
                # Already in a terminal state; free its resources instead
                worker.reduce_state()
                self.socket_manager.close_channel(worker.get_connection_id())
                continue

            conversation_id = 'w_{}'.format(uuid.uuid4())
            if self.accepting_workers:
                # Move the worker into a waiting world
                worker.change_conversation(
                    conversation_id=conversation_id,
                    agent_id='waiting',
                    change_callback=self._set_worker_status_to_waiting
                )
            else:
                self.force_expire_hit(worker_id, assignment_id)

    def _expire_onboarding_pool(self):
        """Expire any worker that is in an onboarding thread"""
        for worker_id in self.mturk_workers:
            for assign_id in self.mturk_workers[worker_id].agents:
                agent = self.mturk_workers[worker_id].agents[assign_id]
                if (agent.state.status == AssignState.STATUS_ONBOARDING):
                    self.force_expire_hit(worker_id, assign_id)

    def _expire_worker_pool(self):
        """Expire all workers in the worker pool"""
        for agent in self.worker_pool:
            self.force_expire_hit(agent.worker_id, agent.assignment_id)

    def _get_unique_pool(self, eligibility_function):
        """Return a filtered version of the worker pool where each worker is
        only listed a maximum of one time. In sandbox this is overridden for
        testing purposes, and the same worker can be returned more than once
        """
        workers = [w for w in self.worker_pool
                   if not w.hit_is_returned and eligibility_function(w)]
        unique_workers = []
        unique_worker_ids = []
        for w in workers:
            if (self.is_sandbox) or (w.worker_id not in unique_worker_ids):
                unique_workers.append(w)
                unique_worker_ids.append(w.worker_id)
        return unique_workers

    def _handle_worker_disconnect(self, worker_id, assignment_id):
        """Mark a worker as disconnected and send a message to all agents in
        his conversation that a partner has disconnected.
        """
        agent = self._get_agent(worker_id, assignment_id)
        if agent is None:
            self._log_missing_agent(worker_id, assignment_id)
        else:
            # Disconnect in conversation is not workable
            agent.state.status = AssignState.STATUS_DISCONNECT
            # in conversation, inform others about disconnect
            conversation_id = agent.conversation_id
            if agent in self.conv_to_agent[conversation_id]:
                for other_agent in self.conv_to_agent[conversation_id]:
                    if agent.assignment_id != other_agent.assignment_id:
                        self._handle_partner_disconnect(
                            other_agent.worker_id,
                            other_agent.assignment_id
                        )
            if len(self.mturk_agent_ids) > 1:
                # The user disconnected from inside a conversation with
                # another turker, record this as bad behavior
                self._handle_bad_disconnect(worker_id)

    def _handle_partner_disconnect(self, worker_id, assignment_id):
        """Send a message to a worker notifying them that a partner has
        disconnected and we marked the HIT as complete for them
        """
        agent = self._get_agent(worker_id, assignment_id)
        if agent is None:
            self._log_missing_agent(worker_id, assignment_id)
        elif not agent.state.is_final():
            # Update the assignment state
            agent.some_agent_disconnected = True
            agent.state.status = AssignState.STATUS_PARTNER_DISCONNECT

            # Create and send the command
            data = agent.get_inactive_command_data()
            self.send_command(worker_id, assignment_id, data)

    def _restore_worker_state(self, worker_id, assignment_id):
        """Send a command to restore the state of an agent who reconnected"""
        agent = self._get_agent(worker_id, assignment_id)
        if agent is None:
            self._log_missing_agent(worker_id, assignment_id)
        else:
            def _push_worker_state(msg):
                # Callback fired after the conversation change acks; replays
                # the stored message history and last command to the client
                if len(agent.state.messages) != 0:
                    data = {
                        'text': data_model.COMMAND_RESTORE_STATE,
                        'messages': agent.state.messages,
                        'last_command': agent.state.last_command
                    }
                    self.send_command(worker_id, assignment_id, data)

            agent.change_conversation(
                conversation_id=agent.conversation_id,
                agent_id=agent.id,
                change_callback=_push_worker_state
            )

    def _setup_socket(self):
        """Set up a socket_manager with defined callbacks"""
        self.socket_manager = SocketManager(
            self.server_url,
            self.port,
            self._on_alive,
            self._on_new_message,
            self._on_socket_dead,
            self.task_group_id
        )

    def _on_alive(self, pkt):
        """Update MTurkManager's state when a worker sends an alive packet.
        This asks the socket manager to open a new channel and then handles
        ensuring the worker state is consistent
        """
        shared_utils.print_and_log(
            logging.DEBUG,
            'on_agent_alive: {}'.format(pkt)
        )
        worker_id = pkt.data['worker_id']
        hit_id = pkt.data['hit_id']
        assign_id = pkt.data['assignment_id']
        conversation_id = pkt.data['conversation_id']
        # Open a channel if it doesn't already exist
        self.socket_manager.open_channel(worker_id, assign_id)

        if not worker_id in self.mturk_workers:
            # First time this worker has connected, start tracking
            self.mturk_workers[worker_id] = WorkerState(worker_id)

        # Update state of worker based on this connect
        curr_worker_state = self._get_worker(worker_id)

        if not assign_id:
            # invalid assignment_id is an auto-fail
            shared_utils.print_and_log(
                logging.WARN,
                'Agent ({}) with no assign_id called alive'.format(worker_id)
            )
        elif not assign_id in curr_worker_state.agents:
            # First time this worker has connected under this assignment, init
            # new agent if we are still accepting workers
            if self.accepting_workers:
                convs = curr_worker_state.active_conversation_count()
                allowed_convs = self.opt['allowed_conversations']
                # allowed_convs == 0 means "no limit"
                if allowed_convs == 0 or convs < allowed_convs:
                    agent = self._create_agent(hit_id, assign_id, worker_id)
                    curr_worker_state.add_agent(assign_id, agent)
                    self._onboard_new_worker(agent)
                else:
                    text = ('You can participate in only {} of these HITs at '
                            'once. Please return this HIT and finish your '
                            'existing HITs before accepting more.'.format(
                                allowed_convs
                            ))
                    self.force_expire_hit(worker_id, assign_id, text)
            else:
                self.force_expire_hit(worker_id, assign_id)
        else:
            # Reconnect under an existing assignment; dispatch on status
            agent = curr_worker_state.agents[assign_id]
            agent.log_reconnect()
            if agent.state.status == AssignState.STATUS_NONE:
                # Reconnecting before even being given a world. The retries
                # for switching to the onboarding world should catch this
                return
            elif (agent.state.status == AssignState.STATUS_ONBOARDING or
                  agent.state.status == AssignState.STATUS_WAITING):
                # Reconnecting to the onboarding world or to a waiting world
                # should either restore state or expire (if workers are no
                # longer being accepted for this task)
                if not self.accepting_workers:
                    self.force_expire_hit(worker_id, assign_id)
                elif not conversation_id:
                    self._restore_worker_state(worker_id, assign_id)
            elif agent.state.status == AssignState.STATUS_IN_TASK:
                # Reconnecting to the onboarding world or to a task world
                # should resend the messages already in the conversation
                if not conversation_id:
                    self._restore_worker_state(worker_id, assign_id)
            elif agent.state.status == AssignState.STATUS_ASSIGNED:
                # Connect after a switch to a task world, mark the switch
                agent.state.status = AssignState.STATUS_IN_TASK
                agent.state.last_command = None
                agent.state.messages = []
            elif (agent.state.status == AssignState.STATUS_DISCONNECT or
                  agent.state.status == AssignState.STATUS_DONE or
                  agent.state.status == AssignState.STATUS_EXPIRED or
                  agent.state.status == AssignState.STATUS_RETURNED or
                  agent.state.status ==
                  AssignState.STATUS_PARTNER_DISCONNECT):
                # inform the connecting user in all of these cases that the
                # task is no longer workable, use appropriate message
                data = agent.get_inactive_command_data()
                self.send_command(worker_id, assign_id, data)

    def _on_new_message(self, pkt):
        """Put an incoming message onto the correct agent's message queue and
        add it to the proper message thread as long as the agent is active
        """
        worker_id = pkt.sender_id
        assignment_id = pkt.assignment_id
        agent = self._get_agent(worker_id, assignment_id)
        if agent is None:
            self._log_missing_agent(worker_id, assignment_id)
        elif not agent.state.is_final():
            shared_utils.print_and_log(
                logging.INFO,
                'Manager received: {}'.format(pkt),
                should_print=self.opt['verbose']
            )
            # Push the message to the message thread to send on a reconnect
            agent.state.messages.append(pkt.data)
            # Clear the send message command, as a message was recieved
            agent.state.last_command = None
            # TODO ensure you can't duplicate a message push here
            agent.msg_queue.put(pkt.data)

    def _on_socket_dead(self, worker_id, assignment_id):
        """Handle a disconnect event, update state as required and notifying
        other agents if the disconnected agent was in conversation with them

        returns False if the socket death should be ignored and the socket
        should stay open and not be considered disconnected
        """
        agent = self._get_agent(worker_id, assignment_id)
        if agent is None:
            # This worker never registered, so we don't do anything
            return

        shared_utils.print_and_log(
            logging.DEBUG,
            'Worker {} disconnected from {} in status {}'.format(
                worker_id,
                assignment_id,
                agent.state.status
            )
        )

        if agent.state.status == AssignState.STATUS_NONE:
            # Agent never made it to onboarding, delete
            agent.state.status = AssignState.STATUS_DISCONNECT
            agent.reduce_state()
        elif agent.state.status == AssignState.STATUS_ONBOARDING:
            # Agent never made it to task pool, the onboarding thread will die
            # and delete the agent if we mark it as a disconnect
            agent.state.status = AssignState.STATUS_DISCONNECT
            agent.disconnected = True
        elif agent.state.status == AssignState.STATUS_WAITING:
            # agent is in pool, remove from pool and delete
            if agent in self.worker_pool:
                with self.worker_pool_change_condition:
                    self.worker_pool.remove(agent)
            agent.state.status = AssignState.STATUS_DISCONNECT
            agent.reduce_state()
        elif agent.state.status == AssignState.STATUS_IN_TASK:
            self._handle_worker_disconnect(worker_id, assignment_id)
            agent.disconnected = True
        elif agent.state.status == AssignState.STATUS_DONE:
            # It's okay if a complete assignment socket dies, but wait for the
            # world to clean up the resource
            return
        elif agent.state.status == AssignState.STATUS_ASSIGNED:
            # mark the agent in the assigned state as disconnected, the task
            # spawn thread is responsible for cleanup
            agent.state.status = AssignState.STATUS_DISCONNECT
            agent.disconnected = True

        self.socket_manager.close_channel(agent.get_connection_id())

    def _create_agent(self, hit_id, assignment_id, worker_id):
        """Initialize an agent and return it"""
        return MTurkAgent(self.opt, self, hit_id, assignment_id, worker_id)

    def _onboard_new_worker(self, mturk_agent):
        """Handle creating an onboarding thread and moving an agent through
        the onboarding process, updating the state properly along the way
        """
        # get state variable in question
        worker_id = mturk_agent.worker_id
        assignment_id = mturk_agent.assignment_id

        def _onboard_function(mturk_agent):
            """Onboarding wrapper to set state to onboarding properly"""
            if self.onboard_function:
                conversation_id = 'o_'+str(uuid.uuid4())
                mturk_agent.change_conversation(
                    conversation_id=conversation_id,
                    agent_id='onboarding',
                    change_callback=self._set_worker_status_to_onboard
                )
                # Wait for turker to be in onboarding status
                mturk_agent.wait_for_status(AssignState.STATUS_ONBOARDING)
                # call onboarding function
                self.onboard_function(mturk_agent)

            # once onboarding is done, move into a waiting world
            self._move_workers_to_waiting([mturk_agent])

        if not assignment_id in self.assignment_to_onboard_thread:
            # Start the onboarding thread and run it
            onboard_thread = threading.Thread(
                target=_onboard_function,
                args=(mturk_agent,),
                name='onboard-{}-{}'.format(worker_id, assignment_id)
            )
            onboard_thread.daemon = True
            onboard_thread.start()

            self.assignment_to_onboard_thread[assignment_id] = onboard_thread

    def _assign_agent_to_conversation(self, agent, conv_id):
        """Register an agent object with a conversation id, update status"""
        if agent.state.status != AssignState.STATUS_IN_TASK:
            # Avoid on a second ack if alive already came through
            agent.state.status = AssignState.STATUS_ASSIGNED
            # Give the client extra time to switch worlds before a missed
            # heartbeat counts as a disconnect
            self.socket_manager.delay_heartbeat_until(
                agent.get_connection_id(),
                time.time() + HEARTBEAT_DELAY_TIME
            )

        agent.conversation_id = conv_id
        if not conv_id in self.conv_to_agent:
            self.conv_to_agent[conv_id] = []
        self.conv_to_agent[conv_id].append(agent)

    def _no_workers_incomplete(self, workers):
        """Return True if all the given workers completed their task"""
        for w in workers:
            if w.state.is_final() and w.state.status != \
                    AssignState.STATUS_DONE:
                return False
        return True

    def _get_worker(self, worker_id):
        """A safe way to get a worker by worker_id"""
        if worker_id in self.mturk_workers:
            return self.mturk_workers[worker_id]
        return None

    def _get_agent(self, worker_id, assignment_id):
        """A safe way to get an agent by worker and assignment_id"""
        worker = self._get_worker(worker_id)
        if worker is not None:
            if assignment_id in worker.agents:
                return worker.agents[assignment_id]
        return None

    def _log_missing_agent(self, worker_id, assignment_id):
        """Logs when an agent was expected to exist, yet for some reason it
        didn't. If these happen often there is a problem"""
        shared_utils.print_and_log(
            logging.WARN,
            'Expected to have an agent for {}_{}, yet none was found'.format(
                worker_id,
                assignment_id
            )
        )

    ### Manager Lifecycle Functions ###

    def setup_server(self, task_directory_path=None):
        """Prepare the MTurk server for the new HIT we would like to submit"""
        fin_word = 'start'
        if self.opt['count_complete']:
            fin_word = 'finish'
        shared_utils.print_and_log(
            logging.INFO,
            '\nYou are going to allow workers from Amazon Mechanical Turk to '
            'be an agent in ParlAI.\nDuring this process, Internet connection '
            'is required, and you should turn off your computer\'s auto-sleep '
            'feature.\nEnough HITs will be created to fulfill {} times the '
            'number of conversations requested, extra HITs will be expired '
            'once the desired conversations {}.'.format(HIT_MULT, fin_word),
            should_print=True
        )
        # NOTE(review): key_input is unused; input() is only for the pause
        key_input = input('Please press Enter to continue... ')
        shared_utils.print_and_log(logging.NOTSET, '', True)

        mturk_utils.setup_aws_credentials()

        # See if there's enough money in the account to fund the HITs
        # requested
        num_assignments = self.required_hits
        payment_opt = {
            'type': 'reward',
            'num_total_assignments': num_assignments,
            'reward': self.opt['reward'],  # in dollars
            'unique': self.opt['unique_worker']
        }
        total_cost = mturk_utils.calculate_mturk_cost(payment_opt=payment_opt)
        if not mturk_utils.check_mturk_balance(
                balance_needed=total_cost,
                is_sandbox=self.opt['is_sandbox']):
            raise SystemExit('Insufficient funds')

        # Require explicit confirmation for expensive live runs
        if ((not self.opt['is_sandbox']) and
                (total_cost > 100 or self.opt['reward'] > 1)):
            confirm_string = '$%.2f' % total_cost
            expected_cost = total_cost / HIT_MULT
            expected_string = '$%.2f' % expected_cost
            shared_utils.print_and_log(
                logging.INFO,
                'You are going to create {} HITs at {} per assignment, for a '
                'total cost up to {} after MTurk fees. Please enter "{}" to '
                'confirm and continue, and anything else to cancel.\nNote that'
                ' of the {}, the target amount to spend is {}.'.format(
                    self.required_hits,
                    '$%.2f' % self.opt['reward'],
                    confirm_string,
                    confirm_string,
                    confirm_string,
                    expected_string
                ),
                should_print=True
            )
            check = input('Enter here: ')
            if (check != confirm_string and ('$' + check) != confirm_string):
                raise SystemExit('Cancelling')

        shared_utils.print_and_log(logging.INFO, 'Setting up MTurk server...',
                                   should_print=True)
        mturk_utils.create_hit_config(
            task_description=self.opt['task_description'],
            unique_worker=self.opt['unique_worker'],
            is_sandbox=self.opt['is_sandbox']
        )
        # Populate files to copy over to the server
        if not self.task_files_to_copy:
            self.task_files_to_copy = []
        if not task_directory_path:
            task_directory_path = os.path.join(
                self.opt['parlai_home'],
                'parlai',
                'mturk',
                'tasks',
                self.opt['task']
            )
        self.task_files_to_copy.append(
            os.path.join(task_directory_path, 'html', 'cover_page.html'))
        for mturk_agent_id in self.mturk_agent_ids + ['onboarding']:
            self.task_files_to_copy.append(os.path.join(
                task_directory_path,
                'html',
                '{}_index.html'.format(mturk_agent_id)
            ))

        # Setup the server with a likely-unique app-name
        task_name = '{}-{}'.format(str(uuid.uuid4())[:8], self.opt['task'])
        self.server_task_name = \
            ''.join(e for e in task_name if e.isalnum() or e == '-')
        self.server_url = server_utils.setup_server(self.server_task_name,
                                                    self.task_files_to_copy)
        shared_utils.print_and_log(logging.INFO, self.server_url)

        shared_utils.print_and_log(logging.INFO, "MTurk server setup done.\n",
                                   should_print=True)

    def ready_to_accept_workers(self):
        """Set up socket to start communicating to workers"""
        shared_utils.print_and_log(logging.INFO,
                                   'Local: Setting up SocketIO...',
                                   not self.is_test)
        self._setup_socket()

    def start_new_run(self):
        """Clear state to prepare for a new run"""
        self.run_id = str(int(time.time()))
        self.task_group_id = '{}_{}'.format(self.opt['task'], self.run_id)
        self._init_state()
    def set_onboard_function(self, onboard_function):
        """Set the function run for each agent during onboarding (or None to
        skip onboarding entirely)"""
        self.onboard_function = onboard_function

    def start_task(self, eligibility_function, assign_role_function,
                   task_function):
        """Handle running a task by checking to see when enough agents are
        in the pool to start an instance of the task. Continue doing this
        until the desired number of conversations is had.

        eligibility_function: filter applied to pooled workers
        assign_role_function: assigns an agent id/role to each chosen worker
        task_function: runs the actual task world; called on its own thread
        """

        def _task_function(opt, workers, conversation_id):
            """Wait for all workers to join world before running the task"""
            shared_utils.print_and_log(
                logging.INFO,
                'Starting task {}...'.format(conversation_id)
            )
            shared_utils.print_and_log(
                logging.DEBUG,
                'Waiting for all workers to join the conversation...'
            )
            start_time = time.time()
            while True:
                all_joined = True
                for worker in workers:
                    # check the status of an individual worker assignment
                    if worker.state.status != AssignState.STATUS_IN_TASK:
                        all_joined = False
                if all_joined:
                    break
                if time.time() - start_time > WORLD_START_TIMEOUT:
                    # We waited but not all workers rejoined, throw workers
                    # back into the waiting pool. Stragglers will disconnect
                    # from there
                    shared_utils.print_and_log(
                        logging.INFO,
                        'Timeout waiting for {}, move back to waiting'.format(
                            conversation_id
                        )
                    )
                    self._move_workers_to_waiting(workers)
                    return
                time.sleep(shared_utils.THREAD_SHORT_SLEEP)

            shared_utils.print_and_log(
                logging.INFO,
                'All workers joined the conversation {}!'.format(
                    conversation_id
                )
            )
            self.started_conversations += 1
            task_function(mturk_manager=self, opt=opt, workers=workers)
            # Delete extra state data that is now unneeded
            for worker in workers:
                worker.state.clear_messages()

            # Count if it's a completed conversation
            if self._no_workers_incomplete(workers):
                self.completed_conversations += 1

        while True:
            # Loop forever starting task worlds until desired convos are had
            with self.worker_pool_change_condition:
                valid_workers = self._get_unique_pool(eligibility_function)
                needed_workers = len(self.mturk_agent_ids)
                if len(valid_workers) >= needed_workers:
                    # enough workers in pool to start new conversation
                    self.conversation_index += 1
                    new_conversation_id = \
                        't_{}'.format(self.conversation_index)

                    # Add the required number of valid workers to the conv
                    workers = [w for w in valid_workers[:needed_workers]]
                    assign_role_function(workers)
                    for w in workers:
                        w.change_conversation(
                            conversation_id=new_conversation_id,
                            agent_id=w.id,
                            change_callback=self._change_worker_to_conv
                        )
                        # Remove selected workers from the pool
                        self.worker_pool.remove(w)

                    # Start a new thread for this task world
                    task_thread = threading.Thread(
                        target=_task_function,
                        args=(self.opt, workers, new_conversation_id),
                        name='task-{}'.format(new_conversation_id)
                    )
                    task_thread.daemon = True
                    task_thread.start()
                    self.task_threads.append(task_thread)

            # Once we've had enough conversations, finish and break
            compare_count = self.started_conversations
            if (self.opt['count_complete']):
                compare_count = self.completed_conversations
            if compare_count == self.num_conversations:
                self.accepting_workers = False
                self.expire_all_unassigned_hits()
                self._expire_onboarding_pool()
                self._expire_worker_pool()
                # Wait for all conversations to finish, then break from
                # the while loop
                for thread in self.task_threads:
                    thread.join()
                break
            time.sleep(shared_utils.THREAD_MEDIUM_SLEEP)

    def shutdown(self):
        """Handle any mturk client shutdown cleanup."""
        # Ensure all threads are cleaned and state and HITs are handled
        self.expire_all_unassigned_hits()
        self._expire_onboarding_pool()
        self._expire_worker_pool()
        self.socket_manager.close_all_channels()
        for assignment_id in self.assignment_to_onboard_thread:
            self.assignment_to_onboard_thread[assignment_id].join()
        self._save_disconnects()
        server_utils.delete_server(self.server_task_name)

    ### MTurk Agent Interaction Functions ###

    def force_expire_hit(self, worker_id, assign_id, text=None,
                         ack_func=None):
        """Send a command to expire a hit to the provided agent, update State
        to reflect that the HIT is now expired
        """
        # Expire in the state
        # NOTE(review): is_final is computed but never read afterwards —
        # confirm it is dead before removing
        is_final = True
        agent = self._get_agent(worker_id, assign_id)
        if agent is not None:
            if not agent.state.is_final():
                is_final = False
                agent.state.status = AssignState.STATUS_EXPIRED
                agent.hit_is_expired = True

        # Send the expiration command
        if text == None:
            text = ('This HIT is expired, please return and take a new '
                    'one if you\'d want to work on this task.')
        data = {'text': data_model.COMMAND_EXPIRE_HIT, 'inactive_text': text}
        self.send_command(worker_id, assign_id, data, ack_func=ack_func)

    def handle_turker_timeout(self, worker_id, assign_id):
        """To be used by the MTurk agent when the worker doesn't send a
        message within the expected window.
        """
        # Expire the hit for the disconnected user
        text = ('You haven\'t entered a message in too long, leaving the '
                'other participant unable to complete the HIT. Thus this hit '
                'has been expired and you have been considered disconnected. '
                'Disconnect too frequently and you will be blocked from '
                'working on these HITs in the future.')
        self.force_expire_hit(worker_id, assign_id, text)

        # Send the disconnect event to all workers in the convo
        self._handle_worker_disconnect(worker_id, assign_id)

    def send_message(self, receiver_id, assignment_id, data,
                     blocking=True, ack_func=None):
        """Send a message through the socket manager,
        update conversation state
        """
        data['type'] = data_model.MESSAGE_TYPE_MESSAGE
        # Force messages to have a unique ID
        if 'message_id' not in data:
            data['message_id'] = str(uuid.uuid4())
        event_id = shared_utils.generate_event_id(receiver_id)
        packet = Packet(
            event_id,
            Packet.TYPE_MESSAGE,
            self.socket_manager.get_my_sender_id(),
            receiver_id,
            assignment_id,
            data,
            blocking=blocking,
            ack_func=ack_func
        )

        shared_utils.print_and_log(
            logging.INFO,
            'Manager sending: {}'.format(packet),
            should_print=self.opt['verbose']
        )
        # Push outgoing message to the message thread to be able to resend
        # on a reconnect event
        agent = self._get_agent(receiver_id, assignment_id)
        if agent is not None:
            agent.state.messages.append(packet.data)
        self.socket_manager.queue_packet(packet)

    def send_command(self, receiver_id, assignment_id, data, blocking=True,
                     ack_func=None):
        """Sends a command through the socket manager,
        update conversation state
        """
        data['type'] = data_model.MESSAGE_TYPE_COMMAND
        event_id = shared_utils.generate_event_id(receiver_id)
        packet = Packet(
            event_id,
            Packet.TYPE_MESSAGE,
            self.socket_manager.get_my_sender_id(),
            receiver_id,
            assignment_id,
            data,
            blocking=blocking,
            ack_func=ack_func
        )

        agent = self._get_agent(receiver_id, assignment_id)
        if (data['text'] != data_model.COMMAND_CHANGE_CONVERSATION and
                data['text'] != data_model.COMMAND_RESTORE_STATE and
                agent is not None):
            # Append last command, as it might be necessary to restore state
            agent.state.last_command = packet.data

        self.socket_manager.queue_packet(packet)

    def mark_workers_done(self, workers):
        """Mark a group of workers as done to keep state consistent"""
        for worker in workers:
            if not worker.state.is_final():
                worker.state.status = AssignState.STATUS_DONE

    def free_workers(self, workers):
        """End completed worker threads"""
        for worker in workers:
            self.socket_manager.close_channel(worker.get_connection_id())

    ### Amazon MTurk Server Functions ###

    def get_agent_work_status(self, assignment_id):
        """Get the current status of an assignment's work.

        Returns the MTurk AssignmentStatus string, or
        MTurkAgent.ASSIGNMENT_NOT_DONE when the assignment is still in
        progress (falls through to None on any other ClientError).
        """
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        try:
            response = client.get_assignment(AssignmentId=assignment_id)
            return response['Assignment']['AssignmentStatus']
        except ClientError as e:
            # If the assignment isn't done, asking for the assignment will
            # fail
            not_done_message = ('This operation can be called with a status '
                                'of: Reviewable,Approved,Rejected')
            if not_done_message in e.response['Error']['Message']:
                return MTurkAgent.ASSIGNMENT_NOT_DONE

    def create_additional_hits(self, num_hits):
        """Handle creation for a specific number of hits/assignments
        Put created HIT ids into the hit_id_list
        """
        shared_utils.print_and_log(logging.INFO,
                                   'Creating {} hits...'.format(num_hits))
        hit_type_id = mturk_utils.create_hit_type(
            hit_title=self.opt['hit_title'],
            hit_description='{} (ID: {})'.format(self.opt['hit_description'],
                                                 self.task_group_id),
            hit_keywords=self.opt['hit_keywords'],
            hit_reward=self.opt['reward'],
            assignment_duration_in_seconds=  # Set to 30 minutes by default
                self.opt.get('assignment_duration_in_seconds', 30 * 60),
            is_sandbox=self.opt['is_sandbox']
        )
        mturk_chat_url = '{}/chat_index?task_group_id={}'.format(
            self.server_url,
            self.task_group_id
        )
        shared_utils.print_and_log(logging.INFO, mturk_chat_url)
        mturk_page_url = None

        if self.opt['unique_worker'] == True:
            # Use a single hit with many assignments to allow
            # workers to only work on the task once
            mturk_page_url, hit_id = mturk_utils.create_hit_with_hit_type(
                page_url=mturk_chat_url,
                hit_type_id=hit_type_id,
                num_assignments=num_hits,
                is_sandbox=self.is_sandbox
            )
            self.hit_id_list.append(hit_id)
        else:
            # Create unique hits, allowing one worker to be able to handle
            # many tasks without needing to be unique
            for i in range(num_hits):
                mturk_page_url, hit_id = mturk_utils.create_hit_with_hit_type(
                    page_url=mturk_chat_url,
                    hit_type_id=hit_type_id,
                    num_assignments=1,
                    is_sandbox=self.is_sandbox
                )
                self.hit_id_list.append(hit_id)
        return mturk_page_url

    def create_hits(self):
        """Create hits based on the managers current config, return hit url"""
        shared_utils.print_and_log(logging.INFO, 'Creating HITs...', True)

        mturk_page_url = self.create_additional_hits(
            num_hits=self.required_hits
        )

        shared_utils.print_and_log(logging.INFO,
                                   'Link to HIT: {}\n'.format(mturk_page_url),
                                   should_print=True)
        shared_utils.print_and_log(
            logging.INFO,
            'Waiting for Turkers to respond... (Please don\'t close'
            ' your laptop or put your computer into sleep or standby mode.)\n',
            should_print=True
        )
        return mturk_page_url

    def get_hit(self, hit_id):
        """Get hit from mturk by hit_id"""
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        return client.get_hit(HITId=hit_id)

    def get_assignment(self, assignment_id):
        """Gets assignment from mturk by assignment_id. Only works if the
        assignment is in a completed state
        """
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        return client.get_assignment(AssignmentId=assignment_id)

    def expire_all_unassigned_hits(self):
        """Move through the whole hit_id list and attempt to expire the
        HITs, though this only immediately expires those that aren't
        assigned.
        """
        shared_utils.print_and_log(logging.INFO,
                                   'Expiring all unassigned HITs...',
                                   should_print=not self.is_test)
        for hit_id in self.hit_id_list:
            mturk_utils.expire_hit(self.is_sandbox, hit_id)

    def approve_work(self, assignment_id):
        """approve work for a given assignment through the mturk client"""
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        client.approve_assignment(AssignmentId=assignment_id)

    def reject_work(self, assignment_id, reason):
        """reject work for a given assignment through the mturk client"""
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        client.reject_assignment(
            AssignmentId=assignment_id,
            RequesterFeedback=reason
        )

    def block_worker(self, worker_id, reason):
        """Block a worker by id using the mturk client, passes reason along"""
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        client.create_worker_block(WorkerId=worker_id, Reason=reason)

    def pay_bonus(self, worker_id, bonus_amount, assignment_id, reason,
                  unique_request_token):
        """Handles paying bonus to a turker, fails for insufficient funds.
        Returns True on success and False on failure
        """
        total_cost = mturk_utils.calculate_mturk_cost(
            payment_opt={'type': 'bonus', 'amount': bonus_amount}
        )
        if not mturk_utils.check_mturk_balance(balance_needed=total_cost,
                                               is_sandbox=self.is_sandbox):
            shared_utils.print_and_log(
                logging.WARN,
                'Cannot pay bonus. Reason: Insufficient '
                'funds in your MTurk account.',
                should_print=True
            )
            return False

        client = mturk_utils.get_mturk_client(self.is_sandbox)
        # unique_request_token may be useful for handling future network
        # errors
        client.send_bonus(
            WorkerId=worker_id,
            BonusAmount=str(bonus_amount),
            AssignmentId=assignment_id,
            Reason=reason,
            UniqueRequestToken=unique_request_token
        )
        shared_utils.print_and_log(
            logging.INFO,
            'Paid ${} bonus to WorkerId: {}'.format(
                bonus_amount,
                worker_id
            )
        )
        return True

    def email_worker(self, worker_id, subject, message_text):
        """Send an email to a worker through the mturk client.

        Returns {'success': True} on success, or a {'failure': message} dict
        when MTurk reports a notify failure for the worker.
        """
        client = mturk_utils.get_mturk_client(self.is_sandbox)
        response = client.notify_workers(
            Subject=subject,
            MessageText=message_text,
            WorkerIds=[worker_id]
        )
        if len(response['NotifyWorkersFailureStatuses']) > 0:
            failure_message = response['NotifyWorkersFailureStatuses'][0]
            return {'failure': failure_message['NotifyWorkersFailureMessage']}
        else:
            return {'success': True}
# NOTE(review): this is a SECOND definition of MTurkManager in the same file
# (an earlier variant appears above, using self.mturk_workers and
# shared_utils logging). At import time this later definition shadows the
# first — confirm which version is intended and remove the other.
class MTurkManager():
    """Manages interactions between MTurk agents as well as direct
    interactions between a world and the MTurk server.
    """

    def __init__(self, opt, mturk_agent_ids):
        """Create an MTurkManager using the given setup opts and a list of
        agent_ids that will participate in each conversation
        """
        self.opt = opt
        self.server_url = None
        self.port = 443
        self.task_group_id = None
        self.run_id = None
        self.mturk_agent_ids = mturk_agent_ids
        self.task_files_to_copy = None
        self.is_sandbox = opt['is_sandbox']
        self.worker_pool_change_condition = threading.Condition()
        self.onboard_function = None
        self.num_conversations = opt['num_conversations']
        # Over-provision HITs by HIT_MULT so disconnects/returns don't stall
        # the run; extras are expired once enough conversations happen
        self.required_hits = math.ceil(self.num_conversations *
                                       len(self.mturk_agent_ids) *
                                       HIT_MULT)
        self.socket_manager = None

    ### Helpers and internal manager methods ###

    def _init_state(self):
        """Initialize everything in the worker, task, and thread states"""
        self.mturk_agents = {}
        self.hit_id_list = []
        self.worker_pool = []
        self.worker_index = 0
        self.assignment_to_onboard_thread = {}
        self.task_threads = []
        self.conversation_index = 0
        self.started_conversations = 0
        self.completed_conversations = 0
        self.worker_state = {}
        self.conv_to_agent = {}
        self.accepting_workers = True
        self._load_disconnects()

    def _load_disconnects(self):
        """Load disconnects from file, populate the disconnects field for any
        worker_id that has disconnects in the list. Any disconnect that
        occurred longer ago than the disconnect persist length is ignored
        """
        self.disconnects = []
        # Load disconnects from file
        file_path = os.path.join(parent_dir, DISCONNECT_FILE_NAME)
        compare_time = time.time()
        if os.path.exists(file_path):
            with open(file_path, 'rb') as f:
                old_disconnects = pickle.load(f)
                self.disconnects = [
                    d
                    for d in old_disconnects
                    if (compare_time - d['time']) < DISCONNECT_PERSIST_LENGTH
                ]
        # Initialize worker states with proper number of disconnects
        for disconnect in self.disconnects:
            worker_id = disconnect['id']
            if not worker_id in self.worker_state:
                # add this worker to the worker state
                self.worker_state[worker_id] = WorkerState(worker_id)
            self.worker_state[worker_id].disconnects += 1

    def _save_disconnects(self):
        """Saves the local list of disconnects to file"""
        file_path = os.path.join(parent_dir, DISCONNECT_FILE_NAME)
        if os.path.exists(file_path):
            os.remove(file_path)
        with open(file_path, 'wb') as f:
            pickle.dump(self.disconnects, f, pickle.HIGHEST_PROTOCOL)

    def _handle_bad_disconnect(self, worker_id):
        """Update the number of bad disconnects for the given worker, block
        them if they've exceeded the disconnect limit
        """
        self.worker_state[worker_id].disconnects += 1
        self.disconnects.append({'time': time.time(), 'id': worker_id})
        if self.worker_state[worker_id].disconnects > MAX_DISCONNECTS:
            text = ('This worker has repeatedly disconnected from these tasks,'
                    ' which require constant connection to complete properly '
                    'as they involve interaction with other Turkers. They have'
                    ' been blocked to ensure a better experience for other '
                    'workers who don\'t disconnect.')
            self.block_worker(worker_id, text)
            print_and_log(
                'Worker {} was blocked - too many disconnects'.format(
                    worker_id))

    def _get_ids_from_pkt(self, pkt):
        """Get sender, assignment, and conv ids from a packet"""
        worker_id = pkt.sender_id
        assignment_id = pkt.assignment_id
        agent = self.mturk_agents[worker_id][assignment_id]
        conversation_id = agent.conversation_id
        return worker_id, assignment_id, conversation_id

    def _change_worker_to_conv(self, pkt):
        """Update a worker to a new conversation given a packet from the
        conversation to be switched to
        """
        worker_id, assignment_id, conversation_id = self._get_ids_from_pkt(pkt)
        self._assign_agent_to_conversation(
            self.mturk_agents[worker_id][assignment_id],
            conversation_id)

    def _set_worker_status_to_onboard(self, pkt):
        """Changes assignment status to onboarding based on the packet"""
        worker_id, assignment_id, conversation_id = self._get_ids_from_pkt(pkt)
        assign_state = self.worker_state[worker_id].assignments[assignment_id]
        assign_state.status = AssignState.STATUS_ONBOARDING
        assign_state.conversation_id = conversation_id

    def _set_worker_status_to_waiting(self, pkt):
        """Changes assignment status to waiting based on the packet"""
        worker_id, assignment_id, conversation_id = self._get_ids_from_pkt(pkt)
        assign_state = self.worker_state[worker_id].assignments[assignment_id]
        assign_state.status = AssignState.STATUS_WAITING
        assign_state.conversation_id = conversation_id

        # Wait for turker to be in waiting status
        self._wait_for_status(assign_state, AssignState.STATUS_WAITING)

        # Add the worker to pool
        with self.worker_pool_change_condition:
            print("Adding worker to pool...")
            self.worker_pool.append(
                self.mturk_agents[worker_id][assignment_id])

    def _move_workers_to_waiting(self, workers):
        """Put all workers into waiting worlds, expire them if no longer
        accepting workers. If the worker is already final, delete it
        """
        for worker in workers:
            worker_id = worker.worker_id
            assignment_id = worker.assignment_id
            assignment = \
                self.worker_state[worker_id].assignments[assignment_id]
            if assignment.is_final():
                # This worker must've disconnected or expired, remove them
                del worker
                continue

            conversation_id = 'w_{}'.format(uuid.uuid4())

            if self.accepting_workers:
                # Move the worker into a waiting world
                worker.change_conversation(
                    conversation_id=conversation_id,
                    agent_id='waiting',
                    change_callback=self._set_worker_status_to_waiting)
            else:
                self.force_expire_hit(worker_id, assignment_id)

    def _wait_for_status(self, assign_state, desired_status):
        """Suspend a thread until a particular assignment state changes to the
        desired state
        """
        # Busy-wait with a short sleep; status is set from socket callbacks
        while True:
            if assign_state.status == desired_status:
                break
            time.sleep(THREAD_SHORT_SLEEP)

    def _expire_onboarding_pool(self):
        """Expire any worker that is in an onboarding thread"""
        for worker_id in self.worker_state:
            for assign_id in self.worker_state[worker_id].assignments:
                assign = self.worker_state[worker_id].assignments[assign_id]
                if (assign.status == AssignState.STATUS_ONBOARDING):
                    self.force_expire_hit(worker_id, assign_id)

    def _expire_worker_pool(self):
        """Expire all workers in the worker pool"""
        for agent in self.worker_pool:
            self.force_expire_hit(agent.worker_id, agent.assignment_id)

    def _get_unique_pool(self, eligibility_function):
        """Return a filtered version of the worker pool where each worker is
        only listed a maximum of one time. In sandbox this is overridden for
        testing purposes, and the same worker can be returned more than once
        """
        workers = [w for w in self.worker_pool
                   if not w.hit_is_returned and eligibility_function(w)]
        unique_workers = []
        unique_worker_ids = []
        for w in workers:
            if (self.is_sandbox) or (w.worker_id not in unique_worker_ids):
                unique_workers.append(w)
                unique_worker_ids.append(w.worker_id)
        return unique_workers

    def _handle_partner_disconnect(self, worker_id, assignment_id):
        """Send a message to a worker notifying them that a partner has
        disconnected and we marked the HIT as complete for them
        """
        state = self.worker_state[worker_id].assignments[assignment_id]
        if not state.is_final():
            # Update the assignment state
            agent = self.mturk_agents[worker_id][assignment_id]
            agent.some_agent_disconnected = True
            state.status = AssignState.STATUS_PARTNER_DISCONNECT

            # Create and send the command
            data = state.get_inactive_command_data(worker_id)
            self.send_command(worker_id, assignment_id, data)

    def _restore_worker_state(self, worker_id, assignment_id):
        """Send a command to restore the state of an agent who reconnected"""
        assignment = self.worker_state[worker_id].assignments[assignment_id]

        def _push_worker_state(msg):
            # Callback run after the conversation change is acknowledged;
            # resends stored messages/last command so the client can rebuild
            if len(assignment.messages) != 0:
                data = {
                    'text': data_model.COMMAND_RESTORE_STATE,
                    'messages': assignment.messages,
                    'last_command': assignment.last_command
                }
                self.send_command(worker_id, assignment_id, data)

        agent = self.mturk_agents[worker_id][assignment_id]
        agent.change_conversation(conversation_id=agent.conversation_id,
                                  agent_id=agent.id,
                                  change_callback=_push_worker_state)

    def _setup_socket(self):
        """Set up a socket_manager with defined callbacks"""
        self.socket_manager = SocketManager(self.server_url,
                                            self.port,
                                            self._on_alive,
                                            self._on_new_message,
                                            self._on_socket_dead,
                                            self.task_group_id)

    def _on_alive(self, pkt):
        """Update MTurkManager's state when a worker sends an alive packet.
        This asks the socket manager to open a new channel and then handles
        ensuring the worker state is consistent
        """
        print_and_log('on_agent_alive: {}'.format(pkt), False)
        worker_id = pkt.data['worker_id']
        hit_id = pkt.data['hit_id']
        assign_id = pkt.data['assignment_id']
        conversation_id = pkt.data['conversation_id']
        # Open a channel if it doesn't already exist
        self.socket_manager.open_channel(worker_id, assign_id)

        if not worker_id in self.worker_state:
            # First time this worker has connected, start tracking
            self.worker_state[worker_id] = WorkerState(worker_id)

        # Update state of worker based on this connect
        curr_worker_state = self.worker_state[worker_id]

        # NOTE(review): curr_worker_state was just (re)created above, so this
        # `not curr_worker_state` branch looks unreachable unless WorkerState
        # defines a falsy __bool__ — confirm against WorkerState
        if conversation_id and not curr_worker_state:
            # This was a request from a previous run and should be expired,
            # send a message and expire when it is acknowledged
            def _close_my_socket(data):
                """Small helper to close the socket after user acknowledges
                that it shouldn't exist"""
                self.socket_manager.close_channel(worker_id, assign_id)

            text = ('You disconnected in the middle of this HIT and the '
                    'HIT expired before you reconnected. It is no longer '
                    'available for completion. Please return this HIT and '
                    'accept a new one if you would like to try again.')
            self.force_expire_hit(worker_id, assign_id, text, _close_my_socket)
        elif not assign_id:
            # invalid assignment_id is an auto-fail
            print_and_log(
                'Agent ({}) with no assign_id called alive'.format(worker_id),
                False)
        elif not assign_id in curr_worker_state.assignments:
            # First time this worker has connected under this assignment, init
            # if we are still accepting workers
            if self.accepting_workers:
                convs = \
                    self.worker_state[worker_id].active_conversation_count()
                allowed_convs = self.opt['allowed_conversations']
                # allowed_convs == 0 means "no limit"
                if allowed_convs == 0 or convs < allowed_convs:
                    curr_worker_state.add_assignment(assign_id)
                    self._create_agent(hit_id, assign_id, worker_id)
                    self._onboard_new_worker(
                        self.mturk_agents[worker_id][assign_id])
                else:
                    text = ('You can participate in only {} of these HITs at '
                            'once. Please return this HIT and finish your '
                            'existing HITs before accepting more.'.format(
                                allowed_convs))
                    self.force_expire_hit(worker_id, assign_id, text)
            else:
                self.force_expire_hit(worker_id, assign_id)
        else:
            # Reconnect under an existing assignment: route by current status
            curr_assign = curr_worker_state.assignments[assign_id]
            curr_assign.log_reconnect(worker_id)
            if curr_assign.status == AssignState.STATUS_NONE:
                # Reconnecting before even being given a world. The retries
                # for switching to the onboarding world should catch this
                return
            elif (curr_assign.status == AssignState.STATUS_ONBOARDING or
                  curr_assign.status == AssignState.STATUS_WAITING):
                # Reconnecting to the onboarding world or to a waiting world
                # should either restore state or expire (if workers are no
                # longer being accepted for this task)
                if not self.accepting_workers:
                    self.force_expire_hit(worker_id, assign_id)
                elif not conversation_id:
                    self._restore_worker_state(worker_id, assign_id)
            elif curr_assign.status == AssignState.STATUS_IN_TASK:
                # Reconnecting to the onboarding world or to a task world
                # should resend the messages already in the conversation
                if not conversation_id:
                    self._restore_worker_state(worker_id, assign_id)
            elif curr_assign.status == AssignState.STATUS_ASSIGNED:
                # Connect after a switch to a task world, mark the switch
                curr_assign.status = AssignState.STATUS_IN_TASK
                curr_assign.last_command = None
                curr_assign.messages = []
            elif (curr_assign.status == AssignState.STATUS_DISCONNECT or
                  curr_assign.status == AssignState.STATUS_DONE or
                  curr_assign.status == AssignState.STATUS_EXPIRED or
                  curr_assign.status == AssignState.STATUS_RETURNED or
                  curr_assign.status ==
                  AssignState.STATUS_PARTNER_DISCONNECT):
                # inform the connecting user in all of these cases that the
                # task is no longer workable, use appropriate message
                data = curr_assign.get_inactive_command_data(worker_id)
                self.send_command(worker_id, assign_id, data)

    def _on_new_message(self, pkt):
        """Put an incoming message onto the correct agent's message queue and
        add it to the proper message thread
        """
        worker_id = pkt.sender_id
        assignment_id = pkt.assignment_id
        curr_state = self.worker_state[worker_id].assignments[assignment_id]
        # Push the message to the message thread ready to send on a reconnect
        curr_state.messages.append(pkt.data)
        # Clear the send message command, as a message was recieved
        curr_state.last_command = None
        self.mturk_agents[worker_id][assignment_id].msg_queue.put(pkt.data)

    def _on_socket_dead(self, worker_id, assignment_id):
        """Handle a disconnect event, update state as required and notifying
        other agents if the disconnected agent was in conversation with them

        Returns False if the socket death should be ignored and the socket
        should stay open, True if the channel can be closed.
        """
        if (worker_id not in self.mturk_agents) or \
                (assignment_id not in self.mturk_agents[worker_id]):
            # This worker never registered, so we don't do anything
            return True

        agent = self.mturk_agents[worker_id][assignment_id]
        agent.disconnected = True
        assignments = self.worker_state[worker_id].assignments
        status = assignments[assignment_id].status
        print_and_log('Worker {} disconnected from {} in status {}'.format(
            worker_id,
            assignment_id,
            status))

        if status == AssignState.STATUS_NONE:
            # Agent never made it to onboarding, delete
            assignments[assignment_id].status = AssignState.STATUS_DISCONNECT
            del agent
        elif status == AssignState.STATUS_ONBOARDING:
            # Agent never made it to task pool, the onboarding thread will die
            # and delete the agent if we mark it as a disconnect
            assignments[assignment_id].status = AssignState.STATUS_DISCONNECT
        elif status == AssignState.STATUS_WAITING:
            # agent is in pool, remove from pool and delete
            if agent in self.worker_pool:
                with self.worker_pool_change_condition:
                    self.worker_pool.remove(agent)
            assignments[assignment_id].status = AssignState.STATUS_DISCONNECT
            del agent
        elif status == AssignState.STATUS_IN_TASK:
            # Disconnect in conversation is not workable
            assignments[assignment_id].status = AssignState.STATUS_DISCONNECT
            # in conversation, inform others about disconnect
            conversation_id = assignments[assignment_id].conversation_id
            if agent in self.conv_to_agent[conversation_id]:
                for other_agent in self.conv_to_agent[conversation_id]:
                    if agent.assignment_id != other_agent.assignment_id:
                        self._handle_partner_disconnect(
                            other_agent.worker_id,
                            other_agent.assignment_id)
            if len(self.mturk_agent_ids) > 1:
                # The user disconnected from inside a conversation with
                # another turker, record this as bad behavoir
                self._handle_bad_disconnect(worker_id)
        elif (status == AssignState.STATUS_DONE or
              status == AssignState.STATUS_EXPIRED or
              status == AssignState.STATUS_DISCONNECT or
              status == AssignState.STATUS_PARTNER_DISCONNECT or
              status == AssignState.STATUS_RETURNED):
            # It's okay if a complete assignment socket dies, but wait for the
            # world to clean up the resource
            return True
        else:
            # A disconnect should be ignored in the assigned state, as we dont
            # check alive status when reconnecting after given an assignment
            return False

        self.socket_manager.close_channel(worker_id, assignment_id)
        return True

    def _create_agent(self, hit_id, assignment_id, worker_id):
        """Initialize an agent and add it to the map"""
        agent = MTurkAgent(self.opt, self, hit_id, assignment_id, worker_id)
        if (worker_id in self.mturk_agents):
            self.mturk_agents[worker_id][assignment_id] = agent
        else:
            self.mturk_agents[worker_id] = {}
            self.mturk_agents[worker_id][assignment_id] = agent

    def _onboard_new_worker(self, mturk_agent):
        """Handle creating an onboarding thread and moving an agent through
        the onboarding process, updating the state properly along the way
        """
        # get state variable in question
        worker_id = mturk_agent.worker_id
        assignment_id = mturk_agent.assignment_id
        assign_state = self.worker_state[worker_id].assignments[assignment_id]

        def _onboard_function(mturk_agent):
            """Onboarding wrapper to set state to onboarding properly"""
            if self.onboard_function:
                conversation_id = 'o_' + str(uuid.uuid4())
                mturk_agent.change_conversation(
                    conversation_id=conversation_id,
                    agent_id='onboarding',
                    change_callback=self._set_worker_status_to_onboard)
                # Wait for turker to be in onboarding status
                self._wait_for_status(assign_state,
                                      AssignState.STATUS_ONBOARDING)
                # call onboarding function
                self.onboard_function(mturk_agent)

            # once onboarding is done, move into a waiting world
            self._move_workers_to_waiting([mturk_agent])

        if not assignment_id in self.assignment_to_onboard_thread:
            # Start the onboarding thread and run it
            onboard_thread = threading.Thread(target=_onboard_function,
                                              args=(mturk_agent, ))
            onboard_thread.daemon = True
            onboard_thread.start()

            self.assignment_to_onboard_thread[assignment_id] = onboard_thread

    def _assign_agent_to_conversation(self, agent, conv_id):
        """Register an agent object with a conversation id, update status"""
        worker_id = agent.worker_id
        assignment_id = agent.assignment_id
        assign_state = self.worker_state[worker_id].assignments[assignment_id]
        if assign_state.status != AssignState.STATUS_IN_TASK:
            # Avoid on a second ack if alive already came through
            assign_state.status = AssignState.STATUS_ASSIGNED

        assign_state.conversation_id = conv_id
        if not conv_id in self.conv_to_agent:
            self.conv_to_agent[conv_id] = []
        self.conv_to_agent[conv_id].append(agent)

    def _no_workers_incomplete(self, workers):
        """Return True if all the given workers completed their task"""
        for w in workers:
            state = self.worker_state[w.worker_id].assignments[w.assignment_id]
            if state.is_final() and state.status != AssignState.STATUS_DONE:
                return False
        return True

    ### Manager Lifecycle Functions ###

    def setup_server(self, task_directory_path=None):
        """Prepare the MTurk server for the new HIT we would like to submit"""
        completion_type = 'start'
        if self.opt['count_complete']:
            completion_type = 'finish'
        print_and_log(
            '\nYou are going to allow workers from Amazon '
            'Mechanical Turk to be an agent in ParlAI.\nDuring this '
            'process, Internet connection is required, and you should '
            'turn off your computer\'s auto-sleep feature.\n'
            'Enough HITs will be created to fulfill {} times the number of '
            'conversations requested, extra HITs will be expired once the '
            'desired conversations {}.'.format(HIT_MULT, completion_type))
        key_input = input('Please press Enter to continue... ')
        print_and_log('')

        setup_aws_credentials()

        # See if there's enough money in the account to fund the HITs requested
        num_assignments = self.required_hits
        payment_opt = {
            'type': 'reward',
            'num_total_assignments': num_assignments,
            'reward': self.opt['reward'],  # in dollars
            'unique': self.opt['unique_worker']
        }
        total_cost = calculate_mturk_cost(payment_opt=payment_opt)
        if not check_mturk_balance(balance_needed=total_cost,
                                   is_sandbox=self.opt['is_sandbox']):
            raise SystemExit('Insufficient funds')

        # Ask for explicit confirmation for expensive runs
        if total_cost > 100 or self.opt['reward'] > 1:
            confirm_string = '$%.2f' % total_cost
            print_and_log(
                'You are going to create {} HITs at {} per assignment, for a '
                'total cost of {} after MTurk fees. Please enter "{}" to '
                'confirm and continue, and anything else to cancel'.format(
                    self.required_hits,
                    '$%.2f' % self.opt['reward'],
                    confirm_string,
                    confirm_string))
            check = input('Enter here: ')
            if (check != confirm_string):
                raise SystemExit('Cancelling')

        print_and_log('Setting up MTurk server...')
        create_hit_config(task_description=self.opt['task_description'],
                          unique_worker=self.opt['unique_worker'],
                          is_sandbox=self.opt['is_sandbox'])
        # Poplulate files to copy over to the server
        if not self.task_files_to_copy:
            self.task_files_to_copy = []
        if not task_directory_path:
            task_directory_path = os.path.join(self.opt['parlai_home'],
                                               'parlai',
                                               'mturk',
                                               'tasks',
                                               self.opt['task'])
        self.task_files_to_copy.append(
            os.path.join(task_directory_path, 'html', 'cover_page.html'))
        for mturk_agent_id in self.mturk_agent_ids + ['onboarding']:
            self.task_files_to_copy.append(os.path.join(
                task_directory_path,
                'html',
                '{}_index.html'.format(mturk_agent_id)))

        # Setup the server with a likely-unique app-name
        task_name = '{}-{}'.format(str(uuid.uuid4())[:8], self.opt['task'])
        self.server_task_name = \
            ''.join(e for e in task_name if e.isalnum() or e == '-')
        self.server_url = \
            setup_server(self.server_task_name, self.task_files_to_copy)
        print_and_log(self.server_url, False)

        print_and_log("MTurk server setup done.\n")

    def ready_to_accept_workers(self):
        """Set up socket to start communicating to workers"""
        print_and_log('Local: Setting up SocketIO...')
        self._setup_socket()

    def start_new_run(self):
        """Clear state to prepare for a new run"""
        self.run_id = str(int(time.time()))
        self.task_group_id = '{}_{}'.format(self.opt['task'], self.run_id)
        self._init_state()

    def set_onboard_function(self, onboard_function):
        # Callback invoked once per worker before they enter the waiting pool
        self.onboard_function = onboard_function

    def start_task(self, eligibility_function, role_function, task_function):
        """Handle running a task by checking to see when enough agents are
        in the pool to start an instance of the task. Continue doing this
        until the desired number of conversations is had.
        """

        def _task_function(opt, workers, conversation_id):
            """Wait for all workers to join world before running the task"""
            print('Starting task...')
            print('Waiting for all workers to join the conversation...')
            start_time = time.time()
            while True:
                all_joined = True
                for worker in workers:
                    # check the status of an individual worker assignment
                    worker_id = worker.worker_id
                    assign_id = worker.assignment_id
                    worker_state = self.worker_state[worker_id]
                    if not assign_id in worker_state.assignments:
                        # This assignment was removed, we should exit this loop
                        print('At least one worker dropped before all joined!')
                        return
                    status = worker_state.assignments[assign_id].status
                    if status != AssignState.STATUS_IN_TASK:
                        all_joined = False
                if all_joined:
                    break
                if time.time() - start_time > WORLD_START_TIMEOUT:
                    # We waited but not all workers rejoined, throw workers
                    # back into the waiting pool. Stragglers will disconnect
                    # from there
                    print('Timeout waiting for workers, move back to waiting')
                    self._move_workers_to_waiting(workers)
                    return
                time.sleep(THREAD_SHORT_SLEEP)

            print('All workers joined the conversation!')
            self.started_conversations += 1
            task_function(mturk_manager=self, opt=opt, workers=workers)
            if self._no_workers_incomplete(workers):
                self.completed_conversations += 1

        while True:
            # Loop forever starting task worlds until desired convos are had
            with self.worker_pool_change_condition:
                valid_workers = self._get_unique_pool(eligibility_function)
                needed_workers = len(self.mturk_agent_ids)
                if len(valid_workers) >= needed_workers:
                    # enough workers in pool to start new conversation
                    self.conversation_index += 1
                    new_conversation_id = \
                        't_{}'.format(self.conversation_index)

                    # Add the required number of valid workers to the conv
                    selected_workers = []
                    for w in valid_workers[:needed_workers]:
                        selected_workers.append(w)
                        w.id = role_function(w)
                        w.change_conversation(
                            conversation_id=new_conversation_id,
                            agent_id=w.id,
                            change_callback=self._change_worker_to_conv)

                    # Remove selected workers from the pool
                    for worker in selected_workers:
                        self.worker_pool.remove(worker)

                    # Start a new thread for this task world
                    task_thread = threading.Thread(
                        target=_task_function,
                        args=(self.opt,
                              selected_workers,
                              new_conversation_id))
                    task_thread.daemon = True
                    task_thread.start()
                    self.task_threads.append(task_thread)

            # Once we've had enough conversations, finish and break
            compare_count = self.started_conversations
            if (self.opt['count_complete']):
                compare_count = self.completed_conversations
            if compare_count == self.num_conversations:
                self.accepting_workers = False
                self.expire_all_unassigned_hits()
                self._expire_onboarding_pool()
                self._expire_worker_pool()
                # Wait for all conversations to finish, then break from
                # the while loop
                for thread in self.task_threads:
                    thread.join()
                break
            time.sleep(THREAD_MEDIUM_SLEEP)

    def shutdown(self):
        """Handle any mturk client shutdown cleanup."""
        # Ensure all threads are cleaned and state and HITs are handled
        self.expire_all_unassigned_hits()
        self._expire_onboarding_pool()
        self._expire_worker_pool()
        for assignment_id in self.assignment_to_onboard_thread:
            self.assignment_to_onboard_thread[assignment_id].join()
        self._save_disconnects()
        delete_server(self.server_task_name)

    ### MTurk Agent Interaction Functions ###

    def force_expire_hit(self, worker_id, assign_id, text=None,
                         ack_func=None):
        """Send a command to expire a hit to the provided agent, update State
        to reflect that the HIT is now expired
        """
        # Expire in the state
        is_final = True
        if worker_id in self.worker_state:
            if assign_id in self.worker_state[worker_id].assignments:
                state = self.worker_state[worker_id].assignments[assign_id]
                if not state.is_final():
                    is_final = False
                    state.status = AssignState.STATUS_EXPIRED

        if not is_final:
            # Expire in the agent
            if worker_id in self.mturk_agents:
                if assign_id in self.mturk_agents[worker_id]:
                    agent = self.mturk_agents[worker_id][assign_id]
                    agent.hit_is_expired = True

        # Send the expiration command
        if text == None:
            text = ('This HIT is expired, please return and take a new '
                    'one if you\'d want to work on this task.')
        data = {'text': data_model.COMMAND_EXPIRE_HIT, 'inactive_text': text}
        self.send_command(worker_id, assign_id, data, ack_func=ack_func)

    def send_message(self, receiver_id, assignment_id, data,
                     blocking=True, ack_func=None):
        """Send a message through the socket manager,
        update conversation state
        """
        data['type'] = data_model.MESSAGE_TYPE_MESSAGE
        # Force messages to have a unique ID
        if 'message_id' not in data:
            data['message_id'] = str(uuid.uuid4())
        event_id = generate_event_id(receiver_id)
        packet = Packet(
            event_id,
            Packet.TYPE_MESSAGE,
            self.socket_manager.get_my_sender_id(),
            receiver_id,
            assignment_id,
            data,
            blocking=blocking,
            ack_func=ack_func
        )
        # Push outgoing message to the message thread to be able to resend
        # on a reconnect event
        assignment = self.worker_state[receiver_id].assignments[assignment_id]
        assignment.messages.append(packet.data)

        self.socket_manager.queue_packet(packet)

    def send_command(self, receiver_id, assignment_id, data,
                     blocking=True, ack_func=None):
        """Sends a command through the socket manager,
        update conversation state
        """
        data['type'] = data_model.MESSAGE_TYPE_COMMAND
        event_id = generate_event_id(receiver_id)
        packet = Packet(
            event_id,
            Packet.TYPE_MESSAGE,
            self.socket_manager.get_my_sender_id(),
            receiver_id,
            assignment_id,
            data,
            blocking=blocking,
            ack_func=ack_func
        )

        if (data['text'] != data_model.COMMAND_CHANGE_CONVERSATION and
                data['text'] != data_model.COMMAND_RESTORE_STATE and
                assignment_id in self.worker_state[receiver_id].assignments):
            # Append last command, as it might be necessary to restore state
            assign = self.worker_state[receiver_id].assignments[assignment_id]
            assign.last_command = packet.data

        self.socket_manager.queue_packet(packet)

    def mark_workers_done(self, workers):
        """Mark a group of workers as done to keep state consistent"""
        for worker in workers:
            worker_id = worker.worker_id
            assign_id = worker.assignment_id
            state = self.worker_state[worker_id].assignments[assign_id]
            if not state.is_final():
                state.status = AssignState.STATUS_DONE

    def free_workers(self, workers):
        """End completed worker threads"""
        for worker in workers:
            worker_id = worker.worker_id
            assign_id = worker.assignment_id
            self.socket_manager.close_channel(worker_id, assign_id)

    ### Amazon MTurk Server Functions ###

    def get_agent_work_status(self, assignment_id):
        """Get the current status of an assignment's work"""
        client = get_mturk_client(self.is_sandbox)
        try:
            response = client.get_assignment(AssignmentId=assignment_id)
            return response['Assignment']['AssignmentStatus']
        except ClientError as e:
            # If the assignment isn't done, asking for the assignment will fail
            not_done_message = ('This operation can be called with a status '
                                'of: Reviewable,Approved,Rejected')
            if not_done_message in e.response['Error']['Message']:
                return MTurkAgent.ASSIGNMENT_NOT_DONE

    def create_additional_hits(self, num_hits):
        """Handle creation for a specific number of hits/assignments
        Put created HIT ids into the hit_id_list
        """
        print_and_log('Creating {} hits...'.format(num_hits), False)
        hit_type_id = create_hit_type(
            hit_title=self.opt['hit_title'],
            hit_description='{} (ID: {})'.format(self.opt['hit_description'],
                                                 self.task_group_id),
            hit_keywords=self.opt['hit_keywords'],
            hit_reward=self.opt['reward'],
            assignment_duration_in_seconds=  # Set to 30 minutes by default
                self.opt.get('assignment_duration_in_seconds', 30 * 60),
            is_sandbox=self.opt['is_sandbox'])
        mturk_chat_url = '{}/chat_index?task_group_id={}'.format(
            self.server_url,
            self.task_group_id)
        print_and_log(mturk_chat_url, False)
        mturk_page_url = None

        if self.opt['unique_worker'] == True:
            # Use a single hit with many assignments to allow
            # workers to only work on the task once
            mturk_page_url, hit_id = create_hit_with_hit_type(
                page_url=mturk_chat_url,
                hit_type_id=hit_type_id,
                num_assignments=num_hits,
                is_sandbox=self.is_sandbox)
            self.hit_id_list.append(hit_id)
        else:
            # Create unique hits, allowing one worker to be able to handle many
            # tasks without needing to be unique
            for i in range(num_hits):
                mturk_page_url, hit_id = create_hit_with_hit_type(
                    page_url=mturk_chat_url,
                    hit_type_id=hit_type_id,
                    num_assignments=1,
                    is_sandbox=self.is_sandbox)
                self.hit_id_list.append(hit_id)
        return mturk_page_url

    def create_hits(self):
        """Create hits based on the managers current config, return hit url"""
        print_and_log('Creating HITs...')
        mturk_page_url = self.create_additional_hits(
            num_hits=self.required_hits)
        print_and_log('Link to HIT: {}\n'.format(mturk_page_url))
        print_and_log(
            'Waiting for Turkers to respond... (Please don\'t close'
            ' your laptop or put your computer into sleep or standby mode.)\n')
        return mturk_page_url

    def get_hit(self, hit_id):
        """Get hit from mturk by hit_id"""
        client = get_mturk_client(self.is_sandbox)
        return client.get_hit(HITId=hit_id)

    def get_assignment(self, assignment_id):
        """Gets assignment from mturk by assignment_id. Only works if the
        assignment is in a completed state
        """
        client = get_mturk_client(self.is_sandbox)
        return client.get_assignment(AssignmentId=assignment_id)

    def expire_all_unassigned_hits(self):
        """Move through the whole hit_id list and attempt to expire the HITs,
        though this only immediately expires those that aren't assigned.
        """
        print_and_log("Expiring all unassigned HITs...")
        for hit_id in self.hit_id_list:
            expire_hit(self.is_sandbox, hit_id)

    def approve_work(self, assignment_id):
        """approve work for a given assignment through the mturk client"""
        client = get_mturk_client(self.is_sandbox)
        client.approve_assignment(AssignmentId=assignment_id)

    def reject_work(self, assignment_id, reason):
        """reject work for a given assignment through the mturk client"""
        client = get_mturk_client(self.is_sandbox)
        client.reject_assignment(AssignmentId=assignment_id,
                                 RequesterFeedback=reason)

    def block_worker(self, worker_id, reason):
        """Block a worker by id using the mturk client, passes reason along"""
        client = get_mturk_client(self.is_sandbox)
        client.create_worker_block(WorkerId=worker_id, Reason=reason)

    def pay_bonus(self, worker_id, bonus_amount, assignment_id, reason,
                  unique_request_token):
        """Handles paying bonus to a turker, fails for insufficient funds.
        Returns True on success and False on failure
        """
        total_cost = calculate_mturk_cost(payment_opt={
            'type': 'bonus',
            'amount': bonus_amount
        })
        if not check_mturk_balance(balance_needed=total_cost,
                                   is_sandbox=self.is_sandbox):
            print_and_log('Cannot pay bonus. Reason: Insufficient funds'
                          ' in your MTurk account.')
            return False

        client = get_mturk_client(self.is_sandbox)
        # unique_request_token may be useful for handling future network errors
        client.send_bonus(WorkerId=worker_id,
                          BonusAmount=str(bonus_amount),
                          AssignmentId=assignment_id,
                          Reason=reason,
                          UniqueRequestToken=unique_request_token)
        print_and_log('Paid ${} bonus to WorkerId: {}'.format(
            bonus_amount,
            worker_id))
        return True

    def email_worker(self, worker_id, subject, message_text):
        """Send an email to a worker through the mturk client"""
        client = get_mturk_client(self.is_sandbox)
        response = client.notify_workers(Subject=subject,
                                         MessageText=message_text,
                                         WorkerIds=[worker_id])
        if len(response['NotifyWorkersFailureStatuses']) > 0:
            failure_message = response['NotifyWorkersFailureStatuses'][0]
            return {'failure': failure_message['NotifyWorkersFailureMessage']}
        else:
            return {'success': True}
class TestSocketManagerRoutingFunctionality(unittest.TestCase):
    """Unit tests for SocketManager packet routing: heartbeats, acks,
    blocking/non-blocking sends, and channel lifecycle. A MockSocket
    stands in for the live server; callbacks record what they receive.
    """

    ID = 'ID'
    SENDER_ID = 'SENDER_ID'
    ASSIGNMENT_ID = 'ASSIGNMENT_ID'
    DATA = 'DATA'
    CONVERSATION_ID = 'CONVERSATION_ID'
    REQUIRES_ACK = True
    BLOCKING = False
    ACK_FUNCTION = 'ACK_FUNCTION'
    WORLD_ID = '[World_{}]'.format(TASK_GROUP_ID_1)

    def on_alive(self, packet):
        # record the alive packet for assertions
        self.alive_packet = packet

    def on_message(self, packet):
        # record the message packet for assertions
        self.message_packet = packet

    def on_worker_death(self, worker_id, assignment_id):
        # record which worker/assignment was reported dead
        self.dead_worker_id = worker_id
        self.dead_assignment_id = assignment_id

    def on_server_death(self):
        self.server_died = True

    def setUp(self):
        # Canned packets used across the tests
        self.AGENT_HEARTBEAT_PACKET = Packet(
            self.ID, Packet.TYPE_HEARTBEAT, self.SENDER_ID, self.WORLD_ID,
            self.ASSIGNMENT_ID, self.DATA, self.CONVERSATION_ID)
        self.AGENT_ALIVE_PACKET = Packet(
            MESSAGE_ID_1, Packet.TYPE_ALIVE, self.SENDER_ID, self.WORLD_ID,
            self.ASSIGNMENT_ID, self.DATA, self.CONVERSATION_ID)
        self.MESSAGE_SEND_PACKET_1 = Packet(
            MESSAGE_ID_2, Packet.TYPE_MESSAGE, self.WORLD_ID, self.SENDER_ID,
            self.ASSIGNMENT_ID, self.DATA, self.CONVERSATION_ID)
        self.MESSAGE_SEND_PACKET_2 = Packet(
            MESSAGE_ID_3, Packet.TYPE_MESSAGE, self.WORLD_ID, self.SENDER_ID,
            self.ASSIGNMENT_ID, self.DATA, self.CONVERSATION_ID,
            requires_ack=False)
        self.MESSAGE_SEND_PACKET_3 = Packet(
            MESSAGE_ID_4, Packet.TYPE_MESSAGE, self.WORLD_ID, self.SENDER_ID,
            self.ASSIGNMENT_ID, self.DATA, self.CONVERSATION_ID,
            blocking=False)

        self.fake_socket = MockSocket()
        # give the mock socket a moment to come up before connecting
        time.sleep(0.3)
        self.alive_packet = None
        self.message_packet = None
        self.dead_worker_id = None
        self.dead_assignment_id = None
        self.server_died = False

        self.socket_manager = SocketManager(
            'https://127.0.0.1', 3030, self.on_alive, self.on_message,
            self.on_worker_death, TASK_GROUP_ID_1, 1, self.on_server_death)

    def tearDown(self):
        self.socket_manager.shutdown()
        self.fake_socket.close()

    def test_init_state(self):
        '''Ensure all of the initial state of the socket_manager is ready'''
        self.assertEqual(self.socket_manager.server_url, 'https://127.0.0.1')
        self.assertEqual(self.socket_manager.port, 3030)
        self.assertEqual(self.socket_manager.alive_callback, self.on_alive)
        self.assertEqual(self.socket_manager.message_callback,
                         self.on_message)
        self.assertEqual(self.socket_manager.socket_dead_callback,
                         self.on_worker_death)
        self.assertEqual(self.socket_manager.task_group_id, TASK_GROUP_ID_1)
        self.assertEqual(self.socket_manager.missed_pongs,
                         1 + (1 / SocketManager.HEARTBEAT_RATE))
        self.assertIsNotNone(self.socket_manager.ws)
        self.assertTrue(self.socket_manager.keep_running)
        self.assertIsNotNone(self.socket_manager.listen_thread)
        self.assertDictEqual(self.socket_manager.queues, {})
        self.assertDictEqual(self.socket_manager.threads, {})
        self.assertDictEqual(self.socket_manager.run, {})
        self.assertDictEqual(self.socket_manager.last_sent_heartbeat_time, {})
        self.assertDictEqual(self.socket_manager.last_received_heartbeat, {})
        self.assertDictEqual(self.socket_manager.pongs_without_heartbeat, {})
        self.assertDictEqual(self.socket_manager.packet_map, {})
        self.assertTrue(self.socket_manager.alive)
        self.assertFalse(self.socket_manager.is_shutdown)
        self.assertEqual(self.socket_manager.get_my_sender_id(),
                         self.WORLD_ID)

    def test_needed_heartbeat(self):
        '''Ensure needed heartbeat sends heartbeats at the right time'''
        self.socket_manager._safe_send = mock.MagicMock()
        connection_id = self.AGENT_HEARTBEAT_PACKET.get_sender_connection_id()

        # Ensure no failure under uninitialized cases
        self.socket_manager._send_needed_heartbeat(connection_id)
        self.socket_manager.last_received_heartbeat[connection_id] = None
        self.socket_manager._send_needed_heartbeat(connection_id)
        self.socket_manager._safe_send.assert_not_called()

        # assert not called when called too recently
        self.socket_manager.last_received_heartbeat[connection_id] = \
            self.AGENT_HEARTBEAT_PACKET
        self.socket_manager.last_sent_heartbeat_time[connection_id] = \
            time.time() + 10
        self.socket_manager._send_needed_heartbeat(connection_id)
        self.socket_manager._safe_send.assert_not_called()

        # Assert called when supposed to
        self.socket_manager.last_sent_heartbeat_time[connection_id] = \
            time.time() - SocketManager.HEARTBEAT_RATE
        self.assertGreater(
            time.time() -
            self.socket_manager.last_sent_heartbeat_time[connection_id],
            SocketManager.HEARTBEAT_RATE)
        self.socket_manager._send_needed_heartbeat(connection_id)
        self.assertLess(
            time.time() -
            self.socket_manager.last_sent_heartbeat_time[connection_id],
            SocketManager.HEARTBEAT_RATE)
        used_packet_json = self.socket_manager._safe_send.call_args[0][0]
        used_packet_dict = json.loads(used_packet_json)
        self.assertEqual(
            used_packet_dict['type'], data_model.SOCKET_ROUTE_PACKET_STRING)
        used_packet = Packet.from_dict(used_packet_dict['content'])
        self.assertNotEqual(self.AGENT_HEARTBEAT_PACKET.id, used_packet.id)
        self.assertEqual(used_packet.type, Packet.TYPE_HEARTBEAT)
        self.assertEqual(used_packet.sender_id, self.WORLD_ID)
        self.assertEqual(used_packet.receiver_id, self.SENDER_ID)
        self.assertEqual(used_packet.assignment_id, self.ASSIGNMENT_ID)
        self.assertEqual(used_packet.data, '')
        self.assertEqual(used_packet.conversation_id, self.CONVERSATION_ID)
        self.assertEqual(used_packet.requires_ack, False)
        self.assertEqual(used_packet.blocking, False)

    def test_ack_send(self):
        '''Ensure acks are being properly created and sent'''
        self.socket_manager._safe_send = mock.MagicMock()
        self.socket_manager._send_ack(self.AGENT_ALIVE_PACKET)
        used_packet_json = self.socket_manager._safe_send.call_args[0][0]
        used_packet_dict = json.loads(used_packet_json)
        self.assertEqual(
            used_packet_dict['type'], data_model.SOCKET_ROUTE_PACKET_STRING)
        used_packet = Packet.from_dict(used_packet_dict['content'])
        self.assertEqual(self.AGENT_ALIVE_PACKET.id, used_packet.id)
        self.assertEqual(used_packet.type, Packet.TYPE_ACK)
        self.assertEqual(used_packet.sender_id, self.WORLD_ID)
        self.assertEqual(used_packet.receiver_id, self.SENDER_ID)
        self.assertEqual(used_packet.assignment_id, self.ASSIGNMENT_ID)
        self.assertEqual(used_packet.conversation_id, self.CONVERSATION_ID)
        self.assertEqual(used_packet.requires_ack, False)
        self.assertEqual(used_packet.blocking, False)
        self.assertEqual(self.AGENT_ALIVE_PACKET.status, Packet.STATUS_SENT)

    def _send_packet_in_background(self, packet, send_time):
        '''creates a thread to handle waiting for a packet send'''
        def do_send():
            self.socket_manager._send_packet(
                packet, packet.get_receiver_connection_id(), send_time
            )
            self.sent = True

        send_thread = threading.Thread(target=do_send, daemon=True)
        send_thread.start()
        # brief pause so the send thread gets to run before we assert
        time.sleep(0.02)

    def test_blocking_ack_packet_send(self):
        '''Checks to see if ack'ed blocking packets are working properly'''
        self.socket_manager._safe_send = mock.MagicMock()
        self.socket_manager._safe_put = mock.MagicMock()
        self.sent = False

        # Test a blocking acknowledged packet
        send_time = time.time()
        self.assertEqual(self.MESSAGE_SEND_PACKET_1.status,
                         Packet.STATUS_INIT)
        self._send_packet_in_background(self.MESSAGE_SEND_PACKET_1, send_time)
        self.assertEqual(self.MESSAGE_SEND_PACKET_1.status,
                         Packet.STATUS_SENT)
        self.socket_manager._safe_send.assert_called_once()
        self.socket_manager._safe_put.assert_not_called()

        # Allow it to time out
        self.assertFalse(self.sent)
        time.sleep(0.5)
        self.assertTrue(self.sent)
        self.assertEqual(self.MESSAGE_SEND_PACKET_1.status,
                         Packet.STATUS_INIT)
        self.socket_manager._safe_put.assert_called_once()
        call_args = self.socket_manager._safe_put.call_args[0]
        connection_id = call_args[0]
        queue_item = call_args[1]
        self.assertEqual(
            connection_id,
            self.MESSAGE_SEND_PACKET_1.get_receiver_connection_id())
        self.assertEqual(queue_item[0], send_time)
        self.assertEqual(queue_item[1], self.MESSAGE_SEND_PACKET_1)
        self.socket_manager._safe_send.reset_mock()
        self.socket_manager._safe_put.reset_mock()

        # Send it again - end outcome should be a call to send only
        # with sent set
        self.sent = False
        self.assertEqual(self.MESSAGE_SEND_PACKET_1.status,
                         Packet.STATUS_INIT)
        self._send_packet_in_background(self.MESSAGE_SEND_PACKET_1, send_time)
        self.assertEqual(self.MESSAGE_SEND_PACKET_1.status,
                         Packet.STATUS_SENT)
        self.socket_manager._safe_send.assert_called_once()
        self.socket_manager._safe_put.assert_not_called()
        self.assertFalse(self.sent)
        self.MESSAGE_SEND_PACKET_1.status = Packet.STATUS_ACK
        time.sleep(0.1)
        self.assertTrue(self.sent)
        self.socket_manager._safe_put.assert_not_called()

    def test_non_blocking_ack_packet_send(self):
        '''Checks to see if ack'ed non-blocking packets are working'''
        self.socket_manager._safe_send = mock.MagicMock()
        self.socket_manager._safe_put = mock.MagicMock()
        self.sent = False

        # Test a non-blocking acknowledged packet
        send_time = time.time()
        self.assertEqual(self.MESSAGE_SEND_PACKET_3.status,
                         Packet.STATUS_INIT)
        self._send_packet_in_background(self.MESSAGE_SEND_PACKET_3, send_time)
        self.assertEqual(self.MESSAGE_SEND_PACKET_3.status,
                         Packet.STATUS_SENT)
        self.socket_manager._safe_send.assert_called_once()
        self.socket_manager._safe_put.assert_called_once()
        self.assertTrue(self.sent)
        call_args = self.socket_manager._safe_put.call_args[0]
        connection_id = call_args[0]
        queue_item = call_args[1]
        self.assertEqual(
            connection_id,
            self.MESSAGE_SEND_PACKET_3.get_receiver_connection_id())
        expected_send_time = send_time + \
            SocketManager.ACK_TIME[self.MESSAGE_SEND_PACKET_3.type]
        self.assertAlmostEqual(queue_item[0], expected_send_time, places=2)
        self.assertEqual(queue_item[1], self.MESSAGE_SEND_PACKET_3)
        used_packet_json = self.socket_manager._safe_send.call_args[0][0]
        used_packet_dict = json.loads(used_packet_json)
        self.assertEqual(
            used_packet_dict['type'], data_model.SOCKET_ROUTE_PACKET_STRING)
        self.assertDictEqual(used_packet_dict['content'],
                             self.MESSAGE_SEND_PACKET_3.as_dict())

    def test_non_ack_packet_send(self):
        '''Checks to see if non-ack'ed packets are working'''
        self.socket_manager._safe_send = mock.MagicMock()
        self.socket_manager._safe_put = mock.MagicMock()
        self.sent = False

        # Test a packet that does not require acknowledgement
        send_time = time.time()
        self.assertEqual(self.MESSAGE_SEND_PACKET_2.status,
                         Packet.STATUS_INIT)
        self._send_packet_in_background(self.MESSAGE_SEND_PACKET_2, send_time)
        self.assertEqual(self.MESSAGE_SEND_PACKET_2.status,
                         Packet.STATUS_SENT)
        self.socket_manager._safe_send.assert_called_once()
        self.socket_manager._safe_put.assert_not_called()
        self.assertTrue(self.sent)
        used_packet_json = self.socket_manager._safe_send.call_args[0][0]
        used_packet_dict = json.loads(used_packet_json)
        self.assertEqual(
            used_packet_dict['type'], data_model.SOCKET_ROUTE_PACKET_STRING)
        self.assertDictEqual(used_packet_dict['content'],
                             self.MESSAGE_SEND_PACKET_2.as_dict())

    def test_simple_packet_channel_management(self):
        '''Ensure that channels are created, managed, and then removed
        as expected
        '''
        self.socket_manager._safe_put = mock.MagicMock()
        use_packet = self.MESSAGE_SEND_PACKET_1
        worker_id = use_packet.receiver_id
        assignment_id = use_packet.assignment_id

        # Open a channel and assert it is there
        self.socket_manager.open_channel(worker_id, assignment_id)
        time.sleep(0.1)
        connection_id = use_packet.get_receiver_connection_id()
        self.assertTrue(self.socket_manager.run[connection_id])
        socket_thread = self.socket_manager.threads[connection_id]
        # Thread.isAlive() was removed in Python 3.9; is_alive() is the
        # long-standing equivalent spelling.
        self.assertTrue(socket_thread.is_alive())
        self.assertIsNotNone(self.socket_manager.queues[connection_id])
        self.assertEqual(
            self.socket_manager.last_sent_heartbeat_time[connection_id], 0)
        self.assertEqual(
            self.socket_manager.pongs_without_heartbeat[connection_id], 0)
        self.assertIsNone(
            self.socket_manager.last_received_heartbeat[connection_id])
        self.assertTrue(self.socket_manager.socket_is_open(connection_id))
        self.assertFalse(self.socket_manager.socket_is_open(FAKE_ID))

        # Send a bad packet, ensure it is ignored
        resp = self.socket_manager.queue_packet(self.AGENT_ALIVE_PACKET)
        self.socket_manager._safe_put.assert_not_called()
        self.assertFalse(resp)
        self.assertNotIn(self.AGENT_ALIVE_PACKET.id,
                         self.socket_manager.packet_map)

        # Send a packet to an open socket, ensure it got queued
        resp = self.socket_manager.queue_packet(use_packet)
        self.socket_manager._safe_put.assert_called_once()
        self.assertIn(use_packet.id, self.socket_manager.packet_map)
        self.assertTrue(resp)

        # Assert we can get the status of a packet in the map, but not
        # existing doesn't throw an error
        self.assertEqual(self.socket_manager.get_status(use_packet.id),
                         use_packet.status)
        self.assertEqual(self.socket_manager.get_status(FAKE_ID),
                         Packet.STATUS_NONE)

        # Assert that closing a thread does the correct cleanup work
        self.socket_manager.close_channel(connection_id)
        time.sleep(0.2)
        self.assertFalse(self.socket_manager.run[connection_id])
        self.assertNotIn(connection_id, self.socket_manager.queues)
        self.assertNotIn(connection_id, self.socket_manager.threads)
        self.assertNotIn(use_packet.id, self.socket_manager.packet_map)
        self.assertFalse(socket_thread.is_alive())

        # Assert that opening multiple threads and closing them is possible
        self.socket_manager.open_channel(worker_id, assignment_id)
        self.socket_manager.open_channel(worker_id + '2', assignment_id)
        time.sleep(0.1)
        self.assertEqual(len(self.socket_manager.queues), 2)
        self.socket_manager.close_all_channels()
        time.sleep(0.1)
        self.assertEqual(len(self.socket_manager.queues), 0)

    def test_safe_put(self):
        '''Test safe put and queue retrieval mechanisms'''
        self.socket_manager._send_packet = mock.MagicMock()
        use_packet = self.MESSAGE_SEND_PACKET_1
        worker_id = use_packet.receiver_id
        assignment_id = use_packet.assignment_id
        connection_id = use_packet.get_receiver_connection_id()

        # Open a channel and assert it is there
        self.socket_manager.open_channel(worker_id, assignment_id)
        send_time = time.time()
        self.socket_manager._safe_put(connection_id, (send_time, use_packet))

        # Wait for the sending thread to try to pull the packet from the queue
        time.sleep(0.3)

        # Ensure the right packet was popped and sent.
        self.socket_manager._send_packet.assert_called_once()
        call_args = self.socket_manager._send_packet.call_args[0]
        self.assertEqual(use_packet, call_args[0])
        self.assertEqual(connection_id, call_args[1])
        self.assertEqual(send_time, call_args[2])

        self.socket_manager.close_all_channels()
        time.sleep(0.1)
        # Puts against a closed channel mark the packet as failed
        self.socket_manager._safe_put(connection_id, (send_time, use_packet))
        self.assertEqual(use_packet.status, Packet.STATUS_FAIL)