def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    analyzer_cache: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    materialize: bool = True,
    disable_analyzer_cache: bool = False,
    custom_config: Optional[Dict[Text, Any]] = None):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the two splits 'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. Exactly one of
      'module_file' or 'preprocessing_fn' must be supplied. The function
      needs to have the following signature:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. If additional inputs are needed for
      preprocessing_fn, they can be passed in custom_config:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any],
                           custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for the expected signature of
      the function. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied.
    transform_graph: Optional output 'TransformPath' channel for output of
      'tf.Transform', which includes an exported Tensorflow graph suitable
      for both training and serving.
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    input_data: Backwards compatibility alias for the 'examples' argument.
    analyzer_cache: Optional input 'TransformCache' channel containing
      cached information from previous Transform runs. When provided,
      Transform will try to use the cached calculation if possible.
    instance_name: Optional unique instance name. Necessary iff multiple
      transform components are declared in the same pipeline.
    materialize: If True, write transformed examples as an output. If
      False, `transformed_examples` must not be provided.
    disable_analyzer_cache: If False, Transform will use input cache if
      provided and write cache output. If True, `analyzer_cache` must not
      be provided.
    custom_config: A dict which contains additional parameters that will
      be passed to preprocessing_fn.

  Raises:
    ValueError: When both or neither of 'module_file' and
      'preprocessing_fn' is supplied.
  """
  if input_data:
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_graph = transform_graph or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])
  if materialize and transformed_examples is None:
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples,
        # TODO(b/161548528): remove the hardcoded artifact.
        artifacts=[standard_artifacts.Examples()],
        matching_channel_name='examples')
  elif not materialize and transformed_examples is not None:
    raise ValueError(
        'Must not specify transformed_examples when materialize is False.')

  if disable_analyzer_cache:
    updated_analyzer_cache = None
    if analyzer_cache:
      raise ValueError(
          '`analyzer_cache` is set when disable_analyzer_cache is True.')
  else:
    updated_analyzer_cache = types.Channel(
        type=standard_artifacts.TransformCache,
        artifacts=[standard_artifacts.TransformCache()])

  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples,
      analyzer_cache=analyzer_cache,
      updated_analyzer_cache=updated_analyzer_cache,
      custom_config=json.dumps(custom_config))
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
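# A minimal sketch of a module_file body for the Transform component above,
# assuming the tensorflow_transform package. The feature names
# ('trip_miles', 'payment_type') and the '_xf' output suffix are
# hypothetical, illustrative choices.
import tensorflow_transform as tft
from typing import Any, Dict, Text


def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
  """Scales a numeric feature and integerizes a string feature."""
  outputs = {}
  # z-score normalization, with mean/stddev computed over the analysis data.
  outputs['trip_miles_xf'] = tft.scale_to_z_score(inputs['trip_miles'])
  # Map strings to ints using a vocabulary built during the analysis phase.
  outputs['payment_type_xf'] = tft.compute_and_apply_vocabulary(
      inputs['payment_type'])
  return outputs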
def __init__(self,
             input_data: types.Channel = None,
             schema: types.Channel = None,
             module_file: Optional[Text] = None,
             preprocessing_fn: Optional[Text] = None,
             transform_output: Optional[types.Channel] = None,
             transformed_examples: Optional[types.Channel] = None,
             name: Optional[Text] = None):
  """Construct a Transform component.

  Args:
    input_data: A Channel of 'ExamplesPath' type. This should contain two
      splits 'train' and 'eval'.
    schema: A Channel of 'SchemaPath' type. This should contain a single
      schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. The function must have
      the following signature.

      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...

      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied.
    transform_output: Optional output 'TransformPath' channel for output
      of 'tf.Transform', which includes an exported Tensorflow graph
      suitable for both training and serving;
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    name: Optional unique name. Necessary iff multiple transform
      components are declared in the same pipeline.

  Raises:
    ValueError: When both or neither of 'module_file' and
      'preprocessing_fn' is supplied.
  """
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_output = transform_output or types.Channel(
      type=standard_artifacts.TransformResult,
      artifacts=[standard_artifacts.TransformResult()])
  transformed_examples = transformed_examples or types.Channel(
      type=standard_artifacts.Examples,
      artifacts=[
          standard_artifacts.Examples(split=split)
          for split in artifact.DEFAULT_EXAMPLE_SPLITS
      ])
  spec = TransformSpec(
      input_data=channel_utils.as_channel(input_data),
      schema=channel_utils.as_channel(schema),
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_output=transform_output,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(spec=spec, name=name)
def __init__(
    self,
    model: types.Channel = None,
    model_blessing: types.Channel = None,
    infra_blessing: Optional[types.Channel] = None,
    push_destination: Optional[Union[pusher_pb2.PushDestination,
                                     Dict[Text, Any]]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    output: Optional[types.Channel] = None,
    model_export: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    enable_cache: Optional[bool] = None):
  """Construct a Pusher component.

  Args:
    model: A Channel of type `standard_artifacts.Model`, usually produced
      by a Trainer component.
    model_blessing: A Channel of type `standard_artifacts.ModelBlessing`,
      usually produced by a ModelValidator component. _required_
    infra_blessing: An optional Channel of type
      `standard_artifacts.InfraBlessing`, usually produced from an
      InfraValidator component.
    push_destination: A pusher_pb2.PushDestination instance, providing
      info for tensorflow serving to load models. Optional if
      executor_class doesn't require push_destination. If any field is
      provided as a RuntimeParameter, push_destination should be
      constructed as a dict with the same field names as the
      PushDestination proto message.
    custom_config: A dict which contains the deployment job parameters to
      be passed to cloud-based training platforms. The [Kubeflow example](
      https://github.com/tensorflow/tfx/blob/6ff57e36a7b65818d4598d41e584a42584d361e6/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow_gcp.py#L278-L285)
      contains an example of how this can be used by custom executors.
    custom_executor_spec: Optional custom executor spec.
    output: Optional output `standard_artifacts.PushedModel` channel with
      the result of the push.
    model_export: Backwards compatibility alias for the 'model' argument.
    instance_name: Optional unique instance name. Necessary if multiple
      Pusher components are declared in the same pipeline.
    enable_cache: Optional boolean to indicate if cache is enabled for the
      Pusher component. If not specified, defaults to the value specified
      for the pipeline's enable_cache parameter.

  Raises:
    ValueError: If push_destination is not supplied and the supplied
      custom_executor_spec does not make it optional.
  """
  if model_export:
    absl.logging.warning(
        'The "model_export" argument to the Pusher component has '
        'been renamed to "model" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    model = model_export
  output = output or types.Channel(
      type=standard_artifacts.PushedModel,
      artifacts=[standard_artifacts.PushedModel()])
  if push_destination is None and not custom_executor_spec:
    raise ValueError('push_destination is required unless a '
                     'custom_executor_spec is supplied that does not require '
                     'it.')
  spec = PusherSpec(
      model=model,
      model_blessing=model_blessing,
      infra_blessing=infra_blessing,
      push_destination=push_destination,
      custom_config=json_utils.dumps(custom_config),
      pushed_model=output)
  super(Pusher, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name,
      enable_cache=enable_cache)
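# A minimal wiring sketch for the Pusher above, assuming upstream `trainer`
# and `evaluator` components with the standard 'model' and 'blessing'
# output keys; the base_directory path is illustrative.
from tfx.proto import pusher_pb2

pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='/serving_model/my_model')))  # illustrative path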
def __init__(self,
             examples: types.Channel = None,
             model: Optional[types.Channel] = None,
             model_blessing: Optional[types.Channel] = None,
             data_spec: Optional[Union[bulk_inferrer_pb2.DataSpec,
                                       Dict[Text, Any]]] = None,
             model_spec: Optional[Union[bulk_inferrer_pb2.ModelSpec,
                                        Dict[Text, Any]]] = None,
             output_example_spec: Optional[Union[
                 bulk_inferrer_pb2.OutputExampleSpec,
                 Dict[Text, Any]]] = None,
             inference_result: Optional[types.Channel] = None,
             output_examples: Optional[types.Channel] = None):
  """Construct a BulkInferrer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, usually
      produced by an ExampleGen component. _required_
    model: A Channel of type `standard_artifacts.Model`, usually produced
      by a Trainer component.
    model_blessing: A Channel of type `standard_artifacts.ModelBlessing`,
      usually produced by a ModelValidator component.
    data_spec: bulk_inferrer_pb2.DataSpec instance that describes data
      selection. If any field is provided as a RuntimeParameter, data_spec
      should be constructed as a dict with the same field names as the
      DataSpec proto message.
    model_spec: bulk_inferrer_pb2.ModelSpec instance that describes the
      model specification. If any field is provided as a RuntimeParameter,
      model_spec should be constructed as a dict with the same field names
      as the ModelSpec proto message.
    output_example_spec: bulk_inferrer_pb2.OutputExampleSpec instance.
      Specify it if you want BulkInferrer to output examples instead of an
      inference result. If any field is provided as a RuntimeParameter,
      output_example_spec should be constructed as a dict with the same
      field names as the OutputExampleSpec proto message.
    inference_result: Channel of type `standard_artifacts.InferenceResult`
      to store the inference results; must not be specified when
      output_example_spec is set.
    output_examples: Channel of type `standard_artifacts.Examples` to
      store the output examples; must not be specified when
      output_example_spec is unset. Check output_example_spec for details.

  Raises:
    ValueError: If inference_result is specified while output_example_spec
      is set, or if output_examples is specified while output_example_spec
      is unset.
  """
  if output_example_spec:
    if inference_result:
      raise ValueError(
          'Must not specify inference_result when output_example_spec is '
          'set.')
    output_examples = output_examples or types.Channel(
        type=standard_artifacts.Examples)
  else:
    if output_examples:
      raise ValueError(
          'Must not specify output_examples when output_example_spec is '
          'unset.')
    inference_result = inference_result or types.Channel(
        type=standard_artifacts.InferenceResult)

  spec = BulkInferrerSpec(
      examples=examples,
      model=model,
      model_blessing=model_blessing,
      data_spec=data_spec or bulk_inferrer_pb2.DataSpec(),
      model_spec=model_spec or bulk_inferrer_pb2.ModelSpec(),
      output_example_spec=output_example_spec,
      inference_result=inference_result,
      output_examples=output_examples)
  super(BulkInferrer, self).__init__(spec=spec)
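# A usage sketch of the default inference_result mode, assuming upstream
# `example_gen` and `trainer` components; the 'unlabelled' split name is
# illustrative and would need to exist in the examples channel.
from tfx.proto import bulk_inferrer_pb2

bulk_inferrer = BulkInferrer(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    data_spec=bulk_inferrer_pb2.DataSpec(example_splits=['unlabelled']))
# Predictions are then available on bulk_inferrer.outputs['inference_result'].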
def __init__(
    self,
    examples: types.Channel = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    base_model: Optional[types.Channel] = None,
    hyperparameters: Optional[types.Channel] = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    # TODO(b/147702778): deprecate trainer_fn.
    trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
    eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as
      the source of examples used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present.
    schema: An optional Channel of type `standard_artifacts.Schema`,
      serving as the schema of training and eval data. Schema is optional
      when 1) transform_graph is provided, which contains the schema, or
      2) the user module bypasses the usage of schema, e.g., hardcoded.
    base_model: A Channel of type `Model`, containing a model that will be
      used for training. This can be used for warmstart, transfer learning
      or model ensembling.
    hyperparameters: A Channel of type
      `standard_artifacts.HyperParameters`, serving as the hyperparameters
      for the training module. Tuner's output best hyperparameters can be
      fed into this.
    module_file: A path to python module file containing UDF model
      definition. The module_file must implement a function named `run_fn`
      at its top level with function signature:
      `def run_fn(trainer.fn_args_utils.FnArgs)`, and the trained model
      must be saved to FnArgs.serving_model_dir when this function is
      executed.

      For Estimator based Executor, the module_file must implement a
      function named `trainer_fn` at its top level. The function must have
      the following signature.

      def trainer_fn(trainer.fn_args_utils.FnArgs,
                     tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
        ...

      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.

      Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
      uses GenericExecutor (default). Use of a RuntimeParameter for this
      argument is experimental.
    run_fn: A python path to UDF model definition function for generic
      trainer. See 'module_file' for details. Exactly one of 'module_file'
      or 'run_fn' must be supplied if Trainer uses GenericExecutor
      (default). Use of a RuntimeParameter for this argument is
      experimental.
    trainer_fn: A python path to UDF model definition function for
      estimator based trainer. See 'module_file' for the required
      signature of the UDF. Exactly one of 'module_file' or 'trainer_fn'
      must be supplied if Trainer uses Estimator based Executor. Use of a
      RuntimeParameter for this argument is experimental.
    train_args: A proto.TrainArgs instance or a dict, containing args used
      for training. Currently only splits and num_steps are available. If
      it's provided as a dict and any field is a RuntimeParameter, it
      should have the same field names as a TrainArgs proto message.
      Default behavior (when splits is empty) is train on the `train`
      split.
    eval_args: A proto.EvalArgs instance or a dict, containing args used
      for evaluation. Currently only splits and num_steps are available.
      If it's provided as a dict and any field is a RuntimeParameter, it
      should have the same field names as an EvalArgs proto message.
      Default behavior (when splits is empty) is evaluate on the `eval`
      split.
    custom_config: A dict which contains additional training job
      parameters that will be passed into the user module.
    custom_executor_spec: Optional custom executor spec. This is
      experimental and is subject to change in the future.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and user function (e.g.,
        trainer_fn and run_fn) is supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_graph' is
        not supplied.
  """
  if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
    raise ValueError(
        "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
        'supplied.')
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        'supplied.')
  if transformed_examples and not transform_graph:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_graph' must be supplied too.")
  if custom_executor_spec:
    logging.warning('`custom_executor_spec` is going to be deprecated.')
  examples = examples or transformed_examples
  model = types.Channel(type=standard_artifacts.Model)
  model_run = types.Channel(type=standard_artifacts.ModelRun)
  spec = standard_component_specs.TrainerSpec(
      examples=examples,
      transform_graph=transform_graph,
      schema=schema,
      base_model=base_model,
      hyperparameters=hyperparameters,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      run_fn=run_fn,
      trainer_fn=trainer_fn,
      custom_config=json_utils.dumps(custom_config),
      model=model,
      model_run=model_run)
  super(Trainer, self).__init__(
      spec=spec, custom_executor_spec=custom_executor_spec)

  if udf_utils.should_package_user_modules():
    # In this case, the `MODULE_PATH_KEY` execution property will be injected
    # as a reference to the given user module file after packaging, at which
    # point the `MODULE_FILE_KEY` execution property will be removed.
    udf_utils.add_user_module_dependency(
        self, standard_component_specs.MODULE_FILE_KEY,
        standard_component_specs.MODULE_PATH_KEY)
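# A minimal `run_fn` sketch for the generic executor contract described
# above. The `_input_fn` and `_build_keras_model` helpers are hypothetical
# and would be defined elsewhere in the same module_file.
from tfx.components.trainer.fn_args_utils import FnArgs


def run_fn(fn_args: FnArgs):
  train_dataset = _input_fn(fn_args.train_files)  # hypothetical helper
  eval_dataset = _input_fn(fn_args.eval_files)    # hypothetical helper
  model = _build_keras_model()                    # hypothetical helper
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)
  # Per the contract above, the trained model must be written to
  # fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, save_format='tf')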
def __init__(
    self,
    examples: types.Channel = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_output: Optional[types.Channel] = None,
    schema: types.Channel = None,
    module_file: Optional[Text] = None,
    trainer_fn: Optional[Text] = None,
    train_args: trainer_pb2.TrainArgs = None,
    eval_args: trainer_pb2.EvalArgs = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    output: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of 'ExamplesPath' type, serving as the source of
      examples that are used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_output: An optional Channel of 'TransformPath' type, serving
      as the input transform graph if present.
    schema: A Channel of 'SchemaPath' type, serving as the schema of
      training and eval data.
    module_file: A path to python module file containing UDF model
      definition. The module_file must implement a function named
      `trainer_fn` at its top level. The function must have the following
      signature.

      def trainer_fn(tf.contrib.training.HParams,
                     tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
        ...

      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of
          tfma.export.EvalInputReceiver

      Exactly one of 'module_file' or 'trainer_fn' must be supplied.
    trainer_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'trainer_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Currently only num_steps is available.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Currently only num_steps is available.
    custom_config: A dict which contains the training job parameters to be
      passed to Google Cloud ML Engine. For the full set of parameters
      supported by Google Cloud ML Engine, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
    custom_executor_spec: Optional custom executor spec.
    output: Optional 'ModelExportPath' channel for the result of exported
      models.
    transform_graph: Forwards compatibility alias for the
      'transform_output' argument.
    instance_name: Optional unique instance name. Necessary iff multiple
      Trainer components are declared in the same pipeline.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and 'trainer_fn' is
        supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_output' is
        not supplied.
  """
  transform_output = transform_output or transform_graph
  if bool(module_file) == bool(trainer_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'trainer_fn' must be supplied.")
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        'supplied.')
  if transformed_examples and not transform_output:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_output' must be supplied too.")
  examples = examples or transformed_examples
  output = output or types.Channel(
      type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
  spec = TrainerSpec(
      examples=examples,
      transform_output=transform_output,
      schema=schema,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      trainer_fn=trainer_fn,
      custom_config=custom_config,
      output=output)
  super(Trainer, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
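# A sketch of the estimator-based `trainer_fn` contract documented above,
# under the assumption of a simple DNN estimator; the feature-column,
# input-fn, and receiver-fn helpers are hypothetical.
import tensorflow as tf


def trainer_fn(hparams, schema):
  """Returns the dict of estimator artifacts described in the docstring."""
  estimator = tf.estimator.DNNClassifier(
      hidden_units=[16, 8],
      feature_columns=_make_feature_columns(schema))  # hypothetical helper
  return {
      'estimator': estimator,
      'train_spec': tf.estimator.TrainSpec(
          input_fn=_train_input_fn,        # hypothetical input fn
          max_steps=hparams.train_steps),
      'eval_spec': tf.estimator.EvalSpec(
          input_fn=_eval_input_fn,         # hypothetical input fn
          steps=hparams.eval_steps),
      'eval_input_receiver_fn': _eval_input_receiver_fn,  # hypothetical
  }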
def __init__(self, name: Text, spec_kwargs: Dict[Text, Any]):
  spec = _FakeComponentSpec(
      output=types.Channel(type_name=name), **spec_kwargs)
  super(_FakeComponent, self).__init__(spec=spec, component_name=name)
def __init__(
    self,
    examples: types.BaseChannel,
    model: Optional[types.BaseChannel] = None,
    baseline_model: Optional[types.BaseChannel] = None,
    # TODO(b/148618405): deprecate feature_slicing_spec.
    feature_slicing_spec: Optional[Union[
        evaluator_pb2.FeatureSlicingSpec,
        data_types.RuntimeParameter]] = None,
    fairness_indicator_thresholds: Optional[Union[
        List[float], data_types.RuntimeParameter]] = None,
    example_splits: Optional[List[str]] = None,
    eval_config: Optional[tfma.EvalConfig] = None,
    schema: Optional[types.BaseChannel] = None,
    module_file: Optional[str] = None,
    module_path: Optional[str] = None):
  """Construct an Evaluator component.

  Args:
    examples: A BaseChannel of type `standard_artifacts.Examples`, usually
      produced by an ExampleGen component. _required_
    model: A BaseChannel of type `standard_artifacts.Model`, usually
      produced by a Trainer component.
    baseline_model: An optional channel of type 'standard_artifacts.Model'
      as the baseline model for model diff and model validation purposes.
    feature_slicing_spec: Deprecated, please use eval_config instead. Only
      supports estimator.
      [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
      instance that describes how Evaluator should slice the data.
    fairness_indicator_thresholds: Optional list of float (or
      RuntimeParameter) threshold values for use with TFMA fairness
      indicators. Experimental functionality: this interface and
      functionality may change at any time.
      TODO(b/142653905): add a link to additional documentation for TFMA
      fairness indicators here.
    example_splits: Names of splits on which the metrics are computed.
      Default behavior (when example_splits is set to None or Empty) is
      using the 'eval' split.
    eval_config: Instance of tfma.EvalConfig containing configuration
      settings for running the evaluation. This config has options for
      both estimator and Keras.
    schema: A `Schema` channel to use for TFXIO.
    module_file: A path to python module file containing UDFs for
      Evaluator customization. This functionality is experimental and may
      change at any time. The module_file can implement the following
      functions at its top level.

      def custom_eval_shared_model(
          eval_saved_model_path, model_name, eval_config, **kwargs,
      ) -> tfma.EvalSharedModel:

      def custom_extractors(
          eval_shared_model, eval_config, tensor_adapter_config,
      ) -> List[tfma.extractors.Extractor]:
    module_path: A python path to the custom module that contains the
      UDFs. See 'module_file' for the required signature of UDFs. This
      functionality is experimental and this API may change at any time.
      Note this cannot be set together with module_file.
  """
  if bool(module_file) and bool(module_path):
    raise ValueError(
        'Python module path can not be set together with module file path.')

  if eval_config is not None and feature_slicing_spec is not None:
    raise ValueError("Exactly one of 'eval_config' or 'feature_slicing_spec' "
                     'must be supplied.')
  if eval_config is None and feature_slicing_spec is None:
    feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
    logging.info('Neither eval_config nor feature_slicing_spec is passed, '
                 'the model is treated as estimator.')

  if feature_slicing_spec:
    logging.warning('feature_slicing_spec is deprecated, please use '
                    'eval_config instead.')

  blessing = types.Channel(type=standard_artifacts.ModelBlessing)
  evaluation = types.Channel(type=standard_artifacts.ModelEvaluation)
  spec = standard_component_specs.EvaluatorSpec(
      examples=examples,
      model=model,
      baseline_model=baseline_model,
      feature_slicing_spec=feature_slicing_spec,
      fairness_indicator_thresholds=(
          fairness_indicator_thresholds if isinstance(
              fairness_indicator_thresholds, data_types.RuntimeParameter)
          else json_utils.dumps(fairness_indicator_thresholds)),
      example_splits=json_utils.dumps(example_splits),
      evaluation=evaluation,
      eval_config=eval_config,
      blessing=blessing,
      schema=schema,
      module_file=module_file,
      module_path=module_path)
  super().__init__(spec=spec)

  if udf_utils.should_package_user_modules():
    # In this case, the `MODULE_PATH_KEY` execution property will be injected
    # as a reference to the given user module file after packaging, at which
    # point the `MODULE_FILE_KEY` execution property will be removed.
    udf_utils.add_user_module_dependency(
        self, standard_component_specs.MODULE_FILE_KEY,
        standard_component_specs.MODULE_PATH_KEY)
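# A sketch of an eval_config for the Evaluator above, assuming the
# tensorflow_model_analysis package; the label key, metric, and slicing
# choices are illustrative, not defaults.
import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='tips')],  # illustrative label
    slicing_specs=[tfma.SlicingSpec()],  # the overall (unsliced) metrics
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[tfma.MetricConfig(class_name='BinaryAccuracy')])
    ])
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],  # assumed upstream component
    model=trainer.outputs['model'],            # assumed upstream component
    eval_config=eval_config)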
def __init__(self,
             examples: types.Channel = None,
             schema: Optional[types.Channel] = None,
             transform_graph: Optional[types.Channel] = None,
             module_file: Optional[Text] = None,
             tuner_fn: Optional[Text] = None,
             train_args: trainer_pb2.TrainArgs = None,
             eval_args: trainer_pb2.EvalArgs = None,
             tune_args: Optional[tuner_pb2.TuneArgs] = None,
             custom_config: Optional[Dict[Text, Any]] = None,
             best_hyperparameters: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Construct a Tuner component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as
      the source of examples that are used in tuning (required).
    schema: An optional Channel of type `standard_artifacts.Schema`,
      serving as the schema of training and eval data. This is used when
      raw examples are provided.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present. This is used when transformed examples are
      provided.
    module_file: A path to python module file containing UDF tuner
      definition. The module_file must implement a function named
      `tuner_fn` at its top level. The function must have the following
      signature.

      def tuner_fn(fn_args: FnArgs) -> TunerFnResult:

      Exactly one of 'module_file' or 'tuner_fn' must be supplied.
    tuner_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'tuner_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is train on the `train` split.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is evaluate on the `eval` split.
    tune_args: A tuner_pb2.TuneArgs instance, containing args used for
      tuning. Currently only num_parallel_trials is available.
    custom_config: A dict which contains additional training job
      parameters that will be passed into the user module.
    best_hyperparameters: Optional Channel of type
      `standard_artifacts.HyperParameters` for the result of the best
      hparams.
    instance_name: Optional unique instance name. Necessary if multiple
      Tuner components are declared in the same pipeline.
  """
  if bool(module_file) == bool(tuner_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

  best_hyperparameters = best_hyperparameters or types.Channel(
      type=standard_artifacts.HyperParameters)
  spec = TunerSpec(
      examples=examples,
      schema=schema,
      transform_graph=transform_graph,
      module_file=module_file,
      tuner_fn=tuner_fn,
      train_args=train_args,
      eval_args=eval_args,
      tune_args=tune_args,
      best_hyperparameters=best_hyperparameters,
      custom_config=json_utils.dumps(custom_config),
  )
  super(Tuner, self).__init__(spec=spec, instance_name=instance_name)
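# A minimal `tuner_fn` sketch matching the signature above, assuming the
# keras_tuner package (named `kerastuner` in older releases); the
# `_build_keras_model` and `_input_fn` helpers are hypothetical.
import keras_tuner
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.tuner.component import TunerFnResult


def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  tuner = keras_tuner.RandomSearch(
      _build_keras_model,        # hypothetical hypermodel builder
      objective='val_accuracy',
      max_trials=10,
      directory=fn_args.working_dir,
      project_name='tuning')
  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={
          'x': _input_fn(fn_args.train_files),               # hypothetical
          'validation_data': _input_fn(fn_args.eval_files),  # hypothetical
      })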
def __init__(
    self,
    examples: types.Channel = None,
    model: types.Channel = None,
    baseline_model: Optional[types.Channel] = None,
    # TODO(b/148618405): deprecate feature_slicing_spec.
    feature_slicing_spec: Optional[Union[evaluator_pb2.FeatureSlicingSpec,
                                         Dict[Text, Any]]] = None,
    fairness_indicator_thresholds: Optional[List[Union[
        float, data_types.RuntimeParameter]]] = None,
    example_splits: Optional[List[Text]] = None,
    evaluation: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    eval_config: Optional[tfma.EvalConfig] = None,
    blessing: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    module_file: Optional[Text] = None,
    module_path: Optional[Text] = None):
  """Construct an Evaluator component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, usually
      produced by an ExampleGen component. _required_
    model: A Channel of type `standard_artifacts.Model`, usually produced
      by a Trainer component.
    baseline_model: An optional channel of type 'standard_artifacts.Model'
      as the baseline model for model diff and model validation purposes.
    feature_slicing_spec: Deprecated, please use eval_config instead. Only
      supports estimator.
      [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
      instance that describes how Evaluator should slice the data. If any
      field is provided as a RuntimeParameter, feature_slicing_spec should
      be constructed as a dict with the same field names as the
      FeatureSlicingSpec proto message.
    fairness_indicator_thresholds: Optional list of float (or
      RuntimeParameter) threshold values for use with TFMA fairness
      indicators. Experimental functionality: this interface and
      functionality may change at any time.
      TODO(b/142653905): add a link to additional documentation for TFMA
      fairness indicators here.
    example_splits: Names of splits on which the metrics are computed.
      Default behavior (when example_splits is set to None or Empty) is
      using the 'eval' split.
    evaluation: Channel of `ModelEvaluation` to store the evaluation
      results.
    instance_name: Optional name assigned to this specific instance of
      Evaluator. Required only if multiple Evaluator components are
      declared in the same pipeline.
    eval_config: Instance of tfma.EvalConfig containing configuration
      settings for running the evaluation. This config has options for
      both estimator and Keras.
    blessing: Output channel of 'ModelBlessing' that contains the blessing
      result.
    schema: A `Schema` channel to use for TFXIO.
    module_file: A path to python module file containing UDFs for
      Evaluator customization. This functionality is experimental and may
      change at any time. The module_file can implement the following
      functions at its top level.

      def custom_eval_shared_model(
          eval_saved_model_path, model_name, eval_config, **kwargs,
      ) -> tfma.EvalSharedModel:

      def custom_extractors(
          eval_shared_model, eval_config, tensor_adapter_config,
      ) -> List[tfma.extractors.Extractor]:
    module_path: A python path to the custom module that contains the
      UDFs. See 'module_file' for the required signature of UDFs. This
      functionality is experimental and this API may change at any time.
      Note this cannot be set together with module_file.
  """
  if bool(module_file) and bool(module_path):
    raise ValueError(
        'Python module path can not be set together with module file path.')

  if eval_config is not None and feature_slicing_spec is not None:
    raise ValueError("Exactly one of 'eval_config' or 'feature_slicing_spec' "
                     'must be supplied.')
  if eval_config is None and feature_slicing_spec is None:
    feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
    logging.info('Neither eval_config nor feature_slicing_spec is passed, '
                 'the model is treated as estimator.')

  if feature_slicing_spec:
    logging.warning('feature_slicing_spec is deprecated, please use '
                    'eval_config instead.')

  blessing = blessing or types.Channel(type=standard_artifacts.ModelBlessing)
  evaluation = evaluation or types.Channel(
      type=standard_artifacts.ModelEvaluation)
  spec = EvaluatorSpec(
      examples=examples,
      model=model,
      baseline_model=baseline_model,
      feature_slicing_spec=feature_slicing_spec,
      fairness_indicator_thresholds=fairness_indicator_thresholds,
      example_splits=json_utils.dumps(example_splits),
      evaluation=evaluation,
      eval_config=eval_config,
      blessing=blessing,
      schema=schema,
      module_file=module_file,
      module_path=module_path)
  super(Evaluator, self).__init__(spec=spec, instance_name=instance_name)
def __init__(
    self,
    name: Text,
    type: Type[types.Artifact],  # pylint: disable=redefined-builtin
    spec_kwargs: Dict[Text, Any]):
  # `name` becomes the unique instance name of the fake component.
  spec = _FakeComponentSpec(output=types.Channel(type=type), **spec_kwargs)
  super(_FakeComponent, self).__init__(spec=spec, instance_name=name)
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    splits_config: transform_pb2.SplitsConfig = None,
    analyzer_cache: Optional[types.Channel] = None,
    materialize: bool = True,
    disable_analyzer_cache: bool = False,
    force_tf_compat_v1: bool = False,
    custom_config: Optional[Dict[Text, Any]] = None,
    compute_statistics: bool = False):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the custom splits specified in splits_config. If
      custom splits are not provided, this should contain the two splits
      'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. Exactly one of
      'module_file' or 'preprocessing_fn' must be supplied. The function
      needs to have the following signature:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. If additional inputs are needed for
      preprocessing_fn, they can be passed in custom_config:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any],
                           custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      Use of a RuntimeParameter for this argument is experimental.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for the expected signature of
      the function. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied. Use of a RuntimeParameter for this argument is
      experimental.
    splits_config: A transform_pb2.SplitsConfig instance, providing splits
      that should be analyzed and splits that should be transformed. Note
      that analyze and transform splits can overlap. Default behavior
      (when splits_config is not set) is to analyze the 'train' split and
      transform all splits. If splits_config is set, analyze cannot be
      empty.
    analyzer_cache: Optional input 'TransformCache' channel containing
      cached information from previous Transform runs. When provided,
      Transform will try to use the cached calculation if possible.
    materialize: If True, write transformed examples as an output.
    disable_analyzer_cache: If False, Transform will use input cache if
      provided and write cache output. If True, `analyzer_cache` must not
      be provided.
    force_tf_compat_v1: (Optional) If True, or if TF2 behaviors are
      disabled, Transform will use Tensorflow in compat.v1 mode
      irrespective of the installed version of Tensorflow. Defaults to
      `False`.
    custom_config: A dict which contains additional parameters that will
      be passed to preprocessing_fn.
    compute_statistics: If True, invoke TFDV to compute pre-transform and
      post-transform statistics.

  Raises:
    ValueError: When both or neither of 'module_file' and
      'preprocessing_fn' is supplied.
  """
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_graph = types.Channel(type=standard_artifacts.TransformGraph)
  transformed_examples = None
  if materialize:
    transformed_examples = types.Channel(type=standard_artifacts.Examples)
    transformed_examples.matching_channel_name = 'examples'

  if disable_analyzer_cache:
    updated_analyzer_cache = None
    if analyzer_cache:
      raise ValueError(
          '`analyzer_cache` is set when disable_analyzer_cache is True.')
  else:
    updated_analyzer_cache = types.Channel(
        type=standard_artifacts.TransformCache)

  spec = standard_component_specs.TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      force_tf_compat_v1=int(force_tf_compat_v1),
      splits_config=splits_config,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples,
      analyzer_cache=analyzer_cache,
      updated_analyzer_cache=updated_analyzer_cache,
      custom_config=json_utils.dumps(custom_config),
      compute_statistics=int(compute_statistics))
  super(Transform, self).__init__(spec=spec)

  if udf_utils.should_package_user_modules():
    # In this case, the `MODULE_PATH_KEY` execution property will be injected
    # as a reference to the given user module file after packaging, at which
    # point the `MODULE_FILE_KEY` execution property will be removed.
    udf_utils.add_user_module_dependency(
        self, standard_component_specs.MODULE_FILE_KEY,
        standard_component_specs.MODULE_PATH_KEY)
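# A wiring sketch for the Transform variant above, assuming upstream
# `example_gen` and `schema_gen` components; the module path is
# illustrative.
from tfx.proto import transform_pb2

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='path/to/preprocessing.py',  # illustrative path
    splits_config=transform_pb2.SplitsConfig(
        analyze=['train'], transform=['train', 'eval']))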
def __init__(self,
             examples: types.Channel,
             schema: Optional[types.Channel] = None,
             transform_graph: Optional[types.Channel] = None,
             module_file: Optional[Text] = None,
             tuner_fn: Optional[Text] = None,
             train_args: Optional[trainer_pb2.TrainArgs] = None,
             eval_args: Optional[trainer_pb2.EvalArgs] = None,
             tune_args: Optional[tuner_pb2.TuneArgs] = None,
             custom_config: Optional[Dict[Text, Any]] = None):
  """Construct a Tuner component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as
      the source of examples that are used in tuning (required).
    schema: An optional Channel of type `standard_artifacts.Schema`,
      serving as the schema of training and eval data. This is used when
      raw examples are provided.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present. This is used when transformed examples are
      provided.
    module_file: A path to python module file containing UDF tuner
      definition. The module_file must implement a function named
      `tuner_fn` at its top level. The function must have the following
      signature.

      def tuner_fn(fn_args: FnArgs) -> TunerFnResult:

      Exactly one of 'module_file' or 'tuner_fn' must be supplied.
    tuner_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'tuner_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is train on the `train` split.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is evaluate on the `eval` split.
    tune_args: A tuner_pb2.TuneArgs instance, containing args used for
      tuning. Currently only num_parallel_trials is available.
    custom_config: A dict which contains additional training job
      parameters that will be passed into the user module.
  """
  if bool(module_file) == bool(tuner_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

  best_hyperparameters = types.Channel(
      type=standard_artifacts.HyperParameters)
  spec = standard_component_specs.TunerSpec(
      examples=examples,
      schema=schema,
      transform_graph=transform_graph,
      module_file=module_file,
      tuner_fn=tuner_fn,
      train_args=train_args or trainer_pb2.TrainArgs(),
      eval_args=eval_args or trainer_pb2.EvalArgs(),
      tune_args=tune_args,
      best_hyperparameters=best_hyperparameters,
      custom_config=json_utils.dumps(custom_config),
  )
  super(Tuner, self).__init__(spec=spec)

  if udf_utils.should_package_user_modules():
    # In this case, the `MODULE_PATH_KEY` execution property will be injected
    # as a reference to the given user module file after packaging, at which
    # point the `MODULE_FILE_KEY` execution property will be removed.
    udf_utils.add_user_module_dependency(
        self, standard_component_specs.MODULE_FILE_KEY,
        standard_component_specs.MODULE_PATH_KEY)
def __init__(
    self,
    examples: types.Channel = None,
    model: types.Channel = None,
    baseline_model: Optional[types.Channel] = None,
    # TODO(b/148618405): deprecate feature_slicing_spec.
    feature_slicing_spec: Optional[Union[evaluator_pb2.FeatureSlicingSpec,
                                         Dict[Text, Any]]] = None,
    fairness_indicator_thresholds: Optional[List[Union[
        float, data_types.RuntimeParameter]]] = None,
    example_splits: Optional[List[Text]] = None,
    output: Optional[types.Channel] = None,
    model_exports: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    eval_config: Optional[tfma.EvalConfig] = None,
    blessing: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None):
  """Construct an Evaluator component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, usually
      produced by an ExampleGen component. _required_
    model: A Channel of type `standard_artifacts.Model`, usually produced
      by a Trainer component.
    baseline_model: An optional channel of type 'standard_artifacts.Model'
      as the baseline model for model diff and model validation purposes.
    feature_slicing_spec: Deprecated, please use eval_config instead. Only
      supports estimator.
      [evaluator_pb2.FeatureSlicingSpec](https://github.com/tensorflow/tfx/blob/master/tfx/proto/evaluator.proto)
      instance that describes how Evaluator should slice the data. If any
      field is provided as a RuntimeParameter, feature_slicing_spec should
      be constructed as a dict with the same field names as the
      FeatureSlicingSpec proto message.
    fairness_indicator_thresholds: Optional list of float (or
      RuntimeParameter) threshold values for use with TFMA fairness
      indicators. Experimental functionality: this interface and
      functionality may change at any time.
      TODO(b/142653905): add a link to additional documentation for TFMA
      fairness indicators here.
    example_splits: Names of splits on which the metrics are computed.
      Default behavior (when example_splits is set to None or Empty) is
      using the 'eval' split.
    output: Channel of `ModelEvalPath` to store the evaluation results.
    model_exports: Backwards compatibility alias for the `model` argument.
    instance_name: Optional name assigned to this specific instance of
      Evaluator. Required only if multiple Evaluator components are
      declared in the same pipeline. Either `model_exports` or `model`
      must be present in the input arguments.
    eval_config: Instance of tfma.EvalConfig containing configuration
      settings for running the evaluation. This config has options for
      both estimator and Keras.
    blessing: Output channel of 'ModelBlessingPath' that contains the
      blessing result.
    schema: A `Schema` channel to use for TFXIO.
  """
  if eval_config is not None and feature_slicing_spec is not None:
    raise ValueError("Exactly one of 'eval_config' or 'feature_slicing_spec' "
                     'must be supplied.')
  if eval_config is None and feature_slicing_spec is None:
    feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
    logging.info('Neither eval_config nor feature_slicing_spec is passed, '
                 'the model is treated as estimator.')

  if model_exports:
    logging.warning(
        'The "model_exports" argument to the Evaluator component has '
        'been renamed to "model" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    model = model_exports

  if feature_slicing_spec:
    logging.warning('feature_slicing_spec is deprecated, please use '
                    'eval_config instead.')

  blessing = blessing or types.Channel(
      type=standard_artifacts.ModelBlessing,
      artifacts=[standard_artifacts.ModelBlessing()])
  evaluation = output or types.Channel(
      type=standard_artifacts.ModelEvaluation,
      artifacts=[standard_artifacts.ModelEvaluation()])
  spec = EvaluatorSpec(
      examples=examples,
      model=model,
      baseline_model=baseline_model,
      feature_slicing_spec=feature_slicing_spec,
      fairness_indicator_thresholds=fairness_indicator_thresholds,
      example_splits=json_utils.dumps(example_splits),
      evaluation=evaluation,
      eval_config=eval_config,
      blessing=blessing,
      schema=schema)
  super(Evaluator, self).__init__(spec=spec, instance_name=instance_name)
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[Text,
                                     data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    enable_cache: Optional[bool] = None):
  """Construct a Transform component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples` (required).
      This should contain the two splits 'train' and 'eval'.
    schema: A Channel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded. The function must have
      the following signature.

      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...

      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor. Exactly one of 'module_file' or 'preprocessing_fn'
      must be supplied.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied.
    transform_graph: Optional output 'TransformPath' channel for output of
      'tf.Transform', which includes an exported Tensorflow graph suitable
      for both training and serving;
    transformed_examples: Optional output 'ExamplesPath' channel for
      materialized transformed examples, which includes both 'train' and
      'eval' splits.
    input_data: Backwards compatibility alias for the 'examples' argument.
    instance_name: Optional unique instance name. Necessary iff multiple
      transform components are declared in the same pipeline.
    enable_cache: Optional boolean to indicate if cache is enabled for the
      Transform component. If not specified, defaults to the value
      specified for pipeline's enable_cache parameter.

  Raises:
    ValueError: When both or neither of 'module_file' and
      'preprocessing_fn' is supplied.
  """
  if input_data:
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  transform_graph = transform_graph or types.Channel(
      type=standard_artifacts.TransformGraph,
      artifacts=[standard_artifacts.TransformGraph()])
  if not transformed_examples:
    example_artifact = standard_artifacts.Examples()
    example_artifact.split_names = artifact_utils.encode_split_names(
        artifact.DEFAULT_EXAMPLE_SPLITS)
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples, artifacts=[example_artifact])
  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(
      spec=spec, instance_name=instance_name, enable_cache=enable_cache)
def __init__(self,
             examples: types.Channel = None,
             schema: Optional[types.Channel] = None,
             transform_graph: Optional[types.Channel] = None,
             module_file: Optional[str] = None,
             tuner_fn: Optional[str] = None,
             train_args: trainer_pb2.TrainArgs = None,
             eval_args: trainer_pb2.EvalArgs = None,
             tune_args: Optional[tuner_pb2.TuneArgs] = None,
             custom_config: Optional[Dict[str, Any]] = None,
             metalearning_algorithm: Optional[str] = None,
             warmup_hyperparameters: Optional[types.Channel] = None,
             metamodel: Optional[types.Channel] = None,
             metafeature: Optional[types.Channel] = None,
             best_hyperparameters: Optional[types.Channel] = None,
             instance_name: Optional[str] = None):
  """Constructs a custom Tuner component that stores the trial learning curve.

  Adapted from the following code:
  https://github.com/tensorflow/tfx/blob/master/tfx/components/tuner/component.py

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as
      the source of examples that are used in tuning (required).
    schema: An optional Channel of type `standard_artifacts.Schema`,
      serving as the schema of training and eval data. This is used when
      raw examples are provided.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present. This is used when transformed examples are
      provided.
    module_file: A path to python module file containing UDF tuner
      definition. The module_file must implement a function named
      `tuner_fn` at its top level. The function must have the following
      signature.

      def tuner_fn(fn_args: FnArgs) -> TunerFnResult:

      Exactly one of 'module_file' or 'tuner_fn' must be supplied.
    tuner_fn: A python path to UDF model definition function. See
      'module_file' for the required signature of the UDF. Exactly one of
      'module_file' or 'tuner_fn' must be supplied.
    train_args: A trainer_pb2.TrainArgs instance, containing args used for
      training. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is train on the `train` split.
    eval_args: A trainer_pb2.EvalArgs instance, containing args used for
      eval. Currently only splits and num_steps are available. Default
      behavior (when splits is empty) is evaluate on the `eval` split.
    tune_args: A tuner_pb2.TuneArgs instance, containing args used for
      tuning. Currently only num_parallel_trials is available.
    custom_config: A dict which contains additional training job
      parameters that will be passed into the user module.
    metalearning_algorithm: Optional str for the type of
      metalearning_algorithm.
    warmup_hyperparameters: Optional Channel of type
      `artifacts.KCandidateHyperParameters` for a list of recommended
      search spaces for warm-starting the tuner (generally the output of a
      metalearning component or subpipeline).
    metamodel: Optional Channel of type `standard_artifacts.Model` for the
      trained meta model.
    metafeature: Optional Channel of `artifacts.MetaFeatures` of the
      dataset to be tuned. This is used as an input to the `metamodel` to
      predict the search space.
    best_hyperparameters: Optional Channel of type
      `standard_artifacts.HyperParameters` for the result of the best
      hparams.
    instance_name: Optional unique instance name. Necessary if multiple
      Tuner components are declared in the same pipeline.
  """
  if bool(module_file) == bool(tuner_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'tuner_fn' must be supplied")

  best_hyperparameters = best_hyperparameters or types.Channel(
      type=standard_artifacts.HyperParameters,
      artifacts=[standard_artifacts.HyperParameters()])
  trial_summary_plot = types.Channel(type=TunerData, artifacts=[TunerData()])
  spec = AugmentedTunerSpec(
      examples=examples,
      schema=schema,
      transform_graph=transform_graph,
      module_file=module_file,
      tuner_fn=tuner_fn,
      train_args=train_args,
      eval_args=eval_args,
      tune_args=tune_args,
      metalearning_algorithm=metalearning_algorithm,
      warmup_hyperparameters=warmup_hyperparameters,
      metamodel=metamodel,
      metafeature=metafeature,
      best_hyperparameters=best_hyperparameters,
      trial_summary_plot=trial_summary_plot,
      custom_config=json_utils.dumps(custom_config),
  )
  super(AugmentedTuner, self).__init__(spec=spec, instance_name=instance_name)
def __init__(
    self,
    examples: types.Channel = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    base_model: Optional[types.Channel] = None,
    hyperparameters: Optional[types.Channel] = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    # TODO(b/147702778): deprecate trainer_fn.
    trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
    eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    output: Optional[types.Channel] = None,
    model_run: Optional[types.Channel] = None,
    transform_output: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as
      the source of examples used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present.
    schema: An optional Channel of type `standard_artifacts.Schema`,
      serving as the schema of training and eval data. Schema is optional
      when 1) transform_graph is provided, which contains the schema, or
      2) the user module bypasses the usage of schema, e.g., hardcoded.
    base_model: A Channel of type `Model`, containing a model that will be
      used for training. This can be used for warmstart, transfer learning
      or model ensembling.
    hyperparameters: A Channel of type
      `standard_artifacts.HyperParameters`, serving as the hyperparameters
      for the training module. Tuner's output best hyperparameters can be
      fed into this.
    module_file: A path to python module file containing UDF model
      definition. For the default executor, the module_file must implement
      a function named `trainer_fn` at its top level. The function must
      have the following signature.

      def trainer_fn(trainer.fn_args_utils.FnArgs,
                     tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
        ...

      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of
          tfma.export.EvalInputReceiver.

      Exactly one of 'module_file' or 'trainer_fn' must be supplied.

      For the generic executor, the module_file must implement a function
      named `run_fn` at its top level with function signature:
      `def run_fn(trainer.fn_args_utils.FnArgs)`, and the trained model
      must be saved to FnArgs.serving_model_dir when this function is
      executed.
    run_fn: A python path to UDF model definition function for generic
      trainer. See 'module_file' for details. Exactly one of 'module_file'
      or 'run_fn' must be supplied if Trainer uses GenericExecutor.
    trainer_fn: A python path to UDF model definition function for
      estimator based trainer. See 'module_file' for the required
      signature of the UDF. Exactly one of 'module_file' or 'trainer_fn'
      must be supplied.
    train_args: A trainer_pb2.TrainArgs instance or a dict, containing
      args used for training. Currently only splits and num_steps are
      available. If it's provided as a dict and any field is a
      RuntimeParameter, it should have the same field names as a TrainArgs
      proto message. Default behavior (when splits is empty) is train on
      the `train` split.
    eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
      used for evaluation. Currently only splits and num_steps are
      available. If it's provided as a dict and any field is a
      RuntimeParameter, it should have the same field names as an EvalArgs
      proto message. Default behavior (when splits is empty) is evaluate
      on the `eval` split.
    custom_config: A dict which contains additional training job
      parameters that will be passed into the user module.
    custom_executor_spec: Optional custom executor spec.
    output: Optional `Model` channel for the result of exported models.
    model_run: Optional `ModelRun` channel, as the working dir of models;
      can be used to output non-model related output (e.g., TensorBoard
      logs).
    transform_output: Backwards compatibility alias for the
      'transform_graph' argument.
    instance_name: Optional unique instance name. Necessary iff multiple
      Trainer components are declared in the same pipeline.

  Raises:
    ValueError:
      - When both or neither of 'module_file' and user function (e.g.,
        trainer_fn and run_fn) is supplied.
      - When both or neither of 'examples' and 'transformed_examples' is
        supplied.
      - When 'transformed_examples' is supplied but 'transform_graph' is
        not supplied.
  """
  if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
    raise ValueError(
        "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
        'supplied.')
  if bool(examples) == bool(transformed_examples):
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        'supplied.')
  if transform_output:
    absl.logging.warning(
        'The "transform_output" argument to the Trainer component has '
        'been renamed to "transform_graph" and is deprecated. Please update '
        'your usage as support for this argument will be removed soon.')
    transform_graph = transform_output
  if transformed_examples and not transform_graph:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_graph' must be supplied too.")
  examples = examples or transformed_examples
  output = output or types.Channel(type=standard_artifacts.Model)
  model_run = model_run or types.Channel(type=standard_artifacts.ModelRun)
  spec = TrainerSpec(
      examples=examples,
      transform_graph=transform_graph,
      schema=schema,
      base_model=base_model,
      hyperparameters=hyperparameters,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      run_fn=run_fn,
      trainer_fn=trainer_fn,
      custom_config=json_utils.dumps(custom_config),
      model=output,
      model_run=model_run)
  super(Trainer, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
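# A wiring sketch for the Trainer above on transformed inputs, assuming
# upstream `transform` and `schema_gen` components; the module path and
# step counts are illustrative.
from tfx.proto import trainer_pb2

trainer = Trainer(
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    module_file='path/to/trainer_module.py',  # illustrative path
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))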
def testAirflowDagRunner(self, mock_airflow_dag_class,
                         mock_airflow_component_class):
  mock_airflow_dag_class.return_value = 'DAG'
  mock_airflow_component_a = mock.Mock()
  mock_airflow_component_b = mock.Mock()
  mock_airflow_component_c = mock.Mock()
  mock_airflow_component_d = mock.Mock()
  mock_airflow_component_e = mock.Mock()
  mock_airflow_component_class.side_effect = [
      mock_airflow_component_a, mock_airflow_component_b,
      mock_airflow_component_c, mock_airflow_component_d,
      mock_airflow_component_e
  ]

  airflow_config = {
      'schedule_interval': '* * * * *',
      'start_date': datetime.datetime(2019, 1, 1)
  }

  component_a = _FakeComponent(
      _FakeComponentSpecA(output=types.Channel(type=_ArtifactTypeA)))
  component_b = _FakeComponent(
      _FakeComponentSpecB(
          a=component_a.outputs['output'],
          output=types.Channel(type=_ArtifactTypeB)))
  component_c = _FakeComponent(
      _FakeComponentSpecC(
          a=component_a.outputs['output'],
          b=component_b.outputs['output'],
          output=types.Channel(type=_ArtifactTypeC)))
  component_d = _FakeComponent(
      _FakeComponentSpecD(
          b=component_b.outputs['output'],
          c=component_c.outputs['output'],
          output=types.Channel(type=_ArtifactTypeD)))
  component_e = _FakeComponent(
      _FakeComponentSpecE(
          a=component_a.outputs['output'],
          b=component_b.outputs['output'],
          d=component_d.outputs['output'],
          output=types.Channel(type=_ArtifactTypeE)))

  test_pipeline = pipeline.Pipeline(
      pipeline_name='x',
      pipeline_root='y',
      metadata_connection_config=None,
      components=[
          component_d, component_c, component_a, component_b, component_e
      ])

  runner = airflow_dag_runner.AirflowDagRunner(
      airflow_dag_runner.AirflowPipelineConfig(
          airflow_dag_config=airflow_config))
  runner.run(test_pipeline)

  mock_airflow_component_a.set_upstream.assert_not_called()
  mock_airflow_component_b.set_upstream.assert_has_calls(
      [mock.call(mock_airflow_component_a)])
  mock_airflow_component_c.set_upstream.assert_has_calls(
      [
          mock.call(mock_airflow_component_a),
          mock.call(mock_airflow_component_b)
      ],
      any_order=True)
  mock_airflow_component_d.set_upstream.assert_has_calls(
      [
          mock.call(mock_airflow_component_b),
          mock.call(mock_airflow_component_c)
      ],
      any_order=True)
  mock_airflow_component_e.set_upstream.assert_has_calls(
      [
          mock.call(mock_airflow_component_a),
          mock.call(mock_airflow_component_b),
          mock.call(mock_airflow_component_d)
      ],
      any_order=True)