示例#1
0
 def testConstructFromTrainerFn(self):
     """A Trainer built from a `trainer_fn` path stores it in exec_properties.

     Outputs are checked through the shared `_verify_outputs` helper, which is
     defined elsewhere in the test class.
     """
     expected_fn = 'path.to.my_trainer_fn'
     trainer_kwargs = dict(
         trainer_fn=expected_fn,
         transformed_examples=self.examples,
         transform_graph=self.transform_output,
         schema=self.schema,
         train_args=self.train_args,
         eval_args=self.eval_args,
     )
     trainer = component.Trainer(**trainer_kwargs)
     self._verify_outputs(trainer)
     recorded_fn = trainer.spec.exec_properties['trainer_fn']
     self.assertEqual(expected_fn, recorded_fn)
示例#2
0
 def testConstructFromModuleFile(self):
     """A Trainer built from a module file stores its path in exec_properties.

     Outputs are checked through the shared `_verify_outputs` helper, which is
     defined elsewhere in the test class.
     """
     module_path = '/path/to/module/file'
     trainer_kwargs = dict(
         module_file=module_path,
         transformed_examples=self.examples,
         transform_graph=self.transform_output,
         schema=self.schema,
         train_args=self.train_args,
         eval_args=self.eval_args,
     )
     trainer = component.Trainer(**trainer_kwargs)
     self._verify_outputs(trainer)
     recorded_path = trainer.spec.exec_properties['module_file']
     self.assertEqual(module_path, recorded_path)
示例#3
0
 def testConstructWithHParams(self):
   """A Trainer accepts a `hyperparameters` input of the HyperParameters type."""
   trainer = component.Trainer(
       trainer_fn='path.to.my_trainer_fn',
       transformed_examples=self.examples,
       transform_graph=self.transform_output,
       schema=self.schema,
       hyperparameters=self.hyperparameters,
       train_args=self.train_args,
       eval_args=self.eval_args)
   self._verify_outputs(trainer)
   # The channel wired to 'hyperparameters' must carry the standard
   # HyperParameters artifact type.
   expected_type = standard_artifacts.HyperParameters.TYPE_NAME
   actual_type = trainer.inputs['hyperparameters'].type_name
   self.assertEqual(expected_type, actual_type)
示例#4
0
 def test_construct(self):
     """A Trainer built from typed input channels exposes a ModelExportPath output."""
     # One artifact per required input, each with its expected type name.
     examples_artifact = types.TfxType(type_name='ExamplesPath')
     transform_artifact = types.TfxType(type_name='TransformPath')
     schema_artifact = types.TfxType(type_name='SchemaPath')
     examples_channel = channel.as_channel([examples_artifact])
     transform_channel = channel.as_channel([transform_artifact])
     schema_channel = channel.as_channel([schema_artifact])
     trainer = component.Trainer(
         module_file='/path/to/module/file',
         transformed_examples=examples_channel,
         transform_output=transform_channel,
         schema=schema_channel,
         train_args=trainer_pb2.TrainArgs(num_steps=100),
         eval_args=trainer_pb2.EvalArgs(num_steps=50))
     output_type = trainer.outputs.output.type_name
     self.assertEqual('ModelExportPath', output_type)
示例#5
0
 def testConstructFromRunFn(self):
     """A Trainer built from a `run_fn` with GenericExecutor stores the run_fn.

     No schema is passed here; outputs are checked through the shared
     `_verify_outputs` helper, which is defined elsewhere in the test class.
     """
     run_fn_path = 'path.to.my_run_fn'
     generic_executor_spec = executor_spec.ExecutorClassSpec(
         executor.GenericExecutor)
     trainer = component.Trainer(
         run_fn=run_fn_path,
         custom_executor_spec=generic_executor_spec,
         transformed_examples=self.examples,
         transform_graph=self.transform_output,
         train_args=self.train_args,
         eval_args=self.eval_args)
     self._verify_outputs(trainer)
     recorded_fn = trainer.spec.exec_properties['run_fn']
     self.assertEqual(run_fn_path, recorded_fn)
示例#6
0
 def testConstructWithParameter(self):
   """A Trainer accepts RuntimeParameters for module file and step counts.

   The module_file exec property is compared to the original parameter by
   their JSON/string forms.
   """
   module_param = data_types.RuntimeParameter(name='module-file', ptype=Text)
   steps_param = data_types.RuntimeParameter(name='n-steps', ptype=int)
   trainer = component.Trainer(
       module_file=module_param,
       transformed_examples=self.examples,
       transform_graph=self.transform_output,
       schema=self.schema,
       train_args={'num_steps': steps_param},
       eval_args={'num_steps': steps_param})
   self._verify_outputs(trainer)
   expected = str(module_param)
   actual = str(trainer.spec.exec_properties['module_file'])
   self.assertJsonEqual(expected, actual)
示例#7
0
 def testConstructFromModuleFile(self):
     """A Trainer built from a module file records it and its custom_config.

     The custom_config dict is expected to be stored as a JSON string.
     """
     module_path = '/path/to/module/file'
     trainer = component.Trainer(
         module_file=module_path,
         examples=self.examples,
         transform_graph=self.transform_graph,
         schema=self.schema,
         custom_config={'test': 10})
     self._verify_outputs(trainer)
     exec_props = trainer.spec.exec_properties
     self.assertEqual(
         module_path, exec_props[standard_component_specs.MODULE_FILE_KEY])
     self.assertEqual(
         '{"test": 10}',
         exec_props[standard_component_specs.CUSTOM_CONFIG_KEY])
示例#8
0
 def testConstructWithParameter(self):
     """A Trainer accepts RuntimeParameters with explicit train/eval splits.

     The module_file exec property is compared to the original parameter by
     their JSON/string forms.
     """
     module_param = data_types.RuntimeParameter(name='module-file',
                                                ptype=Text)
     steps_param = data_types.RuntimeParameter(name='n-steps', ptype=int)
     train_config = {'splits': ['train'], 'num_steps': steps_param}
     eval_config = {'splits': ['eval'], 'num_steps': steps_param}
     trainer = component.Trainer(
         module_file=module_param,
         examples=self.examples,
         transform_graph=self.transform_graph,
         schema=self.schema,
         train_args=train_config,
         eval_args=eval_config)
     self._verify_outputs(trainer)
     recorded = trainer.spec.exec_properties[
         standard_component_specs.MODULE_FILE_KEY]
     self.assertJsonEqual(str(module_param), str(recorded))
    def __init__(self, module_file: str, transformed_examples: str,
                 schema: str, transform_output: str, training_steps: int,
                 eval_training_steps: int):
        """Build a Trainer component and hand it to the base class.

        The Trainer gets freshly created placeholder channels typed
        'ExamplesPath', 'SchemaPath' and 'TransformPath'. The caller-supplied
        input values go to the base class constructor separately, as a dict
        keyed by input name.

        Args:
          module_file: Module file path forwarded to the Trainer.
          transformed_examples: Value passed under the
            "transformed_examples" key.
          schema: Value passed under the "schema" key.
          transform_output: Value passed under the "transform_output" key.
          training_steps: Step count for TrainArgs.
          eval_training_steps: Step count for EvalArgs.
        """
        # Typed placeholder channels; presumably the base class resolves the
        # real inputs from the dict below — confirm against the base class.
        examples_channel = channel.Channel('ExamplesPath')
        schema_channel = channel.Channel('SchemaPath')
        transform_channel = channel.Channel('TransformPath')
        trainer = trainer_component.Trainer(
            module_file=module_file,
            transformed_examples=examples_channel,
            schema=schema_channel,
            transform_output=transform_channel,
            train_args=trainer_pb2.TrainArgs(num_steps=training_steps),
            eval_args=trainer_pb2.EvalArgs(num_steps=eval_training_steps))

        named_inputs = {
            "transformed_examples": transformed_examples,
            "schema": schema,
            "transform_output": transform_output,
        }
        super().__init__(trainer, named_inputs)