def build_preprocess_config(self, flags):
        """Build the JSON config for a preprocessing-only (serialization) run.

        On PAI TF the input/output table names come from the global ``FLAGS``
        (``tables``/``outputs``); otherwise they come from the app ``flags``.
        Also resolves the tokenizer from the model name and copies every
        user-defined advanced parameter onto ``self`` as an attribute.

        Args:
            flags: parsed command-line flags (camelCase attribute names).

        Returns:
            dict: ``{"preprocess_config": {...}}`` consumed by the
            preprocessing pipeline.
        """
        first_sequence, second_sequence, label_name = \
            flags.firstSequence, flags.secondSequence, flags.labelName
        on_pai = "PAI" in tf.__version__
        input_table = FLAGS.tables if on_pai else flags.inputTable
        output_table = FLAGS.outputs if on_pai else flags.outputTable
        append_columns = flags.appendCols.split(",") if flags.appendCols else []
        if on_pai:
            # Only read the columns we actually need from the input table.
            input_schema = get_selected_columns_schema(
                input_table,
                {first_sequence, second_sequence, label_name, *append_columns})
        else:
            input_schema = flags.inputSchema
        # Passthrough columns are appended to the declared output schema.
        output_schema = ",".join([flags.outputSchema] + append_columns)

        user_param_dict = get_user_defined_prams_dict(flags.advancedParameters)
        if flags.modelName in _name_to_app_model:
            # Known app model: the tokenizer may be overridden via advanced
            # parameters, defaulting to the Chinese BERT tokenizer.
            tokenizer_name_or_path = user_param_dict.get(
                "tokenizer_name_or_path", "google-bert-base-zh")
            setattr(self, "model_name", "serialization")
            setattr(self, "app_model_name", flags.modelName)
        else:
            # Otherwise treat the model name itself as the tokenizer path.
            tokenizer_name_or_path = flags.modelName
            setattr(self, "model_name", "serialization")
            setattr(self, "app_model_name", "text_classify_bert")
        config_json = {
            "preprocess_config": {
                "preprocess_input_fp": input_table,
                "preprocess_output_fp": output_table,
                "preprocess_batch_size": flags.batchSize,
                "sequence_length": flags.sequenceLength,
                "tokenizer_name_or_path": tokenizer_name_or_path,
                "input_schema": input_schema,
                "first_sequence": first_sequence,
                "second_sequence": second_sequence,
                "label_name": label_name,
                "label_enumerate_values":
                    get_label_enumerate_values(flags.labelEnumerateValues),
                "output_schema": output_schema,
            }
        }
        # Expose any remaining advanced parameters as attributes on self.
        for key, val in user_param_dict.items():
            setattr(self, key, val)
        return config_json
 def get_default_postprocessor(self):
     if hasattr(self.config, "label_enumerate_values"):
         label_enumerate_values = get_label_enumerate_values(self.config.label_enumerate_values)
     else:
         label_enumerate_values = None
     if hasattr(self.config, "model_name"):
         app_model_name = self.config.model_name
     else:
         app_model_name = None
     return postprocessors.get_postprocessors(
         label_enumerate_values=label_enumerate_values,
         output_schema=self.config.output_schema,
         thread_num=self.thread_num,
         input_queue=queue.Queue(),
         output_queue=queue.Queue(),
         app_model_name=app_model_name)
    def build_train_config(self, flags):
        """Build the JSON config for a training run.

        Splits the comma-separated input into train/eval tables, derives the
        input schema (from the PAI table, the model's expected input tensors,
        or ``flags.inputSchema``), assembles the nested config dict, then
        applies user-defined "advanced parameter" overrides on top of the
        model's defaults.

        Args:
            flags: parsed command-line flags (camelCase attribute names).

        Returns:
            dict: nested config with ``preprocess_config``, ``model_config``,
            ``train_config`` and ``evaluate_config`` sections.

        Raises:
            NotImplementedError: if ``flags.modelName`` is not a registered
                app model.
        """
        # Parse input table/csv schema
        first_sequence, second_sequence, label_name = \
            flags.firstSequence, flags.secondSequence, flags.labelName
        label_enumerate_values = get_label_enumerate_values(
            flags.labelEnumerateValues)
        on_pai = "PAI" in tf.__version__
        if on_pai:
            train_input_fp, eval_input_fp = FLAGS.tables.split(",")
            if first_sequence is None:
                # No explicit text columns: derive the schema from the
                # model's expected input tensors instead of the table.
                assert flags.sequenceLength is not None
                input_schema = _name_to_app_model[
                    flags.modelName].get_input_tensor_schema(
                        sequence_length=flags.sequenceLength)
            else:
                input_schema = get_selected_columns_schema(
                    train_input_fp,
                    [first_sequence, second_sequence, label_name])
        else:
            train_input_fp, eval_input_fp = flags.inputTable.split(",")
            input_schema = flags.inputSchema
        train_input_fp = train_input_fp.strip()
        eval_input_fp = eval_input_fp.strip()
        # Parse args from APP's FLAGS
        config_json = {
            "preprocess_config": {
                "input_schema": input_schema,
                "first_sequence": first_sequence,
                "second_sequence": second_sequence,
                "sequence_length": flags.sequenceLength,
                "label_name": label_name,
                "label_enumerate_values": label_enumerate_values
            },
            "model_config": {
                "model_name": flags.modelName,
            },
            "train_config": {
                "train_input_fp": train_input_fp,
                "num_epochs": flags.numEpochs,
                "save_steps": flags.saveCheckpointSteps,
                "train_batch_size": flags.batchSize,
                # Fixed typo: key was "model_dir5".
                "model_dir": flags.checkpointDir,
                "optimizer_config": {
                    "optimizer": flags.optimizerType,
                    "learning_rate": flags.learningRate
                },
                "distribution_config": {
                    "distribution_strategy": flags.distributionStrategy,
                }
            },
            "evaluate_config": {
                "eval_input_fp": eval_input_fp,
                "eval_batch_size": 32,
                "num_eval_steps": None
            }
        }

        tf.logging.info(flags.advancedParameters)

        user_param_dict = get_user_defined_prams_dict(flags.advancedParameters)

        tf.logging.info(user_param_dict)
        if flags.modelName not in _name_to_app_model:
            raise NotImplementedError
        default_model_params = _name_to_app_model[
            flags.modelName].default_model_params()
        # Overlay user-supplied values onto the model defaults, coercing each
        # string to the default's type (bools are parsed from "true"/"false").
        for key, default_val in default_model_params.items():
            if key in user_param_dict:
                if isinstance(default_val, bool):
                    tmp_val = (user_param_dict[key].lower() == "true")
                else:
                    tmp_val = type(default_val)(user_param_dict[key])
                config_json["model_config"][key] = tmp_val
            else:
                config_json["model_config"][key] = default_val

        config_json["model_config"]["num_labels"] = len(
            label_enumerate_values.split(","))
        if "pretrain_model_name_or_path" in config_json["model_config"]:
            pretrain_model_name_or_path = config_json["model_config"][
                "pretrain_model_name_or_path"]
            contrib_models_path = os.path.join(FLAGS.modelZooBasePath,
                                               "contrib_models.json")
            if not on_pai and "oss://" in contrib_models_path:
                # Cannot read OSS paths without PAI TF; keep the name as-is.
                pass
            elif tf.gfile.Exists(contrib_models_path):
                # Map a contrib model alias to its real path, if registered.
                with tf.gfile.Open(contrib_models_path) as f:
                    contrib_models = json.load(f)
                if pretrain_model_name_or_path in contrib_models:
                    pretrain_model_name_or_path = contrib_models[
                        pretrain_model_name_or_path]

            config_json["model_config"][
                "pretrain_model_name_or_path"] = pretrain_model_name_or_path
            config_json["preprocess_config"][
                "tokenizer_name_or_path"] = pretrain_model_name_or_path
        else:
            config_json["preprocess_config"]["tokenizer_name_or_path"] = ""

        if "num_accumulated_batches" in user_param_dict:
            config_json["train_config"]["distribution_config"]["num_accumulated_batches"] = \
                user_param_dict["num_accumulated_batches"]

        if "pull_evaluation_in_multiworkers_training" in user_param_dict:
            config_json["train_config"]["distribution_config"]["pull_evaluation_in_multiworkers_training"] = \
                (user_param_dict["pull_evaluation_in_multiworkers_training"].lower() == "true")

        # Route remaining scalar overrides to their config sub-section.
        other_param_keys = {
            "train_config":
            ["throttle_secs", "keep_checkpoint_max", "log_step_count_steps"],
            "optimizer_config": [
                "weight_decay_ratio", "lr_decay", "warmup_ratio",
                "gradient_clip", "clip_norm_value"
            ],
            "evaluate_config": ["eval_batch_size", "num_eval_steps"],
        }
        for section, key_list in other_param_keys.items():
            # optimizer_config is nested under train_config.
            if section == "optimizer_config":
                target = config_json["train_config"][section]
            else:
                target = config_json[section]
            for key in key_list:
                if key in user_param_dict:
                    target[key] = user_param_dict[key]

        if "shuffle_buffer_size" in user_param_dict:
            setattr(self, "shuffle_buffer_size",
                    int(user_param_dict["shuffle_buffer_size"]))
        else:
            setattr(self, "shuffle_buffer_size", None)

        if "init_checkpoint_path" in user_param_dict:
            setattr(self, "init_checkpoint_path",
                    user_param_dict["init_checkpoint_path"])

        if "export_best_checkpoint" in user_param_dict:
            export_best = user_param_dict["export_best_checkpoint"].lower()
            assert export_best in ["true", "false"]
            setattr(self, "export_best_checkpoint", export_best == "true")

        # NOTE(review): the text_match branch tests label_enumerate_values
        # against None, but a None value would already have failed the
        # num_labels split() above — verify get_label_enumerate_values.
        if "export_best_checkpoint_metric" in user_param_dict:
            setattr(self, "export_best_checkpoint_metric",
                    user_param_dict["export_best_checkpoint_metric"])
        elif flags.modelName.startswith("text_classify"):
            setattr(self, "export_best_checkpoint_metric", "py_accuracy")
        elif flags.modelName.startswith(
                "text_match") and label_enumerate_values is None:
            setattr(self, "export_best_checkpoint_metric", "mse")
        else:
            setattr(self, "export_best_checkpoint_metric", "accuracy")

        return config_json