def testDumpUiMetadata(self):
  """Checks that a ModelRun output is dumped as a tensorboard UI entry.

  Builds a Trainer `PipelineNode` that declares a single 'model_run' output,
  dumps the UI metadata for one execution to a JSON file, and verifies that
  the last entry under 'outputs' has type 'tensorboard' and uses the ModelRun
  URI as its source.
  """
  # Output spec describing the ModelRun artifact type.
  model_run_type = metadata_store_pb2.ArtifactType(
      name=standard_artifacts.ModelRun.TYPE_NAME)
  model_run_spec = pipeline_pb2.OutputSpec(
      artifact_spec=pipeline_pb2.OutputSpec.ArtifactSpec(type=model_run_type))

  # Pipeline node standing in for the Trainer component.
  trainer_node = pipeline_pb2.PipelineNode()
  trainer_node.node_info.type.name = 'tfx.components.trainer.component.Trainer'
  trainer_node.outputs.outputs['model_run'].CopyFrom(model_run_spec)

  # Execution info carrying only the ModelRun output artifact.
  model_run_artifact = standard_artifacts.ModelRun()
  model_run_artifact.uri = 'model_run_uri'
  execution_info = data_types.ExecutionInfo(
      input_dict={},
      output_dict={'model_run': [model_run_artifact]},
      exec_properties={},
      execution_id='id')

  # The metadata file is named 'json' inside a per-test directory.
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  ui_metadata_path = os.path.join(outputs_root, self._testMethodName, 'json')
  fileio.makedirs(os.path.dirname(ui_metadata_path))

  container_entrypoint._dump_ui_metadata(trainer_node, execution_info,
                                         ui_metadata_path)

  with open(ui_metadata_path) as metadata_file:
    dumped = json.load(metadata_file)
  last_output = dumped['outputs'][-1]
  self.assertEqual('tensorboard', last_output['type'])
  self.assertEqual('model_run_uri', last_output['source'])
def testDumpUiMetadata(self):
  """Checks that a Trainer's ModelRun output is dumped as a tensorboard entry.

  Constructs a `Trainer` component, dumps the UI metadata for one execution to
  a JSON file, and verifies that the last entry under 'outputs' has type
  'tensorboard' and uses the ModelRun URI as its source.
  """
  # Trainer component whose UI metadata is being dumped.
  examples_channel = Channel(type=standard_artifacts.Examples)
  trainer_component = Trainer(
      examples=examples_channel,
      module_file='module_file',
      train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=100),
      eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=50))

  # Execution info carrying only the ModelRun output artifact.
  model_run_artifact = standard_artifacts.ModelRun()
  model_run_artifact.uri = 'model_run_uri'
  execution_info = data_types.ExecutionInfo(
      input_dict={},
      output_dict={'model_run': [model_run_artifact]},
      exec_properties={},
      execution_id='id')

  # The metadata file is named 'json' inside a per-test directory.
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  ui_metadata_path = os.path.join(outputs_root, self._testMethodName, 'json')
  fileio.makedirs(os.path.dirname(ui_metadata_path))

  container_entrypoint._dump_ui_metadata(trainer_component, execution_info,
                                         ui_metadata_path)

  with open(ui_metadata_path) as metadata_file:
    dumped = json.load(metadata_file)
  last_output = dumped['outputs'][-1]
  self.assertEqual('tensorboard', last_output['type'])
  self.assertEqual('model_run_uri', last_output['source'])
def setUp(self):
  """Prepares the artifacts, exec properties and executors used by the tests.

  Populates:
    _source_data_dir / _output_data_dir: test data and per-test output roots.
    _single_artifact / _multiple_artifacts: one Examples artifact, and that
      artifact together with a deep copy of it.
    _input_dict / _output_dict: executor inputs and outputs keyed by the
      `standard_component_specs` constants.
    _exec_properties: JSON-encoded train/eval args plus 'warm_starting'.
    _module_file / _trainer_fn: the user module path and dotted function path.
    _trainer_executor / _generic_trainer_executor: executors under test.
  """
  super(ExecutorTest, self).setUp()

  parent_dir = os.path.dirname(os.path.dirname(__file__))
  self._source_data_dir = os.path.join(parent_dir, 'testdata')
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  self._output_data_dir = os.path.join(outputs_root, self._testMethodName)

  # Input artifacts.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(self._source_data_dir,
                              'transform/transformed_examples')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  # Copied only after uri and split_names are set, so the copy carries both.
  examples_copy = copy.deepcopy(examples)
  self._single_artifact = [examples]
  self._multiple_artifacts = [examples, examples_copy]

  graph_artifact = standard_artifacts.TransformGraph()
  graph_artifact.uri = os.path.join(self._source_data_dir,
                                    'transform/transform_graph')
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(self._source_data_dir, 'schema_gen')
  base_model_artifact = standard_artifacts.Model()
  base_model_artifact.uri = os.path.join(self._source_data_dir,
                                         'trainer/previous')

  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._single_artifact,
      standard_component_specs.TRANSFORM_GRAPH_KEY: [graph_artifact],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
      standard_component_specs.BASE_MODEL_KEY: [base_model_artifact],
  }

  # Output artifacts.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(self._output_data_dir,
                                         'model_export_path')
  self._model_run_exports = standard_artifacts.ModelRun()
  self._model_run_exports.uri = os.path.join(self._output_data_dir,
                                             'model_run_path')
  self._output_dict = {
      standard_component_specs.MODEL_KEY: [self._model_exports],
      standard_component_specs.MODEL_RUN_KEY: [self._model_run_exports],
  }

  # Exec properties skeleton; tests add the module file or function path.
  train_args_json = proto_utils.proto_to_json(
      trainer_pb2.TrainArgs(num_steps=1000))
  eval_args_json = proto_utils.proto_to_json(
      trainer_pb2.EvalArgs(num_steps=500))
  self._exec_properties = {
      standard_component_specs.TRAIN_ARGS_KEY: train_args_json,
      standard_component_specs.EVAL_ARGS_KEY: eval_args_json,
      'warm_starting': False,
  }

  self._module_file = os.path.join(self._source_data_dir,
                                   standard_component_specs.MODULE_FILE_KEY,
                                   'trainer_module.py')
  trainer_fn = trainer_module.trainer_fn
  self._trainer_fn = '%s.%s' % (trainer_fn.__module__, trainer_fn.__name__)

  # Executors under test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()
def __init__(
    self,
    examples: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    transform_graph: Optional[types.Channel] = None,
    schema: Optional[types.Channel] = None,
    base_model: Optional[types.Channel] = None,
    hyperparameters: Optional[types.Channel] = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    # TODO(b/147702778): deprecate trainer_fn.
    trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    train_args: Optional[Union[trainer_pb2.TrainArgs, Dict[Text, Any]]] = None,
    eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text, Any]]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
    output: Optional[types.Channel] = None,
    model_run: Optional[types.Channel] = None,
    transform_output: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  """Construct a Trainer component.

  Args:
    examples: A Channel of type `standard_artifacts.Examples`, serving as the
      source of examples used in training (required). May be raw or
      transformed.
    transformed_examples: Deprecated field. Please set 'examples' instead.
    transform_graph: An optional Channel of type
      `standard_artifacts.TransformGraph`, serving as the input transform
      graph if present.
    schema: A Channel of type `standard_artifacts.Schema`, serving as the
      schema of training and eval data.
    base_model: A Channel of type `Model`, containing model that will be used
      for training. This can be used for warmstart, transfer learning or
      model ensembling.
    hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
      serving as the hyperparameters for training module. Tuner's output best
      hyperparameters can be fed into this.
    module_file: A path to python module file containing UDF model
      definition. Exactly one of 'module_file', 'trainer_fn' or 'run_fn' must
      be supplied.

      For default executor, the module_file must implement a function named
      `trainer_fn` at its top level. The function must have the following
      signature.

        def trainer_fn(trainer.executor.TrainerFnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

      where the returned Dict has the following key-values.
        'estimator': an instance of tf.estimator.Estimator
        'train_spec': an instance of tf.estimator.TrainSpec
        'eval_spec': an instance of tf.estimator.EvalSpec
        'eval_input_receiver_fn': an instance of
          tfma.export.EvalInputReceiver.

      For generic executor, the module_file must implement a function named
      `run_fn` at its top level with function signature:
      `def run_fn(trainer.executor.TrainerFnArgs)`, and the trained model
      must be saved to TrainerFnArgs.serving_model_dir when this function is
      executed.
    run_fn: A python path to UDF model definition function for generic
      trainer. See 'module_file' for details. Exactly one of 'module_file',
      'trainer_fn' or 'run_fn' must be supplied.
    trainer_fn: A python path to UDF model definition function for estimator
      based trainer. See 'module_file' for the required signature of the UDF.
      Exactly one of 'module_file', 'trainer_fn' or 'run_fn' must be
      supplied.
    train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
      used for training. Currently only num_steps is available. If it's
      provided as a dict and any field is a RuntimeParameter, it should have
      the same field names as a TrainArgs proto message.
    eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
      used for evaluation. Currently only num_steps is available. If it's
      provided as a dict and any field is a RuntimeParameter, it should have
      the same field names as an EvalArgs proto message.
    custom_config: A dict which contains additional training job parameters
      that will be passed into user module.
    custom_executor_spec: Optional custom executor spec.
    output: Optional `Model` channel for result of exported models.
    model_run: Optional `ModelRun` channel, as the working dir of models, can
      be used to output non-model related output (e.g., TensorBoard logs).
    transform_output: Backwards compatibility alias for the 'transform_graph'
      argument. When supplied it replaces any 'transform_graph' value.
    instance_name: Optional unique instance name. Necessary iff multiple
      Trainer components are declared in the same pipeline.

  Raises:
    ValueError:
      - When the number of supplied user-code arguments among 'module_file',
        'trainer_fn' and 'run_fn' is not exactly one.
      - When both or neither of 'examples' and 'transformed_examples'
        is supplied.
      - When 'transformed_examples' is supplied but 'transform_graph'
        is not supplied.
  """
  user_code_sources = (module_file, run_fn, trainer_fn)
  if sum(bool(source) for source in user_code_sources) != 1:
    raise ValueError(
        "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
        "supplied.")

  if bool(examples) == bool(transformed_examples):
    # The message names the actual keyword arguments so callers can act on it.
    raise ValueError(
        "Exactly one of 'examples' or 'transformed_examples' must be "
        "supplied.")

  if transform_output:
    absl.logging.warning(
        'The "transform_output" argument to the Trainer component has '
        'been renamed to "transform_graph" and is deprecated. Please update '
        "your usage as support for this argument will be removed soon.")
    # The deprecated alias takes precedence over 'transform_graph'.
    transform_graph = transform_output

  if transformed_examples and not transform_graph:
    raise ValueError("If 'transformed_examples' is supplied, "
                     "'transform_graph' must be supplied too.")

  # Exactly one of the two is set at this point; fold them into 'examples'.
  examples = examples or transformed_examples

  # Create default output channels when the caller did not provide them.
  output = output or types.Channel(
      type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
  model_run = model_run or types.Channel(
      type=standard_artifacts.ModelRun,
      artifacts=[standard_artifacts.ModelRun()])

  spec = TrainerSpec(
      examples=examples,
      transform_graph=transform_graph,
      schema=schema,
      base_model=base_model,
      hyperparameters=hyperparameters,
      train_args=train_args,
      eval_args=eval_args,
      module_file=module_file,
      run_fn=run_fn,
      trainer_fn=trainer_fn,
      custom_config=json_utils.dumps(custom_config),
      model=output,
      # TODO(b/158106209): change the model_run as optional output artifact
      model_run=model_run)
  super(Trainer, self).__init__(
      spec=spec,
      custom_executor_spec=custom_executor_spec,
      instance_name=instance_name)
def setUp(self):
  """Prepares the artifacts, exec properties and executors used by the tests.

  Populates:
    _source_data_dir / _output_data_dir: test data and per-test output roots.
    _input_dict / _output_dict: executor inputs and outputs keyed by the
      `constants` module keys.
    _exec_properties: JSON-encoded train/eval args plus 'warm_starting'.
    _module_file / _trainer_fn: the user module path and dotted function path.
    _trainer_executor / _generic_trainer_executor: executors under test.
  """
  super(ExecutorTest, self).setUp()

  parent_dir = os.path.dirname(os.path.dirname(__file__))
  self._source_data_dir = os.path.join(parent_dir, 'testdata')
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  self._output_data_dir = os.path.join(outputs_root, self._testMethodName)

  # Input artifacts.
  examples_artifact = standard_artifacts.Examples()
  examples_artifact.uri = os.path.join(self._source_data_dir,
                                       'transform/transformed_examples')
  examples_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  graph_artifact = standard_artifacts.TransformGraph()
  graph_artifact.uri = os.path.join(self._source_data_dir,
                                    'transform/transform_graph')
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(self._source_data_dir, 'schema_gen')
  base_model_artifact = standard_artifacts.Model()
  base_model_artifact.uri = os.path.join(self._source_data_dir,
                                         'trainer/previous')

  self._input_dict = {
      constants.EXAMPLES_KEY: [examples_artifact],
      constants.TRANSFORM_GRAPH_KEY: [graph_artifact],
      constants.SCHEMA_KEY: [schema_artifact],
      constants.BASE_MODEL_KEY: [base_model_artifact],
  }

  # Output artifacts.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(self._output_data_dir,
                                         'model_export_path')
  self._model_run_exports = standard_artifacts.ModelRun()
  self._model_run_exports.uri = os.path.join(self._output_data_dir,
                                             'model_run_path')
  self._output_dict = {
      constants.MODEL_KEY: [self._model_exports],
      constants.MODEL_RUN_KEY: [self._model_run_exports],
  }

  # Exec properties skeleton; args are serialized with proto field names kept.
  train_args_json = json_format.MessageToJson(
      trainer_pb2.TrainArgs(num_steps=1000),
      preserving_proto_field_name=True)
  eval_args_json = json_format.MessageToJson(
      trainer_pb2.EvalArgs(num_steps=500),
      preserving_proto_field_name=True)
  self._exec_properties = {
      'train_args': train_args_json,
      'eval_args': eval_args_json,
      'warm_starting': False,
  }

  self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                   'trainer_module.py')
  trainer_fn = trainer_module.trainer_fn
  self._trainer_fn = '%s.%s' % (trainer_fn.__module__, trainer_fn.__name__)

  # Executors under test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()