예제 #1
0
def start_presto_query(presto_server, presto_user, presto_catalog, presto_schema, function_name, query):
    """Submit *query* to Presto and expose its result through a temp function.

    The query is started with a Presto client, and a ``QueryAutoClose`` wrapping
    it is stored in ``session.query_auto_close`` together with the result's
    column names and PostgreSQL column types. Two objects are then (re)created
    in ``pg_temp``:

    * a table ``<function_name>_type`` whose columns describe the result rows;
    * a plpythonu function ``<function_name>()`` returning ``setof`` that type,
      whose body returns ``prestogres.fetch_presto_query_results()``.

    If the session search_path is non-empty and differs from the default
    ``['$user', 'public']``, its first entry replaces *presto_schema*.

    If any step after the query is registered fails, ``session.query_auto_close``
    is reset to None before the error propagates (presumably so that the
    wrapper closes the Presto query).

    Raises:
        plpy.SPIError, presto_client.PrestoException: re-raised unchanged,
            except that the exception class's ``__module__`` is set to
            ``"__main__"`` first.
    """
    try:
        # An explicitly changed search_path takes precedence over the
        # configured schema: use its first entry.
        path = _get_session_search_path_array()
        if path != ['$user', 'public'] and len(path) > 0:
            presto_schema = path[0]

        client = presto_client.Client(
            server=presto_server,
            user=presto_user,
            catalog=presto_catalog,
            schema=presto_schema,
            time_zone=_get_session_time_zone())

        pending = client.query(query)
        session.query_auto_close = QueryAutoClose(pending)
        try:
            # Describe the result: column names plus PostgreSQL-side types.
            names = []
            types = []
            for col in pending.columns():
                names.append(col.name)
                types.append(_pg_result_type(col.type))
            names = _rename_duplicated_column_names(names, "a query result")

            session.query_auto_close.column_names = names
            session.query_auto_close.column_types = types

            # The temp table only serves as the row type of the function.
            row_type = function_name + "_type"
            create_type_sql = _build_create_temp_table_sql(row_type, names, types)
            create_function_sql = """
                create or replace function pg_temp.%s()
                returns setof pg_temp.%s as $$
                    import prestogres
                    return prestogres.fetch_presto_query_results()
                $$ language plpythonu
                """ % (plpy.quote_ident(function_name), plpy.quote_ident(row_type))

            drop_type_sql = "drop table if exists pg_temp.%s cascade" % plpy.quote_ident(row_type)
            for statement in (drop_type_sql, create_type_sql, create_function_sql):
                plpy.execute(statement)

            # Success: signal the finally clause to keep the query registered.
            pending = None
        finally:
            if pending is not None:
                # Something failed: drop the registered query.
                session.query_auto_close = None

    except (plpy.SPIError, presto_client.PrestoException) as e:
        # PL/Python formats a Python exception as just its class name when
        # the class's __module__ is "builtins", "exceptions" or "__main__",
        # and as "module.ClassName" otherwise. Setting __module__ to "__main__"
        # therefore produces a cleaner PostgreSQL error message.
        e.__class__.__module__ = "__main__"
        raise
예제 #2
0
def setup_system_catalog(presto_server, presto_user, presto_catalog, presto_schema, access_role):
    """Rebuild the local PostgreSQL catalog as a mirror of Presto's tables.

    Reads ``information_schema.columns`` from Presto. Drops every local schema
    except ``prestogres_catalog``, ``information_schema`` and ``pg_*``. Then
    recreates one schema per Presto schema, each containing one table per
    Presto table, so that catalog-inspecting clients (such as psql's describe
    commands) can see them.

    Args:
        presto_server, presto_user, presto_catalog: Presto client connection
            settings. *presto_catalog* is also the value that the replacement
            ``pg_catalog.current_database()`` returns.
        presto_schema: used as the session search_path when the search_path
            is still the default ``['$user', 'public']``.
        access_role: granted USAGE on every created schema and SELECT on
            all tables in it.

    Schema, table, and column names longer than ``PG_NAMEDATALEN - 1`` are
    skipped with a warning. A table with more than 1600 columns keeps only
    its first 1600 columns, which is PostgreSQL's per-table limit.
    """
    # PostgreSQL's hard limit on the number of columns in one table.
    max_columns = 1600
    max_name_length = PG_NAMEDATALEN - 1

    search_path = _get_session_search_path_array()
    if search_path == ['$user', 'public']:
        # search_path is default value.
        plpy.execute("set search_path to %s" % plpy.quote_ident(presto_schema))

    client = presto_client.Client(server=presto_server, user=presto_user, catalog=presto_catalog, schema='default')

    # get table list
    sql = "select table_schema, table_name, column_name, is_nullable, data_type" \
          " from information_schema.columns"
    _, rows = client.run(sql)
    if rows is None:
        rows = []

    # schema name -> table name -> [Column, ...]
    schemas = {}

    for schema_name, table_name, column_name, is_nullable, column_type in rows:
        if schema_name == "sys" or schema_name == "information_schema":
            # skip system schemas
            continue

        if len(schema_name) > max_name_length:
            plpy.warning("Schema %s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), max_name_length))
            continue

        tables = schemas.setdefault(schema_name, {})

        if len(table_name) > max_name_length:
            plpy.warning("Table %s.%s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), plpy.quote_ident(table_name), max_name_length))
            continue

        columns = tables.setdefault(table_name, [])

        if len(column_name) > max_name_length:
            plpy.warning("Column %s.%s.%s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), plpy.quote_ident(table_name), \
                     plpy.quote_ident(column_name), max_name_length))
            continue

        # NOTE(review): is_nullable is passed through as returned by Presto.
        # If that is a 'YES'/'NO' string, ``not column.nullable`` below is
        # always False. Confirm how Column stores it.
        columns.append(Column(column_name, column_type, is_nullable))

    # drop all schemas excepting prestogres_catalog, information_schema and pg_%
    sql = "select n.nspname as schema_name from pg_catalog.pg_namespace n" \
          " where n.nspname not in ('prestogres_catalog', 'information_schema')" \
          " and n.nspname not like 'pg_%'"
    for row in plpy.cursor(sql):
        plpy.execute("drop schema %s cascade" % plpy.quote_ident(row["schema_name"]))

    # create schema and tables. Dict keys are unique, so sorting the
    # (name, value) pairs orders by name only.
    for schema_name, tables in sorted(schemas.items()):
        try:
            plpy.execute("create schema %s" % (plpy.quote_ident(schema_name)))
        except plpy.SPIError as e:
            # Best effort. Presumably the schema survived the drop above
            # (e.g. prestogres_catalog) or uses a reserved name. Report the
            # failure and keep going instead of hiding it.
            plpy.warning("Could not create schema %s: %s" % (plpy.quote_ident(schema_name), e))

        for table_name, columns in sorted(tables.items()):
            # Exactly max_columns columns is still allowed; only the excess
            # is dropped.
            if len(columns) > max_columns:
                plpy.warning("Table %s.%s contains more than 1600 columns. Some columns will be inaccessible" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))

            column_names = []
            column_types = []
            not_nulls = []
            for column in columns[:max_columns]:
                column_names.append(column.name)
                column_types.append(_pg_table_type(column.type))
                not_nulls.append(not column.nullable)

            # change columns
            column_names = _rename_duplicated_column_names(column_names,
                    "%s.%s table" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))
            create_sql = _build_create_table(schema_name, table_name, column_names, column_types, not_nulls)
            plpy.execute(create_sql)

        # grant access on the schema to the restricted user so that
        # pg_table_is_visible(reloid) used by \d of psql command returns true
        plpy.execute("grant usage on schema %s to %s" % \
                (plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))
        # this SELECT privilege is unnecessary because queries against those tables
        # won't run on PostgreSQL. causing an exception is good if Prestogres has
        # a bug sending a presto query to PostgreSQL without rewriting.
        # TODO however, it's granted for now because some BI tools might check
        #      has_table_privilege. the best solution is to grant privilege but
        #      actually selecting from those tables causes an exception.
        # Must run after the tables exist: "ALL TABLES IN SCHEMA" only covers
        # tables that exist at grant time.
        plpy.execute("grant select on all tables in schema %s to %s" % \
                (plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))

    # fake current_database() to return Presto's catalog name to be compatible with some
    # applications that use db.schema.table syntax to identify a table
    current_row = plpy.execute("select pg_catalog.current_database()")[0]
    # The query returns a single column; take its value without relying on
    # Python 2's list-returning dict.values().
    if next(iter(current_row.values())) != presto_catalog:
        plpy.execute("delete from pg_catalog.pg_proc where proname='current_database'")
        plpy.execute("create function pg_catalog.current_database() returns name as $$begin return %s::name; end$$ language plpgsql stable strict" % \
                plpy.quote_literal(presto_catalog))
예제 #3
0
def setup_system_catalog(presto_server, presto_user, presto_catalog, access_role):
    """Rebuild the local PostgreSQL catalog as a mirror of Presto's tables.

    Reads ``information_schema.columns`` from Presto. Drops every local schema
    except ``prestogres_catalog``, ``information_schema`` and ``pg_*``. Then
    recreates one schema per Presto schema, each containing one table per
    Presto table. Finally, renames the current database in ``pg_database`` to
    *presto_catalog*.

    Args:
        presto_server, presto_user, presto_catalog: Presto client connection
            settings. *presto_catalog* is also the new database name.
        access_role: granted USAGE on every created schema and SELECT on
            all tables in it.

    Schema, table, and column names longer than ``PG_NAMEDATALEN - 1`` are
    skipped with a warning. A table with more than 1600 columns keeps only
    its first 1600 columns, which is PostgreSQL's per-table limit.
    """
    # PostgreSQL's hard limit on the number of columns in one table.
    max_columns = 1600
    max_name_length = PG_NAMEDATALEN - 1

    client = presto_client.Client(server=presto_server, user=presto_user, catalog=presto_catalog, schema='default')

    # get table list
    sql = "select table_schema, table_name, column_name, is_nullable, data_type" \
          " from information_schema.columns"
    _, rows = client.run(sql)
    if rows is None:
        rows = []

    # schema name -> table name -> [Column, ...]
    schemas = {}

    for schema_name, table_name, column_name, is_nullable, column_type in rows:
        if schema_name == "sys" or schema_name == "information_schema":
            # skip system schemas
            continue

        if len(schema_name) > max_name_length:
            plpy.warning("Schema %s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), max_name_length))
            continue

        tables = schemas.setdefault(schema_name, {})

        if len(table_name) > max_name_length:
            plpy.warning("Table %s.%s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), plpy.quote_ident(table_name), max_name_length))
            continue

        columns = tables.setdefault(table_name, [])

        if len(column_name) > max_name_length:
            plpy.warning("Column %s.%s.%s is skipped because its name is longer than %d characters" % \
                    (plpy.quote_ident(schema_name), plpy.quote_ident(table_name), \
                     plpy.quote_ident(column_name), max_name_length))
            continue

        # NOTE(review): is_nullable is passed through as returned by Presto.
        # If that is a 'YES'/'NO' string, ``not column.nullable`` below is
        # always False. Confirm how Column stores it.
        columns.append(Column(column_name, column_type, is_nullable))

    # drop all schemas excepting prestogres_catalog, information_schema and pg_%
    sql = "select n.nspname as schema_name from pg_catalog.pg_namespace n" \
          " where n.nspname not in ('prestogres_catalog', 'information_schema')" \
          " and n.nspname not like 'pg_%'"
    for row in plpy.cursor(sql):
        plpy.execute("drop schema %s cascade" % plpy.quote_ident(row["schema_name"]))

    # create schema and tables. Dict keys are unique, so sorting the
    # (name, value) pairs orders by name only.
    for schema_name, tables in sorted(schemas.items()):
        try:
            plpy.execute("create schema %s" % (plpy.quote_ident(schema_name)))
        except plpy.SPIError as e:
            # Best effort. Presumably the schema survived the drop above
            # (e.g. prestogres_catalog) or uses a reserved name. Report the
            # failure and keep going instead of hiding it.
            plpy.warning("Could not create schema %s: %s" % (plpy.quote_ident(schema_name), e))

        for table_name, columns in sorted(tables.items()):
            if len(columns) > max_columns:
                plpy.warning("Table %s.%s contains more than 1600 columns. Some columns will be inaccessible" % (plpy.quote_ident(schema_name), plpy.quote_ident(table_name)))

            column_names = []
            column_types = []
            not_nulls = []
            for column in columns[:max_columns]:
                column_names.append(column.name)
                column_types.append(_pg_table_type(column.type))
                not_nulls.append(not column.nullable)

            # change columns
            create_sql = _build_create_table(schema_name, table_name, column_names, column_types, not_nulls)
            plpy.execute(create_sql)

        # Grant access only after the tables exist. "ALL TABLES IN SCHEMA"
        # covers just the tables present at grant time, so granting before
        # creating them granted nothing. USAGE on the schema is also required
        # for the restricted role to reach those tables at all.
        plpy.execute("grant usage on schema %s to %s" % \
                (plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))
        plpy.execute("grant select on all tables in schema %s to %s" % \
                (plpy.quote_ident(schema_name), plpy.quote_ident(access_role)))

    # update pg_database
    plpy.execute("update pg_database set datname=%s where datname=current_database()" % \
            plpy.quote_literal(presto_catalog))