def test_pai_train_step(self):
    """End-to-end check: submit a DNNClassifier training job to PAI via train_step."""
    from runtime.step.tensorflow.train import train_step

    # Hyper-parameters mirroring the WITH clause of the SQL statement below.
    hyperparams = {"hidden_units": [10, 20], "n_classes": 3}
    statement = """ SELECT * FROM alifin_jtest_dev.sqlflow_test_iris_train TO TRAIN DNNClassifier WITH model.n_classes = 3, model.hidden_units = [10, 20] LABEL class INTO e2etest_pai_dnn;"""
    ds = testing.get_datasource()
    model_name = "e2etest_pai_dnn"

    # Populate TF flags with the OSS credentials taken from the environment.
    flags = define_tf_flags()
    flags.sqlflow_oss_ak = os.getenv("SQLFLOW_OSS_AK")
    flags.sqlflow_oss_sk = os.getenv("SQLFLOW_OSS_SK")
    flags.sqlflow_oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
    save_path = pai_model.get_oss_model_save_path(ds, model_name, user="")
    flags.sqlflow_oss_modeldir = pai_model.get_oss_model_url(save_path)

    train_step(statement, "", "DNNClassifier", ds,
               "SELECT * FROM alifin_jtest_dev.sqlflow_iris_train", "",
               "alifin_jtest_dev.sqlflow_iris_train", "", hyperparams, {},
               feature_column_map, label_column, model_name, None)
def init_pai_local_tf_flags_and_envs(oss_model_dir):
    """Prepare TF flags and OSS environment variables for a local PAI run.

    Args:
        oss_model_dir: model directory; a bare path is expanded to a full
            ``oss://`` URL via pai_model.get_oss_model_url.
    """
    flags = define_tf_flags()
    flags.sqlflow_oss_ak = os.getenv("SQLFLOW_OSS_AK")
    flags.sqlflow_oss_sk = os.getenv("SQLFLOW_OSS_SK")
    flags.sqlflow_oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
    model_dir = oss_model_dir
    if not model_dir.startswith("oss://"):
        model_dir = pai_model.get_oss_model_url(model_dir)
    flags.sqlflow_oss_modeldir = model_dir
    # Checkpoints go to the current working directory for local execution.
    flags.checkpointDir = os.getcwd()
    set_oss_environs(flags)
def _create_pai_hyper_param_file(cwd, filename, model_path):
    """Write the PAI hyper-parameter file carrying OSS credentials.

    Args:
        cwd: directory to create the file in.
        filename: name of the hyper-parameter file.
        model_path: OSS model path, expanded to a full URL for the file.

    Raises:
        SQLFlowDiagnostic: if any required OSS credential is not configured.
    """
    oss_ak = os.getenv("SQLFLOW_OSS_AK")
    oss_sk = os.getenv("SQLFLOW_OSS_SK")
    oss_ep = os.getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
    # BUGFIX: os.getenv returns None (not "") for unset variables, so the
    # previous `== ""` comparison let unset credentials slip through and
    # wrote the literal string "None" into the file. Treat empty and unset
    # alike. Validation now also happens BEFORE the file is created, so a
    # failed call does not leave an empty file behind.
    if not oss_ak or not oss_sk or not oss_ep:
        raise SQLFlowDiagnostic(
            "must define SQLFLOW_OSS_AK, SQLFLOW_OSS_SK, "
            "SQLFLOW_OSS_MODEL_ENDPOINT when submitting to PAI")
    oss_model_url = pai_model.get_oss_model_url(model_path)
    with open(path.join(cwd, filename), "w") as file:
        file.write("sqlflow_oss_ak=\"%s\"\n" % oss_ak)
        file.write("sqlflow_oss_sk=\"%s\"\n" % oss_sk)
        file.write("sqlflow_oss_ep=\"%s\"\n" % oss_ep)
        file.write("sqlflow_oss_modeldir=\"%s\"\n" % oss_model_url)
        file.flush()
def get_pai_tf_cmd(cluster_config, tarball, params_file, entry_file,
                   model_name, oss_model_path, train_table, val_table,
                   res_table, project):
    """Get PAI-TF cmd for training

    Args:
        cluster_config: PAI cluster config
        tarball: the zipped resource name
        params_file: PAI param file name
        entry_file: entry file in the tarball
        model_name: trained model name
        oss_model_path: path to save the model
        train_table: train data table
        val_table: evaluate data table
        res_table: table to save train model, if given
        project: current odps project

    Returns:
        The cmd to run on PAI
    """
    # Job names must not contain dots; derive one from the model name.
    job_name = "_".join(["sqlflow", model_name]).replace(".", "_")
    # Escape double quotes so the JSON survives inside a double-quoted flag.
    cf_quote = json.dumps(cluster_config).replace("\"", "\\\"")

    # submit table should format as: odps://<project>/tables/<table >,
    # odps://<project>/tables/<table > ...
    submit_tables = _max_compute_table_url(train_table)
    # Only append the validation table when it is non-empty and distinct
    # from the training table.
    if train_table != val_table and val_table:
        val_table = _max_compute_table_url(val_table)
        submit_tables = "%s,%s" % (submit_tables, val_table)
    output_tables = ""
    if res_table != "":
        table = _max_compute_table_url(res_table)
        output_tables = "-Doutputs=%s" % table

    # NOTE(typhoonzero): use - DhyperParameters to define flags passing
    # OSS credentials.
    # TODO(typhoonzero): need to find a more secure way to pass credentials.
    cmd = ("pai -name tensorflow1150 -project algo_public_dev "
           "-DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn "
           "-Dscript=%s -DentryFile=%s -Dtables=%s %s -DhyperParameters='%s'"
           ) % (job_name, tarball, entry_file, submit_tables, output_tables,
                params_file)

    # format the oss checkpoint path with ARN authorization, should use eval
    # because we use '''json''' in the workflow yaml file.
    # NOTE(review): eval() on an environment variable executes arbitrary
    # Python if the variable is attacker-controlled — confirm this env var
    # is only ever set by trusted deployment tooling. Also, if the variable
    # is unset, os.getenv returns None and eval(None) raises TypeError
    # before the diagnostic below can fire.
    oss_checkpoint_configs = eval(os.getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG"))
    if not oss_checkpoint_configs:
        raise SQLFlowDiagnostic(
            "need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when "
            "submitting to PAI")
    # eval may already have produced a dict; otherwise the value is a JSON
    # string that still needs parsing.
    if isinstance(oss_checkpoint_configs, dict):
        ckpt_conf = oss_checkpoint_configs
    else:
        ckpt_conf = json.loads(oss_checkpoint_configs)
    model_url = pai_model.get_oss_model_url(oss_model_path)
    role_name = _get_project_role_name(project)
    # format the oss checkpoint path with ARN authorization.
    oss_checkpoint_path = "%s/?role_arn=%s/%s&host=%s" % (
        model_url, ckpt_conf["arn"], role_name, ckpt_conf["host"])
    cmd = "%s -DcheckpointDir='%s'" % (cmd, oss_checkpoint_path)
    # Distributed jobs get the full cluster config; single-worker jobs only
    # request the configured GPU count.
    if cluster_config["worker"]["count"] > 1:
        cmd = "%s -Dcluster=\"%s\"" % (cmd, cf_quote)
    else:
        cmd = "%s -DgpuRequired='%d'" % (cmd, cluster_config["worker"]["gpu"])
    return cmd
def test_get_oss_model_url(self):
    """A relative model path is prefixed with the sqlflow-models OSS bucket."""
    expected = "oss://sqlflow-models/user_a/model"
    self.assertEqual(expected, pai_model.get_oss_model_url("user_a/model"))