def example_mixed_agents():
    return [
        AgentContext(
            agent_id=AgentId('i-gpu'),
            agent_addr='10.0.1.1:6001',
            scaling_group='sg01',
            available_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('4.0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
            }),
        ),
        AgentContext(
            agent_id=AgentId('i-cpu'),
            agent_addr='10.0.2.1:6001',
            scaling_group='sg02',
            available_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('2560'),
                'cuda.shares': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
            }),
        ),
    ]
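The agents above are consumed through slot-wise ResourceSlot arithmetic: subtraction yields the remaining capacity and `>=` compares every slot at once (see _assign_agent further below). A minimal standalone sketch, assuming the ai.backend.common.types import path used by this project:

from decimal import Decimal
from ai.backend.common.types import ResourceSlot  # assumed import path

available = ResourceSlot({'cpu': Decimal('4.0'), 'mem': Decimal('4096'),
                          'cuda.shares': Decimal('4.0')})
occupied = ResourceSlot({'cpu': Decimal('1.0'), 'mem': Decimal('1024'),
                         'cuda.shares': Decimal('0')})
requested = ResourceSlot({'cpu': Decimal('2.0'), 'mem': Decimal('2048'),
                          'cuda.shares': Decimal('1.0')})

remaining = available - occupied  # slot-wise subtraction
print(remaining >= requested)     # True only if every slot has enough headroom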
Example #2
    async def test_query_agents(self, create_app_and_client, get_headers):
        app, client = await create_app_and_client(
            modules=['auth', 'admin', 'manager'])

        # Add test agents info to db (all ALIVE)
        async with app['dbpool'].acquire() as conn:
            for i in range(2):
                query = agents.insert().values({
                    'id': f'test-agent-id-{i}',
                    'status': AgentStatus.ALIVE,
                    'region': 'local',
                    'available_slots': ResourceSlot({
                        'cpu': '1',
                        'mem': '1073741824',
                    }),
                    'occupied_slots': ResourceSlot({
                        'cpu': '0',
                        'mem': '0',
                    }),
                    'addr': '127.0.0.1',
                    'lost_at': None,
                })
                await conn.execute(query)

        query = '{ agents { status region addr } }'
        payload = json.dumps({'query': query}).encode()
        headers = get_headers('POST', self.url, payload)
        ret = await client.post(self.url, data=payload, headers=headers)

        assert ret.status == 200
        rsp_json = await ret.json()
        assert rsp_json['agents'][0]['status'] == 'ALIVE'
        assert rsp_json['agents'][0]['region'] == 'local'
        assert rsp_json['agents'][0]['addr'] == '127.0.0.1'
        assert rsp_json['agents'][1]['status'] == 'ALIVE'
        assert rsp_json['agents'][1]['region'] == 'local'
        assert rsp_json['agents'][1]['addr'] == '127.0.0.1'

        # query with status
        query = '{ agents(status: "LOST") { status region addr } }'
        payload = json.dumps({'query': query}).encode()
        headers = get_headers('POST', self.url, payload)
        ret = await client.post(self.url, data=payload, headers=headers)

        assert ret.status == 200
        rsp_json = await ret.json()
        assert len(rsp_json['agents']) == 0
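For reference, the assertions above imply a response body of roughly this shape (reconstructed from the test itself, not from the API documentation):

expected = {
    'agents': [
        {'status': 'ALIVE', 'region': 'local', 'addr': '127.0.0.1'},
        {'status': 'ALIVE', 'region': 'local', 'addr': '127.0.0.1'},
    ],
}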
Example #3
def example_pending_sessions():
    # lower indices are enqueued first.
    return [
        PendingSession(  # rocm
            kernel_id=pending_kernel_ids[0],
            access_key=AccessKey('user01'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('2.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cuda
            kernel_id=pending_kernel_ids[1],
            access_key=AccessKey('user02'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
        PendingSession(  # cpu-only
            kernel_id=pending_kernel_ids[2],
            access_key=AccessKey('user03'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            scaling_group='sg01',
            requested_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            target_sgroup_names=[],
            **_common_dummy_for_pending_session,
        ),
    ]
async def recalculate_usage(request) -> web.Response:
    '''
    Update `keypairs.c.concurrency_used` and `agents.c.occupied_slots`.

    These two values sometimes fall out of sync. In that case, calling this API
    re-calculates them from the currently running containers and updates the DB.
    '''
    async with request.app['dbpool'].acquire() as conn, conn.begin():
        # Query running containers and calculate concurrency_used per AK and
        # occupied_slots per agent.
        query = (sa.select([
            kernels.c.access_key, kernels.c.agent, kernels.c.occupied_slots
        ]).where(kernels.c.status != KernelStatus.TERMINATED).order_by(
            sa.asc(kernels.c.access_key)))
        concurrency_used_per_key: MutableMapping[str, int] = \
            defaultdict(lambda: 0)
        occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = \
            defaultdict(lambda: ResourceSlot({'cpu': 0, 'mem': 0}))
        async for row in conn.execute(query):
            concurrency_used_per_key[row.access_key] += 1
            occupied_slots_per_agent[row.agent] += ResourceSlot(
                row.occupied_slots)

        # Update concurrency_used for keypairs with running containers.
        for ak, used in concurrency_used_per_key.items():
            query = (sa.update(keypairs).values(concurrency_used=used).where(
                keypairs.c.access_key == ak))
            await conn.execute(query)
        # Update all other keypairs to have concurrency_used = 0.
        query = (sa.update(keypairs).values(concurrency_used=0).where(
            keypairs.c.concurrency_used != 0).where(
                sa.not_(
                    keypairs.c.access_key.in_(
                        concurrency_used_per_key.keys()))))
        await conn.execute(query)

        # Update occupied_slots for agents with running containers.
        for aid, slots in occupied_slots_per_agent.items():
            query = (sa.update(agents).values(occupied_slots=slots).where(
                agents.c.id == aid))
            await conn.execute(query)
        # Update all other agents to have empty occupied_slots.
        query = (sa.update(agents).values(
            occupied_slots=ResourceSlot({})).where(
                agents.c.status == AgentStatus.ALIVE).where(
                    sa.not_(agents.c.id.in_(occupied_slots_per_agent.keys()))))
        await conn.execute(query)
    return web.json_response({}, status=200)
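As a usage sketch, a handler like this can be attached to an aiohttp application as follows; the route path here is an assumption for illustration, not necessarily the project's actual API path:

from aiohttp import web

app = web.Application()
# hypothetical path; the real route registration lives elsewhere
app.router.add_post('/resource/recalculate-usage', recalculate_usage)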
async def check_domain_resource_limit(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    query = (sa.select([domains.c.total_resource_slots
                        ]).where(domains.c.name == sess_ctx.domain_name))
    domain_resource_slots = await db_conn.scalar(query)
    domain_resource_policy = {
        'total_resource_slots': domain_resource_slots,
        'default_for_unspecified': DefaultForUnspecified.UNLIMITED
    }
    total_domain_allowed = ResourceSlot.from_policy(domain_resource_policy,
                                                    sched_ctx.known_slot_types)
    domain_occupied = await sched_ctx.registry.get_domain_occupancy(
        sess_ctx.domain_name, conn=db_conn)
    log.debug('domain:{} current-occupancy: {}', sess_ctx.domain_name,
              domain_occupied)
    log.debug('domain:{} total-allowed: {}', sess_ctx.domain_name,
              total_domain_allowed)
    if not (domain_occupied + sess_ctx.requested_slots <=
            total_domain_allowed):
        return PredicateResult(
            False,
            'Your domain resource quota is exceeded. ({})'.format(' '.join(
                f'{k}={v}' for k, v in total_domain_allowed.to_humanized(
                    sched_ctx.known_slot_types).items())),
        )
    return PredicateResult(True)
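The core of the predicate is the slot-wise comparison `occupied + requested <= allowed`. A small worked sketch with made-up numbers, assuming the same ResourceSlot type used throughout these examples:

from decimal import Decimal
from ai.backend.common.types import ResourceSlot  # assumed import path

allowed = ResourceSlot({'cpu': Decimal('8'), 'mem': Decimal('8192')})
occupied = ResourceSlot({'cpu': Decimal('6'), 'mem': Decimal('4096')})
requested = ResourceSlot({'cpu': Decimal('4'), 'mem': Decimal('1024')})
# False: cpu would become 10 > 8, even though mem still fits
print(occupied + requested <= allowed)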
Example #6
 async def mutate(cls, root, info, name, props):
     data = {}
     set_if_set(props, data, 'name')  # data['name'] is new domain name
     set_if_set(props, data, 'description')
     set_if_set(props, data, 'is_active')
     set_if_set(props,
                data,
                'total_resource_slots',
                clean_func=lambda v: ResourceSlot.from_user_input(v, None))
     set_if_set(props, data, 'allowed_vfolder_hosts')
     set_if_set(props, data, 'allowed_docker_registries')
     set_if_set(props, data, 'integration_id')
     if 'name' in data and _rx_slug.search(data['name']) is None:
         raise ValueError('invalid name format. slug format required.')
     update_query = (domains.update().values(data).where(
         domains.c.name == name))
     # The name may have changed if set.
     if 'name' in data:
         name = data['name']
     item_query = domains.select().where(domains.c.name == name)
     return await simple_db_mutate_returning_item(cls,
                                                  info.context,
                                                  update_query,
                                                  item_query=item_query,
                                                  item_cls=Domain)
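set_if_set() is a project helper not shown in this listing. A plausible reconstruction of its behavior, inferred only from the call sites above (treat it as an assumption, not the actual implementation):

def set_if_set(props, data, name, *, clean_func=None):
    # Copy a mutation input into `data` only when the client actually
    # supplied it; graphene leaves omitted inputs as None.
    value = getattr(props, name)
    if value is not None:
        data[name] = clean_func(value) if clean_func is not None else value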
async def check_keypair_resource_limit(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    query = (sa.select([
        keypair_resource_policies
    ]).select_from(keypair_resource_policies).where(
        keypair_resource_policies.c.name == sess_ctx.resource_policy))
    result = await db_conn.execute(query)
    resource_policy = await result.first()
    if len(sess_ctx.kernels) > resource_policy['max_containers_per_session']:
        return PredicateResult(
            False,
            f"You cannot create session with more than "
            f"{resource_policy['max_containers_per_session']} containers.",
            permanent=True,
        )
    total_keypair_allowed = ResourceSlot.from_policy(
        resource_policy, sched_ctx.known_slot_types)
    key_occupied = await sched_ctx.registry.get_keypair_occupancy(
        sess_ctx.access_key, conn=db_conn)
    log.debug('keypair:{} current-occupancy: {}', sess_ctx.access_key,
              key_occupied)
    log.debug('keypair:{} total-allowed: {}', sess_ctx.access_key,
              total_keypair_allowed)
    if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed):
        return PredicateResult(
            False,
            "Your keypair resource quota is exceeded. ({})".format(' '.join(
                f'{k}={v}' for k, v in total_keypair_allowed.to_humanized(
                    sched_ctx.known_slot_types).items())),
        )
    return PredicateResult(True)
Example #8
 async def mutate(cls, root, info, name, props):
     if _rx_slug.search(name) is None:
         raise ValueError('invalid name format. slug format required.')
     data = {
         'name': name,
         'description': props.description,
         'is_active': props.is_active,
         'domain_name': props.domain_name,
         'total_resource_slots': ResourceSlot.from_user_input(
             props.total_resource_slots, None),
         'allowed_vfolder_hosts': props.allowed_vfolder_hosts,
         'integration_id': props.integration_id,
     }
     insert_query = groups.insert().values(data)
     item_query = (groups.select().where((groups.c.name == name) & (
         groups.c.domain_name == props.domain_name)))
     return await simple_db_mutate_returning_item(cls,
                                                  info.context,
                                                  insert_query,
                                                  item_query=item_query,
                                                  item_cls=Group)
async def _unreserve_agent_slots(
    db_conn: SAConnection,
    session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]],
) -> None:
    # Un-reserve agent slots, using the db transaction of the current invocation context.
    keyfunc = lambda item: item.agent_alloc_ctx.agent_id
    for agent_id, kernel_agent_bindings in itertools.groupby(
        sorted(session_agent_binding[1], key=keyfunc), key=keyfunc
    ):
        per_agent_requested_slots = sum(
            (binding.kernel.requested_slots for binding in kernel_agent_bindings),
            start=ResourceSlot(),
        )
        query = (
            sa.select([agents.c.occupied_slots], for_update=True)
            .select_from(agents)
            .where(agents.c.id == agent_id))
        current_occupied_slots = await db_conn.scalar(query)
        query = (
            sa.update(agents)
            .values({
                'occupied_slots': current_occupied_slots - per_agent_requested_slots
            })
            .where(agents.c.id == agent_id))
        await db_conn.execute(query)
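Note that itertools.groupby() only merges *adjacent* items with equal keys, which is why the bindings are sorted by agent id first. A standalone illustration with plain tuples:

import itertools

bindings = [('agent-b', 'k1'), ('agent-a', 'k2'), ('agent-b', 'k3')]
keyfunc = lambda item: item[0]
for agent_id, group in itertools.groupby(sorted(bindings, key=keyfunc),
                                         key=keyfunc):
    print(agent_id, [kernel for _, kernel in group])
# agent-a ['k2']
# agent-b ['k1', 'k3']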
Example #10
 async def resolve_occupied_slots(
         self, info: graphene.ResolveInfo) -> Mapping[str, Any]:
     """
     Calculate the sum of occupied resource slots of all sub-kernels,
     and return the JSON-serializable object from the sum result.
     """
     manager = info.context['dlmgr']
     loader = manager.get_loader('ComputeContainer.by_session')
     containers = await loader.load(self.session_id)
     zero = ResourceSlot()
     return sum(
         (ResourceSlot(
             {SlotName(k): Decimal(v)
              for k, v in c.occupied_slots.items()}) for c in containers),
         start=zero,
     ).to_json()
Example #11
 async def mutate(cls, root, info, name, props):
     data = {
         'name': name,
         'default_for_unspecified':
             DefaultForUnspecified[props.default_for_unspecified],
         'total_resource_slots': ResourceSlot.from_user_input(
             props.total_resource_slots, None),
         'max_concurrent_sessions': props.max_concurrent_sessions,
         'max_containers_per_session': props.max_containers_per_session,
         'idle_timeout': props.idle_timeout,
         'max_vfolder_count': props.max_vfolder_count,
         'max_vfolder_size': props.max_vfolder_size,
         'allowed_vfolder_hosts': props.allowed_vfolder_hosts,
     }
     insert_query = (keypair_resource_policies.insert().values(data))
     item_query = (keypair_resource_policies.select().where(
         keypair_resource_policies.c.name == name))
     return await simple_db_mutate_returning_item(
         cls,
         info.context,
         insert_query,
         item_query=item_query,
         item_cls=KeyPairResourcePolicy)
Example #12
 async def mutate(cls, root, info, name, props):
     if _rx_slug.search(name) is None:
         return cls(False, 'invalid name format. slug format required.',
                    None)
     data = {
         'name': name,
         'description': props.description,
         'is_active': props.is_active,
         'total_resource_slots': ResourceSlot.from_user_input(
             props.total_resource_slots, None),
         'allowed_vfolder_hosts': props.allowed_vfolder_hosts,
         'allowed_docker_registries': props.allowed_docker_registries,
         'integration_id': props.integration_id,
     }
     insert_query = (domains.insert().values(data))
     item_query = domains.select().where(domains.c.name == name)
     return await simple_db_mutate_returning_item(cls,
                                                  info.context,
                                                  insert_query,
                                                  item_query=item_query,
                                                  item_cls=Domain)
async def check_group_resource_limit(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    query = (sa.select([groups.c.total_resource_slots
                        ]).where(groups.c.id == sess_ctx.group_id))
    group_resource_slots = await db_conn.scalar(query)
    group_resource_policy = {
        'total_resource_slots': group_resource_slots,
        'default_for_unspecified': DefaultForUnspecified.UNLIMITED
    }
    total_group_allowed = ResourceSlot.from_policy(group_resource_policy,
                                                   sched_ctx.known_slot_types)
    group_occupied = await sched_ctx.registry.get_group_occupancy(
        sess_ctx.group_id, conn=db_conn)
    log.debug('group:{} current-occupancy: {}', sess_ctx.group_id,
              group_occupied)
    log.debug('group:{} total-allowed: {}', sess_ctx.group_id,
              total_group_allowed)
    if not (group_occupied + sess_ctx.requested_slots <= total_group_allowed):
        return PredicateResult(
            False, "Your group resource quota is exceeded. ({})".format(
                ' '.join(f'{k}={v}'
                         for k, v in total_group_allowed.to_humanized(
                             sched_ctx.known_slot_types).items())))
    return PredicateResult(True)
Example #14
 def process_result_value(self, raw_value: Dict[str, str], dialect):
     if raw_value is None:
         return None
     # legacy handling: convert human-readable memory strings to byte counts
     interim_value: Dict[str, Any] = raw_value
     mem = raw_value.get('mem')
     if isinstance(mem, str) and not mem.isdigit():
         interim_value['mem'] = BinarySize.from_str(mem)
     return ResourceSlot.from_json(interim_value)
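The legacy branch converts human-readable memory sizes into byte counts. A quick sketch of the conversion, assuming BinarySize from ai.backend.common.types as used elsewhere in these examples:

from ai.backend.common.types import BinarySize  # assumed import path

print(int(BinarySize.from_str('1g')))  # 1073741824 (binary gigabyte)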
Example #15
def example_existing_sessions():
    return [
        ExistingSession(
            kernel_id=existing_kernel_ids[0],
            access_key=AccessKey('user01'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('1024'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('1'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernel_id=existing_kernel_ids[1],
            access_key=AccessKey('user02'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('1.0'),
                'mem': Decimal('2048'),
                'cuda.shares': Decimal('0.5'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
        ExistingSession(
            kernel_id=existing_kernel_ids[2],
            access_key=AccessKey('user03'),
            session_name='es01',
            session_type=SessionTypes.BATCH,
            occupying_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            scaling_group='sg01',
            **_common_dummy_for_existing_session,
        ),
    ]
Example #16
 async def mutate(cls, root, info, name, props):
     data = {}
     set_if_set(props,
                data,
                'resource_slots',
                clean_func=lambda v: ResourceSlot.from_user_input(v, None))
     update_query = (resource_presets.update().values(data).where(
         resource_presets.c.name == name))
     return await simple_db_mutate(cls, info.context, update_query)
Example #17
    async def mutate(cls, root, info, gid, props):
        data = {}
        set_if_set(props, data, 'name')
        set_if_set(props, data, 'description')
        set_if_set(props, data, 'is_active')
        set_if_set(props, data, 'domain_name')
        set_if_set(props,
                   data,
                   'total_resource_slots',
                   clean_func=lambda v: ResourceSlot.from_user_input(v, None))
        set_if_set(props, data, 'allowed_vfolder_hosts')
        set_if_set(props, data, 'integration_id')

        if 'name' in data and _rx_slug.search(data['name']) is None:
            raise ValueError('invalid name format. slug format required.')
        if props.user_update_mode not in (None, 'add', 'remove'):
            raise ValueError('invalid user_update_mode')
        if not props.user_uuids:
            props.user_update_mode = None
        if not data and props.user_update_mode is None:
            return cls(ok=False, msg='nothing to update', group=None)
        async with info.context['dbpool'].acquire() as conn, conn.begin():
            try:
                if props.user_update_mode == 'add':
                    values = [{
                        'user_id': uuid,
                        'group_id': gid
                    } for uuid in props.user_uuids]
                    query = sa.insert(association_groups_users).values(values)
                    await conn.execute(query)
                elif props.user_update_mode == 'remove':
                    query = (association_groups_users.delete().where(
                        association_groups_users.c.user_id.in_(
                            props.user_uuids)).where(
                                association_groups_users.c.group_id == gid))
                    await conn.execute(query)

                if data:
                    query = (groups.update().values(data).where(
                        groups.c.id == gid))
                    result = await conn.execute(query)
                    if result.rowcount > 0:
                        checkq = groups.select().where(groups.c.id == gid)
                        result = await conn.execute(checkq)
                        o = Group.from_row(info.context, await result.first())
                        return cls(ok=True, msg='success', group=o)
                    return cls(ok=False, msg='no such group', group=None)
                else:  # only the association_groups_users table was updated
                    return cls(ok=True, msg='success', group=None)
            except (pg.IntegrityError, sa.exc.IntegrityError) as e:
                return cls(ok=False, msg=f'integrity error: {e}', group=None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                raise
            except Exception as e:
                return cls(ok=False, msg=f'unexpected error: {e}', group=None)
Example #18
    async def get_image_slot_ranges(self, image_ref: ImageRef):
        '''
        Returns the minimum and maximum ResourceSlot values for the given image.
        All slot values are converted and normalized to Decimal.
        '''
        data = await self.etcd.get_prefix_dict(image_ref.tag_path)
        slot_units = await self.get_resource_slots()
        min_slot = ResourceSlot()
        max_slot = ResourceSlot()

        for slot_key, slot_range in data['resource'].items():
            slot_unit = slot_units.get(slot_key)
            if slot_unit is None:
                # ignore unknown slots
                continue
            min_value = slot_range.get('min')
            if min_value is None:
                min_value = Decimal(0)
            max_value = slot_range.get('max')
            if max_value is None:
                max_value = Decimal('Infinity')
            if slot_unit == 'bytes':
                if not isinstance(min_value, Decimal):
                    min_value = BinarySize.from_str(min_value)
                if not isinstance(max_value, Decimal):
                    max_value = BinarySize.from_str(max_value)
            else:
                if not isinstance(min_value, Decimal):
                    min_value = Decimal(min_value)
                if not isinstance(max_value, Decimal):
                    max_value = Decimal(max_value)
            min_slot[slot_key] = min_value
            max_slot[slot_key] = max_value

        # fill missing
        for slot_key in slot_units.keys():
            if slot_key not in min_slot:
                min_slot[slot_key] = Decimal(0)
            if slot_key not in max_slot:
                max_slot[slot_key] = Decimal('Infinity')

        return min_slot, max_slot
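The per-image metadata shape that this method walks can be inferred from the code above; a rough sketch (key names inferred, values illustrative; the actual etcd layout may differ):

example_resource_data = {
    'resource': {
        'cpu': {'min': '1', 'max': '8'},
        'mem': {'min': '1g'},         # a missing 'max' becomes Decimal('Infinity')
        'cuda.shares': {'min': '0'},
    },
}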
def example_agents_no_valid():
    return [
        AgentContext(
            agent_id=AgentId('i-001'),
            agent_addr='10.0.1.1:6001',
            scaling_group='sg01',
            available_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('4.0'),
                'mem': Decimal('4096'),
                'cuda.shares': Decimal('4.0'),
                'rocm.devices': Decimal('2'),
            }),
        ),
        AgentContext(
            agent_id=AgentId('i-101'),
            agent_addr='10.0.2.1:6001',
            scaling_group='sg02',
            available_slots=ResourceSlot({
                'cpu': Decimal('0'),
                'mem': Decimal('0'),
                'cuda.shares': Decimal('0'),
                'rocm.devices': Decimal('0'),
            }),
            occupied_slots=ResourceSlot({
                'cpu': Decimal('3.0'),
                'mem': Decimal('2560'),
                'cuda.shares': Decimal('1.0'),
                'rocm.devices': Decimal('8'),
            }),
        ),
    ]
Example #20
 async def mutate(cls, root, info, name, props):
     data = {
         'name': name,
         'resource_slots': ResourceSlot.from_user_input(
             props.resource_slots, None),
     }
     insert_query = (resource_presets.insert().values(data))
     item_query = (resource_presets.select().where(
         resource_presets.c.name == name))
     return await simple_db_mutate_returning_item(cls,
                                                  info.context,
                                                  insert_query,
                                                  item_query=item_query,
                                                  item_cls=ResourcePreset)
Example #21
async def recalc_agent_resource_occupancy(db_conn: SAConnection,
                                          agent_id: AgentId) -> None:
    query = (sa.select([
        kernels.c.occupied_slots,
    ]).select_from(kernels).where((kernels.c.agent == agent_id) & (
        kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))))
    occupied_slots = ResourceSlot()
    result = await db_conn.execute(query)
    async for row in result:
        occupied_slots += row['occupied_slots']
    print(f"ag:{agent_id}'s new occupied_slots:{occupied_slots}")
    query = (sa.update(agents).values({
        'occupied_slots': occupied_slots,
    }).where(agents.c.id == agent_id))
    await db_conn.execute(query)
Example #22
def key_by_requested_slots(
    agent: AgentContext,
    requested_slots: ResourceSlot,
) -> Tuple[int, ResourceSlot]:
    unused_slot_keys = set()
    for k, v in requested_slots.items():
        if v == Decimal(0):
            unused_slot_keys.add(k)
    num_extras = 0
    for k, v in agent.available_slots.items():
        if k in unused_slot_keys and v > Decimal(0):
            num_extras += 1
    # Put agents with more unused extra slot types (e.g., accelerators)
    # at the back, and agents with exactly the required slot types
    # at the front.
    return (-num_extras, agent.available_slots)
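A usage sketch, assuming the key is consumed by max() over candidate agents (an assumption about the caller): because the first tuple element is -num_extras, agents whose extra slot types would go unused rank lower, so an exact-fit agent wins. This reuses example_mixed_agents() from the top of this listing:

from decimal import Decimal
from ai.backend.common.types import ResourceSlot  # assumed import path

agents = example_mixed_agents()
requested = ResourceSlot({'cpu': Decimal('1.0'), 'mem': Decimal('1024'),
                          'cuda.shares': Decimal('0')})
chosen = max(agents, key=lambda ag: key_by_requested_slots(ag, requested))
print(chosen.agent_id)  # 'i-cpu': the GPU agent's cuda.shares would go to waste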
Example #23
 def read_from_string(cls, text: str) -> 'KernelResourceSpec':
     kvpairs = {}
     for line in text.split('\n'):
         if '=' not in line:
             continue
         key, val = line.strip().split('=', maxsplit=1)
         kvpairs[key] = val
     allocations = cast(
         MutableMapping[DeviceName, MutableMapping[SlotName,
                                                   Mapping[DeviceId,
                                                           Decimal]]],
         defaultdict(lambda: defaultdict(Decimal)),
     )
     for key, val in kvpairs.items():
         if key.endswith('_SHARES'):
             slot_name = SlotName(key[:-7].lower())
             device_name = DeviceName(slot_name.split('.')[0])
             per_device_alloc: MutableMapping[DeviceId, Decimal] = {}
             for entry in val.split(','):
                 raw_dev_id, _, raw_alloc = entry.partition(':')
                 if not raw_dev_id or not raw_alloc:
                     continue
                 dev_id = DeviceId(raw_dev_id)
                 try:
                     if known_slot_types.get(slot_name, 'count') == 'bytes':
                         alloc = Decimal(BinarySize.from_str(raw_alloc))
                     else:
                         alloc = Decimal(raw_alloc)
                 except KeyError as e:
                     log.warning(
                         'A previously launched container has '
                         'unknown slot type: {}. Ignoring it.', e.args[0])
                     continue
                 per_device_alloc[dev_id] = alloc
             allocations[device_name][slot_name] = per_device_alloc
     mounts = [Mount.from_str(m) for m in kvpairs['MOUNTS'].split(',') if m]
     return cls(
         container_id=kvpairs.get('CID', 'unknown'),
         scratch_disk_size=BinarySize.finite_from_str(
             kvpairs['SCRATCH_SIZE']),
         allocations=dict(allocations),
         slots=ResourceSlot(json.loads(kvpairs['SLOTS'])),
         mounts=mounts,
     )
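For orientation, the key=value text format this parser expects, reconstructed from the parsing logic itself (field values and device ids are illustrative, not taken from a real container):

sample = '\n'.join([
    'CID=abc123',
    'SCRATCH_SIZE=1g',
    'CPU_SHARES=0:1,1:1',                      # per-device allocations
    'MEM_SHARES=root:1073741824',
    'SLOTS={"cpu": "2", "mem": "2147483648"}',
    'MOUNTS=',
])
spec = KernelResourceSpec.read_from_string(sample)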
Example #24
 async def mutate(cls, root, info, name, props):
     data = {}
     set_if_set(props,
                data,
                'default_for_unspecified',
                clean_func=lambda v: DefaultForUnspecified[v])
     set_if_set(props,
                data,
                'total_resource_slots',
                clean_func=lambda v: ResourceSlot.from_user_input(v, None))
     set_if_set(props, data, 'max_concurrent_sessions')
     set_if_set(props, data, 'max_containers_per_session')
     set_if_set(props, data, 'idle_timeout')
     set_if_set(props, data, 'max_vfolder_count')
     set_if_set(props, data, 'max_vfolder_size')
     set_if_set(props, data, 'allowed_vfolder_hosts')
     update_query = (keypair_resource_policies.update().values(data).where(
         keypair_resource_policies.c.name == name))
     return await simple_db_mutate(cls, info.context, update_query)
Example #25
    def _assign_agent(
        self,
        agents: Sequence[AgentContext],
        access_key: AccessKey,
        requested_slots: ResourceSlot,
    ) -> Optional[AgentId]:
        # If any predicate check fails for the picked session,
        # this method is NOT called at all for that session.
        # In such a case, we simply skip updating the self.per_user_dominant_share
        # state, and the scheduler dispatcher continues to pick another session
        # within the same scaling group.

        possible_agents = []
        for agent in agents:
            remaining_slots = agent.available_slots - agent.occupied_slots
            if remaining_slots >= requested_slots:
                possible_agents.append(agent)

        if possible_agents:
            # We have one or more agents that can host the picked session.

            # Update the dominant share.
            # This is required to use the latest dominant-share information
            # when iterating over multiple pending sessions in a single scaling group.
            dominant_share_from_request = Decimal(0)
            for slot, value in requested_slots.items():
                self.total_capacity.sync_keys(requested_slots)
                slot_cap = Decimal(self.total_capacity[slot])
                if slot_cap == 0:
                    continue
                slot_share = Decimal(value) / slot_cap
                if dominant_share_from_request < slot_share:
                    dominant_share_from_request = slot_share
            if self.per_user_dominant_share[
                    access_key] < dominant_share_from_request:
                self.per_user_dominant_share[
                    access_key] = dominant_share_from_request

            # Choose the agent.
            chosen_agent = \
                max(possible_agents, key=lambda a: a.available_slots)
            return chosen_agent.agent_id

        return None
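The dominant share computed above follows the DRF (dominant resource fairness) idea: a request's dominant share is its largest per-slot fraction of the total capacity. A worked example with made-up numbers:

from decimal import Decimal

total_capacity = {'cpu': Decimal('16'), 'mem': Decimal('65536')}
requested = {'cpu': Decimal('4'), 'mem': Decimal('8192')}
dominant_share = max(
    value / total_capacity[slot]
    for slot, value in requested.items()
    if total_capacity[slot] != 0
)
print(dominant_share)  # 0.25: cpu (4/16) dominates mem (8192/65536 = 0.125)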
Example #26
 async def mutate(cls, root, info, name, props):
     async with info.context['dbpool'].acquire() as conn, conn.begin():
         assert _rx_slug.search(
             name) is not None, 'invalid name format. slug format required.'
         data = {
             'name': name,
             'description': props.description,
             'is_active': props.is_active,
             'domain_name': props.domain_name,
             'total_resource_slots': ResourceSlot.from_user_input(
                 props.total_resource_slots, None),
             'allowed_vfolder_hosts': props.allowed_vfolder_hosts,
             'integration_id': props.integration_id,
         }
         query = (groups.insert().values(data))
         try:
             result = await conn.execute(query)
             if result.rowcount > 0:
                 checkq = groups.select().where((groups.c.name == name) & (
                     groups.c.domain_name == props.domain_name))
                 result = await conn.execute(checkq)
                 o = Group.from_row(await result.first())
                 return cls(ok=True, msg='success', group=o)
             else:
                 return cls(ok=False,
                            msg='failed to create group',
                            group=None)
         except (pg.IntegrityError, sa.exc.IntegrityError) as e:
             return cls(ok=False, msg=f'integrity error: {e}', group=None)
         except (asyncio.CancelledError, asyncio.TimeoutError):
             raise
         except Exception as e:
             return cls(ok=False, msg=f'unexpected error: {e}', group=None)
async def check_keypair_resource_limit(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    query = (sa.select([
        keypair_resource_policies
    ]).select_from(keypair_resource_policies).where(
        keypair_resource_policies.c.name == sess_ctx.resource_policy))
    result = await db_conn.execute(query)
    resource_policy = await result.first()
    total_keypair_allowed = ResourceSlot.from_policy(
        resource_policy, sched_ctx.known_slot_types)
    key_occupied = await sched_ctx.registry.get_keypair_occupancy(
        sess_ctx.access_key, conn=db_conn)
    log.debug('keypair:{} current-occupancy: {}', sess_ctx.access_key,
              key_occupied)
    log.debug('keypair:{} total-allowed: {}', sess_ctx.access_key,
              total_keypair_allowed)
    if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed):

        async def update_status_info(
            db_conn: SAConnection,
            sched_ctx: SchedulingContext,
            sess_ctx: PendingSession,
        ) -> None:
            query = (sa.update(kernels).values(
                status_info='out-of-resource (keypair resource quota exceeded)'
            ).where(kernels.c.id == sess_ctx.kernel_id))
            await db_conn.execute(query)

        return PredicateResult(
            False,
            'Your keypair resource quota is exceeded. ({})'.format(' '.join(
                f'{k}={v}' for k, v in total_keypair_allowed.to_humanized(
                    sched_ctx.known_slot_types).items())),
            failure_cb=update_status_info)
    return PredicateResult(True)
Example #28
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    return list(items.values())
Example #29
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                environ={
                    k: v for k, v
                    in map(lambda s: s.split('=', maxsplit=1), row['environ'])
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    return list(items.values())
Example #30
        async def _schedule_in_sgroup(db_conn: SAConnection, sgroup_name: str) -> None:
            async with db_conn.begin():
                scheduler = await self._load_scheduler(db_conn, sgroup_name)
                pending_sessions = await _list_pending_sessions(db_conn, sgroup_name)
                existing_sessions = await _list_existing_sessions(db_conn, sgroup_name)
            log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
                      sgroup_name, len(pending_sessions), len(existing_sessions))
            zero = ResourceSlot()
            while len(pending_sessions) > 0:
                async with db_conn.begin():
                    candidate_agents = await _list_agents_by_sgroup(db_conn, sgroup_name)
                    total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

                picked_session_id = scheduler.pick_session(
                    total_capacity,
                    pending_sessions,
                    existing_sessions,
                )
                if picked_session_id is None:
                    # no session is picked.
                    # continue to next sgroup.
                    return
                for picked_idx, sess_ctx in enumerate(pending_sessions):
                    if sess_ctx.session_id == picked_session_id:
                        break
                else:
                    # no matching entry for picked session?
                    raise RuntimeError('should not reach here')
                sess_ctx = pending_sessions.pop(picked_idx)

                log_args = (
                    sess_ctx.session_id,
                    sess_ctx.session_type,
                    sess_ctx.session_name,
                    sess_ctx.access_key,
                    sess_ctx.cluster_mode,
                )
                log.debug(log_fmt + 'try-scheduling', *log_args)
                session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

                async with db_conn.begin():
                    predicates: Sequence[Awaitable[PredicateResult]] = [
                        check_reserved_batch_session(db_conn, sched_ctx, sess_ctx),
                        check_concurrency(db_conn, sched_ctx, sess_ctx),
                        check_dependencies(db_conn, sched_ctx, sess_ctx),
                        check_keypair_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_group_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_domain_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_scaling_group(db_conn, sched_ctx, sess_ctx),
                    ]
                    check_results: List[Union[Exception, PredicateResult]] = []
                    for check in predicates:
                        try:
                            check_results.append(await check)
                        except Exception as e:
                            log.exception(log_fmt + 'predicate-error', *log_args)
                            check_results.append(e)
                    has_failure = False
                    for result in check_results:
                        if isinstance(result, Exception):
                            has_failure = True
                            continue
                        if not result.passed:
                            has_failure = True
                    if has_failure:
                        log.debug(log_fmt + 'predicate-checks-failed', *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # Predicate failures are *NOT* permanent errors.
                        # We need to retry the scheduling afterwards.
                        continue

                    if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                        # Assign agent resource per session.
                        try:
                            agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
                            if agent_id is None:
                                raise InstanceNotAvailable
                            agent_alloc_ctx = await _reserve_agent(
                                sched_ctx, db_conn, sgroup_name, agent_id, sess_ctx.requested_slots,
                            )
                        except InstanceNotAvailable:
                            log.debug(log_fmt + 'no-available-instances', *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        except Exception:
                            log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                          *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        query = kernels.update().values({
                            'agent': agent_alloc_ctx.agent_id,
                            'agent_addr': agent_alloc_ctx.agent_addr,
                            'scaling_group': sgroup_name,
                            'status': KernelStatus.PREPARING,
                            'status_info': 'scheduled',
                            'status_changed': datetime.now(tzutc()),
                        }).where(kernels.c.session_id == sess_ctx.session_id)
                        await db_conn.execute(query)
                        session_agent_binding = (
                            sess_ctx,
                            [
                                KernelAgentBinding(kernel, agent_alloc_ctx)
                                for kernel in sess_ctx.kernels
                            ],
                        )
                    elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                        # Assign agent resource per kernel in the session.
                        agent_query_extra_conds = None
                        if len(sess_ctx.kernels) >= 2:
                            # We should use agents that support overlay networking.
                            agent_query_extra_conds = (agents.c.clusterized)
                        kernel_agent_bindings = []
                        for kernel in sess_ctx.kernels:
                            try:
                                agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                                if agent_id is None:
                                    raise InstanceNotAvailable
                                agent_alloc_ctx = await _reserve_agent(
                                    sched_ctx, db_conn, sgroup_name, agent_id, kernel.requested_slots,
                                    extra_conds=agent_query_extra_conds,
                                )
                            except InstanceNotAvailable:
                                log.debug(log_fmt + 'no-available-instances', *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            except Exception:
                                log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                              *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            # TODO: if error occurs for one kernel, should we cancel all others?
                            query = kernels.update().values({
                                'agent': agent_alloc_ctx.agent_id,
                                'agent_addr': agent_alloc_ctx.agent_addr,
                                'scaling_group': sgroup_name,
                                'status': KernelStatus.PREPARING,
                                'status_info': 'scheduled',
                                'status_changed': datetime.now(tzutc()),
                            }).where(kernels.c.id == kernel.kernel_id)
                            await db_conn.execute(query)
                            kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
                        session_agent_binding = (sess_ctx, kernel_agent_bindings)
                    start_task_args.append(
                        (
                            log_args,
                            sched_ctx,
                            session_agent_binding,
                            check_results,
                        )
                    )