Example #1
    def test_pai_train_step(self):
        from runtime.step.tensorflow.train import train_step

        # Estimator hyperparameters forwarded to DNNClassifier.
        model_params = dict()
        model_params["hidden_units"] = [10, 20]
        model_params["n_classes"] = 3

        original_sql = """
SELECT * FROM alifin_jtest_dev.sqlflow_test_iris_train
TO TRAIN DNNClassifier
WITH model.n_classes = 3, model.hidden_units = [10, 20]
LABEL class
INTO e2etest_pai_dnn;"""
        datasource = testing.get_datasource()
        save = "e2etest_pai_dnn"

        # OSS credentials and endpoint are read from the environment.
        FLAGS = define_tf_flags()
        FLAGS.sqlflow_oss_ak = os.getenv("SQLFLOW_OSS_AK")
        FLAGS.sqlflow_oss_sk = os.getenv("SQLFLOW_OSS_SK")
        FLAGS.sqlflow_oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")

        # Resolve the OSS path and URL under which the trained model is saved.
        oss_path_to_save = pai_model.get_oss_model_save_path(datasource,
                                                             save,
                                                             user="")
        FLAGS.sqlflow_oss_modeldir = pai_model.get_oss_model_url(
            oss_path_to_save)

        train_step(original_sql, "", "DNNClassifier", datasource,
                   "SELECT * FROM alifin_jtest_dev.sqlflow_iris_train", "",
                   "alifin_jtest_dev.sqlflow_iris_train", "", model_params, {},
                   feature_column_map, label_column, save, None)
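
The test relies on feature_column_map and label_column fixtures defined elsewhere in the test module. A minimal sketch of what they could look like, assuming the iris schema and the runtime.feature column classes; the exact field descriptions in the real test file may differ:

from runtime.feature.column import NumericColumn
from runtime.feature.field_desc import FieldDesc

# Hypothetical fixtures for illustration: four numeric iris features and the
# "class" label named in the LABEL clause of original_sql above.
feature_column_map = {
    "feature_columns": [
        NumericColumn(FieldDesc(name=name, shape=[1]))
        for name in ("sepal_length", "sepal_width",
                     "petal_length", "petal_width")
    ]
}
label_column = NumericColumn(FieldDesc(name="class", shape=[1]))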
Example #2
def init_pai_local_tf_flags_and_envs(oss_model_dir):
    """Set up TensorFlow flags and OSS environment variables for running a
    PAI job locally, saving the model under oss_model_dir."""
    FLAGS = define_tf_flags()
    FLAGS.sqlflow_oss_ak = os.getenv("SQLFLOW_OSS_AK")
    FLAGS.sqlflow_oss_sk = os.getenv("SQLFLOW_OSS_SK")
    FLAGS.sqlflow_oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
    # Expand a bare model directory name into a full OSS URL if needed.
    if not oss_model_dir.startswith("oss://"):
        oss_model_dir = pai_model.get_oss_model_url(oss_model_dir)
    FLAGS.sqlflow_oss_modeldir = oss_model_dir
    FLAGS.checkpointDir = os.getcwd()
    set_oss_environs(FLAGS)
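
A minimal usage sketch, assuming placeholder credentials and a bare model directory name that the helper expands to a full OSS URL:

import os

# Placeholder credentials for illustration; real values come from the
# deployment environment.
os.environ.setdefault("SQLFLOW_OSS_AK", "my-access-key")
os.environ.setdefault("SQLFLOW_OSS_SK", "my-secret-key")
os.environ.setdefault("SQLFLOW_OSS_MODEL_ENDPOINT",
                      "oss-cn-hangzhou.aliyuncs.com")

# "e2etest_pai_dnn" is not a full oss:// URL, so the helper expands it via
# pai_model.get_oss_model_url before setting the flags.
init_pai_local_tf_flags_and_envs("e2etest_pai_dnn")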
Example #3
def _create_pai_hyper_param_file(cwd, filename, model_path):
    """Write the OSS credential flags consumed by the PAI-TF entry file
    into a hyperparameter file under cwd."""
    with open(path.join(cwd, filename), "w") as file:
        oss_ak = os.getenv("SQLFLOW_OSS_AK")
        oss_sk = os.getenv("SQLFLOW_OSS_SK")
        oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
        # os.getenv returns None for unset variables, so reject both unset
        # and empty values.
        if not oss_ak or not oss_sk or not oss_ep:
            raise SQLFlowDiagnostic(
                "must define SQLFLOW_OSS_AK, SQLFLOW_OSS_SK, "
                "SQLFLOW_OSS_MODEL_ENDPOINT when submitting to PAI")
        file.write("sqlflow_oss_ak=\"%s\"\n" % oss_ak)
        file.write("sqlflow_oss_sk=\"%s\"\n" % oss_sk)
        file.write("sqlflow_oss_ep=\"%s\"\n" % oss_ep)
        oss_model_url = pai_model.get_oss_model_url(model_path)
        file.write("sqlflow_oss_modeldir=\"%s\"\n" % oss_model_url)
        file.flush()
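
A usage sketch with the same placeholder credentials as above; the exact model URL is whatever pai_model.get_oss_model_url produces:

import os

_create_pai_hyper_param_file(os.getcwd(), "params.txt", "user_a/model")

# params.txt would then contain roughly:
#   sqlflow_oss_ak="my-access-key"
#   sqlflow_oss_sk="my-secret-key"
#   sqlflow_oss_ep="oss-cn-hangzhou.aliyuncs.com"
#   sqlflow_oss_modeldir="oss://sqlflow-models/user_a/model"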
Example #4
def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
                   model_name, oss_model_path, train_table, val_table,
                   res_table, project):
    """Get PAI-TF cmd for training

    Args:
        cluster_config: PAI cluster config
        tarball: the zipped resource name
        params_file: PAI param file name
        entry_file: entry file in the tarball
        model_name: trained model name
        oss_model_path: path to save the model
        train_table: train data table
        val_table: evaluate data table
        res_table: table to save train model, if given
        project: current odps project

    Retruns:
        The cmd to run on PAI
    """
    job_name = "_".join(["sqlflow", model_name]).replace(".", "_")
    cf_quote = json.dumps(cluster_config).replace("\"", "\\\"")

    # Submitted tables should be formatted as: odps://<project>/tables/<table>,
    # odps://<project>/tables/<table>...
    submit_tables = _max_compute_table_url(train_table)
    if train_table != val_table and val_table:
        val_table = _max_compute_table_url(val_table)
        submit_tables = "%s,%s" % (submit_tables, val_table)
    output_tables = ""
    if res_table != "":
        table = _max_compute_table_url(res_table)
        output_tables = "-Doutputs=%s" % table

    # NOTE(typhoonzero): use -DhyperParameters to define the flags that pass
    # OSS credentials.
    # TODO(typhoonzero): need to find a more secure way to pass credentials.
    cmd = ("pai -name tensorflow1150 -project algo_public_dev "
           "-DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn "
           "-Dscript=%s -DentryFile=%s -Dtables=%s %s -DhyperParameters='%s'"
           ) % (job_name, tarball, entry_file, submit_tables, output_tables,
                params_file)

    # Format the OSS checkpoint path with ARN authorization. eval is needed
    # because the config is embedded as a JSON string in the workflow yaml file.
    oss_checkpoint_configs = eval(os.getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG"))
    if not oss_checkpoint_configs:
        raise SQLFlowDiagnostic(
            "need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when "
            "submitting to PAI")

    if isinstance(oss_checkpoint_configs, dict):
        ckpt_conf = oss_checkpoint_configs
    else:
        ckpt_conf = json.loads(oss_checkpoint_configs)

    model_url = pai_model.get_oss_model_url(oss_model_path)
    role_name = _get_project_role_name(project)
    # format the oss checkpoint path with ARN authorization.
    oss_checkpoint_path = "%s/?role_arn=%s/%s&host=%s" % (
        model_url, ckpt_conf["arn"], role_name, ckpt_conf["host"])
    cmd = "%s -DcheckpointDir='%s'" % (cmd, oss_checkpoint_path)

    if cluster_config["worker"]["count"] > 1:
        cmd = "%s -Dcluster=\"%s\"" % (cmd, cf_quote)
    else:
        cmd = "%s -DgpuRequired='%d'" % (cmd, cluster_config["worker"]["gpu"])
    return cmd
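
A usage sketch with placeholder cluster config, resource names, and checkpoint config; the real SQLFLOW_OSS_CHECKPOINT_CONFIG is supplied by the SQLFlow server deployment:

import json
import os

# Placeholder ARN/host pair so the checkpoint path can be built.
os.environ["SQLFLOW_OSS_CHECKPOINT_CONFIG"] = json.dumps({
    "arn": "acs:ram::1234567890:role/sqlflow",
    "host": "cn-hangzhou.oss.aliyuncs.com"
})

cluster_config = {"worker": {"count": 1, "gpu": 0}, "ps": {"count": 1}}
cmd = get_pai_tf_cmd(cluster_config, "job.tar.gz", "params.txt", "entry.py",
                     "e2etest_pai_dnn", "user_a/model",
                     "alifin_jtest_dev.sqlflow_iris_train",
                     "alifin_jtest_dev.sqlflow_iris_train",
                     "", "alifin_jtest_dev")
# cmd is a single "pai -name tensorflow1150 ..." string that can be
# submitted through the MaxCompute client.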
Example #5
    def test_get_oss_model_url(self):
        url = pai_model.get_oss_model_url("user_a/model")
        self.assertEqual("oss://sqlflow-models/user_a/model", url)