def __init__(
    self,
    examples: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    base_model: Optional[types.Channel] = None,
    hyperparameters: Optional[types.Channel] = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    # TODO(b/147702778): deprecate trainer_fn.
    trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    train_args: Optional[Union[trainer_pb2.TrainArgs, Dict[Text,
                                                           Any]]] = None,
    eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text, Any]]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    model: Optional[types.Channel] = None,
    model_run: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as the
      source of examples used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present.
    schema: An optional Channel of type `standard_artifacts.Schema`, serving
      as the schema of training and eval data. Schema is optional when
      1) transform_graph is provided which contains schema.
      2) user module bypasses the usage of schema, e.g., hardcoded.
    base_model: A Channel of type `Model`, containing model that will be used
      for training. This can be used for warmstart, transfer learning or
      model ensembling.
    hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
      serving as the hyperparameters for training module. Tuner's output best
      hyperparameters can be fed into this.
    module_file: A path to python module file containing UDF model
      definition. The module_file must implement a function named `run_fn` at
      its top level with function signature:
      `def run_fn(trainer.fn_args_utils.FnArgs)`, and the trained model must
      be saved to FnArgs.serving_model_dir when this function is executed.

      For Estimator based Executor, The module_file must implement a function
      named `trainer_fn` at its top level. The function must have the
      following signature.
        def trainer_fn(trainer.fn_args_utils.FnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...
      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
    run_fn: A python path to UDF model definition function for generic
      trainer. See 'module_file' for details. Exactly one of 'module_file' or
      'run_fn' must be supplied if Trainer uses GenericExecutor (default).
    trainer_fn: A python path to UDF model definition function for estimator
      based trainer. See 'module_file' for the required signature of the UDF.
      Exactly one of 'module_file' or 'trainer_fn' must be supplied if
      Trainer uses Estimator based Executor
    train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
      used for training. Currently only splits and num_steps are available.
      If it's provided as a dict and any field is a RuntimeParameter, it
      should have the same field names as a TrainArgs proto message. Default
      behavior (when splits is empty) is train on `train` split.
    eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
      used for evaluation. Currently only splits and num_steps are available.
      If it's provided as a dict and any field is a RuntimeParameter, it
      should have the same field names as a EvalArgs proto message. Default
      behavior (when splits is empty) is evaluate on `eval` split.
    custom_config: A dict which contains additional training job parameters
      that will be passed into user module.
    custom_executor_spec: Optional custom executor spec.
    model: Optional `Model` channel for result of exported models.
    model_run: Optional `ModelRun` channel, as the working dir of models, can
      be used to output non-model related output (e.g., TensorBoard logs).
    instance_name: Optional unique instance name. Necessary iff multiple
      Trainer components are declared in the same pipeline.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and user function (e.g.,
        trainer_fn and run_fn) is supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_graph' is not
        supplied.
  """
  # Exactly one of the three UDF entry points may be set; module_file is the
  # file-based variant, run_fn/trainer_fn are python-path variants.
  if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
    raise ValueError(
        "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
        "supplied.")
  # 'transformed_examples' is the deprecated alias of 'examples'; callers
  # must pass exactly one of them.
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        "supplied.")
  if transformed_examples and not transform_graph:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_graph' must be supplied too.")
  examples = examples or transformed_examples
  # Create default output channels when the caller did not provide them.
  model = model or types.Channel(type=standard_artifacts.Model)
  model_run = model_run or types.Channel(type=standard_artifacts.ModelRun)
  spec = TrainerSpec(
      examples=examples,
      transform_graph=transform_graph,
      schema=schema,
      base_model=base_model,
      hyperparameters=hyperparameters,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      run_fn=run_fn,
      trainer_fn=trainer_fn,
      # custom_config is serialized to JSON so it can round-trip through the
      # component spec (which stores execution properties as primitives).
      custom_config=json_utils.dumps(custom_config),
      model=model,
      model_run=model_run)
  super(Trainer, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
def __init__(
    self,
    examples: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    base_model: Optional[types.Channel] = None,
    hyperparameters: Optional[types.Channel] = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    train_args: Optional[Union[trainer_pb2.TrainArgs, Dict[Text,
                                                           Any]]] = None,
    eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text, Any]]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    output: Optional[types.Channel] = None,
    transform_output: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as the
      source of examples used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present.
    schema: A Channel of type `standard_artifacts.Schema`, serving as the
      schema of training and eval data.
    base_model: A Channel of type `Model`, containing model that will be used
      for training. This can be used for warmstart, transfer learning or
      model ensembling.
    hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
      serving as the hyperparameters for training module. Tuner's output best
      hyperparameters can be fed into this.
    module_file: A path to python module file containing UDF model
      definition. The module_file must implement a function named
      `trainer_fn` at its top level. The function must have the following
      signature.
        def trainer_fn(trainer.executor._TrainerFnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...
      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of
          tfma.export.EvalInputReceiver.
      Exactly one of 'module_file' or 'trainer_fn' must be supplied.
    trainer_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'trainer_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Current only num_steps is available.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Current only num_steps is available.
    custom_config: A dict which contains additional training job parameters
      that will be passed into user module.
    custom_executor_spec: Optional custom executor spec.
    output: Optional `Model` channel for result of exported models.
    transform_output: Backwards compatibility alias for the
      'transform_graph' argument.
    instance_name: Optional unique instance name. Necessary iff multiple
      Trainer components are declared in the same pipeline.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and 'trainer_fn' is supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_graph' is not
        supplied.
  """
  # Accept the deprecated 'transform_output' name as an alias for
  # 'transform_graph'; the latter wins when both are given.
  transform_graph = transform_graph or transform_output
  if bool(module_file) == bool(trainer_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'trainer_fn' must be supplied")
  # 'transformed_examples' is the deprecated alias of 'examples'; callers
  # must pass exactly one of them.
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        "supplied.")
  if transformed_examples and not transform_graph:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_graph' must be supplied too.")
  examples = examples or transformed_examples
  # Create a default Model output channel when the caller did not provide
  # one.
  output = output or types.Channel(
      type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
  spec = TrainerSpec(
      examples=examples,
      transform_output=transform_graph,
      schema=schema,
      base_model=base_model,
      hyperparameters=hyperparameters,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      trainer_fn=trainer_fn,
      custom_config=custom_config,
      output=output)
  super(Trainer, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
def __init__(
    self,
    examples: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_output: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    module_file: Optional[Text] = None,
    trainer_fn: Optional[Text] = None,
    train_args: Optional[trainer_pb2.TrainArgs] = None,
    eval_args: Optional[trainer_pb2.EvalArgs] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
    output: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of 'ExamplesPath' type, serving as the source of
      examples that are used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_output: An optional Channel of 'TransformPath' type, serving
      as the input transform graph if present.
    schema: A Channel of 'SchemaPath' type, serving as the schema of
      training and eval data.
    module_file: A path to python module file containing UDF model
      definition. The module_file must implement a function named
      `trainer_fn` at its top level. The function must have the following
      signature.
        def trainer_fn(tf.contrib.training.HParams,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...
      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of tfma.export.EvalInputReceiver
      Exactly one of 'module_file' or 'trainer_fn' must be supplied.
    trainer_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'trainer_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Current only num_steps is available.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Current only num_steps is available.
    custom_config: A dict which contains the training job parameters to be
      passed to Google Cloud ML Engine. For the full set of parameters
      supported by Google Cloud ML Engine, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
    executor_class: Optional custom executor class.
    output: Optional 'ModelExportPath' channel for result of exported
      models.
    transform_graph: Forwards compatibility alias for the
      'transform_output' argument.
    name: Optional unique name. Necessary iff multiple Trainer components
      are declared in the same pipeline.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and 'trainer_fn' is supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_output' is
        not supplied.
  """
  # Accept the newer 'transform_graph' name as an alias for
  # 'transform_output'; the latter wins when both are given.
  transform_output = transform_output or transform_graph
  if bool(module_file) == bool(trainer_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'trainer_fn' must be supplied")
  # 'transformed_examples' is the deprecated alias of 'examples'; callers
  # must pass exactly one of them.
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        "supplied.")
  if transformed_examples and not transform_output:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_output' must be supplied too.")
  examples = examples or transformed_examples
  # Create a default Model output channel when the caller did not provide
  # one.
  output = output or types.Channel(
      type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
  spec = TrainerSpec(
      examples=examples,
      transform_output=transform_output,
      schema=schema,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      trainer_fn=trainer_fn,
      custom_config=custom_config,
      output=output)
  super(Trainer, self).__init__(
      spec=spec, custom_executor_class=executor_class, name=name)