Exemple #1
0
    def build_query(
        self,
        container: IContainer,
        query: ParsedQueryInfo,
        select_fields: typing.List[str],
        distinct: typing.Optional[bool] = False,
    ) -> typing.Tuple[str, typing.List[typing.Any]]:
        """Build the paginated ``select`` statement for a parsed catalog query.

        :param container: container passed to ``get_default_where_clauses``,
            whose clauses are ANDed with the query's own conditions.
        :param query: parsed query; this method reads ``sort_on``,
            ``sort_dir``, ``selects``, ``selects_arguments``, ``wheres``,
            ``wheres_arguments``, ``size`` and ``_from``.
        :param select_fields: columns/expressions to select. The caller's
            list is not modified; ``query["selects"]`` expressions are
            appended to a private copy.
        :param distinct: when true, emit ``select distinct`` and omit the
            ``ORDER BY`` clause.
        :returns: ``(sql, arguments)``; ``arguments`` are ordered to match the
            numbered ``{arg}`` placeholders (selects first, then wheres).
        :raises TransactionNotFound: if there is no active transaction.
        """
        if query["sort_on"] is None:
            # always need a sort otherwise paging never works
            order_by_index = get_pg_index("uuid")
        else:
            order_by_index = get_pg_index(query["sort_on"]) or BasicJsonIndex(
                query["sort_on"])

        # Copy so repeated calls never grow the caller's (possibly shared)
        # list of fields.
        select_fields = list(select_fields)
        sql_arguments: typing.List[typing.Any] = []
        sql_wheres: typing.List[str] = []

        # Placeholders are numbered globally, starting at 1 with the selects.
        arg_index = 1
        for idx, select in enumerate(query["selects"]):
            select_fields.append(select.format(arg=arg_index))
            sql_arguments.append(query["selects_arguments"][idx])
            arg_index += 1

        # ``wheres_arguments`` is flat: one entry per plain where and one per
        # sub-where inside an ``(operator, sub_wheres)`` group, in order.
        where_arg_index = 0
        for where in query["wheres"]:
            if isinstance(where, tuple):
                operator, sub_wheres = where
                sub_result = []
                for sub_where in sub_wheres:
                    sub_result.append(
                        sub_where.format(arg=arg_index + where_arg_index))
                    sql_arguments.append(
                        query["wheres_arguments"][where_arg_index])
                    where_arg_index += 1
                sql_wheres.append("(" + operator.join(sub_result) + ")")
            else:
                sql_wheres.append(where.format(arg=arg_index +
                                               where_arg_index))
                sql_arguments.append(
                    query["wheres_arguments"][where_arg_index])
                where_arg_index += 1

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        sql_wheres.extend(self.get_default_where_clauses(container))
        # An empty condition list would otherwise leave a dangling ``where``
        # keyword, which is invalid SQL.
        where_clause = ("where " + " AND ".join(sql_wheres)
                        if sql_wheres else "")

        sql = """select {} {}
                 from {}
                 {}
                 {}
                 limit {} offset {}""".format(
            "distinct" if distinct else "",
            ",".join(select_fields),
            sqlq(txn.storage.objects_table_name),
            where_clause,
            # NOTE(review): DISTINCT results are left unordered — presumably
            # because PostgreSQL requires ORDER BY expressions to appear in
            # the select list of a SELECT DISTINCT.
            "" if distinct else order_by_index.order_by(query["sort_dir"]),
            sqlq(query["size"]),
            sqlq(query["_from"]),
        )
        return sql, sql_arguments
Exemple #2
0
    async def aggregation(self, container: IContainer, query: ParsedQueryInfo):
        """Return the distinct value combinations of the requested metadata.

        Each entry of ``items`` is a list holding one JSON-decoded value per
        field in ``query["metadata"]`` (``None`` when the stored object has
        no such key).

        :returns: ``{"items": [...], "items_total": N}``.
        :raises TransactionNotFound: if there is no active transaction.
        """
        fields = query["metadata"] or []
        select_fields = [
            "json->'" + sqlq(field) + "' as " + sqlq(field)
            for field in fields
        ]  # noqa
        sql, arguments = self.build_query(container, query, select_fields,
                                          distinct=True)

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        conn = await txn.get_connection()

        logger.debug(f"Running search:\n{sql}\n{arguments}")
        async with txn.lock:
            records = await conn.fetch(sql, *arguments)
        # ``json->'key'`` yields SQL NULL when an object lacks the key, and
        # ``json.loads(None)`` would raise TypeError, so decode non-NULL only.
        results = [
            [
                None if record[field] is None else json.loads(record[field])
                for field in fields
            ]
            for record in records
        ]

        total = len(results)
        if total >= query["size"] or query["_from"] != 0:
            # NOTE(review): the count query below is not DISTINCT, so this
            # total counts matching objects, not distinct value combinations.
            sql, arguments = self.build_count_query(container, query)
            logger.debug(f"Running search:\n{sql}\n{arguments}")
            async with txn.lock:
                records = await conn.fetch(sql, *arguments)
            total = records[0]["count"]
        return {"items": results, "items_total": total}
Exemple #3
0
    def build_count_query(
        self,
        context,
        query: ParsedQueryInfo,
        unrestricted: bool = False,
    ) -> typing.Tuple[str, typing.List[typing.Any]]:
        """Build a ``select count(*)`` statement for a parsed catalog query.

        Where clauses are rendered exactly as ``build_query`` renders them.
        ``query["wheres"]`` may hold plain format strings or
        ``(operator, sub_wheres)`` groups. ``query["wheres_arguments"]`` holds
        one argument per rendered (sub-)where, in order.

        :param context: object passed to ``get_default_where_clauses``.
        :param query: parsed query (``wheres`` and ``wheres_arguments`` are read).
        :param unrestricted: forwarded to ``get_default_where_clauses``.
        :returns: ``(sql, arguments)`` for the count statement.
        :raises TransactionNotFound: if there is no active transaction.
        """
        sql_arguments: typing.List[typing.Any] = []
        sql_wheres: typing.List[str] = []
        select_fields = ["count(*)"]
        # There are no select arguments here, so where placeholders start at 1.
        where_arg_index = 0
        for where in query["wheres"]:
            if isinstance(where, tuple):
                # Grouped conditions joined by ``operator``; each sub-where
                # consumes its own argument.
                operator, sub_wheres = where
                sub_result = []
                for sub_where in sub_wheres:
                    sub_result.append(sub_where.format(arg=where_arg_index + 1))
                    sql_arguments.append(
                        query["wheres_arguments"][where_arg_index])
                    where_arg_index += 1
                sql_wheres.append("(" + operator.join(sub_result) + ")")
            else:
                sql_wheres.append(where.format(arg=where_arg_index + 1))
                sql_arguments.append(query["wheres_arguments"][where_arg_index])
                where_arg_index += 1

        sql_wheres.extend(
            self.get_default_where_clauses(context, unrestricted=unrestricted))

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        # Avoid a dangling ``where`` keyword when there are no conditions.
        where_clause = ("where " + " AND ".join(sql_wheres)
                        if sql_wheres else "")
        sql = """select {}
                 from {}
                 {}""".format(",".join(select_fields),
                              sqlq(txn.storage.objects_table_name),
                              where_clause)
        return sql, sql_arguments
Exemple #4
0
    async def search(self, container: IContainer,
                     query: ParsedQueryInfo):  # type: ignore
        """Run a catalog search and return the matching objects' metadata.

        Each member is the result of ``load_meatdata`` for the stored JSON,
        enriched with ``@name`` (row id), ``@uid`` (row zoid) and ``@id``
        (container URL + object path, also written to ``@absolute_url`` on
        the decoded JSON). ``items_count`` is the page length unless the page
        is full or not the first one, in which case a count query is run.

        :returns: ``{'member': [...], 'items_count': N}``.
        :raises TransactionNotFound: if there is no active transaction.
        """
        sql, arguments = self.build_query(
            container, query, ['id', 'zoid', 'json'])
        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        conn = await txn.get_connection()

        try:
            base_url = get_object_url(container)
        except RequestNotFound:
            # get_object_url raised RequestNotFound: use the content path.
            base_url = get_content_path(container)

        logger.debug(f'Running search:\n{sql}\n{arguments}')
        rows = await conn.fetch(sql, *arguments)
        members = []
        for row in rows:
            payload = json.loads(row['json'])
            item = self.load_meatdata(query, payload)
            item['@name'] = row['id']
            item['@uid'] = row['zoid']
            absolute_url = base_url + payload['path']
            item['@id'] = absolute_url
            payload['@absolute_url'] = absolute_url
            members.append(item)

        # A separate count is only needed when more rows may exist.
        items_count = len(members)
        if items_count >= query['size'] or query['_from'] != 0:
            sql, arguments = self.build_count_query(container, query)
            logger.debug(f'Running search:\n{sql}\n{arguments}')
            count_rows = await conn.fetch(sql, *arguments)
            items_count = count_rows[0]['count']
        return {'member': members, 'items_count': items_count}
Exemple #5
0
 def _get_transaction(self) -> ITransaction:
     """Return the transaction from ``get_transaction()``, else ``self.__txn__``.

     :raises TransactionNotFound: if neither source yields a transaction.
     """
     txn = get_transaction()
     if txn is None:
         # Fall back to the transaction stored on this object.
         txn = self.__txn__
     if txn is None:
         raise TransactionNotFound()
     return txn
Exemple #6
0
    async def aggregation(self, container: IContainer, query: ParsedQueryInfo):
        '''Return the distinct value combinations of the requested metadata.

        Each entry of ``member`` is a list holding one JSON-decoded value per
        field in ``query['metadata']`` (``None`` when the stored object has
        no such key).

        :returns: ``{'member': [...], 'items_count': N}``.
        :raises TransactionNotFound: if there is no active transaction.
        '''
        fields = query['metadata'] or []
        select_fields = [
            'json->\'' + sqlq(field) + '\' as ' + sqlq(field)
            for field in fields
        ]  # noqa
        sql, arguments = self.build_query(container, query, select_fields,
                                          distinct=True)

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        conn = await txn.get_connection()

        logger.debug(f'Running search:\n{sql}\n{arguments}')
        records = await conn.fetch(sql, *arguments)
        # ``json->'key'`` yields SQL NULL when an object lacks the key, and
        # ``json.loads(None)`` would raise TypeError, so decode non-NULL only.
        results = [
            [
                None if record[field] is None else json.loads(record[field])
                for field in fields
            ]
            for record in records
        ]

        total = len(results)
        if total >= query['size'] or query['_from'] != 0:
            # NOTE(review): the count query below is not DISTINCT, so this
            # total counts matching objects, not distinct value combinations.
            sql, arguments = self.build_count_query(container, query)
            logger.debug(f'Running search:\n{sql}\n{arguments}')
            records = await conn.fetch(sql, *arguments)
            total = records[0]['count']
        return {'member': results, 'items_count': total}
Exemple #7
0
 def add_job_after_commit(self, func: typing.Callable[[], typing.Coroutine], args=None, kwargs=None):
     """Register ``self._add_job_after_commit`` as an after-commit hook for ``func``.

     The hook is added to the transaction from ``get_transaction()`` with
     ``args=[func]`` and ``kws={"args": args, "kwargs": kwargs}``.

     :raises TransactionNotFound: if there is no active transaction.
     """
     txn = get_transaction()
     if txn is None:
         raise TransactionNotFound("Could not find transaction to run job with")
     hook_kws = {"args": args, "kwargs": kwargs}
     txn.add_after_commit_hook(self._add_job_after_commit, args=[func], kws=hook_kws)
Exemple #8
0
    def build_query(
        self,
        container: IContainer,
        query: ParsedQueryInfo,
        select_fields: typing.List[str],
        distinct: typing.Optional[bool] = False
    ) -> typing.Tuple[str, typing.List[typing.Any]]:
        '''Build the paginated ``select`` statement for a parsed catalog query.

        :param container: container passed to ``get_default_where_clauses``,
            whose clauses are ANDed with the query's own conditions.
        :param query: parsed query; this method reads ``sort_on``,
            ``sort_dir``, ``selects``, ``selects_arguments``, ``wheres``,
            ``wheres_arguments``, ``size`` and ``_from``.
        :param select_fields: columns/expressions to select. The caller's
            list is not modified; ``query['selects']`` expressions are
            appended to a private copy.
        :param distinct: when true, emit ``select distinct`` and omit the
            ``ORDER BY`` clause.
        :returns: ``(sql, arguments)``; ``arguments`` are ordered to match the
            numbered ``{arg}`` placeholders (selects first, then wheres).
        :raises TransactionNotFound: if there is no active transaction.
        '''
        if query['sort_on'] is None:
            # always need a sort otherwise paging never works
            order_by_index = get_pg_index('uuid')
        else:
            order_by_index = get_pg_index(query['sort_on']) or BasicJsonIndex(
                query['sort_on'])

        # Copy so repeated calls never grow the caller's (possibly shared)
        # list of fields.
        select_fields = list(select_fields)
        sql_arguments: typing.List[typing.Any] = []
        sql_wheres: typing.List[str] = []

        # Placeholders are numbered globally, starting at 1 with the selects.
        arg_index = 1
        for idx, select in enumerate(query['selects']):
            select_fields.append(select.format(arg=arg_index))
            sql_arguments.append(query['selects_arguments'][idx])
            arg_index += 1

        # ``wheres_arguments`` is flat: one entry per plain where and one per
        # sub-where inside an ``(operator, sub_wheres)`` group, in order.
        where_arg_index = 0
        for where in query['wheres']:
            if isinstance(where, tuple):
                operator, sub_wheres = where
                sub_result = []
                for sub_where in sub_wheres:
                    sub_result.append(
                        sub_where.format(arg=arg_index + where_arg_index))
                    sql_arguments.append(
                        query['wheres_arguments'][where_arg_index])
                    where_arg_index += 1
                sql_wheres.append('(' + operator.join(sub_result) + ')')
            else:
                sql_wheres.append(where.format(arg=arg_index +
                                               where_arg_index))
                sql_arguments.append(
                    query['wheres_arguments'][where_arg_index])
                where_arg_index += 1

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        sql_wheres.extend(self.get_default_where_clauses(container))
        # An empty condition list would otherwise leave a dangling ``where``
        # keyword, which is invalid SQL.
        where_clause = ('where ' + ' AND '.join(sql_wheres)
                        if sql_wheres else '')

        sql = '''select {} {}
                 from {}
                 {}
                 {}
                 limit {} offset {}'''.format(
            'distinct' if distinct else '', ','.join(select_fields),
            sqlq(txn.storage.objects_table_name), where_clause,
            # NOTE(review): DISTINCT results are left unordered — presumably
            # because PostgreSQL requires ORDER BY expressions to appear in
            # the select list of a SELECT DISTINCT.
            '' if distinct else order_by_index.order_by(query['sort_dir']),
            sqlq(query['size']), sqlq(query['_from']))
        return sql, sql_arguments
Exemple #9
0
 def register(self, prefer_local=False):
     """Register this object with a transaction.

     Unless ``prefer_local`` is true, the transaction from ``get_transaction()``
     is tried first. If it is skipped or is None, ``self.__txn__`` is used.

     :raises TransactionNotFound: if no transaction is available.
     """
     txn = None if prefer_local else get_transaction()
     if txn is None:
         txn = self.__txn__
     if txn is None:
         raise TransactionNotFound()
     txn.register(self)
Exemple #10
0
def get_current_transaction() -> ITransaction:
    """
    Return the transaction bound to the current task context.

    The transaction is read from ``task_vars.txn``. A missing or ``None``
    value is reported as ``TransactionNotFound``.

    :raises TransactionNotFound: if no transaction is set for this context.
    """
    try:
        txn = task_vars.txn.get()
        if txn is not None:
            return txn
    except (LookupError, ValueError, AttributeError, RuntimeError):
        # LookupError is what ``contextvars.ContextVar.get()`` raises when no
        # value (and no default) is set; presumably ``task_vars.txn`` is such
        # a variable — treat it like "no transaction".
        pass

    raise TransactionNotFound(TransactionNotFound.__doc__)
Exemple #11
0
def after_commit(func: Callable, *args, **kwargs):
    '''
    Register ``func`` to be called after the current transaction commits.

    The hook is added to the transaction returned by ``get_transaction()``.

    :param func: function to be queued
    :param \\*args: arguments to call the func with
    :param \\**kwargs: keyword arguments to call the func with; a
        ``_request`` keyword is accepted for backwards compatibility and
        removed before the hook is registered
    :raises TransactionNotFound: if there is no active transaction
    '''
    kwargs.pop('_request', None)  # b/w compat pop unused param
    txn = get_transaction()
    if txn is None:
        raise TransactionNotFound('Could not find transaction to run job with')
    txn.add_after_commit_hook(func, args=args, kwargs=kwargs)
Exemple #12
0
def before_commit(func: Callable[..., Coroutine[Any, Any, Any]], *args,
                  **kwargs):
    '''
    Register ``func`` to be called before the current transaction commits.

    The hook is added to the transaction returned by ``get_transaction()``.

    :param func: coroutine function to be queued
    :param \\*args: arguments to call the func with
    :param \\**kwargs: keyword arguments to call the func with; a
        ``_request`` keyword is accepted only for backwards compatibility and
        is removed (ignored) before the hook is registered
    :raises TransactionNotFound: if there is no active transaction
    '''
    kwargs.pop('_request', None)  # b/w compat pop unused param
    txn = get_transaction()
    if txn is None:
        raise TransactionNotFound('Could not find transaction to run job with')
    txn.add_before_commit_hook(func, args=args, kwargs=kwargs)
Exemple #13
0
 def add_job_after_commit(self,
                          func: typing.Callable[[], typing.Coroutine],
                          request=None,
                          args=None,
                          kwargs=None):
     """Register ``self._add_job_after_commit`` as an after-commit hook for ``func``.

     The hook is added to the transaction from ``get_transaction()`` with
     ``args=[func]`` and ``kws`` carrying ``request``, ``args`` and ``kwargs``.

     :raises TransactionNotFound: if there is no active transaction.
     """
     txn = get_transaction()
     if txn is None:
         raise TransactionNotFound(
             'Could not find transaction to run job with')
     job_options = dict(request=request, args=args, kwargs=kwargs)
     txn.add_after_commit_hook(self._add_job_after_commit,
                               args=[func],
                               kws=job_options)
Exemple #14
0
    async def get_user(self,
                       token: typing.Dict) -> typing.Optional[IPrincipal]:
        """Resolve the ``User`` whose e-mail matches ``token["id"]``.

        The lookup queries the objects table directly instead of the catalog.
        E-mails are compared case-insensitively. The matched user's groups are
        loaded into ``user._groups_cache``; group paths that raise KeyError are
        skipped.

        :returns: the user, or ``None`` in any of these cases:
            - the container or its ``users`` folder cannot be obtained;
            - no row matches;
            - the row does not resolve to an object in ``users``;
            - the user is disabled.
        :raises NoCatalogException: if the catalog utility is not a
            ``PGSearchUtility``.
        :raises TransactionNotFound: if there is no active transaction.
        """
        try:
            container = get_current_container()
            users = await container.async_get("users")
        except (AttributeError, KeyError, ContainerNotFound):
            return None

        catalog = query_utility(ICatalogUtility)
        if not isinstance(catalog, PGSearchUtility):
            raise NoCatalogException()

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        conn = await txn.get_connection()
        # The catalog doesn't work because we are still
        # not authenticated
        # NOTE(review): the 'DDDD…' parent_id presumably marks deleted/trashed
        # objects — confirm against the storage layer.
        sql = f"""
            SELECT id FROM
                {txn.storage.objects_table_name}
            WHERE
              json->>'type_name' = 'User'
              AND parent_id != 'DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD'
              AND json->>'container_id' = $1::varchar
              AND lower(json->>'user_email') = lower($2::varchar)
        """
        async with txn.lock:
            row = await conn.fetchrow(sql, container.id, token.get("id"))
        if not row:
            return None

        # The query matches any 'User' object in the container, not only
        # children of ``users``, so the id may not resolve here (presumably
        # ``async_get`` returns None for a missing key — verify).
        user = await users.async_get(row["id"])
        if user is None or user.disabled:
            # Not found under ``users``, or the user is disabled
            return None

        # Load groups into cache
        for ident in user.groups:
            try:
                user._groups_cache[ident] = await navigate_to(
                    container, f"groups/{ident}")
            except KeyError:
                continue
        return user
Exemple #15
0
    def build_count_query(
            self, context, query: ParsedQueryInfo
    ) -> typing.Tuple[str, typing.List[typing.Any]]:
        '''Build a ``select count(*)`` statement for a parsed catalog query.

        Where clauses are rendered exactly as ``build_query`` renders them.
        ``query['wheres']`` may hold plain format strings or
        ``(operator, sub_wheres)`` groups. ``query['wheres_arguments']`` holds
        one argument per rendered (sub-)where, in order.

        :param context: object passed to ``get_default_where_clauses``.
        :param query: parsed query (``wheres`` and ``wheres_arguments`` are read).
        :returns: ``(sql, arguments)`` for the count statement.
        :raises TransactionNotFound: if there is no active transaction.
        '''
        sql_arguments: typing.List[typing.Any] = []
        sql_wheres: typing.List[str] = []
        select_fields = ['count(*)']
        # There are no select arguments here, so where placeholders start at 1.
        where_arg_index = 0
        for where in query['wheres']:
            if isinstance(where, tuple):
                # Grouped conditions joined by ``operator``; each sub-where
                # consumes its own argument.
                operator, sub_wheres = where
                sub_result = []
                for sub_where in sub_wheres:
                    sub_result.append(sub_where.format(arg=where_arg_index + 1))
                    sql_arguments.append(
                        query['wheres_arguments'][where_arg_index])
                    where_arg_index += 1
                sql_wheres.append('(' + operator.join(sub_result) + ')')
            else:
                sql_wheres.append(where.format(arg=where_arg_index + 1))
                sql_arguments.append(query['wheres_arguments'][where_arg_index])
                where_arg_index += 1

        sql_wheres.extend(self.get_default_where_clauses(context))

        txn = get_transaction()
        if txn is None:
            raise TransactionNotFound()
        # Avoid a dangling ``where`` keyword when there are no conditions.
        where_clause = ('where ' + ' AND '.join(sql_wheres)
                        if sql_wheres else '')
        sql = '''select {}
                 from {}
                 {}'''.format(','.join(select_fields),
                              sqlq(txn.storage.objects_table_name),
                              where_clause)
        return sql, sql_arguments
Exemple #16
0
 async def get_root(self, txn=None) -> IBaseObject:
     """Return the object stored under ``ROOT_ID``.

     Uses ``txn`` when given. Otherwise falls back to the transaction held
     in ``task_vars.txn``.

     :raises TransactionNotFound: if no transaction is available.
     """
     active_txn = txn if txn is not None else task_vars.txn.get()
     if active_txn is None:
         raise TransactionNotFound()
     return await active_txn.get(ROOT_ID)