def testConstructFromTrainerFn(self):
    """Checks that a Trainer built from `trainer_fn` records that value.

    The component is constructed from the fixture attributes held on `self`,
    its outputs are checked by `self._verify_outputs`, and the `trainer_fn`
    exec property is compared against the string that was passed in.
    """
    expected_fn = 'path.to.my_trainer_fn'
    ctor_kwargs = {
        'trainer_fn': expected_fn,
        'transformed_examples': self.examples,
        'transform_graph': self.transform_output,
        'schema': self.schema,
        'train_args': self.train_args,
        'eval_args': self.eval_args,
    }
    trainer_under_test = component.Trainer(**ctor_kwargs)

    self._verify_outputs(trainer_under_test)

    recorded_fn = trainer_under_test.spec.exec_properties['trainer_fn']
    self.assertEqual(expected_fn, recorded_fn)
def testConstructFromModuleFile(self):
    """Checks that a Trainer built from `module_file` records that path.

    After construction the outputs are checked by `self._verify_outputs`,
    then the `module_file` exec property is compared with the path given
    to the constructor.
    """
    module_path = '/path/to/module/file'

    built = component.Trainer(
        module_file=module_path,
        transformed_examples=self.examples,
        transform_graph=self.transform_output,
        schema=self.schema,
        train_args=self.train_args,
        eval_args=self.eval_args,
    )
    self._verify_outputs(built)

    exec_props = built.spec.exec_properties
    self.assertEqual(module_path, exec_props['module_file'])
def testConstructWithHParams(self):
    """Checks the type name of the `hyperparameters` input channel.

    A Trainer is built with `self.hyperparameters` supplied, its outputs
    are checked by `self._verify_outputs`, and the type name of its
    `hyperparameters` input is compared with
    `standard_artifacts.HyperParameters.TYPE_NAME`.
    """
    ctor_kwargs = {
        'trainer_fn': 'path.to.my_trainer_fn',
        'transformed_examples': self.examples,
        'transform_graph': self.transform_output,
        'schema': self.schema,
        'hyperparameters': self.hyperparameters,
        'train_args': self.train_args,
        'eval_args': self.eval_args,
    }
    built = component.Trainer(**ctor_kwargs)
    self._verify_outputs(built)

    # Expected value is read first, matching the original argument order.
    expected_type = standard_artifacts.HyperParameters.TYPE_NAME
    actual_type = built.inputs['hyperparameters'].type_name
    self.assertEqual(expected_type, actual_type)
def test_construct(self):
    """Builds a Trainer from single-item channels and checks its output type.

    Three `types.TfxType` objects are created, each is wrapped with
    `channel.as_channel`, and the resulting Trainer's `outputs.output`
    type name is asserted to be 'ModelExportPath'.
    """
    examples_type = types.TfxType(type_name='ExamplesPath')
    transform_type = types.TfxType(type_name='TransformPath')
    schema_type = types.TfxType(type_name='SchemaPath')

    # Channels and args are built in the same order the original call
    # evaluated them.
    examples_channel = channel.as_channel([examples_type])
    transform_channel = channel.as_channel([transform_type])
    schema_channel = channel.as_channel([schema_type])
    training = trainer_pb2.TrainArgs(num_steps=100)
    evaluation = trainer_pb2.EvalArgs(num_steps=50)

    built = component.Trainer(
        module_file='/path/to/module/file',
        transformed_examples=examples_channel,
        transform_output=transform_channel,
        schema=schema_channel,
        train_args=training,
        eval_args=evaluation,
    )

    output_type_name = built.outputs.output.type_name
    self.assertEqual('ModelExportPath', output_type_name)
def testConstructFromRunFn(self):
    """Checks that a Trainer built from `run_fn` records that value.

    The component is built with an `ExecutorClassSpec` wrapping
    `executor.GenericExecutor` and without a schema argument. Outputs are
    checked by `self._verify_outputs`, then the `run_fn` exec property is
    compared with the string that was passed in.
    """
    expected_fn = 'path.to.my_run_fn'
    generic_spec = executor_spec.ExecutorClassSpec(executor.GenericExecutor)

    built = component.Trainer(
        run_fn=expected_fn,
        custom_executor_spec=generic_spec,
        transformed_examples=self.examples,
        transform_graph=self.transform_output,
        train_args=self.train_args,
        eval_args=self.eval_args,
    )
    self._verify_outputs(built)

    recorded_fn = built.spec.exec_properties['run_fn']
    self.assertEqual(expected_fn, recorded_fn)
def testConstructWithParameter(self):
    """Checks Trainer construction when runtime parameters are supplied.

    `module_file` is a `RuntimeParameter` of type `Text`, and both
    `train_args` and `eval_args` are plain dicts whose `num_steps` entry is
    the same integer `RuntimeParameter`. After `self._verify_outputs`, the
    string form of the `module_file` exec property is compared with the
    string form of the parameter using `assertJsonEqual`.
    """
    module_param = data_types.RuntimeParameter(name='module-file', ptype=Text)
    steps_param = data_types.RuntimeParameter(name='n-steps', ptype=int)

    built = component.Trainer(
        module_file=module_param,
        transformed_examples=self.examples,
        transform_graph=self.transform_output,
        schema=self.schema,
        train_args={'num_steps': steps_param},
        eval_args={'num_steps': steps_param},
    )
    self._verify_outputs(built)

    expected_text = str(module_param)
    recorded_text = str(built.spec.exec_properties['module_file'])
    self.assertJsonEqual(expected_text, recorded_text)
def testConstructFromModuleFile(self):
    """Checks the recorded `module_file` and `custom_config` exec properties.

    A Trainer is built from a module file path and a one-entry custom
    config. After `self._verify_outputs`, the exec property stored under
    `standard_component_specs.MODULE_FILE_KEY` must equal the path, and the
    one stored under `standard_component_specs.CUSTOM_CONFIG_KEY` must
    equal the string '{"test": 10}'.
    """
    module_path = '/path/to/module/file'
    ctor_kwargs = {
        'module_file': module_path,
        'examples': self.examples,
        'transform_graph': self.transform_graph,
        'schema': self.schema,
        'custom_config': dict(test=10),
    }
    built = component.Trainer(**ctor_kwargs)
    self._verify_outputs(built)

    module_key = standard_component_specs.MODULE_FILE_KEY
    self.assertEqual(module_path, built.spec.exec_properties[module_key])

    config_key = standard_component_specs.CUSTOM_CONFIG_KEY
    self.assertEqual('{"test": 10}', built.spec.exec_properties[config_key])
def testConstructWithParameter(self):
    """Checks Trainer construction with runtime parameters and split names.

    `module_file` is a `RuntimeParameter` of type `Text`. `train_args` and
    `eval_args` are dicts that name the 'train' and 'eval' splits
    respectively and share one integer `RuntimeParameter` for `num_steps`.
    After `self._verify_outputs`, the string form of the exec property
    stored under `standard_component_specs.MODULE_FILE_KEY` is compared
    with the string form of the parameter using `assertJsonEqual`.
    """
    module_param = data_types.RuntimeParameter(name='module-file', ptype=Text)
    steps_param = data_types.RuntimeParameter(name='n-steps', ptype=int)

    built = component.Trainer(
        module_file=module_param,
        examples=self.examples,
        transform_graph=self.transform_graph,
        schema=self.schema,
        train_args={'splits': ['train'], 'num_steps': steps_param},
        eval_args={'splits': ['eval'], 'num_steps': steps_param},
    )
    self._verify_outputs(built)

    expected_text = str(module_param)
    recorded_text = str(
        built.spec.exec_properties[standard_component_specs.MODULE_FILE_KEY])
    self.assertJsonEqual(expected_text, recorded_text)
def __init__(self, module_file: str, transformed_examples: str, schema: str,
             transform_output: str, training_steps: int,
             eval_training_steps: int):
    """Wraps a `trainer_component.Trainer` and hands it to the base class.

    The Trainer is built with freshly created channels for its three
    inputs; the `transformed_examples`, `schema` and `transform_output`
    strings are not given to the Trainer itself. They are forwarded to the
    base-class initializer in a dict keyed by input name.

    Args:
      module_file: Passed to the Trainer as `module_file`.
      transformed_examples: Forwarded under the 'transformed_examples' key.
      schema: Forwarded under the 'schema' key.
      transform_output: Forwarded under the 'transform_output' key.
      training_steps: Used as `num_steps` of the Trainer's `TrainArgs`.
      eval_training_steps: Used as `num_steps` of the Trainer's `EvalArgs`.
    """
    wrapped_trainer = trainer_component.Trainer(
        module_file=module_file,
        transformed_examples=channel.Channel('ExamplesPath'),
        schema=channel.Channel('SchemaPath'),
        transform_output=channel.Channel('TransformPath'),
        train_args=trainer_pb2.TrainArgs(num_steps=training_steps),
        eval_args=trainer_pb2.EvalArgs(num_steps=eval_training_steps),
    )

    # Built after the component, as in the original statement order.
    input_values = {
        "transformed_examples": transformed_examples,
        "schema": schema,
        "transform_output": transform_output,
    }
    super().__init__(wrapped_trainer, input_values)