コード例 #1
0
 def new_agent(
     self,
     worker_id: str,
     unit_id: str,
     task_id: str,
     task_run_id: str,
     assignment_id: str,
     task_type: str,
     provider_type: str,
 ) -> str:
     """
     Create an agent through the parent implementation, then mirror the new
     assignment onto the Unit object obtained from that agent.

     The parent call does not guarantee that the Unit held locally is updated,
     but keeping it in step is important for this singleton. So the new agent
     id, the ASSIGNED status and the agent's worker id are copied onto it here.

     Returns the id of the newly created agent.
     """
     creation_args = (
         worker_id,
         unit_id,
         task_id,
         task_run_id,
         assignment_id,
         task_type,
         provider_type,
     )
     created_id = super().new_agent(*creation_args)

     created_agent = Agent(self, created_id)
     local_unit = created_agent.get_unit()
     # Keep the in-memory unit consistent with the assignment just created
     local_unit.agent_id = created_id
     local_unit.db_status = AssignmentState.ASSIGNED
     local_unit.worker_id = created_agent.worker_id
     return created_id
コード例 #2
0
ファイル: unit.py プロジェクト: facebookresearch/Mephisto
    def get_assigned_agent(self) -> Optional[Agent]:
        """
        Return the Agent assigned to this Unit, or None when nobody is assigned.

        Once the unit is in a final status its assignment can no longer change,
        so the cached agent_id is trusted as-is. Otherwise the unit is re-read
        from the database first and the cached agent_id is refreshed from it.
        """
        if self.db_status not in AssignmentState.final_unit():
            # Non-final units may still be (re)assigned after instantiation,
            # so pull the current assignment before answering
            fresh_unit = Unit.get(self.db, self.db_id)
            self.agent_id = fresh_unit.agent_id

        if self.agent_id is None:
            return None
        return Agent.get(self.db, self.agent_id)
コード例 #3
0
ファイル: utils.py プロジェクト: chateval/Mephisto
def make_completed_unit(db: MephistoDB) -> str:
    """
    Create a completed unit for the most recently created task run.

    The unit is worked by the last worker returned by ``db.find_workers()``.
    The task run is the last one returned by
    ``db.find_task_runs(is_completed=False)``.

    A new assignment and unit are created for that task run, and an agent is
    created for the worker on that unit. The agent is marked done and the
    unit's status is then synced.

    Assumes the database already contains at least one worker and at least one
    incomplete task run; an AssertionError is raised otherwise.

    Returns the db_id of the newly created unit.
    """
    workers = db.find_workers()
    assert len(workers) > 0, "Must have at least one worker in database"
    worker = workers[-1]
    task_runs = db.find_task_runs(is_completed=False)
    assert len(task_runs) > 0, "Must be at least one incomplete task run"
    task_run = task_runs[-1]
    assign_id = db.new_assignment(
        task_run.task_id,
        task_run.db_id,
        task_run.requester_id,
        task_run.task_type,
        task_run.provider_type,
    )
    unit_id = db.new_unit(
        task_run.task_id,
        task_run.db_id,
        task_run.requester_id,
        assign_id,
        # NOTE(review): presumably unit_index=0 and pay_amount=0.2 — confirm
        # against the MephistoDB.new_unit signature
        0,
        0.2,
        task_run.provider_type,
        task_run.task_type,
    )
    agent_id = db.new_agent(
        worker.db_id,
        unit_id,
        task_run.task_id,
        task_run.db_id,
        assign_id,
        task_run.task_type,
        task_run.provider_type,
    )
    agent = Agent(db, agent_id)
    agent.mark_done()
    # Propagate the agent's completion onto the unit's stored status
    unit = Unit(db, unit_id)
    unit.sync_status()
    return unit.db_id
コード例 #4
0
    def get_assigned_agent(self) -> Optional[Agent]:
        """
        Get the agent assigned to this Unit if there is one, else return None.

        For units in a final status the locally cached ``agent_id`` is used
        directly. Otherwise the unit row is re-read from the database, and the
        cached ``agent_id`` is refreshed from it. Later calls, including ones
        made after the unit reaches a final status, therefore see the latest
        known assignment rather than the value captured at construction.

        Raises AssertionError if the unit row no longer exists in the database.
        """
        # In these statuses, we know the agent isn't changing anymore, and thus will
        # not need to be re-queried
        # TODO(#97) add test to ensure this behavior/assumption holds always
        if self.db_status in AssignmentState.final_unit():
            if self.agent_id is None:
                return None
            return Agent(self.db, self.agent_id)

        # Query the database to get the most up-to-date assignment, as this can
        # change after instantiation if the Unit status isn't final
        # TODO(#101) this may not be particularly efficient
        row = self.db.get_unit(self.db_id)
        assert row is not None, f"Unit {self.db_id} stopped existing in the db..."
        # Cache the fresh value: the final-status fast path above trusts
        # self.agent_id, so leaving it stale could hide a real assignment
        self.agent_id = row["agent_id"]
        if self.agent_id is None:
            return None
        return Agent(self.db, self.agent_id)
コード例 #5
0
    def test_agent(self) -> None:
        """Test creation and querying of agents"""
        assert self.db is not None, "No db initialized"
        db: MephistoDB = self.db

        # Check creation and retrieval of an agent
        _worker_name, worker_id = get_test_worker(db)
        unit_id = get_test_unit(db)
        unit = Unit(db, unit_id)

        agent_id = db.new_agent(
            worker_id,
            unit_id,
            unit.task_id,
            unit.task_run_id,
            unit.assignment_id,
            unit.task_type,
            unit.provider_type,
        )
        self.assertIsNotNone(agent_id)
        self.assertIsInstance(agent_id, str)
        agent_row = db.get_agent(agent_id)
        self.assertEqual(agent_row["worker_id"], worker_id)
        self.assertEqual(agent_row["unit_id"], unit_id)
        self.assertEqual(agent_row["status"], AgentState.STATUS_NONE)

        # Creating an agent should move its unit into the ASSIGNED state
        units = db.find_units(status=AssignmentState.ASSIGNED)
        self.assertEqual(len(units), 1)

        agent = Agent(db, agent_id)
        self.assertEqual(agent.worker_id, worker_id)

        # Check finding for agents
        agents = db.find_agents()
        self.assertEqual(len(agents), 1)
        self.assertIsInstance(agents[0], Agent)
        self.assertEqual(agents[0].db_id, agent_id)
        self.assertEqual(agents[0].worker_id, worker_id)

        # Check finding for specific agents
        agents = db.find_agents(worker_id=worker_id)
        self.assertEqual(len(agents), 1)
        self.assertIsInstance(agents[0], Agent)
        self.assertEqual(agents[0].db_id, agent_id)
        self.assertEqual(agents[0].worker_id, worker_id)

        # Filtering on a worker id that does not exist must find nothing
        agents = db.find_agents(worker_id=self.get_fake_id("Worker"))
        self.assertEqual(len(agents), 0)
コード例 #6
0
 def find_agents(
     self,
     status: Optional[str] = None,
     unit_id: Optional[str] = None,
     worker_id: Optional[str] = None,
     task_id: Optional[str] = None,
     task_run_id: Optional[str] = None,
     assignment_id: Optional[str] = None,
     task_type: Optional[str] = None,
     provider_type: Optional[str] = None,
 ) -> List[Agent]:
     """
     Return every agent whose row matches all of the supplied filters.

     Any filter left as None is ignored (each SQL clause is guarded by
     ``?N IS NULL OR ...``), so calling with no arguments returns all agents.
     The id-style filters are converted with ``nonesafe_int`` before binding.
     The query runs while holding ``self.table_access_condition``.
     """
     query = """
             SELECT * from agents
             WHERE (?1 IS NULL OR status = ?1)
             AND (?2 IS NULL OR unit_id = ?2)
             AND (?3 IS NULL OR worker_id = ?3)
             AND (?4 IS NULL OR task_id = ?4)
             AND (?5 IS NULL OR task_run_id = ?5)
             AND (?6 IS NULL OR assignment_id = ?6)
             AND (?7 IS NULL OR task_type = ?7)
             AND (?8 IS NULL OR provider_type = ?8)
             """
     with self.table_access_condition:
         connection = self._get_connection()
         cursor = connection.cursor()
         # Positional order must line up with the ?1..?8 placeholders above
         bound_values = (
             status,
             nonesafe_int(unit_id),
             nonesafe_int(worker_id),
             nonesafe_int(task_id),
             nonesafe_int(task_run_id),
             nonesafe_int(assignment_id),
             task_type,
             provider_type,
         )
         cursor.execute(query, bound_values)
         return [
             Agent(self, str(found["agent_id"]), row=found)
             for found in cursor.fetchall()
         ]
コード例 #7
0
    def test_agent_fails(self) -> None:
        """Ensure agents fail to be created or loaded under failure conditions"""
        assert self.db is not None, "No db initialized"
        db: MephistoDB = self.db

        # Can't get non-existent entry
        with self.assertRaises(EntryDoesNotExistException):
            Agent(db, self.get_fake_id("Agent"))

        unit_id = get_test_unit(db)
        _worker_name, worker_id = get_test_worker(db)
        unit = Unit(db, unit_id)

        # Can't use invalid worker id
        with self.assertRaises(EntryDoesNotExistException):
            db.new_agent(
                self.get_fake_id("Worker"),
                unit_id,
                unit.task_id,
                unit.task_run_id,
                unit.assignment_id,
                unit.task_type,
                unit.provider_type,
            )

        # Can't use invalid unit id
        with self.assertRaises(EntryDoesNotExistException):
            db.new_agent(
                worker_id,
                self.get_fake_id("Unit"),
                unit.task_id,
                unit.task_run_id,
                unit.assignment_id,
                unit.task_type,
                unit.provider_type,
            )

        # Ensure no agents were created by the failed attempts above
        agents = db.find_agents()
        self.assertEqual(len(agents), 0)