def db_generator(driver, conn, statement, feature_column_names, label_column_name, feature_specs, fetch_size=128):
    """Build a generator that streams rows from a database as training data.

    Args:
        driver: database driver name ("mysql", "hive", "maxcompute", ...).
        conn: an open DB-API connection matching ``driver``.
        statement: SQL statement to execute.
        feature_column_names: list of column names to decode as features.
        label_column_name: name of the label column; falsy when no label
            is expected (e.g. unsupervised prediction).
        feature_specs: dict mapping feature name -> spec dict with keys
            "is_sparse", "delimiter", "dtype" and (for sparse) "shape".
        fetch_size: rows fetched per cursor round-trip.

    Returns:
        A zero-argument generator function yielding
        ``(tuple_of_features, label)`` — or ``(tuple_of_features,)`` when no
        label column is present. For "maxcompute" the driver-specific
        generator is returned instead.
    """

    def read_feature(raw_val, feature_spec, feature_name):
        """Decode one raw column value according to its feature spec."""
        # FIXME(typhoonzero): Should use correct dtype here.
        if feature_spec["is_sparse"]:
            # Sparse input: raw_val is a delimited list of indices; all
            # values are 1 and the dense shape comes from the spec.
            indices = np.fromstring(raw_val, dtype=int, sep=feature_spec["delimiter"])
            indices = indices.reshape(indices.size, 1)
            values = np.ones([indices.size], dtype=np.int32)
            dense_shape = np.array(feature_spec["shape"], dtype=np.int64)
            return (indices, values, dense_shape)
        else:
            # Dense string vector
            if feature_spec["delimiter"] != "":
                if feature_spec["dtype"] == "float32":
                    return np.fromstring(raw_val, dtype=float, sep=feature_spec["delimiter"])
                elif feature_spec["dtype"] == "int64":
                    return np.fromstring(raw_val, dtype=int, sep=feature_spec["delimiter"])
                else:
                    # BUG FIX: original indexed feature_spec[feature_name],
                    # which raised KeyError instead of this ValueError.
                    raise ValueError('unrecognize dtype {}'.format(
                        feature_spec["dtype"]))
            else:
                # Scalar column: pass the raw value through untouched.
                return (raw_val, )

    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        cursor.execute(statement)
        if driver == "hive":
            # Hive prefixes column names with "<table>."; strip the prefix.
            field_names = None if cursor.description is None \
                else [i[0][i[0].find('.') + 1:] for i in cursor.description]
        else:
            field_names = None if cursor.description is None \
                else [i[0] for i in cursor.description]
        if label_column_name:
            try:
                label_idx = field_names.index(label_column_name)
            except ValueError:
                # NOTE(typhoonzero): For clustering model, label_column_name
                # may not be in field_names when predicting.
                label_idx = None
        else:
            label_idx = None

        while True:
            rows = cursor.fetchmany(size=fetch_size)
            if not rows:
                break
            # NOTE: keep the connection alive while training, or the
            # connection will be lost if no activity appears.
            if driver == "mysql":
                conn.ping(True)
            for row in rows:
                # NOTE: If there is no label clause in the extended SQL, the
                # default label value would be -1; the Model implementation
                # can decide whether to use it.
                label = row[label_idx] if label_idx is not None else -1
                features = []
                for name in feature_column_names:
                    feature = read_feature(row[field_names.index(name)],
                                           feature_specs[name], name)
                    features.append(feature)
                if label_idx is None:
                    yield (tuple(features), )
                else:
                    yield tuple(features), label
            if len(rows) < fetch_size:
                # Short batch means the result set is exhausted.
                break
        cursor.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_column_name, feature_specs,
                                       fetch_size)
    if driver == "hive":
        # strip the trailing ';' to avoid a ParseException in hive
        statement = statement.rstrip(';')
    return reader
def db_generator(driver, conn, statement, feature_column_names, label_spec, feature_specs, fetch_size=128):
    """Build a generator that streams raw rows (plus a decoded label) from a
    database.

    Args:
        driver: database driver name ("mysql", "hive", "maxcompute", ...).
        conn: an open DB-API connection matching ``driver``.
        statement: SQL statement to execute.
        feature_column_names: feature column names (passed through to the
            maxcompute generator; not used for row decoding here).
        label_spec: dict describing the label with keys "feature_name",
            "delimiter" and "dtype", or a falsy value when no label is
            expected.
        feature_specs: per-feature spec dicts (passed through to maxcompute).
        fetch_size: rows fetched per cursor round-trip.

    Returns:
        A zero-argument generator function yielding ``(list_of_row_values,
        label)`` — ``label`` is ``None`` when the label column is absent.
        For "maxcompute" the driver-specific generator is returned instead.
    """

    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        cursor.execute(statement)
        if driver == "hive":
            # Hive prefixes column names with "<table>."; strip the prefix.
            field_names = None if cursor.description is None \
                else [i[0][i[0].find('.') + 1:] for i in cursor.description]
        else:
            field_names = None if cursor.description is None \
                else [i[0] for i in cursor.description]
        if label_spec:
            try:
                label_idx = field_names.index(label_spec["feature_name"])
            except ValueError:
                # NOTE(typhoonzero): For clustering model, label_column_name
                # may not be in field_names when predicting.
                label_idx = None
        else:
            label_idx = None

        while True:
            rows = cursor.fetchmany(size=fetch_size)
            if not rows:
                break
            # NOTE: keep the connection alive while training, or the
            # connection will be lost if no activity appears.
            if driver == "mysql":
                conn.ping(True)
            for row in rows:
                # NOTE: If there is no label clause in the extended SQL, the
                # default label value would be -1; the Model implementation
                # can decide whether to use it.
                label = row[label_idx] if label_idx is not None else -1
                # BUG FIX: only decode a delimited label when one was actually
                # read from the row; the original called np.fromstring(-1, ...)
                # when the label column was absent (clustering predict) and
                # the spec had a delimiter, raising TypeError.
                if label_idx is not None and label_spec and \
                        label_spec["delimiter"] != "":
                    if label_spec["dtype"] == "float32":
                        label = np.fromstring(label, dtype=float,
                                              sep=label_spec["delimiter"])
                    elif label_spec["dtype"] == "int64":
                        label = np.fromstring(label, dtype=int,
                                              sep=label_spec["delimiter"])
                if label_idx is None:
                    yield list(row), None
                else:
                    yield list(row), label
            if len(rows) < fetch_size:
                # Short batch means the result set is exhausted.
                break
        cursor.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_spec, feature_specs, fetch_size)
    if driver == "hive":
        # strip the trailing ';' to avoid a ParseException in hive
        statement = statement.rstrip(';')
    return reader