Example #1
0
async def reenter_txn(pool: SAEngine, conn: SAConnection):
    if conn is None:
        async with pool.acquire() as conn, conn.begin():
            yield conn
    else:
        async with conn.begin_nested():
            yield conn
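A minimal usage sketch for Example #1 (an assumption, not part of the original snippet): reenter_txn is an async generator, so it would have to be wrapped into an async context manager (for instance with contextlib.asynccontextmanager) before the async with blocks above can be driven. The jobs table and the helper name below are hypothetical.

import contextlib
from typing import Optional

reenter_txn_ctx = contextlib.asynccontextmanager(reenter_txn)

async def mark_job_done(pool: SAEngine, job_id: int,
                        conn: Optional[SAConnection] = None) -> None:
    # Reuses the caller's transaction when `conn` is given (nested transaction),
    # otherwise acquires a fresh connection and opens a new transaction.
    async with reenter_txn_ctx(pool, conn) as _conn:
        await _conn.execute(
            jobs.update().values(status='done').where(jobs.c.id == job_id)
        )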
Example #2
0
async def _list_agents_by_sgroup(
    db_conn: SAConnection,
    sgroup_name: str,
) -> Sequence[AgentContext]:
    query = (
        sa.select([
            agents.c.id,
            agents.c.addr,
            agents.c.scaling_group,
            agents.c.available_slots,
            agents.c.occupied_slots,
        ], for_update=True)
        .select_from(agents)
        .where(
            (agents.c.status == AgentStatus.ALIVE) &
            (agents.c.scaling_group == sgroup_name) &
            (agents.c.schedulable == true())
        )
    )
    items = []
    async for row in db_conn.execute(query):
        item = AgentContext(
            row['id'],
            row['addr'],
            row['scaling_group'],
            row['available_slots'],
            row['occupied_slots'],
        )
        items.append(item)
    return items
Example #3
0
    async def group_vfolder_mounted_to_active_kernels(
        cls,
        conn: SAConnection,
        group_id: uuid.UUID,
    ) -> bool:
        """
        Check if no active kernel is using the group's virtual folders.

        :param conn: DB connection
        :param group_id: group's ID

        :return: True if a virtual folder is mounted to active kernels.
        """
        from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES
        query = (sa.select([
            vfolders.c.id
        ]).select_from(vfolders).where(vfolders.c.group == group_id))
        result = await conn.execute(query)
        rows = await result.fetchall()
        group_vfolder_ids = [row.id for row in rows]
        query = (sa.select([
            kernels.c.mounts
        ]).select_from(kernels).where((kernels.c.group_id == group_id) & (
            kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))))
        async for row in conn.execute(query):
            for _mount in row['mounts']:
                try:
                    vfolder_id = uuid.UUID(_mount[2])
                except Exception:
                    continue
                if vfolder_id in group_vfolder_ids:
                    return True
        return False
Example #4
0
 async def __load_user_groups(self, conn: SAConnection,
                              user_id: str) -> List[str]:
     user_groups: List[str] = []
     query = select([user_to_groups.c.gid
                     ]).where(user_to_groups.c.uid == user_id)
     async for row in conn.execute(query):
         user_groups.append(row[user_to_groups.c.gid])
     return user_groups
Example #5
0
 async def __load_user_groups(self, conn: SAConnection,
                              user_id: int) -> List[RowProxy]:
     user_groups: List[RowProxy] = []
     query = (select([groups
                      ]).select_from(groups.join(user_to_groups)).where(
                          user_to_groups.c.uid == user_id))
     async for row in conn.execute(query):
         user_groups.append(row)
     return user_groups
Example #6
0
    async def __load_projects(
        self,
        conn: SAConnection,
        query: str,
        user_id: int,
        user_groups: List[RowProxy],
        filter_by_services: Optional[List[Dict]] = None,
    ) -> Tuple[List[Dict[str, Any]], List[ProjectType]]:
        api_projects: List[Dict] = []  # API model-compatible projects
        db_projects: List[Dict] = []  # DB model-compatible projects
        project_types: List[ProjectType] = []
        async for row in conn.execute(query):
            try:
                _check_project_permissions(row, user_id, user_groups, "read")

                await asyncio.get_event_loop().run_in_executor(
                    None, ProjectAtDB.from_orm, row)

            except ProjectInvalidRightsError:
                continue

            except ValidationError as exc:
                log.warning(
                    "project  %s  failed validation, please check. error: %s",
                    f"{row.id=}",
                    exc,
                )
                continue

            prj = dict(row.items())

            if filter_by_services:
                if not await project_uses_available_services(
                        prj, filter_by_services):
                    log.warning(
                        "Project %s will not be listed for user %s since it has no access rights"
                        " for one or more of the services that includes.",
                        f"{row.id=}",
                        f"{user_id=}",
                    )
                    continue
            db_projects.append(prj)

        # NOTE: DO NOT nest _get_tags_by_project in async loop above !!!
        # FIXME: temporary avoids inner async loops issue https://github.com/aio-libs/aiopg/issues/535
        for db_prj in db_projects:
            db_prj["tags"] = await self._get_tags_by_project(
                conn, project_id=db_prj["id"])
            user_email = await self._get_user_email(conn, db_prj["prj_owner"])
            api_projects.append(_convert_to_schema_names(db_prj, user_email))
            project_types.append(db_prj["type"])

        return (api_projects, project_types)
Example #7
0
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.registry,
            kernels.c.sess_type,
            kernels.c.sess_id,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items = []
    async for row in db_conn.execute(query):
        items.append(ExistingSession(
            kernel_id=row['id'],
            access_key=row['access_key'],
            session_type=row['sess_type'],
            session_name=row['sess_id'],
            domain_name=row['domain_name'],
            group_id=row['group_id'],
            scaling_group=row['scaling_group'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            occupying_slots=row['occupied_slots'],
        ))
    return items
Example #8
0
async def _clusters_from_cluster_ids(
    conn: connection.SAConnection,
    cluster_ids: Iterable[PositiveInt],
    offset: int = 0,
    limit: Optional[int] = None,
) -> List[Cluster]:
    cluster_id_to_cluster: Dict[PositiveInt, Cluster] = {}
    query = (
        sa.select([
            clusters,
            cluster_to_groups.c.gid,
            cluster_to_groups.c.read,
            cluster_to_groups.c.write,
            cluster_to_groups.c.delete,
        ])
        .select_from(
            clusters.join(
                cluster_to_groups,
                clusters.c.id == cluster_to_groups.c.cluster_id,
            )
        )
        .where(clusters.c.id.in_(cluster_ids))
        .offset(offset)
        .limit(limit)
    )
    async for row in conn.execute(query):
        cluster_access_rights = {
            row[cluster_to_groups.c.gid]: ClusterAccessRights(
                read=row[cluster_to_groups.c.read],
                write=row[cluster_to_groups.c.write],
                delete=row[cluster_to_groups.c.delete],
            )
        }

        cluster_id = row[clusters.c.id]
        if cluster_id not in cluster_id_to_cluster:
            cluster_id_to_cluster[cluster_id] = Cluster(
                id=cluster_id,
                name=row[clusters.c.name],
                description=row[clusters.c.description],
                type=row[clusters.c.type],
                owner=row[clusters.c.owner],
                endpoint=row[clusters.c.endpoint],
                authentication=row[clusters.c.authentication],
                thumbnail=row[clusters.c.thumbnail],
                access_rights=cluster_access_rights,
            )
        else:
            cluster_id_to_cluster[cluster_id].access_rights.update(
                cluster_access_rights)

    return list(cluster_id_to_cluster.values())
Example #9
0
    async def __load_projects(self, conn: SAConnection, query) -> List[Dict]:
        api_projects: List[Dict] = []  # API model-compatible projects
        db_projects: List[Dict] = []  # DB model-compatible projects
        async for row in conn.execute(query):
            prj = dict(row.items())
            log.debug("found project: %s", prj)
            db_projects.append(prj)

        # NOTE: DO NOT nest _get_tags_by_project in async loop above !!!
        # FIXME: temporary avoids inner async loops issue https://github.com/aio-libs/aiopg/issues/535
        for db_prj in db_projects:
            db_prj["tags"] = await self._get_tags_by_project(
                conn, project_id=db_prj["id"])
            user_email = await self._get_user_email(conn, db_prj["prj_owner"])
            api_projects.append(_convert_to_schema_names(db_prj, user_email))

        return api_projects
Example #10
0
async def batch_multiresult(
    context: Mapping[str, Any],
    conn: SAConnection,
    query: sa.sql.Select,
    obj_type: Type[_GenericSQLBasedGQLObject],
    key_list: Iterable[_Key],
    key_getter: Callable[[RowProxy], _Key],
) -> Sequence[Sequence[_GenericSQLBasedGQLObject]]:
    """
    A batched query adaptor for (key -> [item]) resolving patterns.
    """
    objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]]
    objs_per_key = collections.OrderedDict()
    for key in key_list:
        objs_per_key[key] = list()
    async for row in conn.execute(query):
        objs_per_key[key_getter(row)].append(obj_type.from_row(context, row))
    return [*objs_per_key.values()]
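A hedged usage sketch for batch_multiresult (illustration only; the keypairs table, the KeyPair GraphQL object with a from_row classmethod, and the loader name are assumptions, not taken from the original source):

async def batch_load_keypairs_by_user(
    context: Mapping[str, Any],
    conn: SAConnection,
    user_ids: Sequence[uuid.UUID],
) -> Sequence[Sequence['KeyPair']]:
    query = (
        sa.select([keypairs])
        .select_from(keypairs)
        .where(keypairs.c.user_id.in_(user_ids))
    )
    # key_getter groups every row under its owning user; the result keeps the
    # order of user_ids, with an empty list for users that own no keypairs.
    return await batch_multiresult(
        context, conn, query, KeyPair,
        user_ids,
        lambda row: row['user_id'],
    )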
Example #11
0
    async def delete_vfolders(
        cls,
        conn: SAConnection,
        user_uuid: uuid.UUID,
        config_server,
    ) -> int:
        """
        Delete user's all virtual folders as well as their physical data.

        :param conn: DB connection
        :param user_uuid: user's UUID to delete virtual folders

        :return: number of deleted rows
        """
        from . import vfolders
        mount_prefix = Path(await config_server.get('volumes/_mount'))
        fs_prefix = await config_server.get('volumes/_fsprefix')
        fs_prefix = Path(fs_prefix.lstrip('/'))
        query = (
            sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.unmanaged_path])
            .select_from(vfolders)
            .where(vfolders.c.user == user_uuid)
        )
        async for row in conn.execute(query):
            if row['unmanaged_path']:
                folder_path = Path(row['unmanaged_path'])
            else:
                folder_path = (mount_prefix / row['host'] / fs_prefix / row['id'].hex)
            log.info('deleting physical files: {0}', folder_path)
            try:
                loop = current_loop()
                await loop.run_in_executor(None, lambda: shutil.rmtree(folder_path))  # type: ignore
            except IOError:
                pass
        query = (
            vfolders.delete()
            .where(vfolders.c.user == user_uuid)
        )
        result = await conn.execute(query)
        if result.rowcount > 0:
            log.info('deleted {0} user\'s virtual folders ({1})', result.rowcount, user_uuid)
        return result.rowcount
Example #12
0
async def list_projects_access_rights(
        conn: SAConnection, user_id: int) -> Dict[ProjectID, AccessRights]:
    """
    Returns access-rights of user (user_id) over all OWNED or SHARED projects
    """

    user_group_ids: List[int] = await _get_user_groups_ids(conn, user_id)

    smt = text(f"""\
    SELECT uuid, access_rights
    FROM projects
    WHERE (
        prj_owner = {user_id}
        OR jsonb_exists_any( access_rights, (
               SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id} )
            )
        )
    )
    """)
    projects_access_rights = {}

    async for row in conn.execute(smt):
        assert isinstance(row.access_rights, dict)
        assert isinstance(row.uuid, ProjectID)

        if row.access_rights:
            # TODO: access_rights should be directly filtered from the result in the stmt instead of calling user_group_ids again
            projects_access_rights[row.uuid] = _aggregate_access_rights(
                row.access_rights, user_group_ids)

        else:
            # backwards compatibility
            # - no access_rights defined BUT project is owned
            projects_access_rights[row.uuid] = AccessRights.all()

    return projects_access_rights
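A hedged variant sketch of the statement in Example #12, using bound parameters instead of f-string interpolation so the query no longer relies on user_id being a trusted integer. Whether the surrounding code adopts this form is an assumption; it presumes aiopg's SAConnection.execute forwards bind parameters for text() statements as it does for other SQLAlchemy core constructs.

    smt = text(
        """\
        SELECT uuid, access_rights
        FROM projects
        WHERE (
            prj_owner = :user_id
            OR jsonb_exists_any( access_rights, (
                   SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = :user_id )
                )
            )
        )
        """
    )
    async for row in conn.execute(smt, user_id=user_id):
        ...  # same aggregation as above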
Example #13
0
    async def migrate_shared_vfolders(
        cls,
        conn: SAConnection,
        deleted_user_uuid: uuid.UUID,
        target_user_uuid: uuid.UUID,
        target_user_email: str,
    ) -> int:
        """
        Migrate shared virtual folders' ownership to a target user.

        If migrating virtual folder's name collides with target user's already
        existing folder, append random string to the migrating one.

        :param conn: DB connection
        :param deleted_user_uuid: user's UUID who will be deleted
        :param target_user_uuid: user's UUID who will get the ownership of virtual folders

        :return: number of migrated virtual folders
        """
        from . import vfolders, vfolder_invitations, vfolder_permissions
        # Gather target user's virtual folders' names.
        query = (
            sa.select([vfolders.c.name])
            .select_from(vfolders)
            .where(vfolders.c.user == target_user_uuid)
        )
        existing_vfolder_names = [row.name async for row in conn.execute(query)]

        # Migrate shared virtual folders.
        # If virtual folder's name collides with target user's folder,
        # append random string to the name of the migrating folder.
        j = vfolder_permissions.join(
            vfolders,
            vfolder_permissions.c.vfolder == vfolders.c.id
        )
        query = (
            sa.select([vfolders.c.id, vfolders.c.name])
            .select_from(j)
            .where(vfolders.c.user == deleted_user_uuid)
        )
        migrate_updates = []
        async for row in conn.execute(query):
            name = row.name
            if name in existing_vfolder_names:
                name += f'-{uuid.uuid4().hex[:10]}'
            migrate_updates.append({'vid': row.id, 'vname': name})
        if migrate_updates:
            # Remove invitations and vfolder_permissions from target user.
            # Target user will be the new owner, and it does not make sense to have
            # invitation and shared permission for its own folder.
            migrate_vfolder_ids = [item['vid'] for item in migrate_updates]
            query = (
                vfolder_invitations.delete()
                .where((vfolder_invitations.c.invitee == target_user_email) &
                       (vfolder_invitations.c.vfolder.in_(migrate_vfolder_ids)))
            )
            await conn.execute(query)
            query = (
                vfolder_permissions.delete()
                .where((vfolder_permissions.c.user == target_user_uuid) &
                       (vfolder_permissions.c.vfolder.in_(migrate_vfolder_ids)))
            )
            await conn.execute(query)

            rowcount = 0
            for item in migrate_updates:
                query = (
                    vfolders.update()
                    .values(
                        user=target_user_uuid,
                        name=item['vname'],
                    )
                    .where(vfolders.c.id == item['vid'])
                )
                result = await conn.execute(query)
                rowcount += result.rowcount
            if rowcount > 0:
                log.info('{0} shared folders detected. migrated to user {1}',
                         rowcount, target_user_uuid)
            return rowcount
        else:
            return 0
Example #14
0
async def test_basic_workflow(project: RowProxy, conn: SAConnection):

    # git init
    async with conn.begin():
        # create repo
        repo_orm = ReposOrm(conn)
        repo_id = await repo_orm.insert(project_uuid=project.uuid)
        assert repo_id is not None
        assert isinstance(repo_id, int)

        repo_orm.set_filter(rowid=repo_id)
        repo = await repo_orm.fetch()
        assert repo
        assert repo.project_uuid == project.uuid
        assert repo.project_checksum is None
        assert repo.created == repo.modified

        # create main branch
        branches_orm = BranchesOrm(conn)
        branch_id = await branches_orm.insert(repo_id=repo.id)
        assert branch_id is not None
        assert isinstance(branch_id, int)

        branches_orm.set_filter(rowid=branch_id)
        main_branch: Optional[RowProxy] = await branches_orm.fetch()
        assert main_branch
        assert main_branch.name == "main", "Expected 'main' as default branch"
        assert main_branch.head_commit_id is None, "still not assigned"
        assert main_branch.created == main_branch.modified

        # assign head branch
        heads_orm = HeadsOrm(conn)
        await heads_orm.insert(repo_id=repo.id, head_branch_id=branch_id)

        heads_orm.set_filter(rowid=repo.id)
        head = await heads_orm.fetch()
        assert head

    #
    # create first commit -- TODO: separate tests

    # fetch a *full copy* of the project (WC)
    repo = await repo_orm.fetch("id project_uuid project_checksum")
    assert repo

    project_orm = ProjectsOrm(conn).set_filter(uuid=repo.project_uuid)
    project_wc = await project_orm.fetch()
    assert project_wc
    assert project == project_wc

    # call external function to compute checksum
    checksum = eval_checksum(project_wc.workbench)
    assert repo.project_checksum != checksum

    # take snapshot <=> git add & commit
    async with conn.begin():

        snapshot_checksum = await add_snapshot(project_wc, checksum, repo,
                                               conn)

        # get HEAD = repo.branch_id -> .head_commit_id
        assert head.repo_id == repo.id
        branches_orm.set_filter(head.head_branch_id)
        branch = await branches_orm.fetch("head_commit_id name")
        assert branch
        assert branch.name == "main"
        assert branch.head_commit_id is None, "First commit"

        # create commit
        commits_orm = CommitsOrm(conn)
        commit_id = await commits_orm.insert(
            repo_id=repo.id,
            parent_commit_id=branch.head_commit_id,
            snapshot_checksum=snapshot_checksum,
            message="first commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

        # update checksum cache
        await repo_orm.update(project_checksum=snapshot_checksum)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 1
    assert commits[0].id == commit_id

    # tag
    tag_orm = TagsOrm(conn)
    tag_id = await tag_orm.insert(
        repo_id=repo.id,
        commit_id=commit_id,
        name="v1",
    )
    assert tag_id is not None
    assert isinstance(tag_id, int)

    tag = await tag_orm.fetch(rowid=tag_id)
    assert tag
    assert tag.name == "v1"

    ############# NEW COMMIT #####################

    # user adds some changes
    repo = await repo_orm.fetch()
    assert repo

    project_orm.set_filter(uuid=repo.project_uuid)
    assert project_orm.is_filter_set()

    await project_orm.update(workbench={"node": {
        "input": 3,
    }})

    project_wc = await project_orm.fetch("workbench ui")
    assert project_wc
    assert project.workbench != project_wc.workbench

    # get HEAD = repo.branch_id -> .head_commit_id
    head = await heads_orm.fetch()
    assert head
    branch = await branches_orm.fetch("head_commit_id",
                                      rowid=head.head_branch_id)
    assert branch
    # TODO: get subquery ... and compose
    head_commit = await commits_orm.fetch(rowid=branch.head_commit_id)
    assert head_commit

    # compare checksums between wc and HEAD
    checksum = eval_checksum(project_wc.workbench)
    assert head_commit.snapshot_checksum != checksum

    # updates wc checksum cache
    await repo_orm.update(project_checksum=checksum)

    # take snapshot = add & commit
    async with conn.begin():
        snapshot_uuid: str = await add_snapshot(project_wc, checksum, repo,
                                                conn)

        commit_id = await commits_orm.insert(
            repo_id=head_commit.repo_id,
            parent_commit_id=head_commit.id,
            snapshot_checksum=checksum,
            message="second commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 2
    assert commits[1].id == commit_id
Example #15
0
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    return list(items.values())
Example #16
0
 async def _get_tags_by_project(self, conn: SAConnection,
                                project_id: str) -> List:
     query = sa.select([study_tags.c.tag_id
                        ]).where(study_tags.c.study_id == project_id)
     return [row.tag_id async for row in conn.execute(query)]
Example #17
0
        async def _schedule_in_sgroup(db_conn: SAConnection, sgroup_name: str) -> None:
            async with db_conn.begin():
                scheduler = await self._load_scheduler(db_conn, sgroup_name)
                pending_sessions = await _list_pending_sessions(db_conn, sgroup_name)
                existing_sessions = await _list_existing_sessions(db_conn, sgroup_name)
            log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
                      sgroup_name, len(pending_sessions), len(existing_sessions))
            zero = ResourceSlot()
            while len(pending_sessions) > 0:
                async with db_conn.begin():
                    candidate_agents = await _list_agents_by_sgroup(db_conn, sgroup_name)
                    total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

                picked_session_id = scheduler.pick_session(
                    total_capacity,
                    pending_sessions,
                    existing_sessions,
                )
                if picked_session_id is None:
                    # no session is picked.
                    # continue to next sgroup.
                    return
                for picked_idx, sess_ctx in enumerate(pending_sessions):
                    if sess_ctx.session_id == picked_session_id:
                        break
                else:
                    # no matching entry for picked session?
                    raise RuntimeError('should not reach here')
                sess_ctx = pending_sessions.pop(picked_idx)

                log_args = (
                    sess_ctx.session_id,
                    sess_ctx.session_type,
                    sess_ctx.session_name,
                    sess_ctx.access_key,
                    sess_ctx.cluster_mode,
                )
                log.debug(log_fmt + 'try-scheduling', *log_args)
                session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

                async with db_conn.begin():
                    predicates: Sequence[Awaitable[PredicateResult]] = [
                        check_reserved_batch_session(db_conn, sched_ctx, sess_ctx),
                        check_concurrency(db_conn, sched_ctx, sess_ctx),
                        check_dependencies(db_conn, sched_ctx, sess_ctx),
                        check_keypair_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_group_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_domain_resource_limit(db_conn, sched_ctx, sess_ctx),
                        check_scaling_group(db_conn, sched_ctx, sess_ctx),
                    ]
                    check_results: List[Union[Exception, PredicateResult]] = []
                    for check in predicates:
                        try:
                            check_results.append(await check)
                        except Exception as e:
                            log.exception(log_fmt + 'predicate-error', *log_args)
                            check_results.append(e)
                    has_failure = False
                    for result in check_results:
                        if isinstance(result, Exception):
                            has_failure = True
                            continue
                        if not result.passed:
                            has_failure = True
                    if has_failure:
                        log.debug(log_fmt + 'predicate-checks-failed', *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # Predicate failures are *NOT* permanent errors.
                        # We need to retry the scheduling afterwards.
                        continue

                    if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                        # Assign agent resource per session.
                        try:
                            agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
                            if agent_id is None:
                                raise InstanceNotAvailable
                            agent_alloc_ctx = await _reserve_agent(
                                sched_ctx, db_conn, sgroup_name, agent_id, sess_ctx.requested_slots,
                            )
                        except InstanceNotAvailable:
                            log.debug(log_fmt + 'no-available-instances', *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        except Exception:
                            log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                          *log_args)
                            await _invoke_failure_callbacks(
                                db_conn, sched_ctx, sess_ctx, check_results,
                            )
                            raise
                        query = kernels.update().values({
                            'agent': agent_alloc_ctx.agent_id,
                            'agent_addr': agent_alloc_ctx.agent_addr,
                            'scaling_group': sgroup_name,
                            'status': KernelStatus.PREPARING,
                            'status_info': 'scheduled',
                            'status_changed': datetime.now(tzutc()),
                        }).where(kernels.c.session_id == sess_ctx.session_id)
                        await db_conn.execute(query)
                        session_agent_binding = (
                            sess_ctx,
                            [
                                KernelAgentBinding(kernel, agent_alloc_ctx)
                                for kernel in sess_ctx.kernels
                            ],
                        )
                    elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                        # Assign agent resource per kernel in the session.
                        agent_query_extra_conds = None
                        if len(sess_ctx.kernels) >= 2:
                            # We should use agents that supports overlay networking.
                            agent_query_extra_conds = (agents.c.clusterized)
                        kernel_agent_bindings = []
                        for kernel in sess_ctx.kernels:
                            try:
                                agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                                if agent_id is None:
                                    raise InstanceNotAvailable
                                agent_alloc_ctx = await _reserve_agent(
                                    sched_ctx, db_conn, sgroup_name, agent_id, kernel.requested_slots,
                                    extra_conds=agent_query_extra_conds,
                                )
                            except InstanceNotAvailable:
                                log.debug(log_fmt + 'no-available-instances', *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            except Exception:
                                log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                              *log_args)
                                await _invoke_failure_callbacks(
                                    db_conn, sched_ctx, sess_ctx, check_results,
                                )
                                # continue
                                raise
                            # TODO: if error occurs for one kernel, should we cancel all others?
                            query = kernels.update().values({
                                'agent': agent_alloc_ctx.agent_id,
                                'agent_addr': agent_alloc_ctx.agent_addr,
                                'scaling_group': sgroup_name,
                                'status': KernelStatus.PREPARING,
                                'status_info': 'scheduled',
                                'status_changed': datetime.now(tzutc()),
                            }).where(kernels.c.id == kernel.kernel_id)
                            await db_conn.execute(query)
                            kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
                        session_agent_binding = (sess_ctx, kernel_agent_bindings)
                    start_task_args.append(
                        (
                            log_args,
                            sched_ctx,
                            session_agent_binding,
                            check_results,
                        )
                    )
Example #18
0
    async def _schedule_single_node_session(
        self,
        sched_ctx: SchedulingContext,
        scheduler: AbstractScheduler,
        agent_db_conn: SAConnection,
        kernel_db_conn: SAConnection,
        sgroup_name: str,
        candidate_agents: Sequence[AgentContext],
        sess_ctx: PendingSession,
        check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
    ) -> Tuple[PendingSession, List[KernelAgentBinding]]:
        # Assign agent resource per session.
        log_fmt = _log_fmt.get()
        log_args = _log_args.get()
        try:
            agent_id = scheduler.assign_agent_for_session(
                candidate_agents, sess_ctx)
            if agent_id is None:
                raise InstanceNotAvailable
            async with agent_db_conn.begin():
                agent_alloc_ctx = await _reserve_agent(
                    sched_ctx,
                    agent_db_conn,
                    sgroup_name,
                    agent_id,
                    sess_ctx.requested_slots,
                )
        except InstanceNotAvailable:
            log.debug(log_fmt + 'no-available-instances', *log_args)
            async with kernel_db_conn.begin():
                await _invoke_failure_callbacks(
                    kernel_db_conn,
                    sched_ctx,
                    sess_ctx,
                    check_results,
                )
                query = kernels.update().values({
                    'status_info': "no-available-instances",
                    'status_data': sql_json_increment(
                        kernels.c.status_data,
                        ('scheduler', 'retries'),
                        parent_updates={
                            'last_try': datetime.now(tzutc()).isoformat(),
                        },
                    ),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)
            raise
        except Exception as e:
            log.exception(
                log_fmt + 'unexpected-error, during agent allocation',
                *log_args,
            )
            async with kernel_db_conn.begin():
                await _invoke_failure_callbacks(
                    kernel_db_conn,
                    sched_ctx,
                    sess_ctx,
                    check_results,
                )
                query = kernels.update().values({
                    'status_info': "scheduler-error",
                    'status_data': convert_to_status_data(e),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)
            raise

        async with kernel_db_conn.begin():
            query = kernels.update().values({
                'agent': agent_alloc_ctx.agent_id,
                'agent_addr': agent_alloc_ctx.agent_addr,
                'scaling_group': sgroup_name,
                'status': KernelStatus.PREPARING,
                'status_info': 'scheduled',
                'status_data': {},
                'status_changed': datetime.now(tzutc()),
            }).where(kernels.c.session_id == sess_ctx.session_id)
            await kernel_db_conn.execute(query)
        return (
            sess_ctx,
            [
                KernelAgentBinding(kernel, agent_alloc_ctx)
                for kernel in sess_ctx.kernels
            ],
        )
Example #19
0
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                environ={
                    k: v for k, v
                    in map(lambda s: s.split('=', maxsplit=1), row['environ'])
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    return list(items.values())
Example #20
0
    async def _schedule_multi_node_session(
        self,
        sched_ctx: SchedulingContext,
        scheduler: AbstractScheduler,
        agent_db_conn: SAConnection,
        kernel_db_conn: SAConnection,
        sgroup_name: str,
        candidate_agents: Sequence[AgentContext],
        sess_ctx: PendingSession,
        check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
    ) -> Tuple[PendingSession, List[KernelAgentBinding]]:
        # Assign agent resource per kernel in the session.
        log_fmt = _log_fmt.get()
        log_args = _log_args.get()
        agent_query_extra_conds = None
        kernel_agent_bindings: List[KernelAgentBinding] = []
        async with agent_db_conn.begin(isolation_level="REPEATABLE READ"):
            # This outer transaction is rolled back when any exception occurs inside,
            # including scheduling failures of a kernel.
            # It ensures that occupied_slots are recovered when there are partial
            # scheduling failures.
            for kernel in sess_ctx.kernels:
                try:
                    agent_id = scheduler.assign_agent_for_kernel(
                        candidate_agents, kernel)
                    if agent_id is None:
                        raise InstanceNotAvailable
                    async with agent_db_conn.begin_nested():
                        agent_alloc_ctx = await _reserve_agent(
                            sched_ctx,
                            agent_db_conn,
                            sgroup_name,
                            agent_id,
                            kernel.requested_slots,
                            extra_conds=agent_query_extra_conds,
                        )
                        candidate_agents = await _list_agents_by_sgroup(
                            agent_db_conn, sgroup_name)
                except InstanceNotAvailable:
                    log.debug(log_fmt + 'no-available-instances', *log_args)
                    async with kernel_db_conn.begin():
                        await _invoke_failure_callbacks(
                            kernel_db_conn,
                            sched_ctx,
                            sess_ctx,
                            check_results,
                        )
                        query = kernels.update().values({
                            'status_info': "no-available-instances",
                            'status_data': sql_json_increment(
                                kernels.c.status_data,
                                ('scheduler', 'retries'),
                                parent_updates={
                                    'last_try': datetime.now(tzutc()).isoformat(),
                                },
                            ),
                        }).where(kernels.c.id == kernel.kernel_id)
                        await kernel_db_conn.execute(query)
                    raise
                except Exception as e:
                    log.exception(
                        log_fmt + 'unexpected-error, during agent allocation',
                        *log_args,
                    )
                    async with kernel_db_conn.begin():
                        await _invoke_failure_callbacks(
                            kernel_db_conn,
                            sched_ctx,
                            sess_ctx,
                            check_results,
                        )
                        query = kernels.update().values({
                            'status_info': "scheduler-error",
                            'status_data': convert_to_status_data(e),
                        }).where(kernels.c.id == kernel.kernel_id)
                        await kernel_db_conn.execute(query)
                    raise
                else:
                    kernel_agent_bindings.append(
                        KernelAgentBinding(kernel, agent_alloc_ctx))

        if len(kernel_agent_bindings) == len(sess_ctx.kernels):
            # Proceed to PREPARING only when all kernels are successfully scheduled.
            async with kernel_db_conn.begin():
                for binding in kernel_agent_bindings:
                    query = kernels.update().values({
                        'agent': binding.agent_alloc_ctx.agent_id,
                        'agent_addr': binding.agent_alloc_ctx.agent_addr,
                        'scaling_group': sgroup_name,
                        'status': KernelStatus.PREPARING,
                        'status_info': 'scheduled',
                        'status_data': {},
                        'status_changed': datetime.now(tzutc()),
                    }).where(kernels.c.id == binding.kernel.kernel_id)
                    await kernel_db_conn.execute(query)

        return (sess_ctx, kernel_agent_bindings)
Example #21
0
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.registry,
            kernels.c.sess_type,
            kernels.c.sess_id,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    items = []
    async for row in db_conn.execute(query):
        items.append(PendingSession(
            kernel_id=row['id'],
            access_key=row['access_key'],
            session_type=row['sess_type'],
            session_name=row['sess_id'],
            domain_name=row['domain_name'],
            group_id=row['group_id'],
            scaling_group=row['scaling_group'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            resource_policy=row['resource_policy'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
            internal_data=row['internal_data'],
            target_sgroup_names=[],
            environ={
                k: v for k, v
                in map(lambda s: s.split('=', maxsplit=1), row['environ'])
            },
            mounts=row['mounts'],
            mount_map=row['mount_map'],
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            preopen_ports=row['preopen_ports'],
        ))
    return items
Example #22
0
    async def _schedule_in_sgroup(
        self,
        sched_ctx: SchedulingContext,
        agent_db_conn: SAConnection,
        kernel_db_conn: SAConnection,
        sgroup_name: str,
    ) -> List[StartTaskArgs]:
        async with kernel_db_conn.begin():
            scheduler = await self._load_scheduler(kernel_db_conn, sgroup_name)
            pending_sessions = await _list_pending_sessions(
                kernel_db_conn, sgroup_name)
            existing_sessions = await _list_existing_sessions(
                kernel_db_conn, sgroup_name)
        log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
                  sgroup_name, len(pending_sessions), len(existing_sessions))
        zero = ResourceSlot()
        args_list: List[StartTaskArgs] = []
        while len(pending_sessions) > 0:
            async with agent_db_conn.begin():
                candidate_agents = await _list_agents_by_sgroup(
                    agent_db_conn, sgroup_name)
                total_capacity = sum(
                    (ag.available_slots for ag in candidate_agents), zero)

            picked_session_id = scheduler.pick_session(
                total_capacity,
                pending_sessions,
                existing_sessions,
            )
            if picked_session_id is None:
                # no session is picked.
                # continue to next sgroup.
                return []
            for picked_idx, sess_ctx in enumerate(pending_sessions):
                if sess_ctx.session_id == picked_session_id:
                    break
            else:
                # no matching entry for picked session?
                raise RuntimeError('should not reach here')
            sess_ctx = pending_sessions.pop(picked_idx)

            log_fmt = 'schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): '
            log_args = (
                sess_ctx.session_id,
                sess_ctx.session_type,
                sess_ctx.session_name,
                sess_ctx.access_key,
                sess_ctx.cluster_mode,
            )
            _log_fmt.set(log_fmt)
            _log_args.set(log_args)
            log.debug(log_fmt + 'try-scheduling', *log_args)
            session_agent_binding: Tuple[PendingSession,
                                         List[KernelAgentBinding]]

            async with kernel_db_conn.begin():
                predicates: Sequence[Tuple[
                    str, Awaitable[PredicateResult]]] = [
                        (
                            'reserved_time',
                            check_reserved_batch_session(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        ('concurrency',
                         check_concurrency(kernel_db_conn, sched_ctx,
                                           sess_ctx)),
                        ('dependencies',
                         check_dependencies(kernel_db_conn, sched_ctx,
                                            sess_ctx)),
                        (
                            'keypair_resource_limit',
                            check_keypair_resource_limit(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        (
                            'user_group_resource_limit',
                            check_group_resource_limit(kernel_db_conn,
                                                       sched_ctx, sess_ctx),
                        ),
                        (
                            'domain_resource_limit',
                            check_domain_resource_limit(
                                kernel_db_conn, sched_ctx, sess_ctx),
                        ),
                        (
                            'scaling_group_resource_limit',
                            check_scaling_group(kernel_db_conn, sched_ctx,
                                                sess_ctx),
                        ),
                    ]
                check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = []
                for predicate_name, check_coro in predicates:
                    try:
                        check_results.append((predicate_name, await check_coro))
                    except Exception as e:
                        log.exception(log_fmt + 'predicate-error', *log_args)
                        check_results.append((predicate_name, e))
            has_failure = False
            # has_permanent_failure = False
            failed_predicates = []
            passed_predicates = []
            for predicate_name, result in check_results:
                if isinstance(result, Exception):
                    has_failure = True
                    failed_predicates.append({
                        'name': predicate_name,
                        'msg': repr(result),
                    })
                    continue
                if result.passed:
                    passed_predicates.append({
                        'name': predicate_name,
                    })
                else:
                    failed_predicates.append({
                        'name': predicate_name,
                        'msg': result.message or "",
                    })
                    has_failure = True
                    # if result.permanent:
                    #     has_permanent_failure = True
            if has_failure:
                log.debug(log_fmt + 'predicate-checks-failed (temporary)',
                          *log_args)
                # TODO: handle has_permanent_failure as cancellation
                #  - An early implementation of it has caused DB query blocking due to
                #    the inclusion of the kernels.status field. :(
                #    Let's fix it.
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn,
                        sched_ctx,
                        sess_ctx,
                        check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "predicate-checks-failed",
                        'status_data': sql_json_increment(
                            kernels.c.status_data,
                            ('scheduler', 'retries'),
                            parent_updates={
                                'last_try': datetime.now(tzutc()).isoformat(),
                                'failed_predicates': failed_predicates,
                                'passed_predicates': passed_predicates,
                            },
                        ),
                    }).where(kernels.c.id == sess_ctx.session_id)
                    await kernel_db_conn.execute(query)
                # Predicate failures are *NOT* permanent errors.
                # We need to retry the scheduling afterwards.
                continue
            else:
                async with kernel_db_conn.begin():
                    query = kernels.update().values({
                        'status_data': sql_json_merge(
                            kernels.c.status_data,
                            ('scheduler',),
                            {
                                'last_try': datetime.now(tzutc()).isoformat(),
                                'failed_predicates': failed_predicates,
                                'passed_predicates': passed_predicates,
                            },
                        ),
                    }).where(kernels.c.id == sess_ctx.session_id)
                    await kernel_db_conn.execute(query)

            if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                session_agent_binding = await self._schedule_single_node_session(
                    sched_ctx,
                    scheduler,
                    agent_db_conn,
                    kernel_db_conn,
                    sgroup_name,
                    candidate_agents,
                    sess_ctx,
                    check_results,
                )
            elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                session_agent_binding = await self._schedule_multi_node_session(
                    sched_ctx,
                    scheduler,
                    agent_db_conn,
                    kernel_db_conn,
                    sgroup_name,
                    candidate_agents,
                    sess_ctx,
                    check_results,
                )
            else:
                raise RuntimeError(
                    f"should not reach here; unknown cluster_mode: {sess_ctx.cluster_mode}"
                )
            args_list.append((
                log_args,
                sched_ctx,
                session_agent_binding,
                check_results,
            ))

        return args_list