def new_agent(
    self,
    worker_id: str,
    unit_id: str,
    task_id: str,
    task_run_id: str,
    assignment_id: str,
    task_type: str,
    provider_type: str,
) -> str:
    """
    Create an agent via the parent implementation, then mirror the new
    assignment onto the locally held Unit object.

    The parent call alone does not guarantee the in-memory Unit reflects the
    new agent, and keeping that Unit in sync is an important part of the
    singleton.

    Returns the id of the newly created agent.
    """
    created_id = super().new_agent(
        worker_id,
        unit_id,
        task_id,
        task_run_id,
        assignment_id,
        task_type,
        provider_type,
    )
    created_agent = Agent(self, created_id)
    local_unit = created_agent.get_unit()

    # Bring the local Unit in line with what was just written to the db
    local_unit.agent_id = created_id
    local_unit.db_status = AssignmentState.ASSIGNED
    local_unit.worker_id = created_agent.worker_id
    return created_id
def get_assigned_agent(self) -> Optional[Agent]:
    """
    Return the Agent assigned to this Unit, or None when nothing is assigned.

    When the unit's status is not yet final, the assignment may still change,
    so the cached agent_id is refreshed from the database first.
    """
    if self.db_status not in AssignmentState.final_unit():
        # The assignment can change after instantiation, so re-read it
        refreshed_unit = Unit.get(self.db, self.db_id)
        self.agent_id = refreshed_unit.agent_id

    if self.agent_id is None:
        return None
    return Agent.get(self.db, self.agent_id)
def make_completed_unit(db: MephistoDB) -> str:
    """
    Create a completed unit for the latest incomplete task run, worked by the
    latest worker.

    "Latest" means the last entry returned by `db.find_workers()` and by
    `db.find_task_runs(is_completed=False)` respectively (presumably the
    most recently created; depends on the db's result ordering).

    Assumes the database already contains at least one worker and at least
    one incomplete task run; asserts otherwise.

    Steps: a new assignment is created for the task run, a unit (index 0,
    pay amount 0.2) is created in that assignment, an agent for the chosen
    worker is created on that unit, the agent is marked done, and then
    `sync_status()` is called on the unit.

    Returns the db_id of the created unit.
    """
    workers = db.find_workers()
    assert workers, "Must have at least one worker in database"
    worker = workers[-1]

    task_runs = db.find_task_runs(is_completed=False)
    assert task_runs, "Must be at least one incomplete task run"
    task_run = task_runs[-1]

    assign_id = db.new_assignment(
        task_run.task_id,
        task_run.db_id,
        task_run.requester_id,
        task_run.task_type,
        task_run.provider_type,
    )
    # NOTE: new_unit takes provider_type before task_type, unlike new_assignment
    unit_id = db.new_unit(
        task_run.task_id,
        task_run.db_id,
        task_run.requester_id,
        assign_id,
        0,
        0.2,
        task_run.provider_type,
        task_run.task_type,
    )
    agent_id = db.new_agent(
        worker.db_id,
        unit_id,
        task_run.task_id,
        task_run.db_id,
        assign_id,
        task_run.task_type,
        task_run.provider_type,
    )

    agent = Agent(db, agent_id)
    agent.mark_done()
    unit = Unit(db, unit_id)
    unit.sync_status()
    return unit.db_id
def get_assigned_agent(self) -> Optional[Agent]:
    """
    Get the agent assigned to this Unit if there is one, else return None.

    For units in a final status, the cached `agent_id` is trusted as-is.
    Otherwise the unit row is re-read from the database, and the fresh
    `agent_id` is stored back on this object. This keeps later calls
    consistent, including calls made after `db_status` has been refreshed to
    a final status.
    """
    # In these statuses, we know the agent isn't changing anymore, and thus will
    # not need to be re-queried
    # TODO(#97) add test to ensure this behavior/assumption holds always
    if self.db_status in AssignmentState.final_unit():
        if self.agent_id is None:
            return None
        return Agent(self.db, self.agent_id)

    # Query the database to get the most up-to-date assignment, as this can
    # change after instantiation if the Unit status isn't final
    # TODO(#101) this may not be particularly efficient
    row = self.db.get_unit(self.db_id)
    assert row is not None, f"Unit {self.db_id} stopped existing in the db..."
    # Cache the fresh value. Without this, the final-status branch above could
    # keep using a stale (e.g. None) agent_id once db_status is refreshed.
    self.agent_id = row["agent_id"]
    if self.agent_id is None:
        return None
    return Agent(self.db, self.agent_id)
def test_agent(self) -> None:
    """Exercise agent creation, direct retrieval, and lookup via find_agents"""
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db

    # Build a worker and a unit, then create an agent linking the two
    _worker_name, worker_id = get_test_worker(db)
    unit_id = get_test_unit(db)
    unit = Unit(db, unit_id)
    agent_id = db.new_agent(
        worker_id,
        unit_id,
        unit.task_id,
        unit.task_run_id,
        unit.assignment_id,
        unit.task_type,
        unit.provider_type,
    )
    self.assertIsNotNone(agent_id)
    self.assertTrue(isinstance(agent_id, str))

    # The stored row should echo the inputs and start with no status
    stored_row = db.get_agent(agent_id)
    expected_fields = {
        "worker_id": worker_id,
        "unit_id": unit_id,
        "status": AgentState.STATUS_NONE,
    }
    for field_name, expected_value in expected_fields.items():
        self.assertEqual(stored_row[field_name], expected_value)

    # Creating the agent should have moved the unit into ASSIGNED
    assigned_units = db.find_units(status=AssignmentState.ASSIGNED)
    self.assertEqual(len(assigned_units), 1)

    loaded_agent = Agent(db, agent_id)
    self.assertEqual(loaded_agent.worker_id, worker_id)

    # Both an unfiltered search and a search filtered by this worker should
    # return exactly the agent created above
    for search_filters in ({}, {"worker_id": worker_id}):
        matches = db.find_agents(**search_filters)
        self.assertEqual(len(matches), 1)
        self.assertTrue(isinstance(matches[0], Agent))
        self.assertEqual(matches[0].db_id, agent_id)
        self.assertEqual(matches[0].worker_id, worker_id)

    # Filtering by an unknown worker must match nothing
    no_matches = db.find_agents(worker_id=self.get_fake_id("Worker"))
    self.assertEqual(len(no_matches), 0)
def find_agents(
    self,
    status: Optional[str] = None,
    unit_id: Optional[str] = None,
    worker_id: Optional[str] = None,
    task_id: Optional[str] = None,
    task_run_id: Optional[str] = None,
    assignment_id: Optional[str] = None,
    task_type: Optional[str] = None,
    provider_type: Optional[str] = None,
) -> List[Agent]:
    """
    Return every agent that matches all of the given filters.

    A filter left as None is ignored, so calling with no arguments returns
    all agents. Id filters are converted with `nonesafe_int` before binding.
    """
    # Each numbered placeholder ?N is one filter; binding NULL disables it
    sql = """
        SELECT * from agents
        WHERE (?1 IS NULL OR status = ?1)
        AND (?2 IS NULL OR unit_id = ?2)
        AND (?3 IS NULL OR worker_id = ?3)
        AND (?4 IS NULL OR task_id = ?4)
        AND (?5 IS NULL OR task_run_id = ?5)
        AND (?6 IS NULL OR assignment_id = ?6)
        AND (?7 IS NULL OR task_type = ?7)
        AND (?8 IS NULL OR provider_type = ?8)
        """
    with self.table_access_condition:
        cursor = self._get_connection().cursor()
        bound_values = (
            status,
            nonesafe_int(unit_id),
            nonesafe_int(worker_id),
            nonesafe_int(task_id),
            nonesafe_int(task_run_id),
            nonesafe_int(assignment_id),
            task_type,
            provider_type,
        )
        cursor.execute(sql, bound_values)
        matching_rows = cursor.fetchall()
    return [
        Agent(self, str(agent_row["agent_id"]), row=agent_row)
        for agent_row in matching_rows
    ]
def test_agent_fails(self) -> None:
    """Check that agents cannot be loaded or created from invalid references"""
    assert self.db is not None, "No db initialized"
    db: MephistoDB = self.db

    # Loading an agent id that was never created must fail
    with self.assertRaises(EntryDoesNotExistException):
        Agent(db, self.get_fake_id("Agent"))

    unit_id = get_test_unit(db)
    _worker_name, worker_id = get_test_worker(db)
    unit = Unit(db, unit_id)

    def expect_rejected(candidate_worker_id: str, candidate_unit_id: str) -> None:
        # Creation must be refused when a referenced entry doesn't exist
        with self.assertRaises(EntryDoesNotExistException):
            db.new_agent(
                candidate_worker_id,
                candidate_unit_id,
                unit.task_id,
                unit.task_run_id,
                unit.assignment_id,
                unit.task_type,
                unit.provider_type,
            )

    # Bogus worker id, then bogus unit id
    expect_rejected(self.get_fake_id("Worker"), unit_id)
    expect_rejected(worker_id, self.get_fake_id("Unit"))

    # None of the failed attempts may have left an agent behind
    self.assertEqual(len(db.find_agents()), 0)