def _generate_hugectr_config(name, output_path, hugectr_params, max_batch_size=None):
    """Build the Triton model config for a HugeCTR model and write it to disk.

    The config declares three inputs (``DES``, ``CATCOLUMN``, ``ROWINDEX``),
    ``hugectr_params["n_outputs"]`` FP32 outputs named ``OUTPUT0``, ``OUTPUT1``,
    ..., a single instance group, and a set of string-valued model parameters
    copied from ``hugectr_params``.  The result is serialized in protobuf text
    format to ``<output_path>/config.pbtxt``.

    Parameters
    ----------
    name:
        Model name stored in the config.
    output_path:
        Existing directory in which ``config.pbtxt`` is written.
    hugectr_params:
        Mapping with the required keys ``n_outputs``, ``config``,
        ``label_dim``, ``slots``, ``des_feature_num``, ``cat_feature_num``,
        ``max_nnz`` and ``embedding_vector_size``.  Optional keys are
        ``gpucache`` (default ``"true"``), ``embeddingkey_long_type`` (default
        ``"true"``) and ``gpucacheper`` (default ``"0.5"``; the historical
        spelling ``gpucacheper_val`` is still honoured and wins if both are
        given).  ``config``, ``gpucache`` and ``embeddingkey_long_type`` are
        passed through without conversion, so they are expected to already
        be strings.
    max_batch_size:
        Forwarded unchanged to ``ModelConfig``.

    Returns
    -------
    The populated ``model_config.ModelConfig`` message.

    Raises
    ------
    KeyError
        If a required key is missing from ``hugectr_params``.
    """
    config = model_config.ModelConfig(
        name=name, backend="hugectr", max_batch_size=max_batch_size
    )

    # Fixed HugeCTR input tensors, each declared with a single variable-length dim.
    fixed_inputs = (
        ("DES", model_config.TYPE_FP32),
        ("CATCOLUMN", model_config.TYPE_INT64),
        ("ROWINDEX", model_config.TYPE_INT32),
    )
    for tensor_name, tensor_type in fixed_inputs:
        config.input.append(
            model_config.ModelInput(name=tensor_name, data_type=tensor_type, dims=[-1])
        )

    for i in range(hugectr_params["n_outputs"]):
        config.output.append(
            model_config.ModelOutput(
                name="OUTPUT" + str(i), data_type=model_config.TYPE_FP32, dims=[-1]
            )
        )

    # NOTE(review): kind=1 is presumably KIND_GPU in Triton's instance-group
    # enum (consistent with gpus=[0]) -- confirm against model_config.proto.
    config.instance_group.append(
        model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)
    )

    def _set_parameter(key, value):
        # Store ``value`` as the string-valued model parameter ``key``.
        config.parameters[key].CopyFrom(model_config.ModelParameter(string_value=value))

    _set_parameter("config", hugectr_params["config"])
    _set_parameter("gpucache", hugectr_params.get("gpucache", "true"))

    # The original code only looked up "gpucacheper_val", so a caller using the
    # name of the emitted parameter ("gpucacheper") was silently ignored.  Keep
    # the legacy key's precedence for backward compatibility, but fall back to
    # the parameter's own name before using the default.
    gpucacheper_val = hugectr_params.get(
        "gpucacheper_val", hugectr_params.get("gpucacheper", "0.5")
    )
    _set_parameter("gpucacheper", str(gpucacheper_val))

    # Required parameters that are stringified before being stored.
    for key in (
        "label_dim",
        "slots",
        "des_feature_num",
        "cat_feature_num",
        "max_nnz",
        "embedding_vector_size",
    ):
        _set_parameter(key, str(hugectr_params[key]))

    _set_parameter(
        "embeddingkey_long_type", hugectr_params.get("embeddingkey_long_type", "true")
    )

    with open(os.path.join(output_path, "config.pbtxt"), "w") as o:
        text_format.PrintMessage(config, o)
    return config
def _generate_nvtabular_config(
    workflow,
    name,
    output_path,
    output_model=None,
    max_batch_size=None,
    cats=None,
    conts=None,
    output_info=None,
    backend="python",
):
    """Build the Triton ModelConfig proto for an NVTabular workflow and save it.

    The inputs and outputs declared in the config depend on ``output_model``:

    * ``"hugectr"``: one input per name in
      ``workflow.column_group.input_column_names`` (dtype looked up in
      ``workflow.input_dtypes``), the fixed outputs ``DES``, ``CATCOLUMN`` and
      ``ROWINDEX``, plus an instance group with ``kind=2``.
    * ``"pytorch"``: inputs from ``workflow.input_dtypes``; one output per
      entry of ``output_info``, using its ``"dtype"`` and dims
      ``[-1, len(entry["columns"])]``.  ``output_info`` must be a mapping here.
    * anything else: inputs from ``workflow.input_dtypes`` and outputs from
      ``workflow.output_dtypes``.

    The config is written in protobuf text format to
    ``<output_path>/config.pbtxt`` and also returned.  ``cats`` and ``conts``
    are accepted but not used by this function.
    """
    config = model_config.ModelConfig(
        name=name, backend=backend, max_batch_size=max_batch_size
    )
    params = config.parameters
    params["python_module"].string_value = "nvtabular.inference.triton.model"
    params["output_model"].string_value = output_model or ""

    if output_model == "hugectr":
        # NOTE(review): kind=2 is presumably KIND_CPU in Triton's
        # instance-group enum -- confirm against model_config.proto.
        config.instance_group.append(model_config.ModelInstanceGroup(kind=2))

        for col_name in workflow.column_group.input_column_names:
            col_dtype = workflow.input_dtypes[col_name]
            tensor = model_config.ModelInput(
                name=col_name, data_type=_convert_dtype(col_dtype), dims=[-1]
            )
            config.input.append(tensor)

        # Fixed output tensors declared for the HugeCTR case.
        hugectr_outputs = (
            ("DES", model_config.TYPE_FP32),
            ("CATCOLUMN", model_config.TYPE_INT64),
            ("ROWINDEX", model_config.TYPE_INT32),
        )
        for out_name, out_type in hugectr_outputs:
            config.output.append(
                model_config.ModelOutput(name=out_name, data_type=out_type, dims=[-1])
            )
    else:
        # Both remaining cases describe their inputs straight from the
        # workflow's input dtypes; only the outputs differ.
        for col_name, col_dtype in workflow.input_dtypes.items():
            _add_model_param(col_name, col_dtype, model_config.ModelInput, config.input)

        if output_model == "pytorch":
            for out_name, info in output_info.items():
                out_dims = [-1, len(info["columns"])]
                _add_model_param(
                    out_name,
                    info["dtype"],
                    model_config.ModelOutput,
                    config.output,
                    out_dims,
                )
        else:
            for col_name, col_dtype in workflow.output_dtypes.items():
                _add_model_param(
                    col_name, col_dtype, model_config.ModelOutput, config.output
                )

    config_path = os.path.join(output_path, "config.pbtxt")
    with open(config_path, "w") as o:
        text_format.PrintMessage(config, o)
    return config