def setUp(self):
  """Creates the input channels and train/eval args used by component tests."""
  super(ComponentTest, self).setUp()
  to_channel = channel_utils.as_channel

  # One-artifact channels for each component input.
  self.examples = to_channel([standard_artifacts.Examples()])
  self.transform_output = to_channel([standard_artifacts.TransformGraph()])
  self.schema = to_channel([standard_artifacts.Schema()])
  self.hyperparameters = to_channel([standard_artifacts.HyperParameters()])

  # Training reads the 'train' split for 100 steps; eval reads 'eval' for 50.
  self.train_args = trainer_pb2.TrainArgs(splits=['train'], num_steps=100)
  self.eval_args = trainer_pb2.EvalArgs(splits=['eval'], num_steps=50)
def setUp(self):
  """Builds artifact fixtures rooted at `self._intermediate_data_root`.

  Each fixture is exposed as a list attribute named `self._test_*`.
  """
  super(KubeflowGCPIntegrationTest, self).setUp()

  def _located(artifact, relative_path):
    """Points `artifact.uri` below the intermediate data root; returns it."""
    artifact.uri = os.path.join(self._intermediate_data_root, relative_path)
    return artifact

  # Raw Example artifacts for testing.
  self._test_raw_examples = [
      _located(standard_artifacts.Examples(split='train'),
               'csv_example_gen/examples/test-pipeline/train/'),
      _located(standard_artifacts.Examples(split='eval'),
               'csv_example_gen/examples/test-pipeline/eval/'),
  ]

  # Transformed Example artifacts for testing.
  self._test_transformed_examples = [
      _located(standard_artifacts.Examples(split='train'),
               'transform/transformed_examples/test-pipeline/train/'),
      _located(standard_artifacts.Examples(split='eval'),
               'transform/transformed_examples/test-pipeline/eval/'),
  ]

  # Schema artifact for testing.
  self._test_schema = [
      _located(standard_artifacts.Schema(),
               'schema_gen/output/test-pipeline/')
  ]

  # TransformGraph artifact for testing.
  self._test_transform_graph = [
      _located(standard_artifacts.TransformGraph(),
               'transform/transform_output/test-pipeline/')
  ]

  # Model artifact for testing.
  self._test_model = [
      _located(standard_artifacts.Model(), 'trainer/output/test-pipeline/')
  ]

  # ModelBlessing artifact for testing.
  self._test_model_blessing = [
      _located(standard_artifacts.ModelBlessing(),
               'model_validator/blessing/test-pipeline/')
  ]
def _make_base_do_params(self, source_data_dir, output_data_dir):
  """Populates `self._input_dict`, `self._output_dict` and exec properties.

  Args:
    source_data_dir: Directory containing the 'schema_gen' test input.
    output_data_dir: Directory under which every output artifact URI is
      placed: the transform graph, the transformed examples, and the
      updated analyzer cache.
  """
  # Create input dict.
  example1 = standard_artifacts.Examples()
  example1.uri = self._ARTIFACT1_URI
  example1.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  # Second artifact: same split metadata, different location.
  example2 = copy.deepcopy(example1)
  example2.uri = self._ARTIFACT2_URI
  self._example_artifacts = [example1, example2]

  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

  # Only the first example artifact is wired in here.
  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._example_artifacts[:1],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
  }

  # Create output dict.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_graph')
  transformed1 = standard_artifacts.Examples()
  transformed1.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '0')
  transformed2 = standard_artifacts.Examples()
  transformed2.uri = os.path.join(output_data_dir, 'transformed_examples',
                                  '1')
  self._transformed_example_artifacts = [transformed1, transformed2]

  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()

  self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache()
  # Derive the cache location from the `output_data_dir` argument, like
  # every other output above, rather than from `self._output_data_dir`.
  self._updated_analyzer_cache_artifact.uri = os.path.join(
      output_data_dir, 'CACHE')

  self._output_dict = {
      standard_component_specs.TRANSFORM_GRAPH_KEY: [self._transformed_output],
      standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
          self._transformed_example_artifacts[:1],
      executor.TEMP_PATH_KEY: [temp_path_output],
      standard_component_specs.UPDATED_ANALYZER_CACHE_KEY:
          [self._updated_analyzer_cache_artifact],
  }

  # Create exec properties skeleton.
  self._exec_properties = {}
def setUp(self):
  """Builds input/output dicts and exec properties for trainer executor tests."""
  super(ExecutorTest, self).setUp()
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)

  # Create input dict.
  train_examples = standard_artifacts.Examples(split='train')
  train_examples.uri = os.path.join(self._source_data_dir,
                                    'transform/transformed_examples/train/')
  eval_examples = standard_artifacts.Examples(split='eval')
  eval_examples.uri = os.path.join(self._source_data_dir,
                                   'transform/transformed_examples/eval/')
  transform_output = standard_artifacts.TransformGraph()
  transform_output.uri = os.path.join(self._source_data_dir,
                                      'transform/transform_output/')
  # The 'schema' input must be a Schema artifact, not an Examples artifact.
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(self._source_data_dir, 'schema_gen/')

  self._input_dict = {
      'examples': [train_examples, eval_examples],
      'transform_output': [transform_output],
      'schema': [schema],
  }

  # Create output dict.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(self._output_data_dir,
                                         'model_export_path')
  self._output_dict = {'output': [self._model_exports]}

  # Create exec properties skeleton.
  self._exec_properties = {
      'train_args':
          json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=1000)),
      'eval_args':
          json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=500)),
      'warm_starting':
          False,
  }
  self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                   'trainer_module.py')
  self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                trainer_module.trainer_fn.__name__)

  # Executor for test.
  self._trainer_executor = executor.Executor()
def testGetCommonFnArgs(self):
  """get_common_fn_args derives file paths and step counts from its inputs."""
  testdata_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')

  # Input artifacts.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(testdata_dir, 'transform/transformed_examples')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  transform_output = standard_artifacts.TransformGraph()
  transform_output.uri = os.path.join(testdata_dir,
                                      'transform/transform_graph')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(testdata_dir, 'schema_gen')
  input_dict = {
      constants.EXAMPLES_KEY: [examples],
      constants.TRANSFORM_GRAPH_KEY: [transform_output],
      constants.SCHEMA_KEY: [schema],
  }

  def _as_json(message):
    """Serializes a proto to JSON, keeping the original field names."""
    return json_format.MessageToJson(message, preserving_proto_field_name=True)

  exec_properties = {
      'train_args': _as_json(trainer_pb2.TrainArgs(num_steps=1000)),
      'eval_args': _as_json(trainer_pb2.EvalArgs(num_steps=500)),
  }

  fn_args = fn_args_utils.get_common_fn_args(input_dict, exec_properties,
                                             'tempdir')

  self.assertEqual(fn_args.working_dir, 'tempdir')
  self.assertEqual(fn_args.train_steps, 1000)
  self.assertEqual(fn_args.eval_steps, 500)
  # Each split yields exactly one glob pattern under the examples URI.
  for split_name, split_files in (('train', fn_args.train_files),
                                  ('eval', fn_args.eval_files)):
    self.assertLen(split_files, 1)
    self.assertEqual(split_files[0],
                     os.path.join(examples.uri, split_name, '*'))
  self.assertEqual(fn_args.schema_path,
                   os.path.join(schema.uri, 'schema.pbtxt'))
  self.assertEqual(fn_args.transform_graph_path, transform_output.uri)
  self.assertIsInstance(fn_args.data_accessor, fn_args_utils.DataAccessor)
def testConstructTransformGraph(self):
  """TransformGraphPusher consumes and emits TransformGraph-typed channels."""
  base_directory = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  graph_channel = channel_utils.as_channel(
      [standard_artifacts.TransformGraph()])
  destination = pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=base_directory))

  pusher = component.TransformGraphPusher(
      artifact=graph_channel, push_destination=destination)

  expected_type_name = 'TransformGraph'
  self.assertEqual(expected_type_name, pusher.inputs.artifact.type_name)
  self.assertEqual(expected_type_name,
                   pusher.outputs.pushed_artifact.type_name)
def setUp(self):
  """Prepares executor inputs, outputs and a custom_config exec property."""
  super(ExecutorTest, self).setUp()

  # Inputs come from the checked-in testdata next to this file.
  self._input_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
  example_artifact = standard_artifacts.Examples()
  example_artifact.uri = os.path.join(self._input_data_dir, 'example_gen')
  example_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(self._input_data_dir, 'schema_gen')
  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: [example_artifact],
      standard_component_specs.SCHEMA_KEY: [schema],
  }

  # Outputs go under a per-test directory. The fallback temp dir is created
  # eagerly, even when TEST_UNDECLARED_OUTPUTS_DIR is set.
  fallback_root = tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)
  per_test_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', fallback_root),
      self._testMethodName)

  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(per_test_dir,
                                              'transformed_output')
  self._transformed_examples = standard_artifacts.Examples()
  self._transformed_examples.uri = per_test_dir
  self._transformed_examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  scratch = _TempPath()
  scratch.uri = tempfile.mkdtemp()

  self._output_dict = {
      standard_component_specs.TRANSFORM_GRAPH_KEY: [self._transformed_output],
      standard_component_specs.TRANSFORMED_EXAMPLES_KEY:
          [self._transformed_examples],
      tfx_executor.TEMP_PATH_KEY: [scratch],
  }

  # Create exec properties.
  self._exec_properties = {
      'custom_config':
          json.dumps({'problem_statement_path': '/some/fake/path'})
  }
def setUp(self):
  """Creates channels, args and configs shared by the tuner component tests."""
  super(ComponentTest, self).setUp()
  wrap = channel_utils.as_channel

  split_examples = standard_artifacts.Examples()
  split_examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  graph_artifact = standard_artifacts.TransformGraph()

  # Input channels.
  self.examples = wrap([split_examples])
  self.schema = wrap([standard_artifacts.Schema()])
  self.transform_graph = wrap([graph_artifact])

  # Configuration and arguments.
  self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}
  self.train_args = trainer_pb2.TrainArgs(num_steps=100)
  self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
  self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)

  # Optional warm-start inputs.
  self.warmup_hyperparams = wrap([artifacts.KCandidateHyperParameters()])
  self.meta_model = wrap([standard_artifacts.Model()])
def _make_base_do_params(self, source_data_dir, output_data_dir):
  """Fills `self._input_dict`, `self._output_dict` and exec properties.

  Args:
    source_data_dir: Root holding the 'csv_example_gen' and 'schema_gen'
      test inputs.
    output_data_dir: Root for the transform graph and transformed examples.
  """

  def _split_examples(split, root, relative_path):
    """Returns an Examples artifact for `split` located at root/relative."""
    examples = standard_artifacts.Examples(split=split)
    examples.uri = os.path.join(root, relative_path)
    return examples

  # Inputs.
  raw_train = _split_examples('train', source_data_dir, 'csv_example_gen/train/')
  raw_eval = _split_examples('eval', source_data_dir, 'csv_example_gen/eval/')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen/')
  self._input_dict = {
      'input_data': [raw_train, raw_eval],
      'schema': [schema],
  }

  # Outputs.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_output')
  self._transformed_train_examples = _split_examples('train', output_data_dir,
                                                     'train')
  self._transformed_eval_examples = _split_examples('eval', output_data_dir,
                                                    'eval')
  scratch = types.Artifact('TempPath')
  scratch.uri = tempfile.mkdtemp()
  self._output_dict = {
      'transform_output': [self._transformed_output],
      'transformed_examples': [
          self._transformed_train_examples, self._transformed_eval_examples
      ],
      'temp_path': [scratch],
  }

  # Exec properties start empty; individual tests add to them.
  self._exec_properties = {}
def _make_base_do_params(self, source_data_dir, output_data_dir):
  """Populates `self._input_dict`, `self._output_dict` and exec properties.

  Args:
    source_data_dir: Root holding the 'csv_example_gen' and 'schema_gen'
      test inputs.
    output_data_dir: Directory under which every output artifact URI is
      placed: the transform graph, the transformed examples, and the
      updated analyzer cache.
  """
  # Create input dict.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

  self._input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.SCHEMA_KEY: [schema_artifact],
  }

  # Create output dict.
  self._transformed_output = standard_artifacts.TransformGraph()
  self._transformed_output.uri = os.path.join(output_data_dir,
                                              'transformed_graph')
  self._transformed_examples = standard_artifacts.Examples()
  self._transformed_examples.uri = os.path.join(output_data_dir,
                                                'transformed_examples')
  temp_path_output = _TempPath()
  temp_path_output.uri = tempfile.mkdtemp()

  self._updated_analyzer_cache_artifact = standard_artifacts.TransformCache()
  # Derive the cache location from the `output_data_dir` argument, like
  # every other output above, rather than from `self._output_data_dir`.
  self._updated_analyzer_cache_artifact.uri = os.path.join(
      output_data_dir, 'CACHE')

  self._output_dict = {
      executor.TRANSFORM_GRAPH_KEY: [self._transformed_output],
      executor.TRANSFORMED_EXAMPLES_KEY: [self._transformed_examples],
      executor.TEMP_PATH_KEY: [temp_path_output],
      executor.UPDATED_ANALYZER_CACHE_KEY: [
          self._updated_analyzer_cache_artifact
      ],
  }

  # Create exec properties skeleton.
  self._exec_properties = {}
def setUp(self):
  """Prepares trainer executor fixtures keyed by standard_component_specs.

  `self._single_artifact` and `self._multiple_artifacts` hold one and two
  example artifacts respectively; the input dict uses the single-artifact
  list object by default.
  """
  super(ExecutorTest, self).setUp()
  testdata_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._source_data_dir = testdata_dir
  outputs_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  self._output_data_dir = outputs_dir

  # Inputs.
  first_examples = standard_artifacts.Examples()
  first_examples.uri = os.path.join(testdata_dir,
                                    'transform/transformed_examples')
  first_examples.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  # A deep copy yields a distinct artifact with identical properties.
  second_examples = copy.deepcopy(first_examples)
  self._single_artifact = [first_examples]
  self._multiple_artifacts = [first_examples, second_examples]

  graph = standard_artifacts.TransformGraph()
  graph.uri = os.path.join(testdata_dir, 'transform/transform_graph')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(testdata_dir, 'schema_gen')
  base_model = standard_artifacts.Model()
  base_model.uri = os.path.join(testdata_dir, 'trainer/previous')

  self._input_dict = {
      standard_component_specs.EXAMPLES_KEY: self._single_artifact,
      standard_component_specs.TRANSFORM_GRAPH_KEY: [graph],
      standard_component_specs.SCHEMA_KEY: [schema],
      standard_component_specs.BASE_MODEL_KEY: [base_model]
  }

  # Outputs.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(outputs_dir, 'model_export_path')
  self._model_run_exports = standard_artifacts.ModelRun()
  self._model_run_exports.uri = os.path.join(outputs_dir, 'model_run_path')
  self._output_dict = {
      standard_component_specs.MODEL_KEY: [self._model_exports],
      standard_component_specs.MODEL_RUN_KEY: [self._model_run_exports]
  }

  # Exec properties skeleton.
  self._exec_properties = {
      standard_component_specs.TRAIN_ARGS_KEY:
          proto_utils.proto_to_json(trainer_pb2.TrainArgs(num_steps=1000)),
      standard_component_specs.EVAL_ARGS_KEY:
          proto_utils.proto_to_json(trainer_pb2.EvalArgs(num_steps=500)),
      'warm_starting': False,
  }
  # NOTE(review): the spec key constant MODULE_FILE_KEY is used here as a
  # directory name; presumably its value is 'module_file' -- confirm.
  self._module_file = os.path.join(testdata_dir,
                                   standard_component_specs.MODULE_FILE_KEY,
                                   'trainer_module.py')
  self._trainer_fn = '.'.join([
      trainer_module.trainer_fn.__module__, trainer_module.trainer_fn.__name__
  ])

  # Executors for test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[
        Text, data_types.RuntimeParameter]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    input_data: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None,
    materialize: bool = True):
  """Creates a Transform component.

  Args:
    examples: Channel of `standard_artifacts.Examples` (required); expected
      to carry the 'train' and 'eval' splits.
    schema: Channel of `standard_artifacts.Schema` holding one schema artifact.
    module_file: Path to a Python module file from which 'preprocessing_fn'
      is loaded. That function must have the signature

        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

      where input and output dict values are tf.Tensor or tf.SparseTensor.
      Supply exactly one of 'module_file' and 'preprocessing_fn'.
    preprocessing_fn: Python path of a function implementing
      'preprocessing_fn' (same signature as described for 'module_file').
      Supply exactly one of 'module_file' and 'preprocessing_fn'.
    transform_graph: Optional output 'TransformPath' channel for the
      tf.Transform output, an exported TensorFlow graph usable for both
      training and serving. A fresh TransformGraph channel is created when
      not provided.
    transformed_examples: Optional output 'ExamplesPath' channel for the
      materialized transformed examples (both 'train' and 'eval' splits).
    input_data: Deprecated alias for 'examples'; when given it overrides
      'examples' and logs a warning.
    instance_name: Optional unique instance name, required only when the
      pipeline declares several Transform components.
    materialize: When True, transformed examples are emitted as an output
      (a default channel is created if none is given). When False,
      `transformed_examples` must be omitted.

  Raises:
    ValueError: If both or neither of 'module_file' and 'preprocessing_fn'
      are supplied, or if `transformed_examples` is given while
      `materialize` is False.
  """
  if input_data:
    # Honour the deprecated alias, but tell the caller to migrate.
    absl.logging.warning(
        'The "input_data" argument to the Transform component has '
        'been renamed to "examples" and is deprecated. Please update your '
        'usage as support for this argument will be removed soon.')
    examples = input_data

  # Equal truthiness means both were given or neither was.
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  if not transform_graph:
    transform_graph = types.Channel(
        type=standard_artifacts.TransformGraph,
        artifacts=[standard_artifacts.TransformGraph()])

  if materialize:
    if transformed_examples is None:
      transformed_examples = types.Channel(
          type=standard_artifacts.Examples,
          # TODO(b/161548528): remove the hardcode artifact.
          artifacts=[standard_artifacts.Examples()],
          matching_channel_name='examples')
  elif transformed_examples is not None:
    raise ValueError(
        'must not specify transformed_examples when materialize==False')

  spec = TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
def __init__(self,
             input_data: types.Channel = None,
             schema: types.Channel = None,
             module_file: Optional[Text] = None,
             preprocessing_fn: Optional[Text] = None,
             transform_output: Optional[types.Channel] = None,
             transformed_examples: Optional[types.Channel] = None,
             examples: Optional[types.Channel] = None,
             instance_name: Optional[Text] = None):
  """Creates a Transform component.

  Args:
    input_data: Channel of 'ExamplesPath' type (required); expected to carry
      the 'train' and 'eval' splits. Falls back to `examples` when falsy.
    schema: Channel of 'SchemaPath' type holding one schema artifact.
    module_file: Path to a Python module file from which 'preprocessing_fn'
      is loaded. That function must have the signature

        def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

      where input and output dict values are tf.Tensor or tf.SparseTensor.
      Supply exactly one of 'module_file' and 'preprocessing_fn'.
    preprocessing_fn: Python path of a function implementing
      'preprocessing_fn' (same signature as described for 'module_file').
      Supply exactly one of 'module_file' and 'preprocessing_fn'.
    transform_output: Optional output 'TransformPath' channel for the
      tf.Transform output graph, usable for both training and serving.
      A fresh TransformGraph channel is created when not provided.
    transformed_examples: Optional output 'ExamplesPath' channel for the
      materialized transformed examples. When not provided, one Examples
      artifact per split in `artifact.DEFAULT_EXAMPLE_SPLITS` is created.
    examples: Forward-compatible alias for 'input_data'.
    instance_name: Optional unique instance name, required only when the
      pipeline declares several Transform components.

  Raises:
    ValueError: If both or neither of 'module_file' and 'preprocessing_fn'
      are supplied.
  """
  if not input_data:
    input_data = examples

  # Equal truthiness means both were given or neither was.
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  if not transform_output:
    transform_output = types.Channel(
        type=standard_artifacts.TransformGraph,
        artifacts=[standard_artifacts.TransformGraph()])

  if not transformed_examples:
    per_split_artifacts = [
        standard_artifacts.Examples(split=split_name)
        for split_name in artifact.DEFAULT_EXAMPLE_SPLITS
    ]
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples, artifacts=per_split_artifacts)

  spec = TransformSpec(
      input_data=input_data,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      transform_output=transform_output,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
def __init__(
    self,
    examples: types.Channel = None,
    schema: types.Channel = None,
    module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[
        Text, data_types.RuntimeParameter]] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    transform_graph: Optional[types.Channel] = None,
    transformed_examples: Optional[types.Channel] = None,
    instance_name: Optional[Text] = None):
  # pyformat: disable
  # pylint: disable=g-doc-args
  """Creates a Transform component.

  Args:
    examples: Channel of `standard_artifacts.Examples` (required); expected
      to carry the 'train' and 'eval' splits.
    schema: Channel of `standard_artifacts.Schema` holding one schema artifact.
    module_file: Path to a Python module file from which 'preprocessing_fn'
      is loaded. That function must have the signature

        def preprocessing_fn(inputs: Dict[Text, Any],
                             schema: schema_pb2.Schema,
                             custom_config: Dict[Text, Any]) -> Dict[Text, Any]:
          ...

      where input and output dict values are tf.Tensor or tf.SparseTensor.
      The 'schema' and 'custom_config' parameters are optional and may be
      left out. Supply exactly one of 'module_file' and 'preprocessing_fn'.
    preprocessing_fn: Python path of a function implementing
      'preprocessing_fn' (same signature as described for 'module_file').
      Supply exactly one of 'module_file' and 'preprocessing_fn'.
    custom_config: Dict of extra transform parameters forwarded to the
      preprocessing_fn; it is JSON-encoded with `json.dumps` before being
      placed in the spec.
    transform_graph: Optional output 'TransformPath' channel for the
      tf.Transform output graph, usable for both training and serving.
      A fresh TransformGraph channel is created when not provided.
    transformed_examples: Optional output 'ExamplesPath' channel for the
      materialized transformed examples. When not provided, a single
      Examples artifact whose split_names encode
      `artifact.DEFAULT_EXAMPLE_SPLITS` is created.
    instance_name: Optional unique instance name, required only when the
      pipeline declares several Transform components.

  Raises:
    ValueError: If both or neither of 'module_file' and 'preprocessing_fn'
      are supplied.
  """
  # pyformat: enable
  # pylint: enable=g-doc-args
  # Equal truthiness means both were given or neither was.
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )

  if not transform_graph:
    transform_graph = types.Channel(
        type=standard_artifacts.TransformGraph,
        artifacts=[standard_artifacts.TransformGraph()])

  if not transformed_examples:
    default_examples = standard_artifacts.Examples()
    default_examples.split_names = artifact_utils.encode_split_names(
        artifact.DEFAULT_EXAMPLE_SPLITS)
    transformed_examples = types.Channel(
        type=standard_artifacts.Examples, artifacts=[default_examples])

  serialized_config = json.dumps(custom_config)
  spec = standard_component_specs.TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      custom_config=serialized_config,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples)
  super(Transform, self).__init__(spec=spec, instance_name=instance_name)
def setUp(self):
  """Prepares trainer executor fixtures keyed by `constants`."""
  super(ExecutorTest, self).setUp()
  self._source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  self._output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  source_dir = self._source_data_dir
  output_dir = self._output_data_dir

  # Inputs.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_dir, 'transform/transformed_examples')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  graph = standard_artifacts.TransformGraph()
  graph.uri = os.path.join(source_dir, 'transform/transform_graph')
  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_dir, 'schema_gen')
  base_model = standard_artifacts.Model()
  base_model.uri = os.path.join(source_dir, 'trainer/previous')
  self._input_dict = {
      constants.EXAMPLES_KEY: [examples],
      constants.TRANSFORM_GRAPH_KEY: [graph],
      constants.SCHEMA_KEY: [schema],
      constants.BASE_MODEL_KEY: [base_model]
  }

  # Outputs.
  self._model_exports = standard_artifacts.Model()
  self._model_exports.uri = os.path.join(output_dir, 'model_export_path')
  self._model_run_exports = standard_artifacts.ModelRun()
  self._model_run_exports.uri = os.path.join(output_dir, 'model_run_path')
  self._output_dict = {
      constants.MODEL_KEY: [self._model_exports],
      constants.MODEL_RUN_KEY: [self._model_run_exports]
  }

  def _as_json(message):
    """Serializes a proto to JSON, keeping the original field names."""
    return json_format.MessageToJson(message, preserving_proto_field_name=True)

  # Exec properties skeleton.
  self._exec_properties = {
      'train_args': _as_json(trainer_pb2.TrainArgs(num_steps=1000)),
      'eval_args': _as_json(trainer_pb2.EvalArgs(num_steps=500)),
      'warm_starting': False,
  }
  self._module_file = os.path.join(source_dir, 'module_file',
                                   'trainer_module.py')
  self._trainer_fn = '.'.join([
      trainer_module.trainer_fn.__module__, trainer_module.trainer_fn.__name__
  ])

  # Executors for test.
  self._trainer_executor = executor.Executor()
  self._generic_trainer_executor = executor.GenericExecutor()