Code example #1
File: cluster_conf.py  Project: zlb1028/sqlflow
def get_cluster_config(attrs):
    """Get PAI cluster config from attrs

    Args:
        attrs: input config

    Returns:
        The config merged from attrs and the defaults
    """
    default_map = {
        "train.num_ps": 0,
        "train.num_workers": 1,
        "train.worker_cpu": 400,
        "train.worker_gpu": 0,
        "train.ps_cpu": 200,
        "train.ps_gpu": 0,
        "train.num_evaluator": 0,
        "train.evaluator_cpu": 200,
        "train.evaluator_gpu": 0,
    }

    update = {}
    for k, v in attrs.items():
        if k in default_map:
            update[k] = v
        elif "train." + k in default_map:
            update["train." + k] = v

    if not all(isinstance(v, int) for v in update.values()):
        raise SQLFlowDiagnostic("value for cluster config should be int")
    default_map.update(update)

    ps = {
        "count": default_map["train.num_ps"],
        "cpu": default_map["train.ps_cpu"],
        "gpu": default_map["train.ps_gpu"],
    }
    worker = {
        "count": default_map["train.num_workers"],
        "cpu": default_map["train.worker_cpu"],
        "gpu": default_map["train.worker_gpu"],
    }
    # FIXME(weiguoz): adhoc for running distributed xgboost train on pai
    if worker["count"] > 1 and ps["count"] < 1:
        ps["count"] = 1

    if default_map["train.num_evaluator"] == 0:
        evaluator = None
    elif default_map["train.num_evaluator"] == 1:
        evaluator = {
            "count": default_map["train.num_evaluator"],
            "cpu": default_map["train.evaluator_cpu"],
            "gpu": default_map["train.evaluator_gpu"],
        }
    else:
        raise SQLFlowDiagnostic("train.num_evaluator should only be 1 or 0")
    conf = {"ps": ps, "worker": worker}
    if evaluator is not None:
        conf["evaluator"] = evaluator
    return conf
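
A minimal usage sketch for get_cluster_config above; the attribute values are hypothetical and the expected result assumes the defaults in default_map.

# assuming get_cluster_config from the example above is in scope
attrs = {"train.num_workers": 2, "train.worker_gpu": 1}  # hypothetical WITH-clause attributes
conf = get_cluster_config(attrs)
# a parameter server is forced on because there is more than one worker:
# {'ps': {'count': 1, 'cpu': 200, 'gpu': 0},
#  'worker': {'count': 2, 'cpu': 400, 'gpu': 1}}
print(conf)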
Code example #2
def create_explain_result_table(datasource, data_table, result_table,
                                model_type, estimator, label_column):
    """Create explain result table from given datasource

    Args:
        datasource: current datasource
        data_table: input data table name
        result_table: table name to store the result
        model_type: type of the model to use
        estimator: estimator class if the model is TensorFlow estimator
        label_column: column name of the predict label
    """
    conn = db.connect_with_data_source(datasource)
    drop_stmt = "DROP TABLE IF EXISTS %s" % result_table
    conn.execute(drop_stmt)

    create_stmt = ""
    if model_type == EstimatorType.PAIML:
        return
    elif model_type == EstimatorType.TENSORFLOW:
        if estimator.startswith("BoostedTrees"):
            column_def = ""
            if conn.driver == "mysql":
                column_def = "(feature VARCHAR(255), dfc FLOAT, gain FLOAT)"
            else:
                # Hive & MaxCompute
                column_def = "(feature STRING, dfc STRING, gain STRING)"
            create_stmt = "CREATE TABLE IF NOT EXISTS %s %s;" % (result_table,
                                                                 column_def)
        else:
            if not label_column:
                raise SQLFlowDiagnostic(
                    "need to specify WITH label_col=lable_col_name "
                    "when explaining deep models")
            create_stmt = get_create_shap_result_sql(conn, data_table,
                                                     result_table,
                                                     label_column)
    elif model_type == EstimatorType.XGBOOST:
        if not label_column:
            raise SQLFlowDiagnostic(
                "need to specify WITH label_col=lable_col_name "
                "when explaining xgboost models")
        create_stmt = get_create_shap_result_sql(conn, data_table,
                                                 result_table, label_column)
    else:
        raise SQLFlowDiagnostic(
            "not supported modelType %d for creating Explain result table" %
            model_type)

    if not conn.execute(create_stmt):
        raise SQLFlowDiagnostic("Can't create explain result table")
Code example #3
File: oss.py  Project: redskycry/sqlflow
def delete_oss_dir_recursive(bucket, directory):
    """
    Recursively delete a directory on the OSS

    Args:
        bucket: bucket on OSS
        directory (str): the directory to delete

    Returns:
        None.
    """
    if not directory.endswith("/"):
        raise SQLFlowDiagnostic("dir to delete must end with /")

    loc = bucket.list_objects(prefix=directory, delimiter="/")
    object_path_list = []
    for obj in loc.object_list:
        object_path_list.append(obj.key)

    # delete sub dir first
    if len(loc.prefix_list) > 0:
        for sub_prefix in loc.prefix_list:
            delete_oss_dir_recursive(bucket, sub_prefix)
    # empty list param will raise error
    if len(object_path_list) > 0:
        bucket.batch_delete_objects(object_path_list)
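
A hedged usage sketch for delete_oss_dir_recursive above; it assumes the oss2 SDK (which the bucket object appears to come from) and uses placeholder credentials and names.

import oss2  # Aliyun OSS SDK; its Bucket exposes list_objects/batch_delete_objects

auth = oss2.Auth("<access_key_id>", "<access_key_secret>")          # placeholders
bucket = oss2.Bucket(auth, "https://oss-cn-hangzhou.aliyuncs.com",  # placeholder endpoint
                     "<bucket-name>")

# the prefix must end with "/", otherwise SQLFlowDiagnostic is raised
delete_oss_dir_recursive(bucket, "sqlflow_models/my_model/")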
Code example #4
File: entry.py  Project: hsjung6/sqlflow
def call_fun(func, params):
    """Call a function with given params, entries in params will be treated
    as func' param if the key matches some argument name. Do not support
    var-args in func.

    Arags:
        func: callable
            a Python callable object
        params: dict
            dict of params
    Returns:
        the return value of func if success

    Raises:
        SQLFlowDiagnostic if a non-optional argument is not found in params
    """
    # getargspec returns (pos_args, var_args, dict_args, defaults)
    sig = getargspec(func)
    required_len = len(sig[0]) - (0 if sig[3] is None else len(sig[3]))
    # if func has dict args, pass all params into it
    if sig[2] is not None:
        return func(**params)

    # if func has no dict args, we need to remove non-param entries in params
    dict_args = dict()
    for i, name in enumerate(sig[0]):
        if i < required_len:
            if name not in params:
                raise SQLFlowDiagnostic("Non-default param is not passed:%s" %
                                        name)
        if name in params:
            dict_args[name] = params[name]
    return func(**dict_args)
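
A small sketch of how call_fun above is meant to be used; extra keys in params are silently dropped when the callee has no **kwargs, and a missing non-default argument raises SQLFlowDiagnostic.

# assuming call_fun from the example above is in scope
def train(model_name, epochs=1):
    return "training %s for %d epoch(s)" % (model_name, epochs)

params = {"model_name": "DNNClassifier", "epochs": 3, "unused_key": "ignored"}
print(call_fun(train, params))  # -> training DNNClassifier for 3 epoch(s)
# call_fun(train, {"epochs": 3}) would raise SQLFlowDiagnostic: model_name is missing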
Code example #5
File: submitter.py  Project: vmnet04/sqlflow
def submit_alisa_task(datasource, task_type, submit_code, args):
    """Submit an Alias task

    Args:
        datasource: the datasource to use
        task_type: AlisaTaskTypePAI or AlisaTaskTypePyODPS
        submit_code: the code to submit a PAI task
        args: map of arguments, like codeResourceURL and others
    """
    cfg = parse_alisa_config(datasource)

    if task_type == AlisaTaskTypePAI:
        cfg["Env"]["RES_DOWNLOAD_URL"] = (
            """[{"downloadUrl":"%s", "resourceName":"%s"}, """
            """{"downloadUrl":"%s", "resourceName":"%s"}]""") % (
                args["codeResourceURL"], args["resourceName"],
                args["paramsResourceURL"], args["paramsFile"])

    cfg["Verbose"] = True

    if task_type == AlisaTaskTypePAI:
        alisa_execute(submit_code, None)
    elif task_type == AlisaTaskTypePyODPS:
        alisa_execute(submit_code, args)
    else:
        return SQLFlowDiagnostic("Unknown AlisaTaskType %d" % task_type)
Code example #6
def get_evaluate_metrics(model_type, model_attrs):
    """Get evaluate metrics from model attributes or return default

    Args:
        model_type: type of the model, see runtime.model.EstimatorType
        model_attrs: model attributes passed by the WITH clause

    Returns:
        An array of metrics names
    """
    metrics = []
    met_conf = model_attrs.get("validation.metrics") or model_attrs.get(
        "validationMetrics")
    if met_conf:
        for m in met_conf.split(","):
            if m and m not in metrics:
                metrics.append(m)
    # add defaults if no extra metrics are provided
    if len(metrics) == 0:
        if model_type == EstimatorType.XGBOOST:
            metrics.append("accuracy_score")
        elif model_type == EstimatorType.TENSORFLOW:
            metrics.append("Accuracy")
        else:
            raise SQLFlowDiagnostic("No metrics is provided.")
    return metrics
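
A minimal sketch of get_evaluate_metrics above; the WITH-clause attributes are hypothetical.

# assuming get_evaluate_metrics and EstimatorType (runtime.model) are in scope
attrs = {"validation.metrics": "Accuracy,AUC,Accuracy"}
print(get_evaluate_metrics(EstimatorType.TENSORFLOW, attrs))  # ['Accuracy', 'AUC']
print(get_evaluate_metrics(EstimatorType.XGBOOST, {}))        # ['accuracy_score'] (default)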
Code example #7
def prepare_archive(cwd, conf, project, estimator, model_name, train_tbl,
                    val_tbl, model_save_path, train_params):
    """package needed resource into a tarball"""
    create_pai_hyper_param_file(cwd, PARAMS_FILE, model_save_path)

    with open(path.join(cwd, TRAIN_PARAMS_FILE), "wb") as param_file:
        pickle.dump(train_params, param_file, protocol=2)

    with open(path.join(cwd, "requirements.txt"), "w") as require:
        require.write(get_requirement(estimator))

    # copy entry.py to top level directory, so the package name `xgboost`
    # and `tensorflow` in runtime.pai will not conflict with the global ones
    shutil.copyfile(path.join(path.dirname(__file__), ENTRY_FILE),
                    path.join(cwd, ENTRY_FILE))
    copy_python_package("runtime", cwd)
    copy_python_package("sqlflow_models", cwd)
    copy_custom_package(estimator, cwd)

    args = [
        "tar", "czf", JOB_ARCHIVE_FILE, ENTRY_FILE, "runtime",
        "sqlflow_models", "requirements.txt", TRAIN_PARAMS_FILE
    ]
    if subprocess.call(args, cwd=cwd) != 0:
        raise SQLFlowDiagnostic("Can't zip resource")
Code example #8
def get_explain_random_forests_cmd(datasource, model_name, data_table,
                                   result_table, label_column):
    """Get PAI random forest explanation command

    Args:
        datasource: current datasource
        model_name: model name on PAI
        data_table: input data table name
        result_table: result table name
        label_column: name of the label column

    Returns:
        a PAI cmd to explain the data using given model
    """
    # NOTE(typhoonzero): for PAI random forest prediction, we cannot load
    # the TrainStmt since model saving is fully done by PAI. We directly use
    # the columns in the SELECT statement for prediction; the PAI job will
    # report an error if the columns do not match.
    if not label_column:
        raise SQLFlowDiagnostic("must specify WITH label_column when using "
                                "pai random forest to explain models")

    conn = db.connect_with_data_source(datasource)
    # drop result table if exists
    conn.execute("DROP TABLE IF EXISTS %s;" % result_table)
    schema = db.get_table_schema(conn, data_table)
    fields = [f[0] for f in schema if f[0] != label_column]
    return ('''pai -name feature_importance -project algo_public '''
            '''-DmodelName="%s" -DinputTableName="%s"  '''
            '''-DoutputTableName="%s" -DlabelColName="%s" '''
            '''-DfeatureColNames="%s" ''') % (model_name, data_table,
                                              result_table, label_column,
                                              ",".join(fields))
Code example #9
def setup_explain_entry(params, model_type):
    """Setup PAI prediction entry function according to model type"""
    if model_type == EstimatorType.TENSORFLOW:
        params["entry_type"] = "explain_tf"
    elif model_type == EstimatorType.PAIML:
        params["entry_type"] = ""
    elif model_type == EstimatorType.XGBOOST:
        params["entry_type"] = "explain_xgb"
    else:
        raise SQLFlowDiagnostic("unsupported model type: %d" % model_type)
Code example #10
def copy_python_package(module, dest):
    """Copy given Python module to dist

    Args:
        module: The module to copy
        dest: the destination directory
    """
    module_path = find_python_module_path(module)
    if not module_path:
        raise SQLFlowDiagnostic("Can't find module %s" % module)
    shutil.copytree(module_path, path.join(dest, path.basename(module_path)))
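
A minimal sketch of copy_python_package above staging packages into a scratch directory, mirroring how prepare_archive uses it.

import tempfile

# hypothetical staging directory that would later be packed into the PAI tarball
cwd = tempfile.mkdtemp(prefix="sqlflow", dir="/tmp")
copy_python_package("runtime", cwd)         # assuming copy_python_package above is in scope
copy_python_package("sqlflow_models", cwd)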
Code example #11
def submit_pai_evaluate(datasource,
                        model_name,
                        select,
                        result_table,
                        model_attrs,
                        user=""):
    """Submit a PAI evaluation task

    Args:
        datasource: current datasource
        model_name: model used to do evaluation
        select: SQL statement to get the evaluation data set
        result_table: the table name to save the result
        model_attrs: dict, params for training, corresponding to the WITH clause
    """

    params = dict(locals())
    cwd = tempfile.mkdtemp(prefix="sqlflow", dir="/tmp")

    project = table_ops.get_project(datasource)
    if result_table.count(".") == 0:
        result_table = "%s.%s" % (project, result_table)
    oss_model_path = pai_model.get_oss_model_save_path(datasource,
                                                       model_name,
                                                       user=user)
    params["oss_model_path"] = oss_model_path

    model_type, estimator = pai_model.get_oss_saved_model_type_and_estimator(
        oss_model_path, project)
    if model_type == EstimatorType.PAIML:
        raise SQLFlowDiagnostic("PAI model evaluation is not supported yet.")

    data_table = table_ops.create_tmp_table_from_select(select, datasource)
    params["data_table"] = data_table

    metrics = get_evaluate_metrics(model_type, model_attrs)
    params["metrics"] = metrics
    create_evaluate_result_table(datasource, result_table, metrics)

    conf = cluster_conf.get_cluster_config(model_attrs)

    if model_type == EstimatorType.XGBOOST:
        params["entry_type"] = "evaluate_xgb"
    else:
        params["entry_type"] = "evaluate_tf"
    prepare_archive(cwd, estimator, oss_model_path, params)
    cmd = get_pai_tf_cmd(conf, "file://" + os.path.join(cwd, JOB_ARCHIVE_FILE),
                         "file://" + os.path.join(cwd, PARAMS_FILE),
                         ENTRY_FILE, model_name, oss_model_path, data_table,
                         "", result_table, project)
    submit_pai_task(cmd, datasource)
    table_ops.drop_tables([data_table], datasource)
Code example #12
File: submitter.py  Project: vmnet04/sqlflow
def getAlisaBucket():
    """Get Alisa oss bucket, this function get params from env variables"""
    ep = os.getenv("SQLFLOW_OSS_ALISA_ENDPOINT")
    ak = os.getenv("SQLFLOW_OSS_AK")
    sk = os.getenv("SQLFLOW_OSS_SK")
    bucketName = os.getenv("SQLFLOW_OSS_ALISA_BUCKET")

    if ep == "" or ak == "" or sk == "":
        raise SQLFlowDiagnostic(
            "should define SQLFLOW_OSS_ALISA_ENDPOINT, "
            "SQLFLOW_OSS_ALISA_BUCKET, SQLFLOW_OSS_AK, SQLFLOW_OSS_SK "
            "when using submitter alisa")

    return oss.get_bucket(bucketName, ak, sk, endpoint=ep)
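
A hedged usage sketch for getAlisaBucket above; the environment values are placeholders.

import os

os.environ["SQLFLOW_OSS_ALISA_ENDPOINT"] = "https://oss-cn-hangzhou.aliyuncs.com"  # placeholder
os.environ["SQLFLOW_OSS_AK"] = "<access_key_id>"
os.environ["SQLFLOW_OSS_SK"] = "<access_key_secret>"
os.environ["SQLFLOW_OSS_ALISA_BUCKET"] = "<bucket-name>"

bucket = getAlisaBucket()  # assuming getAlisaBucket from the example above is in scope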
Code example #13
File: prepare_archive.py  Project: zlb1028/sqlflow
def _create_pai_hyper_param_file(cwd, filename, model_path):
    with open(path.join(cwd, filename), "w") as file:
        oss_ak = os.getenv("SQLFLOW_OSS_AK")
        oss_sk = os.getenv("SQLFLOW_OSS_SK")
        oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
        if oss_ak == "" or oss_sk == "" or oss_ep == "":
            raise SQLFlowDiagnostic(
                "must define SQLFLOW_OSS_AK, SQLFLOW_OSS_SK, "
                "SQLFLOW_OSS_MODEL_ENDPOINT when submitting to PAI")
        file.write("sqlflow_oss_ak=\"%s\"\n" % oss_ak)
        file.write("sqlflow_oss_sk=\"%s\"\n" % oss_sk)
        file.write("sqlflow_oss_ep=\"%s\"\n" % oss_ep)
        oss_model_url = pai_model.get_oss_model_url(model_path)
        file.write("sqlflow_oss_modeldir=\"%s\"\n" % oss_model_url)
        file.flush()
Code example #14
def submit_pai_task(pai_cmd, datasource):
    """Submit given cmd to PAI which manipulate datasource

    Args:
        pai_cmd: The command to submit
        datasource: The datasource this cmd will manipulate
    """
    user, passwd, address, project = parse_maxcompute_dsn(datasource)
    cmd = [
        "odpscmd", "--instance-priority", "9", "-u", user, "-p", passwd,
        "--project", project, "--endpoint", address, "-e", pai_cmd
    ]
    print(" ".join(cmd))
    if subprocess.call(cmd) != 0:
        raise SQLFlowDiagnostic("Execute odps cmd fail: cmd is %s" %
                                " ".join(cmd))
Code example #15
File: db.py  Project: jackylovecoding/sqlflow
def _create_table(conn, table):
    if conn.driver == "mysql":
        stmt = "CREATE TABLE IF NOT EXISTS {0} (id INT, block TEXT,\
        PRIMARY KEY (id))".format(table)
    elif conn.driver == "hive":
        stmt = 'CREATE TABLE IF NOT EXISTS {0} (id INT, block STRING) ROW\
            FORMAT DELIMITED FIELDS TERMINATED BY "\\001" \
                STORED AS TEXTFILE'.format(table)
    elif conn.driver == "maxcompute":
        stmt = "CREATE TABLE IF NOT EXISTS {0} (id INT,\
            block STRING)".format(table)
    else:
        raise SQLFlowDiagnostic("unsupported driver {0} on creating\
            table.".format(conn.driver))

    conn.execute(stmt)
Code example #16
def submit_pai_task(pai_cmd, datasource):
    """Submit given cmd to PAI which manipulate datasource

    Args:
        pai_cmd: The command to submit
        datasource: The datasource this cmd will manipulate
    """
    user, passwd, address, project = MaxComputeConnection.get_uri_parts(
        datasource)
    cmd = [
        "odpscmd", "--instance-priority", "9", "-u", user, "-p", passwd,
        "--project", project, "--endpoint", address, "-e", pai_cmd
    ]
    exitcode = run_command_and_log(cmd)
    if exitcode != 0:
        raise SQLFlowDiagnostic("Execute odps cmd fail: cmd is %s" %
                                " ".join(cmd))
Code example #17
def create_tmp_table_from_select(select, datasource):
    """Create temp table for given select query

    Args:
        select: string, the selection statement
        datasource: string, the datasource to connect
    """
    if not select:
        return None
    conn = db.connect_with_data_source(datasource)
    project = get_project(datasource)
    tmp_tb_name = gen_rand_string()
    create_sql = "CREATE TABLE %s LIFECYCLE %s AS %s" % (
        tmp_tb_name, LIFECYCLE_ON_TMP_TABLE, select)
    # (NOTE: lhw) maxcompute conn doesn't support close
    # we should unify db interface
    if not conn.execute(create_sql):
        raise SQLFlowDiagnostic("Can't crate tmp table for %s" % select)
    return "%s.%s" % (project, tmp_tb_name)
Code example #18
def create_pai_hyper_param_file(cwd, filename, model_path):
    """Create param needed by PAI training

    Args:
        cwd: current working dir
        filename: the output file name
        model_path: the model saving path
    """
    with open(path.join(cwd, filename), "w") as file:
        oss_ak = os.getenv("SQLFLOW_OSS_AK")
        oss_sk = os.getenv("SQLFLOW_OSS_SK")
        oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
        if oss_ak == "" or oss_sk == "" or oss_ep == "":
            raise SQLFlowDiagnostic(
                "must define SQLFLOW_OSS_AK, SQLFLOW_OSS_SK, "
                "SQLFLOW_OSS_MODEL_ENDPOINT when submitting to PAI")
        file.write("sqlflow_oss_ak=\"%s\"\n" % oss_ak)
        file.write("sqlflow_oss_sk=\"%s\"\n" % oss_sk)
        file.write("sqlflow_oss_ep=\"%s\"\n" % oss_ep)
        oss_model_url = get_oss_model_url(model_path)
        file.write("sqlflow_oss_modeldir=\"%s\"\n" % oss_model_url)
        file.flush()
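
A minimal sketch of create_pai_hyper_param_file above; the credentials are placeholders and the file name is hypothetical.

import os

os.environ["SQLFLOW_OSS_AK"] = "<access_key_id>"      # placeholders; real values come
os.environ["SQLFLOW_OSS_SK"] = "<access_key_secret>"  # from the deployment environment
os.environ["SQLFLOW_OSS_MODEL_ENDPOINT"] = "https://oss-cn-hangzhou.aliyuncs.com"

# writes /tmp/params_file with lines like sqlflow_oss_ak="..." for the PAI entry script
create_pai_hyper_param_file("/tmp", "params_file", "bucket/sqlflow_models/my_model")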
Code example #19
File: kmeans.py  Project: vmnet04/sqlflow
def get_train_kmeans_pai_cmd(datasource, model_name, data_table, model_attrs,
                             feature_column_names):
    """Get a command to submit a KMeans training task to PAI

    Args:
        datasource: current datasource
        model_name: model name on PAI
        data_table: input data table name
        model_attrs: model attributes for KMeans
        feature_column_names: names of feature columns

    Returns:
        A string which is a PAI cmd
    """
    for k, v in default_attrs.items():
        if k not in model_attrs:
            model_attrs[k] = v
    center_count = model_attrs["center_count"]
    idx_table_name = model_attrs["idx_table_name"]
    if not idx_table_name:
        raise SQLFlowDiagnostic("Need to set idx_table_name in WITH clause")
    exclude_columns = model_attrs["excluded_columns"].split(",")

    # selectedCols indicates the feature columns used for clustering
    selected_cols = [
        fc for fc in feature_column_names if fc not in exclude_columns
    ]

    conn = db.connect_with_data_source(datasource)
    db.execute(conn, "DROP TABLE IF EXISTS %s" % idx_table_name)

    return (
        """pai -name kmeans -project algo_public """
        """-DinputTableName=%s -DcenterCount=%d -DmodelName %s """
        """-DidxTableName=%s -DselectedColNames="%s" -DappendColNames="%s" """
    ) % (data_table, center_count, model_name, idx_table_name,
         ",".join(selected_cols), ",".join(feature_column_names))
Code example #20
def submit_pai_evaluate(datasource,
                        original_sql,
                        select,
                        label_name,
                        model,
                        model_params,
                        result_table,
                        user=""):
    """Submit a PAI evaluation task

    Args:
        datasource: string
            Like: maxcompute://ak:[email protected]/api?
                  curr_project=test_ci&scheme=http
        original_sql: string
            Original "TO EVALUATE" statement.
        select: string
            SQL statement to get the evaluation data set.
        label_name: string
            The label name to evaluate.
        model: string
            Model to load and evaluate.
        model_params: dict
            Params for training, corresponding to the WITH clause.
        result_table: string
            The table name to save the evaluation result.
        user: string
            A string to identify the user, used to load model from the user's
            directory.
    """

    params = dict(locals())
    project = table_ops.get_project(datasource)
    if result_table.count(".") == 0:
        result_table = "%s.%s" % (project, result_table)
    params["result_table"] = result_table

    oss_model_path = pai_model.get_oss_model_save_path(datasource,
                                                       model,
                                                       user=user)

    model_type, estimator = pai_model.get_saved_model_type_and_estimator(
        datasource, model)
    if model_type == EstimatorType.PAIML:
        raise SQLFlowDiagnostic("PAI model evaluation is not supported yet.")

    if model_type == EstimatorType.XGBOOST:
        params["entry_type"] = "evaluate_xgb"
        validation_metrics = model_params.get("validation.metrics",
                                              "accuracy_score")
    else:
        params["entry_type"] = "evaluate_tf"
        validation_metrics = model_params.get("validation.metrics", "Accuracy")

    validation_metrics = [m.strip() for m in validation_metrics.split(",")]
    with db.connect_with_data_source(datasource) as conn:
        result_column_names = create_evaluate_table(conn, result_table,
                                                    validation_metrics)

    with table_ops.create_tmp_tables_guard(select, datasource) as data_table:
        params["pai_table"] = data_table
        params["result_column_names"] = result_column_names

        if try_pai_local_run(params, oss_model_path):
            return

        conf = cluster_conf.get_cluster_config(model_params)
        with temp_file.TemporaryDirectory(prefix="sqlflow", dir="/tmp") as cwd:
            prepare_archive(cwd, estimator, oss_model_path, params)
            cmd = get_pai_tf_cmd(
                conf, "file://" + os.path.join(cwd, JOB_ARCHIVE_FILE),
                "file://" + os.path.join(cwd, PARAMS_FILE), ENTRY_FILE, model,
                oss_model_path, data_table, "", result_table, project)
            submit_pai_task(cmd, datasource)
Code example #21
File: get_pai_tf_cmd.py  Project: zlb1028/sqlflow
def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
                   model_name, oss_model_path, train_table, val_table,
                   res_table, project):
    """Get PAI-TF cmd for training

    Args:
        cluster_config: PAI cluster config
        tarball: the zipped resource name
        params_file: PAI param file name
        entry_file: entry file in the tarball
        model_name: trained model name
        oss_model_path: path to save the model
        train_table: train data table
        val_table: evaluate data table
        res_table: table to save the output, if given
        project: current odps project

    Returns:
        The cmd to run on PAI
    """
    job_name = "_".join(["sqlflow", model_name]).replace(".", "_")
    cf_quote = json.dumps(cluster_config).replace("\"", "\\\"")

    # submitted tables should be formatted as: odps://<project>/tables/<table>,
    # odps://<project>/tables/<table>, ...
    submit_tables = _max_compute_table_url(train_table)
    if train_table != val_table and val_table:
        val_table = _max_compute_table_url(val_table)
        submit_tables = "%s,%s" % (submit_tables, val_table)
    output_tables = ""
    if res_table != "":
        table = _max_compute_table_url(res_table)
        output_tables = "-Doutputs=%s" % table

    # NOTE(typhoonzero): use -DhyperParameters to define flags for passing
    # OSS credentials.
    # TODO(typhoonzero): need to find a more secure way to pass credentials.
    cmd = ("pai -name tensorflow1150 -project algo_public_dev "
           "-DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn "
           "-Dscript=%s -DentryFile=%s -Dtables=%s %s -DhyperParameters='%s'"
           ) % (job_name, tarball, entry_file, submit_tables, output_tables,
                params_file)

    # format the oss checkpoint path with ARN authorization, should use eval
    # because we use '''json''' in the workflow yaml file.
    oss_checkpoint_configs = eval(os.getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG"))
    if not oss_checkpoint_configs:
        raise SQLFlowDiagnostic(
            "need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when "
            "submitting to PAI")

    if isinstance(oss_checkpoint_configs, dict):
        ckpt_conf = oss_checkpoint_configs
    else:
        ckpt_conf = json.loads(oss_checkpoint_configs)

    model_url = pai_model.get_oss_model_url(oss_model_path)
    role_name = _get_project_role_name(project)
    # format the oss checkpoint path with ARN authorization.
    oss_checkpoint_path = "%s/?role_arn=%s/%s&host=%s" % (
        model_url, ckpt_conf["arn"], role_name, ckpt_conf["host"])
    cmd = "%s -DcheckpointDir='%s'" % (cmd, oss_checkpoint_path)

    if cluster_config["worker"]["count"] > 1:
        cmd = "%s -Dcluster=\"%s\"" % (cmd, cf_quote)
    else:
        cmd = "%s -DgpuRequired='%d'" % (cmd, cluster_config["worker"]["gpu"])
    return cmd
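
A hedged sketch of get_pai_tf_cmd above with a hand-built cluster config; the checkpoint ARN/host, tables, and paths are placeholders, and the module-level helpers it calls are assumed to be in scope.

import json
import os

os.environ["SQLFLOW_OSS_CHECKPOINT_CONFIG"] = json.dumps(
    {"arn": "acs:ram::<account-id>:role/<role-name>",
     "host": "cn-hangzhou.oss.aliyuncs.com"})  # placeholder ARN/host pair

conf = {"ps": {"count": 0, "cpu": 200, "gpu": 0},
        "worker": {"count": 1, "cpu": 400, "gpu": 0}}
cmd = get_pai_tf_cmd(conf, "file:///tmp/job.tar.gz", "file:///tmp/params.txt",
                     "entry.py", "my_dnn_model", "bucket/models/my_dnn_model",
                     "proj.train_tbl", "proj.val_tbl", "", "proj")
# single worker, so the command ends with -DgpuRequired='0' instead of -Dcluster=...
print(cmd)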
Code example #22
File: get_pai_tf_cmd.py  Project: zlb1028/sqlflow
def _max_compute_table_url(table):
    parts = table.split(".")
    if len(parts) != 2:
        raise SQLFlowDiagnostic("odps table: %s should be format db.table" %
                                table)
    return "odps://%s/tables/%s" % (parts[0], parts[1])
Code example #23
File: submitter_evaluate.py  Project: hsjung6/sqlflow
def submit_pai_evaluate(datasource,
                        original_sql,
                        select,
                        model_name,
                        model_params,
                        result_table,
                        user=""):
    """Submit a PAI evaluation task

    Args:
        datasource: string
            Like: maxcompute://ak:[email protected]/api?
                  curr_project=test_ci&scheme=http
        original_sql: string
            Original "TO EVALUATE" statement.
        select: string
            SQL statement to get the evaluation data set.
        model_name: string
            Model to load and evaluate.
        model_params: dict
            Params for training, corresponding to the WITH clause.
        result_table: string
            The table name to save the evaluation result.
        user: string
            A string to identify the user, used to load model from the user's
            directory.
    """

    params = dict(locals())
    cwd = tempfile.mkdtemp(prefix="sqlflow", dir="/tmp")

    project = table_ops.get_project(datasource)
    if result_table.count(".") == 0:
        result_table = "%s.%s" % (project, result_table)
    params["result_table"] = result_table

    oss_model_path = pai_model.get_oss_model_save_path(datasource,
                                                       model_name,
                                                       user=user)
    params["oss_model_path"] = oss_model_path

    model_type, estimator = pai_model.get_oss_saved_model_type_and_estimator(
        oss_model_path, project)
    if model_type == EstimatorType.PAIML:
        raise SQLFlowDiagnostic("PAI model evaluation is not supported yet.")

    data_table = table_ops.create_tmp_table_from_select(select, datasource)
    params["data_table"] = data_table

    metrics = get_evaluate_metrics(model_type, model_params)
    params["metrics"] = metrics
    create_evaluate_result_table(datasource, result_table, metrics)

    conf = cluster_conf.get_cluster_config(model_params)

    if model_type == EstimatorType.XGBOOST:
        params["entry_type"] = "evaluate_xgb"
    else:
        params["entry_type"] = "evaluate_tf"
    prepare_archive(cwd, estimator, oss_model_path, params)
    cmd = get_pai_tf_cmd(conf, "file://" + os.path.join(cwd, JOB_ARCHIVE_FILE),
                         "file://" + os.path.join(cwd, PARAMS_FILE),
                         ENTRY_FILE, model_name, oss_model_path, data_table,
                         "", result_table, project)
    submit_pai_task(cmd, datasource)
    table_ops.drop_tables([data_table], datasource)
Code example #24
File: prepare_archive.py  Project: zlb1028/sqlflow
def _copy_python_package(module, dest):
    module_path = _find_python_module_path(module)
    if not module_path:
        raise SQLFlowDiagnostic("Can't find module %s" % module)
    shutil.copytree(module_path, path.join(dest, path.basename(module_path)))