Exemplo n.º 1
0
def connect_with_data_source(driver_dsn):
    """Open a database connection described by a ``driver://dsn`` string.

    Args:
        driver_dsn (str): a data source string of the form
            ``<driver>://<driver specific DSN>``; supported drivers are
            ``mysql``, ``hive`` and ``maxcompute``.

    Returns:
        A connection object with an extra ``driver`` attribute holding the
        driver name. Hive connections additionally get ``session_cfg`` and
        ``default_db`` attributes.

    Raises:
        ValueError: if ``driver_dsn`` contains no ``://`` separator or names
            an unsupported driver.
    """
    # Split on the FIRST "://" only, so a DSN whose remainder also contains
    # "://" (e.g. inside a password) is still handed to the driver parser.
    driver, sep, dsn = driver_dsn.partition("://")
    if not sep:
        # Do not echo the input: it may contain credentials.
        raise ValueError(
            "connect_with_data_source expects a data source of the form "
            "driver://dsn")

    if driver == "mysql":
        # NOTE: use MySQLdb to avoid bugs like infinite reading:
        # https://bugs.mysql.com/bug.php?id=91971
        from MySQLdb import connect
        user, passwd, host, port, database, _config = parseMySQLDSN(dsn)
        conn = connect(user=user,
                       passwd=passwd,
                       db=database,
                       host=host,
                       port=int(port))
    elif driver == "hive":
        from impala.dbapi import connect
        user, passwd, host, port, database, auth, session_cfg = parseHiveDSN(
            dsn)
        conn = connect(user=user,
                       password=passwd,
                       database=database,
                       host=host,
                       port=int(port),
                       auth_mechanism=auth)
        # Later queries open cursors with this session configuration.
        conn.session_cfg = session_cfg
        conn.default_db = database
    elif driver == "maxcompute":
        from runtime.maxcompute import MaxCompute
        user, passwd, address, database = parseMaxComputeDSN(dsn)
        conn = MaxCompute.connect(database, user, passwd, address)
    else:
        raise ValueError(
            "connect_with_data_source doesn't support driver type {}".format(
                driver))

    # Record the driver so other helpers can dispatch on conn.driver.
    conn.driver = driver
    return conn
Exemplo n.º 2
0
def selected_columns_and_types(conn, select):
    """Get the columns and types returned by the select statement.

    The statement is stripped of surrounding whitespace and trailing ';' and
    passed through ``limit_select(select, 1)`` before execution.

    Args:
        conn: the connection object; ``conn.driver`` selects the backend.
        select (str): the select SQL statement.

    Returns:
        A tuple whose each element is (column_name, column_type).

    Raises:
        NotImplementedError: if ``conn.driver`` is not mysql, hive or
            maxcompute.
    """
    select = limit_select(select.strip().rstrip(";"), 1)

    driver = conn.driver
    if driver == "maxcompute":
        from runtime.maxcompute import MaxCompute
        return MaxCompute.selected_columns_and_types(conn, select)

    if driver == "mysql":
        cursor = conn.cursor()
        get_columns_and_types = _get_mysql_columns_and_types
    elif driver == "hive":
        cursor = conn.cursor(configuration=conn.session_cfg)
        get_columns_and_types = _get_hive_columns_and_types
    else:
        raise NotImplementedError("unsupported driver {}".format(driver))

    # Close the cursor even if execute() or the column inspection fails.
    try:
        cursor.execute(select)
        return get_columns_and_types(cursor)
    finally:
        cursor.close()
Exemplo n.º 3
0
def db_generator(conn, statement, label_meta=None, fetch_size=128):
    """Build a zero-argument callable that streams ``(row, label)`` pairs.

    Args:
        conn: the connection object; ``conn.driver`` selects the backend.
        statement (str): the SELECT statement to run.
        label_meta (dict|None): optional label description read with keys
            ``feature_name`` (label column name), ``delimiter`` (when not
            empty, the label value is parsed with ``np.fromstring`` using
            this separator) and ``dtype`` (``"float32"`` or ``"int64"``
            select the parse dtype).
        fetch_size (int): number of rows fetched per ``fetchmany`` call.

    Returns:
        For the ``maxcompute`` driver, whatever ``MaxCompute.db_generator``
        returns. Otherwise a callable; calling it returns a generator that
        yields ``(list(row), label)`` for every result row, where ``label``
        is ``None`` when no label column was found in the result.
    """
    driver = conn.driver
    if driver == "maxcompute":
        from runtime.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, label_meta, fetch_size)
    if driver == "hive":
        # trim the suffix ';' to avoid the ParseException in hive
        statement = statement.rstrip(';')

    def parse_label(label):
        """Parse a delimited label string according to ``label_meta``."""
        if label_meta["delimiter"] == "":
            return label
        # NOTE(review): "float32" is parsed with dtype=float (float64);
        # kept as-is for compatibility — confirm whether float32 is intended.
        if label_meta["dtype"] == "float32":
            return np.fromstring(label, dtype=float,
                                 sep=label_meta["delimiter"])
        if label_meta["dtype"] == "int64":
            return np.fromstring(label, dtype=int,
                                 sep=label_meta["delimiter"])
        return label

    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        # try/finally also closes the cursor when the consumer stops early
        # (generator.close()) or an error is raised mid-iteration.
        try:
            cursor.execute(statement)
            if cursor.description is None:
                field_names = None
            elif driver == "hive":
                # Drop any "prefix." qualifier (presumably the table name)
                # from hive column names.
                field_names = [
                    i[0][i[0].find('.') + 1:] for i in cursor.description
                ]
            else:
                field_names = [i[0] for i in cursor.description]

            label_idx = None
            if label_meta and field_names is not None:
                try:
                    label_idx = field_names.index(label_meta["feature_name"])
                except ValueError:
                    # NOTE(typhoonzero): For clustering model, label_column_name
                    # may not in field_names when predicting.
                    label_idx = None

            while True:
                rows = cursor.fetchmany(size=fetch_size)
                if not rows:
                    break
                # NOTE: keep the connection while training or connection will
                # be lost if no activities appear.
                if driver == "mysql":
                    conn.ping(True)
                for row in rows:
                    if label_idx is None:
                        # No label column: do not try to parse a placeholder.
                        yield list(row), None
                    else:
                        yield list(row), parse_label(row[label_idx])
                # A short batch means the result set is exhausted.
                if len(rows) < fetch_size:
                    break
        finally:
            cursor.close()

    return reader
Exemplo n.º 4
0
def connect(driver,
            database,
            user,
            password,
            host,
            port,
            session_cfg=None,
            auth=""):
    """Open a connection to a mysql, hive or maxcompute data source.

    Args:
        driver (str): one of ``mysql``, ``hive`` or ``maxcompute``.
        database (str): database name.
        user (str): user name.
        password (str): password.
        host (str): host (for maxcompute, passed as the endpoint address).
        port: port; converted with ``int()`` for mysql and hive.
        session_cfg (dict|None): hive session configuration stored on the
            connection; ``None`` means a fresh empty dict.
        auth (str): hive ``auth_mechanism``.

    Returns:
        A connection object with an extra ``driver`` attribute. Hive
        connections additionally get ``default_db`` and ``session_cfg``.

    Raises:
        ValueError: for an unrecognized driver.
    """
    if driver == "mysql":
        # NOTE: use MySQLdb to avoid bugs like infinite reading:
        # https://bugs.mysql.com/bug.php?id=91971
        from MySQLdb import connect
        conn = connect(user=user,
                       passwd=password,
                       db=database,
                       host=host,
                       port=int(port))
    elif driver == "hive":
        from impala.dbapi import connect
        conn = connect(user=user,
                       password=password,
                       database=database,
                       host=host,
                       port=int(port),
                       auth_mechanism=auth)
        conn.default_db = database
        # Use a new dict per connection: a shared mutable default would leak
        # session settings between connections.
        conn.session_cfg = {} if session_cfg is None else session_cfg
    elif driver == "maxcompute":
        from runtime.maxcompute import MaxCompute
        conn = MaxCompute.connect(database, user, password, host)
    else:
        raise ValueError("unrecognized database driver: %s" % driver)

    # Record the driver so other helpers can dispatch on conn.driver.
    conn.driver = driver
    return conn
Exemplo n.º 5
0
def selected_cols(conn, select):
    """Return the column names produced by a SELECT statement.

    The statement is stripped of surrounding whitespace and a trailing ';'.
    If it does not already end with a ``LIMIT <n>`` clause, ``LIMIT 1`` is
    appended so only the schema, not the data, has to be fetched.

    Args:
        conn: the connection object; ``conn.driver`` selects the backend.
        select (str): the select SQL statement.

    Returns:
        A list of column names, or ``None`` when the cursor reports no
        description. For hive, any ``prefix.`` qualifier is removed from
        each name.
    """
    # Strip again after removing ';' so "... LIMIT 10 ;" is recognized.
    select = select.strip().rstrip(";").rstrip()
    # Case-insensitive and tolerant of extra whitespace, so "limit  10" is not
    # turned into the invalid "limit  10 LIMIT 1".
    if not re.search(r"\bLIMIT\s+\d+$", select, re.IGNORECASE):
        select += " LIMIT 1"

    driver = conn.driver
    if driver == "maxcompute":
        from runtime.maxcompute import MaxCompute
        return MaxCompute.selected_cols(conn, select)

    if driver == "hive":
        cursor = conn.cursor(configuration=conn.session_cfg)
    else:
        cursor = conn.cursor()
    # Close the cursor even when execute() raises.
    try:
        cursor.execute(select)
        description = cursor.description
    finally:
        cursor.close()

    if description is None:
        return None
    if driver == "hive":
        # Drop any "prefix." qualifier (presumably the table name).
        return [col[0][col[0].find('.') + 1:] for col in description]
    return [col[0] for col in description]