def connect_with_data_source(driver_dsn):
    """Open a database connection described by a "<driver>://<dsn>" string.

    The returned connection gets a ``driver`` attribute recording which
    backend it talks to; Hive connections additionally carry
    ``session_cfg`` and ``default_db``.

    Args:
        driver_dsn: connection string of the form "driver://dsn", where
            driver is one of "mysql", "hive" or "maxcompute".

    Returns:
        A DB-API style connection object.

    Raises:
        ValueError: when the driver part is not a supported backend.
    """
    driver, dsn = driver_dsn.split("://")
    if driver == "mysql":
        # NOTE: use MySQLdb to avoid bugs like infinite reading:
        # https://bugs.mysql.com/bug.php?id=91971
        from MySQLdb import connect
        user, passwd, host, port, database, _ = parseMySQLDSN(dsn)
        conn = connect(user=user,
                       passwd=passwd,
                       db=database,
                       host=host,
                       port=int(port))
    elif driver == "hive":
        from impala.dbapi import connect
        (user, passwd, host, port, database, auth,
         session_cfg) = parseHiveDSN(dsn)
        conn = connect(user=user,
                       password=passwd,
                       database=database,
                       host=host,
                       port=int(port),
                       auth_mechanism=auth)
        # Stash Hive session settings so cursor creation can pass them on.
        conn.session_cfg = session_cfg
        conn.default_db = database
    elif driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        user, passwd, address, database = parseMaxComputeDSN(dsn)
        conn = MaxCompute.connect(database, user, passwd, address)
    else:
        raise ValueError(
            "connect_with_data_source doesn't support driver type {}".format(
                driver))
    conn.driver = driver
    return conn
def connect(driver,
            database,
            user,
            password,
            host,
            port,
            session_cfg=None,
            auth=""):
    """Connect to a database selected by driver name.

    Args:
        driver: one of "mysql", "hive" or "maxcompute".
        database: database (or project) name to use.
        user: login user name.
        password: login password.
        host: server host name or address (the endpoint for maxcompute).
        port: server port; converted with int() before use.
        session_cfg: optional dict of Hive session configurations; ignored
            by the other drivers. Defaults to an empty dict.
        auth: Hive authentication mechanism (e.g. "PLAIN"); ignored by the
            other drivers.

    Returns:
        A DB-API style connection object.

    Raises:
        ValueError: if driver is not one of the supported names.
    """
    # BUGFIX: the default used to be the mutable literal ``session_cfg={}``,
    # a single dict shared by every call; use a None sentinel instead.
    if session_cfg is None:
        session_cfg = {}
    if driver == "mysql":
        # NOTE: use MySQLdb to avoid bugs like infinite reading:
        # https://bugs.mysql.com/bug.php?id=91971
        from MySQLdb import connect
        return connect(user=user,
                       passwd=password,
                       db=database,
                       host=host,
                       port=int(port))
    elif driver == "hive":
        from impala.dbapi import connect
        conn = connect(user=user,
                       password=password,
                       database=database,
                       host=host,
                       port=int(port),
                       auth_mechanism=auth)
        # Stash Hive-specific settings on the connection so downstream
        # cursor creation can pass them along.
        conn.default_db = database
        conn.session_cfg = session_cfg
        return conn
    elif driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.connect(database, user, password, host)
    raise ValueError("unrecognized database driver: %s" % driver)
def db_generator(driver,
                 conn,
                 statement,
                 feature_column_names,
                 label_column_name,
                 feature_specs,
                 fetch_size=128):
    """Create a generator factory that streams (features, label) rows.

    Args:
        driver: "mysql", "hive" or "maxcompute".
        conn: an open connection created for that driver.
        statement: SQL query to execute.
        feature_column_names: column names decoded as features, in order.
        label_column_name: name of the label column, or None/"" when the
            statement carries no label (e.g. clustering prediction).
        feature_specs: dict mapping column name to a spec dict with keys
            "is_sparse", "delimiter", "dtype" and "shape".
        fetch_size: rows fetched per cursor round trip.

    Returns:
        A zero-argument callable; the generator it returns yields
        (features_tuple, label) pairs, or (features_tuple,) when the label
        column is absent.
    """
    def read_feature(raw_val, feature_spec, feature_name):
        """Decode one raw column value according to its feature spec."""
        # FIXME(typhoonzero): Should use correct dtype here.
        if feature_spec["is_sparse"]:
            # Sparse encoding: the value is a delimited list of indices,
            # each with implicit value 1.
            indices = np.fromstring(raw_val,
                                    dtype=int,
                                    sep=feature_spec["delimiter"])
            indices = indices.reshape(indices.size, 1)
            values = np.ones([indices.size], dtype=np.int32)
            dense_shape = np.array(feature_spec["shape"], dtype=np.int64)
            return (indices, values, dense_shape)
        else:
            # Dense string vector
            if feature_spec["delimiter"] != "":
                if feature_spec["dtype"] == "float32":
                    return np.fromstring(raw_val,
                                         dtype=float,
                                         sep=feature_spec["delimiter"])
                elif feature_spec["dtype"] == "int64":
                    return np.fromstring(raw_val,
                                         dtype=int,
                                         sep=feature_spec["delimiter"])
                else:
                    # BUGFIX: previously indexed
                    # feature_spec[feature_name]["dtype"], which raised
                    # KeyError instead of the intended ValueError.
                    raise ValueError('unrecognize dtype {}'.format(
                        feature_spec["dtype"]))
            else:
                return (raw_val, )

    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        cursor.execute(statement)
        if driver == "hive":
            # Hive returns "table.column" names; keep only the column part.
            field_names = None if cursor.description is None \
                else [i[0][i[0].find('.') + 1:] for i in cursor.description]
        else:
            field_names = None if cursor.description is None \
                else [i[0] for i in cursor.description]
        if label_column_name:
            try:
                label_idx = field_names.index(label_column_name)
            except ValueError:
                # NOTE(typhoonzero): For clustering model, label_column_name
                # may not be in field_names when predicting.
                label_idx = None
        else:
            label_idx = None
        # Hoist the O(n) field_names.index() lookups out of the per-row loop.
        if field_names is not None:
            feature_idxs = [
                field_names.index(name) for name in feature_column_names
            ]
        while True:
            rows = cursor.fetchmany(size=fetch_size)
            if not rows:
                break
            # NOTE: keep the connection while training or connection will
            # lost if no activities appear.
            if driver == "mysql":
                conn.ping(True)
            for row in rows:
                # NOTE: If there is no label clause in the extened SQL, the
                # default label value would be -1, the Model implementation
                # can determine use it or not.
                label = row[label_idx] if label_idx is not None else -1
                features = []
                for name, idx in zip(feature_column_names, feature_idxs):
                    features.append(
                        read_feature(row[idx], feature_specs[name], name))
                if label_idx is None:
                    yield (tuple(features), )
                else:
                    yield tuple(features), label
            if len(rows) < fetch_size:
                break
        cursor.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_column_name, feature_specs,
                                       fetch_size)
    if driver == "hive":
        # trip the suffix ';' to avoid the ParseException in hive
        statement = statement.rstrip(';')
    return reader
def db_generator(driver,
                 conn,
                 statement,
                 feature_column_names,
                 label_spec,
                 feature_specs,
                 fetch_size=128):
    """Create a generator factory yielding (row, label) pairs.

    Args:
        driver: "mysql", "hive" or "maxcompute".
        conn: an open connection created for that driver.
        statement: SQL query to execute.
        feature_column_names: feature column names (forwarded unchanged to
            the MaxCompute backend).
        label_spec: dict describing the label column with keys
            "feature_name", "delimiter" and "dtype", or a falsy value when
            there is no label.
        feature_specs: per-column feature specs (forwarded to MaxCompute).
        fetch_size: rows fetched per cursor round trip.

    Returns:
        A zero-argument callable; the generator it returns yields
        (row_as_list, label), with label None when the label column is
        absent from the result set.
    """

    def reader():
        # Hive cursors must be created with the stored session settings.
        if driver == "hive":
            cur = conn.cursor(configuration=conn.session_cfg)
        else:
            cur = conn.cursor()
        cur.execute(statement)
        if cur.description is None:
            column_names = None
        elif driver == "hive":
            # Hive prefixes column names with the table name; strip it.
            column_names = [
                d[0][d[0].find('.') + 1:] for d in cur.description
            ]
        else:
            column_names = [d[0] for d in cur.description]
        label_idx = None
        if label_spec:
            try:
                label_idx = column_names.index(label_spec["feature_name"])
            except ValueError:
                # NOTE(typhoonzero): for clustering models the label column
                # may be absent from the result set when predicting.
                label_idx = None
        while True:
            batch = cur.fetchmany(size=fetch_size)
            if not batch:
                break
            # Keep the connection alive during long training runs; it would
            # otherwise drop with no activity.
            if driver == "mysql":
                conn.ping(True)
            for rec in batch:
                # With no label clause in the extended SQL the label
                # defaults to -1; the model decides whether to use it.
                label = -1 if label_idx is None else rec[label_idx]
                if label_spec and label_spec["delimiter"] != "":
                    # Delimited labels are parsed into a numeric vector.
                    if label_spec["dtype"] == "float32":
                        label = np.fromstring(label,
                                              dtype=float,
                                              sep=label_spec["delimiter"])
                    elif label_spec["dtype"] == "int64":
                        label = np.fromstring(label,
                                              dtype=int,
                                              sep=label_spec["delimiter"])
                if label_idx is None:
                    yield list(rec), None
                else:
                    yield list(rec), label
            if len(batch) < fetch_size:
                break
        cur.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_spec, feature_specs, fetch_size)
    if driver == "hive":
        # strip the trailing ';' to avoid a ParseException in hive
        statement = statement.rstrip(';')
    return reader