Пример #1
0
def db_generator(driver,
                 conn,
                 statement,
                 feature_column_names,
                 label_column_name,
                 feature_specs,
                 fetch_size=128):
    """Build a generator function yielding (features, label) pairs from a DB.

    Executes ``statement`` on ``conn`` and decodes each row's feature
    columns according to ``feature_specs``: sparse columns become a COO
    triple ``(indices, values, dense_shape)``, dense delimited columns are
    parsed with numpy, and anything else is passed through as a 1-tuple of
    the raw value.

    Args:
        driver: database driver name ("mysql", "hive", "maxcompute", ...).
        conn: an open DB-API connection; for "hive" it must carry a
            ``session_cfg`` attribute used when creating cursors.
        statement: SQL query to execute.
        feature_column_names: ordered list of feature column names.
        label_column_name: name of the label column; may be falsy, or
            missing from the result set (e.g. clustering prediction), in
            which case the label defaults to -1.
        feature_specs: dict mapping column name -> spec dict with keys
            "is_sparse", "delimiter", "dtype" and "shape".
        fetch_size: number of rows fetched from the cursor per round trip.

    Returns:
        A zero-argument generator function, or the MaxCompute-specific
        generator when ``driver == "maxcompute"``.

    Raises:
        ValueError: (from the generator) when a dense delimited column has
            an unsupported dtype.
    """
    def read_feature(raw_val, feature_spec, feature_name):
        # Decode a single raw column value according to its feature spec.
        # FIXME(typhoonzero): Should use correct dtype here.
        if feature_spec["is_sparse"]:
            # Sparse column: raw_val holds delimited indices whose values
            # are implicitly 1 (one-hot style encoding).
            indices = np.fromstring(raw_val,
                                    dtype=int,
                                    sep=feature_spec["delimiter"])
            indices = indices.reshape(indices.size, 1)
            values = np.ones([indices.size], dtype=np.int32)
            dense_shape = np.array(feature_spec["shape"], dtype=np.int64)
            return (indices, values, dense_shape)
        # Dense string vector
        if feature_spec["delimiter"] != "":
            if feature_spec["dtype"] == "float32":
                return np.fromstring(raw_val,
                                     dtype=float,
                                     sep=feature_spec["delimiter"])
            if feature_spec["dtype"] == "int64":
                return np.fromstring(raw_val,
                                     dtype=int,
                                     sep=feature_spec["delimiter"])
            # BUGFIX: was `feature_spec[feature_name]["dtype"]`, which
            # indexed the spec dict with the feature name and crashed
            # before the intended ValueError could be raised.
            raise ValueError('unrecognized dtype {}'.format(
                feature_spec["dtype"]))
        # Scalar (no delimiter): pass the raw value through untouched.
        return (raw_val, )

    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        cursor.execute(statement)
        if driver == "hive":
            # Hive prefixes column names with "<table>."; strip the prefix.
            field_names = None if cursor.description is None \
                else [i[0][i[0].find('.') + 1:] for i in cursor.description]
        else:
            field_names = None if cursor.description is None \
                else [i[0] for i in cursor.description]
        if label_column_name:
            try:
                label_idx = field_names.index(label_column_name)
            except ValueError:
                # NOTE(typhoonzero): For clustering model, label_column_name may not in field_names when predicting.
                label_idx = None
        else:
            label_idx = None

        while True:
            rows = cursor.fetchmany(size=fetch_size)
            if not rows:
                break
            # NOTE: keep the connection while training or connection will lost if no activities appear.
            if driver == "mysql":
                conn.ping(True)
            for row in rows:
                # NOTE: If there is no label clause in the extended SQL, the default label value would
                # be -1, the Model implementation can determine use it or not.
                label = row[label_idx] if label_idx is not None else -1
                features = []
                for name in feature_column_names:
                    feature = read_feature(row[field_names.index(name)],
                                           feature_specs[name], name)
                    features.append(feature)
                if label_idx is None:
                    yield (tuple(features), )
                else:
                    yield tuple(features), label
            # A short batch means the result set is exhausted.
            if len(rows) < fetch_size:
                break
        cursor.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_column_name, feature_specs,
                                       fetch_size)
    if driver == "hive":
        # strip the suffix ';' to avoid the ParseException in hive
        statement = statement.rstrip(';')
    return reader
Пример #2
0
def db_generator(driver,
                 conn,
                 statement,
                 feature_column_names,
                 label_spec,
                 feature_specs,
                 fetch_size=128):
    """Build a generator function yielding (row, label) pairs from a DB.

    Executes ``statement`` on ``conn`` and yields each row as a plain list
    together with its label column, optionally parsing a delimited label
    string into a numpy array per ``label_spec``.

    Args:
        driver: database driver name ("mysql", "hive", "maxcompute", ...).
        conn: an open DB-API connection; for "hive" it must carry a
            ``session_cfg`` attribute used when creating cursors.
        statement: SQL query to execute.
        feature_column_names: feature column names; kept for interface
            compatibility (only forwarded to the MaxCompute generator).
        label_spec: dict with keys "feature_name", "delimiter", "dtype",
            or falsy when there is no label. The named column may be
            missing from the result set (e.g. clustering prediction), in
            which case the yielded label is None.
        feature_specs: feature spec dict; kept for interface compatibility
            (only forwarded to the MaxCompute generator).
        fetch_size: number of rows fetched from the cursor per round trip.

    Returns:
        A zero-argument generator function, or the MaxCompute-specific
        generator when ``driver == "maxcompute"``.
    """
    def reader():
        if driver == "hive":
            cursor = conn.cursor(configuration=conn.session_cfg)
        else:
            cursor = conn.cursor()
        cursor.execute(statement)
        if driver == "hive":
            # Hive prefixes column names with "<table>."; strip the prefix.
            field_names = None if cursor.description is None \
                else [i[0][i[0].find('.') + 1:] for i in cursor.description]
        else:
            field_names = None if cursor.description is None \
                else [i[0] for i in cursor.description]

        if label_spec:
            try:
                label_idx = field_names.index(label_spec["feature_name"])
            except ValueError:
                # NOTE(typhoonzero): For clustering model, label_column_name may not in field_names when predicting.
                label_idx = None
        else:
            label_idx = None

        while True:
            rows = cursor.fetchmany(size=fetch_size)
            if not rows:
                break
            # NOTE: keep the connection while training or connection will lost if no activities appear.
            if driver == "mysql":
                conn.ping(True)
            for row in rows:
                # NOTE: If there is no label clause in the extended SQL, the default label value would
                # be -1, the Model implementation can determine use it or not.
                label = row[label_idx] if label_idx is not None else -1
                # BUGFIX: only parse the label when a real label column
                # exists; previously the int placeholder -1 was fed to
                # np.fromstring and crashed with TypeError.
                if label_idx is not None and label_spec \
                        and label_spec["delimiter"] != "":
                    if label_spec["dtype"] == "float32":
                        label = np.fromstring(label,
                                              dtype=float,
                                              sep=label_spec["delimiter"])
                    elif label_spec["dtype"] == "int64":
                        label = np.fromstring(label,
                                              dtype=int,
                                              sep=label_spec["delimiter"])
                if label_idx is None:
                    yield list(row), None
                else:
                    yield list(row), label
            # A short batch means the result set is exhausted.
            if len(rows) < fetch_size:
                break
        cursor.close()

    if driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        return MaxCompute.db_generator(conn, statement, feature_column_names,
                                       label_spec, feature_specs, fetch_size)
    if driver == "hive":
        # strip the suffix ';' to avoid the ParseException in hive
        statement = statement.rstrip(';')
    return reader