def example_mixed_agents():
    return [
        AgentContext(
            agent_id=AgentId('i-gpu'),
            agent_addr='10.0.1.1:6001',
            scaling_group='sg01',
            available_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('4.0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
            }),
        ),
        AgentContext(
            agent_id=AgentId('i-cpu'),
            agent_addr='10.0.2.1:6001',
            scaling_group='sg02',
            available_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('2560'),
                'cuda.shares': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
            }),
        ),
    ]
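A quick way to exercise this fixture is to compute each agent's free headroom and compare it with a request. The helper below is a hypothetical sketch (its name and signature are not from the original code), assuming ResourceSlot supports element-wise subtraction and '>=' comparison as used elsewhere in these examples.

def pick_agents_with_headroom(agent_ctxs, requested: ResourceSlot):
    # Return the ids of agents whose remaining capacity can hold `requested`.
    viable = []
    for ag in agent_ctxs:
        remaining = ag.available_slots - ag.occupied_slots
        if remaining >= requested:
            viable.append(ag.agent_id)
    return viable

# With example_mixed_agents(), a request that includes a positive
# 'cuda.shares' amount can only land on the 'i-gpu' agent.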
Example #2
    async def test_query_agents(self, create_app_and_client, get_headers):
        app, client = await create_app_and_client(
            modules=['auth', 'admin', 'manager'])

        # Add test agents info to db (all ALIVE)
        async with app['dbpool'].acquire() as conn:
            for i in range(2):
                query = agents.insert().values({
                    'id': f'test-agent-id-{i}',
                    'status': AgentStatus.ALIVE,
                    'region': 'local',
                    'available_slots': ResourceSlot({
                        'cpu': '1',
                        'mem': '1073741824',
                    }),
                    'occupied_slots': ResourceSlot({
                        'cpu': '0',
                        'mem': '0',
                    }),
                    'addr': '127.0.0.1',
                    'lost_at': None,
                })
                await conn.execute(query)

        query = '{ agents { status region addr } }'
        payload = json.dumps({'query': query}).encode()
        headers = get_headers('POST', self.url, payload)
        ret = await client.post(self.url, data=payload, headers=headers)

        assert ret.status == 200
        rsp_json = await ret.json()
        assert rsp_json['agents'][0]['status'] == 'ALIVE'
        assert rsp_json['agents'][0]['region'] == 'local'
        assert rsp_json['agents'][0]['addr'] == '127.0.0.1'
        assert rsp_json['agents'][1]['status'] == 'ALIVE'
        assert rsp_json['agents'][1]['region'] == 'local'
        assert rsp_json['agents'][1]['addr'] == '127.0.0.1'

        # query with status
        query = '{ agents(status: "LOST") { status region addr } }'
        payload = json.dumps({'query': query}).encode()
        headers = get_headers('POST', self.url, payload)
        ret = await client.post(self.url, data=payload, headers=headers)

        assert ret.status == 200
        rsp_json = await ret.json()
        assert len(rsp_json['agents']) == 0
Example #3
def example_pending_sessions():
    # lower indices are enqueued first.
    return [
        PendingSession(  # rocm
            kernel_id=pending_kernel_ids[0],
            access_key=AccessKey('user01'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('2.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cuda
            kernel_id=pending_kernel_ids[1],
            access_key=AccessKey('user02'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cpu-only
            kernel_id=pending_kernel_ids[2],
            access_key=AccessKey('user03'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
    ]
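These pending sessions are meant to be consumed in enqueue order. Below is a hypothetical first-fit sketch of that ordering (the real FIFOSlotScheduler.pick_session may behave differently), assuming ResourceSlot supports element-wise '>=' comparison.

def naive_fifo_pick(total_capacity: ResourceSlot, pending_sessions):
    # Walk the queue in enqueue order and return the first pending session
    # whose requested slots fit into the aggregate cluster capacity.
    for sess in pending_sessions:
        if total_capacity >= sess.requested_slots:
            return sess
    return None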
async def recalculate_usage(request) -> web.Response:
    '''
    Update `keypairs.c.concurrency_used` and `agents.c.occupied_slots`.

    These two values can drift out of sync. In that case, calling this API
    recalculates them from the currently running containers and updates the DB.
    '''
    async with request.app['dbpool'].acquire() as conn, conn.begin():
        # Query running containers and calculate concurrency_used per AK and
        # occupied_slots per agent.
        query = (sa.select([
            kernels.c.access_key, kernels.c.agent, kernels.c.occupied_slots
        ]).where(kernels.c.status != KernelStatus.TERMINATED).order_by(
            sa.asc(kernels.c.access_key)))
        concurrency_used_per_key: MutableMapping[str, int] = defaultdict(lambda: 0)
        occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = \
            defaultdict(lambda: ResourceSlot({'cpu': 0, 'mem': 0}))
        async for row in conn.execute(query):
            concurrency_used_per_key[row.access_key] += 1
            occupied_slots_per_agent[row.agent] += ResourceSlot(
                row.occupied_slots)

        # Update concurrency_used for keypairs with running containers.
        for ak, used in concurrency_used_per_key.items():
            query = (sa.update(keypairs).values(concurrency_used=used).where(
                keypairs.c.access_key == ak))
            await conn.execute(query)
        # Update all other keypairs to have concurrency_used = 0.
        query = (sa.update(keypairs).values(concurrency_used=0).where(
            keypairs.c.concurrency_used != 0).where(
                sa.not_(
                    keypairs.c.access_key.in_(
                        concurrency_used_per_key.keys()))))
        await conn.execute(query)

        # Update occupied_slots for agents with running containers.
        for aid, slots in occupied_slots_per_agent.items():
            query = (sa.update(agents).values(occupied_slots=slots).where(
                agents.c.id == aid))
            await conn.execute(query)
        # Update all other agents to have empty occupied_slots.
        query = (sa.update(agents).values(
            occupied_slots=ResourceSlot({})).where(
                agents.c.status == AgentStatus.ALIVE).where(
                    sa.not_(agents.c.id.in_(occupied_slots_per_agent.keys()))))
        await conn.execute(query)
    return web.json_response({}, status=200)
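The aggregation step above can be exercised without a database. The helper below repeats the same accumulation pattern over an in-memory iterable; it is a sketch that reuses the handler's imports and assumes each row exposes access_key, agent, and occupied_slots attributes.

def aggregate_usage(rows):
    # Mirrors the accumulation loop in recalculate_usage(): count running
    # kernels per access key and sum their occupied slots per agent.
    concurrency_used_per_key: MutableMapping[str, int] = defaultdict(lambda: 0)
    occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = \
        defaultdict(lambda: ResourceSlot({'cpu': 0, 'mem': 0}))
    for row in rows:
        concurrency_used_per_key[row.access_key] += 1
        occupied_slots_per_agent[row.agent] += ResourceSlot(row.occupied_slots)
    return concurrency_used_per_key, occupied_slots_per_agent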
    async def resolve_occupied_slots(
            self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
        """
        Calculate the sum of occupied resource slots of all sub-kernels,
        and return the JSON-serializable object from the sum result.
        """
        manager = info.context['dlmgr']
        loader = manager.get_loader('ComputeContainer.by_session')
        containers = await loader.load(self.session_id)
        zero = ResourceSlot()
        return sum(
            (ResourceSlot({SlotName(k): Decimal(v)
                           for k, v in c.occupied_slots.items()})
             for c in containers),
            start=zero,
        ).to_json()
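The reduction used by this resolver also works over any plain iterable of container-like objects. A minimal sketch (the helper name is illustrative), assuming each item carries an occupied_slots mapping of slot names to numeric values:

def sum_occupied_slots(containers) -> Mapping[str, Any]:
    # Fold per-container slot maps into a single ResourceSlot, then serialize it.
    zero = ResourceSlot()
    total = sum(
        (ResourceSlot({SlotName(k): Decimal(v)
                       for k, v in c.occupied_slots.items()})
         for c in containers),
        start=zero,
    )
    return total.to_json()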
async def _unreserve_agent_slots(
    db_conn: SAConnection,
    session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]],
) -> None:
    # Un-reserve agent slots, using the db transaction of the current invocation context.
    keyfunc = lambda item: item.agent_alloc_ctx.agent_id
    for agent_id, kernel_agent_bindings in itertools.groupby(
        sorted(session_agent_binding[1], key=keyfunc), key=keyfunc
    ):
        per_agent_requested_slots = sum(
            (binding.kernel.requested_slots for binding in kernel_agent_bindings),
            start=ResourceSlot(),
        )
        query = (
            sa.select([agents.c.occupied_slots], for_update=True)
            .select_from(agents)
            .where(agents.c.id == agent_id))
        current_occupied_slots = await db_conn.scalar(query)
        query = (
            sa.update(agents)
            .values({
                'occupied_slots': current_occupied_slots - per_agent_requested_slots
            })
            .where(agents.c.id == agent_id))
        await db_conn.execute(query)
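The core of the un-reservation above is the per-agent grouping: sort the bindings by agent id, group them with itertools.groupby, and sum each group's requested slots. A standalone sketch of just that step (the helper name is illustrative), assuming each binding exposes agent_alloc_ctx.agent_id and kernel.requested_slots:

def requested_slots_per_agent(kernel_agent_bindings):
    # Group kernel bindings by their assigned agent and sum each group's
    # requested slots, mirroring the loop in _unreserve_agent_slots().
    keyfunc = lambda item: item.agent_alloc_ctx.agent_id
    totals = {}
    for agent_id, group in itertools.groupby(
            sorted(kernel_agent_bindings, key=keyfunc), key=keyfunc):
        totals[agent_id] = sum(
            (binding.kernel.requested_slots for binding in group),
            start=ResourceSlot(),
        )
    return totals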
Example #7
def example_existing_sessions():
    return [
        ExistingSession(
            kernel_id=existing_kernel_ids[0],
            access_key=AccessKey('user01'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernel_id=existing_kernel_ids[1],
            access_key=AccessKey('user02'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernel_id=existing_kernel_ids[2],
            access_key=AccessKey('user03'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
    ]
Example #8
    async def get_image_slot_ranges(self, image_ref: ImageRef):
        '''
        Returns the minimum and maximum ResourceSlot values.
        All slot values are converted and normalized to Decimal.
        '''
        data = await self.etcd.get_prefix_dict(image_ref.tag_path)
        slot_units = await self.get_resource_slots()
        min_slot = ResourceSlot()
        max_slot = ResourceSlot()

        for slot_key, slot_range in data['resource'].items():
            slot_unit = slot_units.get(slot_key)
            if slot_unit is None:
                # ignore unknown slots
                continue
            min_value = slot_range.get('min')
            if min_value is None:
                min_value = Decimal(0)
            max_value = slot_range.get('max')
            if max_value is None:
                max_value = Decimal('Infinity')
            if slot_unit == 'bytes':
                if not isinstance(min_value, Decimal):
                    min_value = BinarySize.from_str(min_value)
                if not isinstance(max_value, Decimal):
                    max_value = BinarySize.from_str(max_value)
            else:
                if not isinstance(min_value, Decimal):
                    min_value = Decimal(min_value)
                if not isinstance(max_value, Decimal):
                    max_value = Decimal(max_value)
            min_slot[slot_key] = min_value
            max_slot[slot_key] = max_value

        # fill missing
        for slot_key in slot_units.keys():
            if slot_key not in min_slot:
                min_slot[slot_key] = Decimal(0)
            if slot_key not in max_slot:
                max_slot[slot_key] = Decimal('Infinity')

        return min_slot, max_slot
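For illustration, a hypothetical shape of the per-image resource metadata this method reads (the actual etcd layout may differ), assuming get_resource_slots() reports 'mem' as a 'bytes' slot and 'cpu' / 'cuda.shares' as count-like slots:

# data['resource'] == {
#     'cpu': {'min': '1', 'max': '8'},
#     'mem': {'min': '1g'},       # 'bytes' slot: parsed via BinarySize.from_str()
#     'cuda.shares': {},          # no bounds given
# }
#
# would yield, roughly:
#
# min_slot == ResourceSlot({'cpu': Decimal('1'), 'mem': BinarySize.from_str('1g'),
#                           'cuda.shares': Decimal(0)})
# max_slot == ResourceSlot({'cpu': Decimal('8'), 'mem': Decimal('Infinity'),
#                           'cuda.shares': Decimal('Infinity')})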
def example_agents_no_valid():
    return [
        AgentContext(
            agent_id=AgentId('i-001'),
            agent_addr='10.0.1.1:6001',
            scaling_group='sg01',
            available_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('4.0'),
                'rocm.devices': Decimal('2'),
            }),
        ),
        AgentContext(
            agent_id=AgentId('i-101'),
            agent_addr='10.0.2.1:6001',
            scaling_group='sg02',
            available_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('2560'),
                'cuda.shares': Decimal('1.0'),
                'rocm.devices': Decimal('8'),
            }),
        ),
    ]
Example #10
async def recalc_agent_resource_occupancy(db_conn: SAConnection,
                                          agent_id: AgentId) -> None:
    query = (sa.select([
        kernels.c.occupied_slots,
    ]).select_from(kernels).where((kernels.c.agent == agent_id) & (
        kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))))
    occupied_slots = ResourceSlot()
    result = await db_conn.execute(query)
    async for row in result:
        occupied_slots += row['occupied_slots']
    print(f"ag:{agent_id}'s new occupied_slots:{occupied_slots}")
    query = (sa.update(agents).values({
        'occupied_slots': occupied_slots,
    }).where(agents.c.id == agent_id))
    await db_conn.execute(query)
Example #11
    @classmethod
    def read_from_string(cls, text: str) -> 'KernelResourceSpec':
        kvpairs = {}
        for line in text.split('\n'):
            if '=' not in line:
                continue
            key, val = line.strip().split('=', maxsplit=1)
            kvpairs[key] = val
        allocations = cast(
            MutableMapping[DeviceName,
                           MutableMapping[SlotName, Mapping[DeviceId, Decimal]]],
            defaultdict(lambda: defaultdict(Decimal)),
        )
        for key, val in kvpairs.items():
            if key.endswith('_SHARES'):
                slot_name = SlotName(key[:-7].lower())
                device_name = DeviceName(slot_name.split('.')[0])
                per_device_alloc: MutableMapping[DeviceId, Decimal] = {}
                for entry in val.split(','):
                    raw_dev_id, _, raw_alloc = entry.partition(':')
                    if not raw_dev_id or not raw_alloc:
                        continue
                    dev_id = DeviceId(raw_dev_id)
                    try:
                        if known_slot_types.get(slot_name, 'count') == 'bytes':
                            alloc = Decimal(BinarySize.from_str(raw_alloc))
                        else:
                            alloc = Decimal(raw_alloc)
                    except KeyError as e:
                        log.warning(
                            'A previously launched container has '
                            'unknown slot type: {}. Ignoring it.', e.args[0])
                        continue
                    per_device_alloc[dev_id] = alloc
                allocations[device_name][slot_name] = per_device_alloc
        mounts = [Mount.from_str(m) for m in kvpairs['MOUNTS'].split(',') if m]
        return cls(
            container_id=kvpairs.get('CID', 'unknown'),
            scratch_disk_size=BinarySize.finite_from_str(kvpairs['SCRATCH_SIZE']),
            allocations=dict(allocations),
            slots=ResourceSlot(json.loads(kvpairs['SLOTS'])),
            mounts=mounts,
        )
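For reference, a hypothetical serialized resource spec that read_from_string() would accept; the field names come from the parser above, while the concrete values are illustrative.

_example_resource_spec_text = '''
CID=abcdef012345
SCRATCH_SIZE=1g
MOUNTS=
SLOTS={"cpu": "2", "mem": "1073741824"}
CPU_SHARES=0:1,1:1
'''
# Each '<SLOTNAME>_SHARES' line holds per-device allocations as
# "<device-id>:<amount>" pairs. For the text above, read_from_string() would
# produce a per-device Decimal allocation for 'cpu', an empty mounts list,
# and a ResourceSlot parsed from the SLOTS JSON.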
    async def _schedule_in_sgroup(
        self,
        sched_ctx: SchedulingContext,
        agent_db_conn: SAConnection,
        kernel_db_conn: SAConnection,
        sgroup_name: str,
    ) -> List[StartTaskArgs]:
        async with kernel_db_conn.begin():
            scheduler = await self._load_scheduler(kernel_db_conn, sgroup_name)
            pending_sessions = await _list_pending_sessions(
                kernel_db_conn, sgroup_name)
            existing_sessions = await _list_existing_sessions(
                kernel_db_conn, sgroup_name)
        log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
                  sgroup_name, len(pending_sessions), len(existing_sessions))
        zero = ResourceSlot()
        args_list: List[StartTaskArgs] = []
        while len(pending_sessions) > 0:
            async with agent_db_conn.begin():
                candidate_agents = await _list_agents_by_sgroup(
                    agent_db_conn, sgroup_name)
                total_capacity = sum(
                    (ag.available_slots for ag in candidate_agents), zero)

            picked_session_id = scheduler.pick_session(
                total_capacity,
                pending_sessions,
                existing_sessions,
            )
            if picked_session_id is None:
                # no session is picked.
                # continue to next sgroup.
                return []
            for picked_idx, sess_ctx in enumerate(pending_sessions):
                if sess_ctx.session_id == picked_session_id:
                    break
            else:
                # no matching entry for picked session?
                raise RuntimeError('should not reach here')
            sess_ctx = pending_sessions.pop(picked_idx)

            log_fmt = 'schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): '
            log_args = (
                sess_ctx.session_id,
                sess_ctx.session_type,
                sess_ctx.session_name,
                sess_ctx.access_key,
                sess_ctx.cluster_mode,
            )
            _log_fmt.set(log_fmt)
            _log_args.set(log_args)
            log.debug(log_fmt + 'try-scheduling', *log_args)
            session_agent_binding: Tuple[PendingSession,
                                         List[KernelAgentBinding]]

            async with kernel_db_conn.begin():
                predicates: Sequence[Tuple[
                    str, Awaitable[PredicateResult]]] = [
                        (
                            'reserved_time',
                            check_reserved_batch_session(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        ('concurrency',
                         check_concurrency(kernel_db_conn, sched_ctx,
                                           sess_ctx)),
                        ('dependencies',
                         check_dependencies(kernel_db_conn, sched_ctx,
                                            sess_ctx)),
                        (
                            'keypair_resource_limit',
                            check_keypair_resource_limit(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        (
                            'user_group_resource_limit',
                            check_group_resource_limit(kernel_db_conn,
                                                       sched_ctx, sess_ctx),
                        ),
                        (
                            'domain_resource_limit',
                            check_domain_resource_limit(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        (
                            'scaling_group_resource_limit',
                            check_scaling_group(kernel_db_conn, sched_ctx,
                                                sess_ctx),
                        ),
                    ]
                check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = []
                for predicate_name, check_coro in predicates:
                    try:
                        check_results.append((predicate_name, await check_coro))
                    except Exception as e:
                        log.exception(log_fmt + 'predicate-error', *log_args)
                        check_results.append((predicate_name, e))
            has_failure = False
            # has_permanent_failure = False
            failed_predicates = []
            passed_predicates = []
            for predicate_name, result in check_results:
                if isinstance(result, Exception):
                    has_failure = True
                    failed_predicates.append({
                        'name': predicate_name,
                        'msg': repr(result),
                    })
                    continue
                if result.passed:
                    passed_predicates.append({
                        'name': predicate_name,
                    })
                else:
                    failed_predicates.append({
                        'name': predicate_name,
                        'msg': result.message or "",
                    })
                    has_failure = True
                    # if result.permanent:
                    #     has_permanent_failure = True
            if has_failure:
                log.debug(log_fmt + 'predicate-checks-failed (temporary)',
                          *log_args)
                # TODO: handle has_permanent_failure as cancellation
                #  - An early implementation of it has caused DB query blocking due to
                #    the inclusion of the kernels.status field. :(
                #    Let's fix it.
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn,
                        sched_ctx,
                        sess_ctx,
                        check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "predicate-checks-failed",
                        'status_data': sql_json_increment(
                            kernels.c.status_data, ('scheduler', 'retries'),
                            parent_updates={
                                'last_try': datetime.now(tzutc()).isoformat(),
                                'failed_predicates': failed_predicates,
                                'passed_predicates': passed_predicates,
                            }),
                    }).where(kernels.c.id == sess_ctx.session_id)
                    await kernel_db_conn.execute(query)
                # Predicate failures are *NOT* permanent errors.
                # We need to retry the scheduling afterwards.
                continue
            else:
                async with kernel_db_conn.begin():
                    query = kernels.update().values({
                        'status_data': sql_json_merge(
                            kernels.c.status_data, ('scheduler',), {
                                'last_try': datetime.now(tzutc()).isoformat(),
                                'failed_predicates': failed_predicates,
                                'passed_predicates': passed_predicates,
                            }),
                    }).where(kernels.c.id == sess_ctx.session_id)
                    await kernel_db_conn.execute(query)

            if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                session_agent_binding = await self._schedule_single_node_session(
                    sched_ctx,
                    scheduler,
                    agent_db_conn,
                    kernel_db_conn,
                    sgroup_name,
                    candidate_agents,
                    sess_ctx,
                    check_results,
                )
            elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                session_agent_binding = await self._schedule_multi_node_session(
                    sched_ctx,
                    scheduler,
                    agent_db_conn,
                    kernel_db_conn,
                    sgroup_name,
                    candidate_agents,
                    sess_ctx,
                    check_results,
                )
            else:
                raise RuntimeError(
                    f"should not reach here; unknown cluster_mode: {sess_ctx.cluster_mode}"
                )
            args_list.append((
                log_args,
                sched_ctx,
                session_agent_binding,
                check_results,
            ))

        return args_list
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    return list(items.values())
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                environ={
                    k: v for k, v
                    in map(lambda s: s.split('=', maxsplit=1), row['environ'])
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    return list(items.values())
        async def _schedule_in_sgroup(db_conn: SAConnection, sgroup_name: str) -> None:
            async with db_conn.begin():
                scheduler = await self._load_scheduler(db_conn, sgroup_name)
                pending_sessions = await _list_pending_sessions(db_conn, sgroup_name)
                existing_sessions = await _list_existing_sessions(db_conn, sgroup_name)
            log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
                      sgroup_name, len(pending_sessions), len(existing_sessions))
            zero = ResourceSlot()
            while len(pending_sessions) > 0:
                async with db_conn.begin():
                    candidate_agents = await _list_agents_by_sgroup(db_conn, sgroup_name)
                    total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

                picked_session_id = scheduler.pick_session(
                    total_capacity,
                    pending_sessions,
                    existing_sessions,
                )
                if picked_session_id is None:
                    # no session is picked.
                    # continue to next sgroup.
                    return
                for picked_idx, sess_ctx in enumerate(pending_sessions):
                    if sess_ctx.session_id == picked_session_id:
                        break
                else:
                    # no matching entry for picked session?
                    raise RuntimeError('should not reach here')
                sess_ctx = pending_sessions.pop(picked_idx)

                log_args = (
                    sess_ctx.session_id,
                    sess_ctx.session_type,
                    sess_ctx.session_name,
                    sess_ctx.access_key,
                    sess_ctx.cluster_mode,
                )
                log.debug(log_fmt + 'try-scheduling', *log_args)
                session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

                async with db_conn.begin():
                    predicates: Sequence[Awaitable[PredicateResult]] = [
                        check_reserved_batch_session(db_conn, sched_ctx, sess_ctx),
                        check_concurrency(db_conn, sched_ctx, sess_ctx),
                        check_dependencies(db_conn, sched_ctx, sess_ctx),
                        check_keypair_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_group_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_domain_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_scaling_group(db_conn, sched_ctx, sess_ctx),
                    ]
                    check_results: List[Union[Exception, PredicateResult]] = []
                    for check in predicates:
                        try:
                            check_results.append(await check)
                        except Exception as e:
                            log.exception(log_fmt + 'predicate-error', *log_args)
                            check_results.append(e)
                    has_failure = False
                    for result in check_results:
                        if isinstance(result, Exception):
                            has_failure = True
                            continue
                        if not result.passed:
                            has_failure = True
                    if has_failure:
                        log.debug(log_fmt + 'predicate-checks-failed', *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # Predicate failures are *NOT* permanent errors.
                        # We need to retry the scheduling afterwards.
                        continue

                    if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                        # Assign agent resource per session.
                        try:
                            agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
                            if agent_id is None:
                                raise InstanceNotAvailable
                            agent_alloc_ctx = await _reserve_agent(
                                sched_ctx, db_conn, sgroup_name, agent_id, sess_ctx.requested_slots,
                            )
                        except InstanceNotAvailable:
                            log.debug(log_fmt + 'no-available-instances', *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        except Exception:
                            log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                          *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        query = kernels.update().values({
                            'agent': agent_alloc_ctx.agent_id,
                            'agent_addr': agent_alloc_ctx.agent_addr,
                            'scaling_group': sgroup_name,
                            'status': KernelStatus.PREPARING,
                            'status_info': 'scheduled',
                            'status_changed': datetime.now(tzutc()),
                        }).where(kernels.c.session_id == sess_ctx.session_id)
                        await db_conn.execute(query)
                        session_agent_binding = (
                            sess_ctx,
                            [
                                KernelAgentBinding(kernel, agent_alloc_ctx)
                                for kernel in sess_ctx.kernels
                            ],
                        )
                    elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                        # Assign agent resource per kernel in the session.
                        agent_query_extra_conds = None
                        if len(sess_ctx.kernels) >= 2:
                            # We should use agents that support overlay networking.
                            agent_query_extra_conds = (agents.c.clusterized)
                        kernel_agent_bindings = []
                        for kernel in sess_ctx.kernels:
                            try:
                                agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                                if agent_id is None:
                                    raise InstanceNotAvailable
                                agent_alloc_ctx = await _reserve_agent(
                                    sched_ctx, db_conn, sgroup_name, agent_id, kernel.requested_slots,
                                    extra_conds=agent_query_extra_conds,
                                )
                            except InstanceNotAvailable:
                                log.debug(log_fmt + 'no-available-instances', *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            except Exception:
                                log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                              *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            # TODO: if error occurs for one kernel, should we cancel all others?
                            query = kernels.update().values({
                                'agent': agent_alloc_ctx.agent_id,
                                'agent_addr': agent_alloc_ctx.agent_addr,
                                'scaling_group': sgroup_name,
                                'status': KernelStatus.PREPARING,
                                'status_info': 'scheduled',
                                'status_changed': datetime.now(tzutc()),
                            }).where(kernels.c.id == kernel.kernel_id)
                            await db_conn.execute(query)
                            kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
                        session_agent_binding = (sess_ctx, kernel_agent_bindings)
                    start_task_args.append(
                        (
                            log_args,
                            sched_ctx,
                            session_agent_binding,
                            check_results,
                        )
                    )
async def check_presets(request: web.Request, params: Any) -> web.Response:
    '''
    Returns the list of all resource presets in the current scaling group,
    with additional information: whether each preset is currently allocatable,
    the total amount of remaining resources, and the current keypair resource limits.
    '''
    try:
        access_key = request['keypair']['access_key']
        resource_policy = request['keypair']['resource_policy']
        # TODO: uncomment when we implement scaling group.
        # scaling_group = request.query.get('scaling_group')
        # assert scaling_group is not None, 'scaling_group parameter is missing.'
    except (json.decoder.JSONDecodeError, AssertionError) as e:
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    registry = request.app['registry']
    known_slot_types = await registry.config_server.get_resource_slots()
    keypair_limits = ResourceSlot.from_policy(resource_policy,
                                              known_slot_types)
    resp: MutableMapping[str, Any] = {
        'keypair_limits': None,
        'keypair_using': None,
        'keypair_remaining': None,
        'scaling_group_remaining': None,
        'presets': [],
    }
    async with request.app['dbpool'].acquire() as conn, conn.begin():
        keypair_occupied = await registry.get_keypair_occupancy(access_key,
                                                                conn=conn)
        keypair_remaining = keypair_limits - keypair_occupied
        resp['keypair_limits'] = keypair_limits.to_json()
        resp['keypair_using'] = keypair_occupied.to_json()
        resp['keypair_remaining'] = keypair_remaining.to_json()
        # query all agent's capacity and occupancy
        agent_slots = []

        j = sa.join(groups, association_groups_users,
                    association_groups_users.c.group_id == groups.c.id)
        query = (sa.select(
            [association_groups_users.c.group_id]).select_from(j).where(
                (association_groups_users.c.user_id == request['user']['uuid'])
                & (groups.c.name == params['group'])))
        group_id = await conn.scalar(query)
        if group_id is None:
            raise InvalidAPIParameters('Unknown user group')

        sgroups = await query_allowed_sgroups(conn,
                                              request['user']['domain_name'],
                                              group_id, access_key)
        sgroups = [sg.name for sg in sgroups]
        if params['scaling_group'] is not None:
            if params['scaling_group'] not in sgroups:
                raise InvalidAPIParameters('Unknown scaling group')
            sgroups = [params['scaling_group']]

        sgroup_remaining = ResourceSlot(
            {k: Decimal(0)
             for k in known_slot_types.keys()})
        query = (sa.select([
            agents.c.available_slots, agents.c.occupied_slots
        ]).select_from(agents).where((agents.c.status == AgentStatus.ALIVE)
                                     & (agents.c.scaling_group.in_(sgroups))))
        async for row in conn.execute(query):
            remaining = row['available_slots'] - row['occupied_slots']
            sgroup_remaining += remaining
            agent_slots.append(remaining)
        resp['scaling_group_remaining'] = sgroup_remaining.to_json()
        # fetch all resource presets in the current scaling group.
        query = (sa.select([resource_presets]).select_from(resource_presets))
        async for row in conn.execute(query):
            # check whether any agent can allocate this preset
            allocatable = False
            preset_slots = row['resource_slots'].filter_slots(known_slot_types)
            for agent_slot in agent_slots:
                if agent_slot >= preset_slots and keypair_remaining >= preset_slots:
                    allocatable = True
                    break
            resp['presets'].append({
                'name': row['name'],
                'resource_slots': preset_slots.to_json(),
                'allocatable': allocatable,
            })
    return web.json_response(resp, status=200)
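The allocatability rule above reduces to two element-wise comparisons. A minimal sketch of just that decision (the helper name is illustrative), assuming ResourceSlot defines '>=' over matching slot keys:

def is_allocatable(preset_slots: ResourceSlot, agent_slots,
                   keypair_remaining: ResourceSlot) -> bool:
    # A preset is allocatable iff at least one agent has enough remaining
    # capacity and the keypair still has enough quota (same rule as above).
    return any(
        agent_slot >= preset_slots and keypair_remaining >= preset_slots
        for agent_slot in agent_slots
    )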
def example_existing_sessions():
    return [
        ExistingSession(
            kernels=[
                KernelInfo(
                    kernel_id=existing_session_kernel_ids[0].kernel_ids[0],
                    session_id=existing_session_kernel_ids[0].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('1.0'),
                        'mem': Decimal('512'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
                KernelInfo(
                    kernel_id=existing_session_kernel_ids[0].kernel_ids[1],
                    session_id=existing_session_kernel_ids[0].session_id,
                    access_key='dummy-access-key',
                    cluster_role='sub',
                    cluster_idx=2,
                    cluster_hostname="sub1",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('2.0'),
                        'mem': Decimal('512'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('1'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user01'),
            session_id=existing_session_kernel_ids[0].session_id,
            session_name='es01',
            session_type=SessionTypes.BATCH,
            cluster_mode='single-node',
            cluster_size=2,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernels=[
                KernelInfo(
                    kernel_id=existing_session_kernel_ids[1].kernel_ids[0],
                    session_id=existing_session_kernel_ids[1].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('1.0'),
                        'mem': Decimal('2048'),
                        'cuda.shares': Decimal('0.5'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user02'),
            session_id=existing_session_kernel_ids[1].session_id,
            session_type=SessionTypes.BATCH,
            session_name='es01',
            cluster_mode='single-node',
            cluster_size=1,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernels=[
                KernelInfo(
                    kernel_id=existing_session_kernel_ids[2].kernel_ids[0],
                    session_id=existing_session_kernel_ids[2].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('4.0'),
                        'mem': Decimal('4096'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user03'),
            session_id=existing_session_kernel_ids[2].session_id,
            session_type=SessionTypes.BATCH,
            session_name='es01',
            cluster_mode='single-node',
            cluster_size=1,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
    ]
def example_pending_sessions():
    # lower indices are enqueued first.
    return [
        PendingSession(  # rocm
            kernels=[
                KernelInfo(
                    kernel_id=pending_session_kernel_ids[0].kernel_ids[0],
                    session_id=pending_session_kernel_ids[0].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('2.0'),
                        'mem': Decimal('1024'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('1'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user01'),
            session_id=pending_session_kernel_ids[0].session_id,
            session_name='es01',
            session_type=SessionTypes.BATCH,
            cluster_mode='single-node',
            cluster_size=1,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('2.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cuda
            kernels=[
                KernelInfo(
                    kernel_id=pending_session_kernel_ids[1].kernel_ids[0],
                    session_id=pending_session_kernel_ids[1].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('1.0'),
                        'mem': Decimal('2048'),
                        'cuda.shares': Decimal('0.5'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user02'),
            session_id=pending_session_kernel_ids[1].session_id,
            session_name='es01',
            session_type=SessionTypes.BATCH,
            cluster_mode='single-node',
            cluster_size=1,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cpu-only
            kernels=[
                KernelInfo(
                    kernel_id=pending_session_kernel_ids[2].kernel_ids[0],
                    session_id=pending_session_kernel_ids[2].session_id,
                    access_key='dummy-access-key',
                    cluster_role=DEFAULT_ROLE,
                    cluster_idx=1,
                    cluster_hostname=f"{DEFAULT_ROLE}0",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('0.4'),
                        'mem': Decimal('512'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
                KernelInfo(
                    kernel_id=pending_session_kernel_ids[2].kernel_ids[1],
                    session_id=pending_session_kernel_ids[2].session_id,
                    access_key='dummy-access-key',
                    cluster_role='sub',
                    cluster_idx=2,
                    cluster_hostname="sub1",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('0.3'),
                        'mem': Decimal('256'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
                KernelInfo(
                    kernel_id=pending_session_kernel_ids[2].kernel_ids[2],
                    session_id=pending_session_kernel_ids[2].session_id,
                    access_key='dummy-access-key',
                    cluster_role='sub',
                    cluster_idx=3,
                    cluster_hostname="sub2",
                    image_ref=common_image_ref,
                    resource_opts={},
                    requested_slots=ResourceSlot({
                        'cpu': Decimal('0.3'),
                        'mem': Decimal('256'),
                        'cuda.shares': Decimal('0'),
                        'rocm.devices': Decimal('0'),
                    }),
                    bootstrap_script=None,
                    startup_command=None,
                ),
            ],
            access_key=AccessKey('user03'),
            session_id=pending_session_kernel_ids[2].session_id,
            session_name='es01',
            session_type=SessionTypes.BATCH,
            cluster_mode='single-node',
            cluster_size=3,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
    ]
from ai.backend.manager.scheduler.fifo import FIFOSlotScheduler, LIFOSlotScheduler
from ai.backend.manager.scheduler.drf import DRFScheduler
from ai.backend.manager.scheduler.mof import MOFScheduler
from ai.backend.manager.scheduler.predicates import check_reserved_batch_session


def test_load_intrinsic():
    assert isinstance(load_scheduler('fifo', {}), FIFOSlotScheduler)
    assert isinstance(load_scheduler('lifo', {}), LIFOSlotScheduler)
    assert isinstance(load_scheduler('drf', {}), DRFScheduler)
    assert isinstance(load_scheduler('mof', {}), MOFScheduler)
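For context, the loader exercised above maps a short name to a scheduler class. A minimal sketch of that dispatch is shown below; the real load_scheduler may resolve plugin entry points instead, the function name load_scheduler_sketch is hypothetical, and passing the config mapping straight to the constructor is an assumption.

from typing import Any, Mapping

_INTRINSIC_SCHEDULERS = {
    'fifo': FIFOSlotScheduler,
    'lifo': LIFOSlotScheduler,
    'drf': DRFScheduler,
    'mof': MOFScheduler,
}


def load_scheduler_sketch(name: str, sched_config: Mapping[str, Any]):
    # Look up the scheduler class registered under the given name and
    # instantiate it with the scheduler-specific config mapping.
    try:
        scheduler_cls = _INTRINSIC_SCHEDULERS[name]
    except KeyError:
        raise ValueError(f'unknown scheduler: {name!r}') from None
    return scheduler_cls(sched_config)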


example_group_id = uuid4()

example_total_capacity = ResourceSlot({'cpu': '4.0', 'mem': '4096'})


@pytest.fixture
def example_agents():
    return [
        AgentContext(
            agent_id=AgentId('i-001'),
            agent_addr='10.0.1.1:6001',
            scaling_group='sg01',
            available_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('4.0'),
                'rocm.devices': Decimal('2'),
            }),
Example #20
0
async def test_handle_heartbeat(mocker):
    mock_shared_config = MagicMock()
    mock_shared_config.update_resource_slots = AsyncMock()
    mock_shared_config.etcd = None
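    # Stub out the DB pool so that `async with dbpool.acquire() as conn` and
    # `async with conn.begin()` work without touching a real database.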
    mock_dbpool = MagicMock()
    mock_dbconn = MagicMock()
    mock_dbconn_ctx = MagicMock()
    mock_dbtxn_ctx = MagicMock()
    mock_dbresult = MagicMock()
    mock_dbresult.rowcount = 1
    mock_dbpool.acquire = MagicMock(return_value=mock_dbconn_ctx)
    mock_dbconn_ctx.__aenter__ = AsyncMock(return_value=mock_dbconn)
    mock_dbconn_ctx.__aexit__ = AsyncMock()
    mock_dbconn.execute = AsyncMock(return_value=mock_dbresult)
    mock_dbconn.begin = MagicMock(return_value=mock_dbtxn_ctx)
    mock_dbtxn_ctx.__aenter__ = AsyncMock()
    mock_dbtxn_ctx.__aexit__ = AsyncMock()
    mock_redis_stat = MagicMock()
    mock_redis_live = MagicMock()
    mock_redis_live.hset = AsyncMock()
    mock_redis_image = MagicMock()
    mock_event_dispatcher = MagicMock()
    mock_event_dispatcher.produce_event = AsyncMock()
    mock_get_known_registries = AsyncMock(return_value=[
        {
            'index.docker.io': 'https://registry-1.docker.io'
        },
    ])
    mocker.patch('ai.backend.manager.registry.get_known_registries',
                 mock_get_known_registries)
    mock_redis_wrapper = MagicMock()
    mock_redis_wrapper.execute_with_retries = AsyncMock()
    mocker.patch('ai.backend.manager.registry.redis', mock_redis_wrapper)
    image_data = snappy.compress(
        msgpack.packb([
            ('index.docker.io/lablup/python:3.6-ubuntu18.04', ),
        ]))

    def mocked_entrypoints(entry_point_group: str):
        return []

    mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points',
                 mocked_entrypoints)
    mocked_etcd = DummyEtcd()
    # mocker.object.patch(mocked_etcd, 'get_prefix', AsyncMock(return_value={}))
    hook_plugin_ctx = HookPluginContext(mocked_etcd, {})

    registry = AgentRegistry(
        shared_config=mock_shared_config,
        dbpool=mock_dbpool,
        redis_stat=mock_redis_stat,
        redis_live=mock_redis_live,
        redis_image=mock_redis_image,
        event_dispatcher=mock_event_dispatcher,
        storage_manager=None,
        hook_plugin_ctx=hook_plugin_ctx,
    )
    await registry.init()

    # Join
    mock_dbresult.first = AsyncMock(return_value=None)
    await registry.handle_heartbeat(
        'i-001', {
            'scaling_group': 'sg-testing',
            'resource_slots': {
                'cpu': ('count', '1'),
                'mem': ('bytes', '1g')
            },
            'region': 'ap-northeast-2',
            'addr': '10.0.0.5',
            'version': '19.12.0',
            'compute_plugins': [],
            'images': image_data,
        })
    mock_shared_config.update_resource_slots.assert_awaited_once()
    q = mock_dbconn.execute.await_args_list[1].args[0]
    assert isinstance(q, Insert)

    # Update alive instance
    mock_shared_config.update_resource_slots.reset_mock()
    mock_dbconn.execute.reset_mock()
    mock_dbresult.first = AsyncMock(
        return_value={
            'status': AgentStatus.ALIVE,
            'addr': '10.0.0.5',
            'scaling_group': 'sg-testing',
            'available_slots': ResourceSlot({
                'cpu': '1',
                'mem': '1g'
            }),
        })
    await registry.handle_heartbeat(
        'i-001', {
            'scaling_group': 'sg-testing',
            'resource_slots': {
                'cpu': ('count', '1'),
                'mem': ('bytes', '2g')
            },
            'region': 'ap-northeast-2',
            'addr': '10.0.0.6',
            'version': '19.12.0',
            'compute_plugins': [],
            'images': image_data,
        })
    mock_shared_config.update_resource_slots.assert_awaited_once()
    q = mock_dbconn.execute.await_args_list[1].args[0]
    assert isinstance(q, Update)
    assert q.parameters['addr'] == '10.0.0.6'
    assert q.parameters['available_slots'] == ResourceSlot({
        'cpu': '1',
        'mem': '2g'
    })
    assert 'scaling_group' not in q.parameters

    # Rejoin
    mock_shared_config.update_resource_slots.reset_mock()
    mock_dbconn.execute.reset_mock()
    mock_dbresult.first = AsyncMock(
        return_value={
            'status': AgentStatus.LOST,
            'addr': '10.0.0.5',
            'scaling_group': 'sg-testing',
            'available_slots': ResourceSlot({
                'cpu': '1',
                'mem': '1g'
            }),
        })
    await registry.handle_heartbeat(
        'i-001', {
            'scaling_group': 'sg-testing2',
            'resource_slots': {
                'cpu': ('count', '4'),
                'mem': ('bytes', '2g')
            },
            'region': 'ap-northeast-2',
            'addr': '10.0.0.6',
            'version': '19.12.0',
            'compute_plugins': [],
            'images': image_data,
        })
    mock_shared_config.update_resource_slots.assert_awaited_once()
    q = mock_dbconn.execute.await_args_list[1].args[0]
    assert isinstance(q, Update)
    assert q.parameters['status'] == AgentStatus.ALIVE
    assert q.parameters['addr'] == '10.0.0.6'
    assert q.parameters['lost_at'] is None
    assert q.parameters['available_slots'] == ResourceSlot({
        'cpu': '4',
        'mem': '2g'
    })
    assert q.parameters['scaling_group'] == 'sg-testing2'
    assert 'compute_plugins' in q.parameters
    assert 'version' in q.parameters
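The three scenarios above (first join, update while ALIVE, rejoin after LOST) imply roughly the following branching. This is a reconstruction from the test's assertions, not the actual AgentRegistry code: plan_heartbeat_write is a hypothetical helper, available_slots is assumed to be the ResourceSlot already converted from the (unit, value) pairs in the heartbeat payload, and the image sync, hook-plugin, and Redis work that the mocks stand in for is omitted.

def plan_heartbeat_write(existing_row, agent_info, available_slots):
    # Sketch: decide whether a heartbeat becomes an INSERT (first contact)
    # or an UPDATE, and which columns change.
    if existing_row is None:
        # Join: no prior record for this agent id.
        return ('insert', {
            'status': AgentStatus.ALIVE,
            'addr': agent_info['addr'],
            'scaling_group': agent_info['scaling_group'],
            'available_slots': available_slots,
            'version': agent_info['version'],
            'compute_plugins': agent_info['compute_plugins'],
        })
    updates = {
        'addr': agent_info['addr'],
        'available_slots': available_slots,
        'version': agent_info['version'],
        'compute_plugins': agent_info['compute_plugins'],
    }
    if existing_row['status'] != AgentStatus.ALIVE:
        # Rejoin: revive the record, clear the lost timestamp, and allow
        # the scaling group to change.
        updates.update({
            'status': AgentStatus.ALIVE,
            'scaling_group': agent_info['scaling_group'],
            'lost_at': None,
        })
    # While ALIVE, the scaling group is intentionally left untouched.
    return ('update', updates)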
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        'resource_presets',
        sa.Column('name', sa.String(length=256), nullable=False),
        sa.Column('resource_slots', ResourceSlotColumn(), nullable=False),
        sa.PrimaryKeyConstraint('name', name=op.f('pk_resource_presets')))
    # Add initial fixtures for resource presets
    query = '''
    INSERT INTO resource_presets
    VALUES (
        'small',
        '{"cpu":"1","mem":"2147483648"}'::jsonb
    );
    INSERT INTO resource_presets
    VALUES (
        'small-gpu',
        '{"cpu":"1","mem":"2147483648","cuda.device":"1","cuda.shares":"0.5"}'::jsonb
    );
    INSERT INTO resource_presets
    VALUES (
        'medium',
        '{"cpu":"2","mem":"4294967296"}'::jsonb
    );
    INSERT INTO resource_presets
    VALUES (
        'medium-gpu',
        '{"cpu":"2","mem":"4294967296","cuda.device":"1","cuda.shares":"1.0"}'::jsonb
    );
    INSERT INTO resource_presets
    VALUES (
        'large',
        '{"cpu":"4","mem":"8589934592"}'::jsonb
    );
    INSERT INTO resource_presets
    VALUES (
        'large-gpu',
        '{"cpu":"4","mem":"8589934592","cuda.device":"2","cuda.shares":"2.0"}'::jsonb
    );
    '''
    connection = op.get_bind()
    connection.execute(query)

    query = '''
    SELECT name, total_resource_slots
    FROM keypair_resource_policies
    '''
    connection = op.get_bind()
    result = connection.execute(query)
    updates = []
    for row in result:
        converted = ResourceSlot(row['total_resource_slots'])
        if 'mem' in converted:
            converted['mem'] = Decimal(BinarySize.from_str(converted['mem']))
            updates.append((
                row['name'],
                converted,
            ))
    for name, slots in updates:
        query = (sa.update(keypair_resource_policies).values(
            total_resource_slots=slots).where(
                keypair_resource_policies.c.name == name))
        connection.execute(query)
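For reference, the loop above rewrites human-readable memory sizes into plain byte counts. A minimal usage sketch of that conversion, assuming BinarySize.from_str accepts binary suffixes such as 'g' (so '2g' equals 2 GiB):

from decimal import Decimal

from ai.backend.common.types import BinarySize, ResourceSlot

# A legacy policy row that stored mem as a size string:
legacy = {'cpu': '4', 'mem': '2g'}
converted = ResourceSlot(legacy)
# The migration rewrites mem as a byte count (2 GiB == 2147483648 bytes):
converted['mem'] = Decimal(BinarySize.from_str(converted['mem']))
assert converted['mem'] == Decimal('2147483648')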
Example #22
0
    async def test_query_compute_worker(self, create_app_and_client,
                                        get_headers, user_keypair):
        app, client = await create_app_and_client(
            modules=['auth', 'admin', 'manager'])

        # Add test agent info to db (needed for foreign key for kernel)
        async with app['dbpool'].acquire() as conn:
            query = agents.insert().values({
                'id':
                'test-agent-id',
                'status':
                AgentStatus.ALIVE,
                'region':
                'local',
                'available_slots':
                ResourceSlot({
                    'cpu': '1',
                    'mem': '1073741824',
                }),
                'occupied_slots':
                ResourceSlot({
                    'cpu': '0',
                    'mem': '0',
                }),
                'addr':
                '127.0.0.1',
                'lost_at':
                None,
            })
            await conn.execute(query)

        async with app['dbpool'].acquire() as conn:
            # Add master kernel info to db
            query = kernels.insert().values({
                'id':
                uuid.uuid4(),
                'status':
                KernelStatus.PREPARING,
                'sess_id':
                'test-sess-id',
                'role':
                'master',
                'agent':
                'test-agent-id',
                'agent_addr':
                '127.0.0.1:5002',
                'image':
                'lablup/lua:latest',
                'tag':
                'test-tag',
                'access_key':
                user_keypair['access_key'],
                'occupied_slots':
                ResourceSlot({
                    'cpu': '1',
                    'mem': '1073741824',
                }),
                'environ': [],
                'repl_in_port':
                0,
                'repl_out_port':
                0,
                'stdin_port':
                0,
                'stdout_port':
                0,
            })
            await conn.execute(query)
            # Add worker
            query = kernels.insert().values({
                'id':
                uuid.uuid4(),
                'status':
                KernelStatus.PREPARING,
                'sess_id':
                'test-sess-id',
                'role':
                'worker',
                'agent':
                'test-agent-id',
                'agent_addr':
                '127.0.0.1:5002',
                'image':
                'lablup/lua:latest',
                'access_key':
                user_keypair['access_key'],
                'occupied_slots':
                ResourceSlot({
                    'cpu': '1',
                    'mem': '1073741824',
                }),
                'environ': [],
                'repl_in_port':
                0,
                'repl_out_port':
                0,
                'stdin_port':
                0,
                'stdout_port':
                0,
            })
            await conn.execute(query)

        query = '{ compute_workers(sess_id: "test-sess-id") { role agent image } }'
        payload = json.dumps({'query': query}).encode()
        headers = get_headers('POST', self.url, payload, keypair=user_keypair)
        ret = await client.post(self.url, data=payload, headers=headers)

        assert ret.status == 200
        rsp_json = await ret.json()
        assert len(rsp_json) == 1
        assert rsp_json['compute_workers'][0]['role'] == 'worker'
Example #23
0
async def check_presets(request: web.Request, params: Any) -> web.Response:
    '''
    Returns the list of all resource presets in the current scaling group,
    with additional information including the allocatability of each preset,
    the total amount of remaining resources, and the current keypair resource limits.
    '''
    try:
        access_key = request['keypair']['access_key']
        resource_policy = request['keypair']['resource_policy']
        domain_name = request['user']['domain_name']
        # TODO: uncomment when we implement scaling group.
        # scaling_group = request.query.get('scaling_group')
        # assert scaling_group is not None, 'scaling_group parameter is missing.'
    except (json.decoder.JSONDecodeError, AssertionError) as e:
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    registry = request.app['registry']
    known_slot_types = await registry.config_server.get_resource_slots()
    resp: MutableMapping[str, Any] = {
        'keypair_limits': None,
        'keypair_using': None,
        'keypair_remaining': None,
        'scaling_group_remaining': None,
        'scaling_groups': None,
        'presets': [],
    }
    log.info('CHECK_PRESETS (ak:{}, g:{}, sg:{})',
             request['keypair']['access_key'], params['group'],
             params['scaling_group'])

    async with request.app['dbpool'].acquire() as conn, conn.begin():
        # Check keypair resource limit.
        keypair_limits = ResourceSlot.from_policy(resource_policy,
                                                  known_slot_types)
        keypair_occupied = await registry.get_keypair_occupancy(access_key,
                                                                conn=conn)
        keypair_remaining = keypair_limits - keypair_occupied

        # Check group resource limit and get group_id.
        j = sa.join(groups, association_groups_users,
                    association_groups_users.c.group_id == groups.c.id)
        query = (sa.select(
            [groups.c.id, groups.c.total_resource_slots]).select_from(j).where(
                (association_groups_users.c.user_id == request['user']['uuid'])
                & (groups.c.name == params['group'])
                & (domains.c.name == domain_name)))
        result = await conn.execute(query)
        row = await result.fetchone()
        if row is None:
            raise InvalidAPIParameters('Unknown user group')
        group_id = row.id
        group_resource_slots = row.total_resource_slots
        group_resource_policy = {
            'total_resource_slots': group_resource_slots,
            'default_for_unspecified': DefaultForUnspecified.UNLIMITED
        }
        group_limits = ResourceSlot.from_policy(group_resource_policy,
                                                known_slot_types)
        group_occupied = await registry.get_group_occupancy(group_id,
                                                            conn=conn)
        group_remaining = group_limits - group_occupied

        # Check domain resource limit.
        query = (sa.select([domains.c.total_resource_slots
                            ]).where(domains.c.name == domain_name))
        domain_resource_slots = await conn.scalar(query)
        domain_resource_policy = {
            'total_resource_slots': domain_resource_slots,
            'default_for_unspecified': DefaultForUnspecified.UNLIMITED
        }
        domain_limits = ResourceSlot.from_policy(domain_resource_policy,
                                                 known_slot_types)
        domain_occupied = await registry.get_domain_occupancy(domain_name,
                                                              conn=conn)
        domain_remaining = domain_limits - domain_occupied

        # Take the minimum of the remaining resources. There is no need to merge limits and occupied.
        # For legacy compatibility, we just merge all remaining slots into `keypair_remaining`.
        for slot in known_slot_types:
            keypair_remaining[slot] = min(
                keypair_remaining[slot],
                group_remaining[slot],
                domain_remaining[slot],
            )

        # Prepare per scaling group resource.
        sgroups = await query_allowed_sgroups(conn, domain_name, group_id,
                                              access_key)
        sgroup_names = [sg.name for sg in sgroups]
        if params['scaling_group'] is not None:
            if params['scaling_group'] not in sgroup_names:
                raise InvalidAPIParameters('Unknown scaling group')
            sgroup_names = [params['scaling_group']]
        per_sgroup = {
            sgname: {
                'using':
                ResourceSlot({k: Decimal(0)
                              for k in known_slot_types.keys()}),
                'remaining':
                ResourceSlot({k: Decimal(0)
                              for k in known_slot_types.keys()}),
            }
            for sgname in sgroup_names
        }

        # Per-scaling-group resource usage, accumulated from resource-occupying kernels.
        query = (sa.select([
            kernels.c.occupied_slots, kernels.c.scaling_group
        ]).select_from(kernels).where(
            (kernels.c.user_uuid == request['user']['uuid'])
            & (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
            & (kernels.c.scaling_group.in_(sgroup_names))))
        async for row in conn.execute(query):
            per_sgroup[row.scaling_group]['using'] += row.occupied_slots

        # Per-scaling-group remaining resources, computed from agent stats.
        sgroup_remaining = ResourceSlot(
            {k: Decimal(0)
             for k in known_slot_types.keys()})
        query = (sa.select([
            agents.c.available_slots, agents.c.occupied_slots,
            agents.c.scaling_group
        ]).select_from(agents).where((agents.c.status == AgentStatus.ALIVE) & (
            agents.c.scaling_group.in_(sgroup_names))))
        agent_slots = []
        async for row in conn.execute(query):
            remaining = row['available_slots'] - row['occupied_slots']
            remaining += ResourceSlot(
                {k: Decimal(0)
                 for k in known_slot_types.keys()})
            sgroup_remaining += remaining
            agent_slots.append(remaining)
            per_sgroup[row.scaling_group]['remaining'] += remaining

        # Cap each scaling group's remaining resources by the keypair's remaining quota,
        # i.e., the maximum actually allocatable per scaling group.
        for sgname, sgfields in per_sgroup.items():
            for rtype, slots in sgfields.items():
                if rtype == 'remaining':
                    for slot in known_slot_types.keys():
                        if slot in slots:
                            slots[slot] = min(keypair_remaining[slot],
                                              slots[slot])
                per_sgroup[sgname][rtype] = slots.to_json()  # type: ignore  # it's serialization
        for slot in known_slot_types.keys():
            sgroup_remaining[slot] = min(keypair_remaining[slot],
                                         sgroup_remaining[slot])

        # Fetch all resource presets in the current scaling group.
        query = (sa.select([resource_presets]).select_from(resource_presets))
        async for row in conn.execute(query):
            # Check whether any agent can allocate this preset.
            allocatable = False
            preset_slots = row['resource_slots'].normalize_slots(
                ignore_unknown=True)
            for agent_slot in agent_slots:
                if agent_slot >= preset_slots and keypair_remaining >= preset_slots:
                    allocatable = True
                    break
            resp['presets'].append({
                'name':
                row['name'],
                'resource_slots':
                preset_slots.to_json(),
                'shared_memory':
                str(row['shared_memory'])
                if row['shared_memory'] is not None else None,
                'allocatable':
                allocatable,
            })

        # Mask group resource status with NaN if group resource visibility is disabled.
        group_resource_visibility = await request.app[
            'registry'].config_server.get(
                'config/api/resources/group_resource_visibility')
        group_resource_visibility = t.ToBool().check(group_resource_visibility)
        if not group_resource_visibility:
            group_limits = ResourceSlot(
                {k: Decimal('NaN')
                 for k in known_slot_types.keys()})
            group_occupied = ResourceSlot(
                {k: Decimal('NaN')
                 for k in known_slot_types.keys()})
            group_remaining = ResourceSlot(
                {k: Decimal('NaN')
                 for k in known_slot_types.keys()})

        resp['keypair_limits'] = keypair_limits.to_json()
        resp['keypair_using'] = keypair_occupied.to_json()
        resp['keypair_remaining'] = keypair_remaining.to_json()
        resp['group_limits'] = group_limits.to_json()
        resp['group_using'] = group_occupied.to_json()
        resp['group_remaining'] = group_remaining.to_json()
        resp['scaling_group_remaining'] = sgroup_remaining.to_json()
        resp['scaling_groups'] = per_sgroup
    return web.json_response(resp, status=200)
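For reference, a successful response from this handler has roughly the following shape. The values are purely illustrative, the variable name example_check_presets_response is hypothetical, and the actual slot keys depend on the deployment's configured slot types.

example_check_presets_response = {
    'keypair_limits': {'cpu': '8', 'mem': '34359738368'},
    'keypair_using': {'cpu': '1', 'mem': '1073741824'},
    'keypair_remaining': {'cpu': '7', 'mem': '33285996544'},
    'group_limits': {'cpu': '16', 'mem': '68719476736'},
    'group_using': {'cpu': '1', 'mem': '1073741824'},
    'group_remaining': {'cpu': '15', 'mem': '67645734912'},
    'scaling_group_remaining': {'cpu': '3', 'mem': '3221225472'},
    'scaling_groups': {
        'default': {
            'using': {'cpu': '1', 'mem': '1073741824'},
            'remaining': {'cpu': '3', 'mem': '3221225472'},
        },
    },
    'presets': [
        {
            'name': 'small',
            'resource_slots': {'cpu': '1', 'mem': '2147483648'},
            'shared_memory': None,
            'allocatable': True,
        },
    ],
}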