コード例 #1
0
 def testVertexDistributedTunerPipeline(self):
     """Runs a Tuner-only pipeline that tunes in parallel on Vertex AI Training.

     The penguin examples and schema are imported, then the AI Platform Tuner
     component runs three parallel trials with Vertex AI enabled in the test's
     GCP region. The pipeline is compiled and run, and the test then asserts
     that hyperparameters were written for it.
     """
     pipeline_name = self._make_unique_pipeline_name(
         'kubeflow-vertex-dist-tuner')
     examples_importer = self.penguin_examples_importer
     schema_importer = self.penguin_schema_importer

     train_args = trainer_pb2.TrainArgs(num_steps=10)
     eval_args = trainer_pb2.EvalArgs(num_steps=5)
     # Three trials run concurrently, i.e. three tuning workers.
     tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
     # Enable Vertex AI for the tuning job and pin it to the test region.
     vertex_custom_config = {
         ai_platform_tuner_executor.TUNING_ARGS_KEY:
             self._getVertexTrainingArgs(pipeline_name),
         constants.ENABLE_VERTEX_KEY: True,
         constants.VERTEX_REGION_KEY: self._GCP_REGION,
     }
     tuner = ai_platform_tuner_component.Tuner(
         examples=examples_importer.outputs['result'],
         module_file=self._penguin_tuner_module,
         schema=schema_importer.outputs['result'],
         train_args=train_args,
         eval_args=eval_args,
         tune_args=tune_args,
         custom_config=vertex_custom_config)

     pipeline = self._create_pipeline(
         pipeline_name, [examples_importer, schema_importer, tuner])
     self._compile_and_run_pipeline(pipeline)
     self._assertHyperparametersAreWritten(pipeline_name)
コード例 #2
0
ファイル: component_test.py プロジェクト: jay90099/tfx
 def testConstructWithoutCustomConfig(self):
     """Checks that a Tuner can be built when no custom_config is supplied.

     Every argument except custom_config comes from the test fixtures. The
     resulting component is handed to the shared output verifier.
     """
     tuner_kwargs = {
         'examples': self.examples,
         'schema': self.schema,
         'train_args': self.train_args,
         'eval_args': self.eval_args,
         'tune_args': self.tune_args,
         'module_file': '/path/to/module/file',
     }
     self._verify_output(component.Tuner(**tuner_kwargs))
コード例 #3
0
 def testAIPlatformDistributedTunerPipeline(self):
     """Runs a Tuner-only pipeline that tunes in parallel on AI Platform Training.

     The iris examples and schema are imported, then the AI Platform Tuner
     component runs three parallel trials using the CAIP training args built
     for this pipeline. The pipeline is compiled and run, and the test then
     asserts that hyperparameters were written for it.
     """
     # A random suffix keeps concurrent test runs from colliding on the name.
     name_suffix = test_utils.random_id()
     pipeline_name = 'kubeflow-aip-dist-tuner-test-{}'.format(name_suffix)
     examples_importer = self.iris_examples_importer
     schema_importer = self.iris_schema_importer

     train_args = trainer_pb2.TrainArgs(num_steps=10)
     eval_args = trainer_pb2.EvalArgs(num_steps=5)
     # Three trials run concurrently, i.e. three tuning workers.
     tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
     # NOTE(review): this passes the *trainer* executor's args key to the
     # Tuner. That presumably matches the TFX revision this test targets;
     # confirm before aligning it with a tuner-specific key.
     caip_custom_config = {
         ai_platform_trainer_executor.TRAINING_ARGS_KEY:
             self._getCaipTrainingArgs(pipeline_name)
     }
     tuner = ai_platform_tuner_component.Tuner(
         examples=examples_importer.outputs['result'],
         module_file=self._iris_tuner_module,
         schema=schema_importer.outputs['result'],
         train_args=train_args,
         eval_args=eval_args,
         tune_args=tune_args,
         custom_config=caip_custom_config)

     pipeline = self._create_pipeline(
         pipeline_name, [examples_importer, schema_importer, tuner])
     self._compile_and_run_pipeline(pipeline)
     self._assertHyperparametersAreWritten(pipeline_name)