Example #1
    def test_no_column_clause(self):
        columns = [
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
        ]

        select = "select %s, class from iris.train" % ",".join(columns)

        conn = testing.get_singleton_db_connection()
        features = None
        label = NumericColumn(
            FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
        features, label = fd.infer_feature_columns(conn, select, features,
                                                   label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertTrue("feature_columns" in features)
        features = features["feature_columns"]
        self.assertEqual(len(features), 4)

        for i, f in enumerate(features):
            self.assertTrue(isinstance(f, NumericColumn))
            self.assertEqual(len(f.get_field_desc()), 1)
            field_desc = f.get_field_desc()[0]
            self.assertEqual(field_desc.name, columns[i])
            self.assertEqual(field_desc.dtype, DataType.FLOAT32)
            self.assertEqual(field_desc.format, DataFormat.PLAIN)
            self.assertFalse(field_desc.is_sparse)
            self.assertEqual(field_desc.shape, [1])

        self.assertTrue(isinstance(label, NumericColumn))
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [])
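
Once derived this way, the IR columns are typically compiled into
framework-specific feature columns before training, as the train()
examples below do. A minimal sketch of that follow-up step, assuming the
compile_ir_feature_columns, get_ordered_field_descs, and EstimatorType
names used in those later examples (the variable `derived` is
illustrative):

# Sketch only: compile the derived IR into concrete feature columns.
# `derived` stands for the dict returned by fd.infer_feature_columns above.
fc_map = compile_ir_feature_columns(derived, EstimatorType.XGBOOST)
field_descs = get_ordered_field_descs(derived)
feature_names = [desc.name for desc in field_descs]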
Example #2
    def test_with_cross(self):
        c1 = NumericColumn(
            FieldDesc(name='c1', dtype=DataType.INT64, shape=[1]))
        c2 = NumericColumn(
            FieldDesc(name='c2', dtype=DataType.INT64, shape=[1]))
        c4 = NumericColumn(
            FieldDesc(name='c4', dtype=DataType.INT64, shape=[1]))
        c5 = NumericColumn(
            FieldDesc(name='c5',
                      dtype=DataType.INT64,
                      shape=[1],
                      is_sparse=True))

        features = {
            'feature_columns': [
                c1,
                c2,
                CrossColumn([c4, c5], 128),
                CrossColumn([c1, c2], 256),
            ]
        }

        label = NumericColumn(
            FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
        select = "select c1, c2, c3, c4, c5, class " \
                 "from feature_derivation_case.train"

        conn = testing.get_singleton_db_connection()
        features, label = fd.infer_feature_columns(conn, select, features,
                                                   label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertTrue("feature_columns" in features)
        features = features["feature_columns"]
        self.assertEqual(len(features), 5)

        fc1 = features[0]
        self.assertTrue(isinstance(fc1, NumericColumn))
        self.assertEqual(len(fc1.get_field_desc()), 1)
        field_desc = fc1.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c1")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc2 = features[1]
        self.assertTrue(isinstance(fc2, NumericColumn))
        self.assertEqual(len(fc2.get_field_desc()), 1)
        field_desc = fc2.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c2")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc3 = features[2]
        self.assertTrue(isinstance(fc3, NumericColumn))
        self.assertEqual(len(fc3.get_field_desc()), 1)
        field_desc = fc3.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c3")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])

        fc4 = features[3]
        self.assertTrue(isinstance(fc4, CrossColumn))
        self.assertEqual(len(fc4.get_field_desc()), 2)
        field_desc1 = fc4.get_field_desc()[0]
        self.assertEqual(field_desc1.name, "c4")
        self.assertEqual(field_desc1.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc1.format, DataFormat.CSV)
        self.assertEqual(field_desc1.shape, [4])
        self.assertFalse(field_desc1.is_sparse)
        field_desc2 = fc4.get_field_desc()[1]
        self.assertEqual(field_desc2.name, "c5")
        self.assertEqual(field_desc2.dtype, DataType.INT64)
        self.assertEqual(field_desc2.format, DataFormat.CSV)
        self.assertTrue(field_desc2.is_sparse)

        fc5 = features[4]
        self.assertTrue(isinstance(fc5, CrossColumn))
        self.assertEqual(len(fc5.get_field_desc()), 2)
        field_desc1 = fc5.get_field_desc()[0]
        self.assertEqual(field_desc1.name, "c1")
        self.assertEqual(field_desc1.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc1.format, DataFormat.PLAIN)
        self.assertEqual(field_desc1.shape, [1])
        self.assertFalse(field_desc1.is_sparse)
        field_desc2 = fc5.get_field_desc()[1]
        self.assertEqual(field_desc2.name, "c2")
        self.assertEqual(field_desc2.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc2.format, DataFormat.PLAIN)
        self.assertEqual(field_desc2.shape, [1])
        self.assertFalse(field_desc2.is_sparse)

        self.assertTrue(isinstance(label, NumericColumn))
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [])
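
For orientation, CrossColumn([c1, c2], 256) hashes the cross of the two
columns into 256 buckets. On the TensorFlow side the rough analogue is a
crossed column; whether compile_ir_feature_columns emits exactly this is
an assumption:

import tensorflow as tf

# Hedged sketch: plain-TensorFlow analogue of CrossColumn([c1, c2], 256).
crossed = tf.feature_column.crossed_column(["c1", "c2"], hash_bucket_size=256)
indicator = tf.feature_column.indicator_column(crossed)  # one-hot wrapper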
Example #3
    def test_without_cross(self):
        features = {
            'feature_columns': [
                EmbeddingColumn(dimension=256, combiner="mean", name="c3"),
                EmbeddingColumn(category_column=CategoryIDColumn(
                    FieldDesc(name="c5",
                              dtype=DataType.INT64,
                              shape=[10000],
                              delimiter=",",
                              is_sparse=True),
                    bucket_size=5000),
                                dimension=64,
                                combiner="sqrtn",
                                name="c5"),
            ]
        }

        label = NumericColumn(
            FieldDesc(name="class", dtype=DataType.INT64, shape=[1]))

        select = "select c1, c2, c3, c4, c5, c6, class " \
                 "from feature_derivation_case.train"
        conn = testing.get_singleton_db_connection()
        features, label = fd.infer_feature_columns(conn, select, features,
                                                   label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertTrue("feature_columns" in features)
        features = features["feature_columns"]
        self.assertEqual(len(features), 6)

        fc1 = features[0]
        self.assertTrue(isinstance(fc1, NumericColumn))
        self.assertEqual(len(fc1.get_field_desc()), 1)
        field_desc = fc1.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c1")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc2 = features[1]
        self.assertTrue(isinstance(fc2, NumericColumn))
        self.assertEqual(len(fc2.get_field_desc()), 1)
        field_desc = fc2.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c2")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc3 = features[2]
        self.assertTrue(isinstance(fc3, EmbeddingColumn))
        self.assertEqual(len(fc3.get_field_desc()), 1)
        field_desc = fc3.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c3")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])
        self.assertEqual(fc3.dimension, 256)
        self.assertEqual(fc3.combiner, "mean")
        self.assertEqual(fc3.name, "c3")
        self.assertTrue(isinstance(fc3.category_column, CategoryIDColumn))
        self.assertEqual(fc3.category_column.bucket_size, 10)

        fc4 = features[3]
        self.assertTrue(isinstance(fc4, NumericColumn))
        self.assertEqual(len(fc4.get_field_desc()), 1)
        field_desc = fc4.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c4")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])

        fc5 = features[4]
        self.assertTrue(isinstance(fc5, EmbeddingColumn))
        self.assertEqual(len(fc5.get_field_desc()), 1)
        field_desc = fc5.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c5")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertTrue(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [10000])
        self.assertEqual(fc5.dimension, 64)
        self.assertEqual(fc5.combiner, "sqrtn")
        self.assertEqual(fc5.name, "c5")
        self.assertTrue(isinstance(fc5.category_column, CategoryIDColumn))
        self.assertEqual(fc5.category_column.bucket_size, 5000)

        fc6 = features[5]
        self.assertTrue(isinstance(fc6, EmbeddingColumn))
        self.assertEqual(len(fc6.get_field_desc()), 1)
        field_desc = fc6.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c6")
        self.assertEqual(field_desc.dtype, DataType.STRING)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])
        self.assertEqual(field_desc.vocabulary, set(['FEMALE', 'MALE',
                                                     'NULL']))
        self.assertEqual(fc6.dimension, 128)
        self.assertEqual(fc6.combiner, "sum")
        self.assertEqual(fc6.name, "c6")
        self.assertTrue(isinstance(fc6.category_column, CategoryIDColumn))
        self.assertEqual(fc6.category_column.bucket_size, 3)

        self.assertTrue(isinstance(label, NumericColumn))
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [])
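
The derived fc6 above shows the string-column path: a three-word
vocabulary becomes a 3-bucket category column under a 128-dimensional
embedding. A hedged TensorFlow analogue (again an assumption about what
compile_ir_feature_columns would emit for this IR):

import tensorflow as tf

# Sketch: embedding over the derived 3-entry vocabulary of c6.
cat = tf.feature_column.categorical_column_with_vocabulary_list(
    "c6", ["FEMALE", "MALE", "NULL"])
emb = tf.feature_column.embedding_column(cat, dimension=128, combiner="sum")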
Example #4
def submit_local_train(datasource,
                       original_sql,
                       select,
                       validation_select,
                       estimator_string,
                       model_image,
                       feature_column_map,
                       label_column,
                       model_params,
                       train_params,
                       validation_params,
                       save,
                       load,
                       user=""):
    """This function run train task locally.

    Args:
        original_sql: string
            The original SQL statement, kept for model metadata.
        datasource: string
            Like: odps://access_id:access_key@service.com/api?
                  curr_project=test_ci&scheme=http
        select: string
            The SQL statement for selecting training data.
        validation_select: string
            The SQL statement for selecting validation data.
        estimator_string: string
            TensorFlow estimator name, Keras class name, or XGBoost
            booster type like xgboost.gbtree.
        model_image: string
            Docker image used to train this model,
            default: sqlflow/sqlflow:step
        feature_column_map: dict
            A map of Python feature column IR.
        label_column: FeatureColumn
            Feature column instance describing the label.
        model_params: dict
            Params for training, corresponding to the WITH clause.
        train_params: dict
            Extra train params, passed through to runtime.tensorflow.train
            or runtime.xgboost.train. Optional fields:
            - disk_cache: Use DMatrix disk cache if True, default: False.
            - batch_size: Split data into batches and train, default: 1.
            - epoch: Epochs to train, default: 1.
        validation_params: dict
            Params for validation.
        save: string
            Model name to be saved.
        load: string
            The pre-trained model name to load.
        user: string
            Not used for local submitter, used in runtime.pai only.
    """
    if estimator_string.lower().startswith("xgboost"):
        train_func = xgboost_train
    else:
        train_func = tf_train

    with db.connect_with_data_source(datasource) as conn:
        feature_column_map, label_column = infer_feature_columns(
            conn, select, feature_column_map, label_column, n=1000)

    return train_func(original_sql=original_sql,
                      model_image=model_image,
                      estimator_string=estimator_string,
                      datasource=datasource,
                      select=select,
                      validation_select=validation_select,
                      model_params=model_params,
                      train_params=train_params,
                      validation_params=validation_params,
                      feature_column_map=feature_column_map,
                      label_column=label_column,
                      save=save,
                      load=load)
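
A hedged usage sketch for submit_local_train; every literal below (the
datasource URI, table names, and model parameters) is an illustrative
placeholder, not a value from a real deployment:

# Illustrative call only.
submit_local_train(
    datasource="mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0",
    original_sql="SELECT * FROM iris.train TO TRAIN xgboost.gbtree ...",
    select="select * from iris.train",
    validation_select="select * from iris.test",
    estimator_string="xgboost.gbtree",
    model_image="sqlflow/sqlflow:step",
    feature_column_map=None,
    label_column=NumericColumn(
        FieldDesc(name="class", dtype=DataType.INT64, shape=[1])),
    model_params={"objective": "multi:softprob", "num_class": 3},
    train_params={"batch_size": 1, "epoch": 1},
    validation_params={},
    save="sqlflow_models.my_xgb_model",
    load=None)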
Example #5
def train(original_sql,
          model_image,
          estimator_string,
          datasource,
          select,
          validation_select,
          model_params,
          train_params,
          feature_column_map,
          label_column,
          save,
          load=None):
    """
    Train, evaluate and save the XGBoost model locally.

    Args:
        original_sql (str): the original SQL statement.
        model_image (str): the model repo docker image.
        estimator_string (str): the XGBoost booster type like xgboost.gbtree.
        datasource (str): the database connection URI.
        select (str): the SQL statement for training.
        validation_select (str): the SQL statement for evaluation.
        model_params (dict): the XGBoost model parameters.
        train_params (dict): the training parameters, can have
                             disk_cache(bool), batch_size(int), epoch(int)
                             settings in the dict.
        feature_column_map (dict): the feature column map to do derivation.
        label_column (FeatureColumn): the label column.
        save (str): the table name to save the trained model and meta.
        load (str): the table name to load the pretrained model.

    Returns:
        A dict which indicates the evaluation result.
    """
    conn = db.connect_with_data_source(datasource)
    fc_map_ir, fc_label_ir = infer_feature_columns(conn,
                                                   select,
                                                   feature_column_map,
                                                   label_column,
                                                   n=1000)
    fc_map = compile_ir_feature_columns(fc_map_ir, EstimatorType.XGBOOST)

    feature_column_list = fc_map["feature_columns"]
    field_descs = get_ordered_field_descs(fc_map_ir)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict()) for fd in field_descs])
    label_meta = label_column.get_field_desc()[0].to_dict()

    # NOTE: in the current implementation, we are generating a transform_fn
    # from the COLUMN clause. The transform_fn is executed during the process
    # of dumping the original data into DMatrix SVM file.
    transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(
        feature_column_names, *feature_column_list)

    disk_cache = train_params.pop("disk_cache", False)
    batch_size = train_params.pop("batch_size", None)
    epoch = train_params.pop("epoch", 1)

    def build_dataset(fn, slct):
        return xgb_dataset(datasource,
                           fn,
                           slct,
                           feature_metas,
                           feature_column_names,
                           label_meta,
                           cache=disk_cache,
                           batch_size=batch_size,
                           epoch=epoch,
                           transform_fn=transform_fn)

    file_name = "my_model"
    if load:
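        # Assumption from context: Model.load_from_db fetches the saved
        # model file into the working directory so the Booster below can
        # read it; its return value is unused here.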
        Model.load_from_db(datasource, load)
        bst = xgb.Booster()
        bst.load_model(file_name)
    else:
        bst = None

    with temp_file.TemporaryDirectory() as tmp_dir_name:
        train_fn = os.path.join(tmp_dir_name, 'train.txt')
        val_fn = os.path.join(tmp_dir_name, 'val.txt')
        train_dataset = build_dataset(train_fn, select)
        if validation_select:
            val_dataset = build_dataset(val_fn, validation_select)
        else:
            val_dataset = None

        eval_result = dict()
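        # The leading None is a placeholder slot; it is overwritten with
        # the current training batch on every loop iteration below.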
        watchlist = [None]
        if val_dataset:
            # The `xgboost.train` API only accepts the XGBoost DMatrix
            # object as the training or validation dataset, so we should
            # convert the generator to DMatrix.
            if isinstance(val_dataset, types.GeneratorType):
                val_dataset = list(val_dataset)[0]
            watchlist.append((val_dataset, "validate"))

        for per_batch_dmatrix in train_dataset:
            watchlist[0] = (per_batch_dmatrix, "train")
            bst = xgb.train(model_params,
                            per_batch_dmatrix,
                            evals=watchlist,
                            evals_result=eval_result,
                            xgb_model=bst,
                            **train_params)
            print("Evaluation result: %s" % eval_result)

    meta = collect_metadata(original_sql=original_sql,
                            select=select,
                            validation_select=validation_select,
                            model_repo_image=model_image,
                            class_name=estimator_string,
                            attributes=model_params,
                            features=fc_map_ir,
                            label=fc_label_ir,
                            evaluation=eval_result,
                            num_workers=1)

    save_model_to_local_file(bst, model_params, file_name)
    model = Model(EstimatorType.XGBOOST, meta)
    model.save_to_db(datasource, save)
    return eval_result
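
The training loop above relies on xgboost's warm-start support: passing
the previous Booster via xgb_model continues training on each new
DMatrix. A self-contained sketch of that pattern on synthetic data:

import numpy as np
import xgboost as xgb

# Sketch: incremental training over pseudo-batches via xgb_model, the
# same mechanism the loop above uses for batched SQL data.
params = {"objective": "binary:logistic", "max_depth": 3}
bst = None
for _ in range(3):
    X = np.random.rand(100, 4)
    y = (X[:, 0] > 0.5).astype(int)
    batch = xgb.DMatrix(X, label=y)
    bst = xgb.train(params, batch, num_boost_round=10, xgb_model=bst)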
Example #6
def train_step(original_sql,
               model_image,
               estimator_string,
               datasource,
               select,
               validation_select,
               pai_table,
               pai_val_table,
               model_params,
               train_params,
               feature_column_map,
               label_column,
               save,
               load=None):
    FLAGS = define_tf_flags()
    num_workers = len(FLAGS.worker_hosts.split(","))
    is_dist_train = num_workers > 1
    oss_model_dir = FLAGS.sqlflow_oss_modeldir

    oss_path_to_load = train_params.pop("oss_path_to_load")
    if load:
        oss.load_file(oss_path_to_load, "my_model")

    conn = db.connect_with_data_source(datasource)
    fc_map_ir, fc_label_ir = infer_feature_columns(conn,
                                                   select,
                                                   feature_column_map,
                                                   label_column,
                                                   n=1000)
    feature_columns = compile_ir_feature_columns(fc_map_ir,
                                                 EstimatorType.XGBOOST)
    field_descs = get_ordered_field_descs(fc_map_ir)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict()) for fd in field_descs])
    label_meta = label_column.get_field_desc()[0].to_dict()

    transform_fn = ComposedColumnTransformer(
        feature_column_names, *feature_columns["feature_columns"])

    batch_size = train_params.pop("batch_size", None)
    epoch = train_params.pop("epoch", 1)
    load_pretrained_model = bool(load)
    disk_cache = train_params.pop("disk_cache", False)

    if is_dist_train:
        dist_train(flags=FLAGS,
                   datasource=datasource,
                   select=select,
                   model_params=model_params,
                   train_params=train_params,
                   feature_metas=feature_metas,
                   feature_column_names=feature_column_names,
                   label_meta=label_meta,
                   validation_select=validation_select,
                   disk_cache=disk_cache,
                   batch_size=batch_size,
                   epoch=epoch,
                   load_pretrained_model=load_pretrained_model,
                   is_pai=True,
                   pai_train_table=pai_table,
                   pai_validate_table=pai_val_table,
                   oss_model_dir=oss_model_dir,
                   transform_fn=transform_fn,
                   feature_column_code=fc_map_ir,
                   model_repo_image=model_image,
                   original_sql=original_sql)
    else:
        local_train(datasource=datasource,
                    select=select,
                    model_params=model_params,
                    train_params=train_params,
                    feature_metas=feature_metas,
                    feature_column_names=feature_column_names,
                    label_meta=label_meta,
                    validation_select=validation_select,
                    disk_cache=disk_cache,
                    batch_size=batch_size,
                    epoch=epoch,
                    load_pretrained_model=load_pretrained_model,
                    is_pai=True,
                    pai_train_table=pai_table,
                    pai_validate_table=pai_val_table,
                    rank=0,
                    nworkers=1,
                    oss_model_dir=oss_model_dir,
                    transform_fn=transform_fn,
                    feature_column_code=fc_map_ir,
                    model_repo_image=model_image,
                    original_sql=original_sql)
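
The dist-versus-local branch above is decided solely by the worker count
parsed from the comma-separated worker_hosts flag; a minimal sketch of
that decision, with an illustrative flag value:

# Two hosts in worker_hosts give num_workers == 2, which routes
# train_step into the dist_train branch.
worker_hosts = "worker-0:2222,worker-1:2222"
num_workers = len(worker_hosts.split(","))
is_dist_train = num_workers > 1  # True here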
Example #7
def train_step(original_sql,
               model_image,
               estimator_string,
               datasource,
               select,
               validation_select,
               model_params,
               train_params,
               validation_params,
               feature_column_map,
               label_column,
               save,
               load=None,
               pai_table=None,
               pai_val_table=None):
    if model_params is None:
        model_params = {}

    if train_params is None:
        train_params = {}

    if validation_params is None:
        validation_params = {}

    conn = db.connect_with_data_source(datasource)
    fc_map_ir, fc_label_ir = infer_feature_columns(conn,
                                                   select,
                                                   feature_column_map,
                                                   label_column,
                                                   n=1000)
    fc_map = compile_ir_feature_columns(fc_map_ir, EstimatorType.TENSORFLOW)
    field_descs = get_ordered_field_descs(fc_map_ir)
    feature_column_names = [fd.name for fd in field_descs]
    feature_metas = dict([(fd.name, fd.to_dict(dtype_to_string=True))
                          for fd in field_descs])
    label_meta = fc_label_ir.get_field_desc()[0].to_dict(dtype_to_string=True)

    feature_column_names_map = dict()
    for target in fc_map_ir:
        fclist = fc_map_ir[target]
        feature_column_names_map[target] = [
            fc.get_field_desc()[0].name for fc in fclist
        ]

    # Construct optimizer objects to pass to model initializer.
    # The original model_params is serializable (do not have tf.xxx objects).
    model_params_constructed = copy.deepcopy(model_params)
    for optimizer_arg in ["optimizer", "dnn_optimizer", "linear_optimizer"]:
        if optimizer_arg in model_params_constructed:
            model_params_constructed[optimizer_arg] = get_tf_optimizer(
                model_params_constructed[optimizer_arg])

    if "loss" in model_params_constructed:
        model_params_constructed["loss"] = get_tf_loss(
            model_params_constructed["loss"])

    # extract params for training.
    verbose = train_params.get("verbose", 1)
    batch_size = train_params.get("batch_size", 1)
    epoch = train_params.get("epoch", 1)
    save_checkpoints_steps = train_params.get("save_checkpoints_steps", 100)
    max_steps = train_params.get("max_steps", None)
    if max_steps is not None and max_steps <= 0:
        max_steps = None

    validation_metrics = validation_params.get("metrics", "Accuracy")
    validation_metrics = [v.strip() for v in validation_metrics.split(",")]
    validation_steps = validation_params.get("steps", 1)
    validation_start_delay_secs = validation_params.get("start_delay_secs", 0)
    validation_throttle_secs = validation_params.get("throttle_secs", 0)

    estimator = import_model(estimator_string)
    is_estimator = is_tf_estimator(estimator)

    is_pai = bool(pai_table)
    # always use verbose >= 1 so that PAI jobs emit enough logs
    if verbose < 1:
        verbose = 1
    set_log_level(verbose, is_estimator)

    model_params_constructed.update(fc_map)

    FLAGS = define_tf_flags()
    set_oss_environs(FLAGS)
    num_workers = len(FLAGS.worker_hosts.split(","))
    worker_id = FLAGS.task_index

    train_dataset_fn = get_dataset_fn(select,
                                      datasource,
                                      feature_column_names,
                                      feature_metas,
                                      label_meta,
                                      is_pai,
                                      pai_table,
                                      batch_size,
                                      epochs=epoch,
                                      shuffle_size=1000,
                                      num_workers=num_workers,
                                      worker_id=worker_id)
    val_dataset_fn = None
    if validation_select:
        val_dataset_fn = get_dataset_fn(validation_select, datasource,
                                        feature_column_names, feature_metas,
                                        label_meta, is_pai, pai_val_table,
                                        batch_size)

    model_meta = collect_metadata(original_sql=original_sql,
                                  select=select,
                                  validation_select=validation_select,
                                  model_repo_image=model_image,
                                  class_name=estimator_string,
                                  attributes=model_params,
                                  features=fc_map_ir,
                                  label=fc_label_ir)

    # FIXME(typhoonzero): avoid saving model_meta twice; keras_train_and_save
    # and estimator_train_and_save also dump model_meta to a file under cwd.
    # We should only keep the model.save_to_db part.
    save_dir = "model_save"
    if not is_estimator:
        if isinstance(estimator, types.FunctionType):
            # functional model need field_metas parameter
            model_params_constructed["field_metas"] = feature_metas
        keras_train_and_save(estimator, model_params_constructed, save_dir,
                             FLAGS, train_dataset_fn, val_dataset_fn,
                             label_meta, epoch, verbose, validation_metrics,
                             validation_steps, load, model_meta, is_pai)
    else:
        estimator_train_and_save(estimator, model_params_constructed, save_dir,
                                 FLAGS, train_dataset_fn, val_dataset_fn,
                                 max_steps, validation_start_delay_secs,
                                 validation_throttle_secs,
                                 save_checkpoints_steps, validation_metrics,
                                 load, model_meta)

    # save model to DB
    if num_workers == 1 or worker_id == 0:
        if is_pai:
            oss_model_dir = FLAGS.sqlflow_oss_modeldir
            oss.save_oss_model(oss_model_dir, estimator_string, is_estimator,
                               feature_column_names, feature_column_names_map,
                               feature_metas, label_meta, model_params,
                               fc_map_ir, num_workers)
            print("Model saved to OSS: %s" % oss_model_dir)
        else:
            model = Model(EstimatorType.TENSORFLOW, model_meta)
            model.save_to_db(datasource, save)
            print("Model saved to db: %s" % save)

    print("Done training")
    conn.close()
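
One detail in this example deserves a standalone look: model_params
arrives fully serializable, and train_step deep-copies it before swapping
optimizer and loss names for real TensorFlow objects. A hedged sketch of
that substitution, with a hypothetical lookup standing in for SQLFlow's
get_tf_optimizer (whose real implementation is not shown here):

import copy
import tensorflow as tf

# Hypothetical stand-in for get_tf_optimizer; illustration only.
def get_tf_optimizer_sketch(name):
    return {"Adam": tf.keras.optimizers.Adam,
            "SGD": tf.keras.optimizers.SGD}[name]()

model_params = {"optimizer": "Adam", "batch_norm": True}
constructed = copy.deepcopy(model_params)  # original stays serializable
constructed["optimizer"] = get_tf_optimizer_sketch(constructed["optimizer"])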