def testVertexDistributedTunerPipeline(self):
  """Runs a Tuner-only pipeline that tunes in parallel on Vertex AI Training.

  The pipeline imports the penguin examples and schema, then runs the Cloud
  AI Platform Tuner with Vertex enabled in ``self._GCP_REGION`` and three
  parallel trials. Once the pipeline has been compiled and run, the test
  asserts that the tuned hyperparameters were written.
  """
  pipeline_name = self._make_unique_pipeline_name('kubeflow-vertex-dist-tuner')

  # Route the tuning job to Vertex AI in the test's GCP region.
  vertex_tuning_config = {
      ai_platform_tuner_executor.TUNING_ARGS_KEY:
          self._getVertexTrainingArgs(pipeline_name),
      constants.ENABLE_VERTEX_KEY: True,
      constants.VERTEX_REGION_KEY: self._GCP_REGION,
  }

  examples_importer = self.penguin_examples_importer
  schema_importer = self.penguin_schema_importer
  distributed_tuner = ai_platform_tuner_component.Tuner(
      examples=examples_importer.outputs['result'],
      module_file=self._penguin_tuner_module,
      schema=schema_importer.outputs['result'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5),
      # Three trials run concurrently, i.e. a flock of three tuning workers.
      tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3),
      custom_config=vertex_tuning_config)

  pipeline_components = [examples_importer, schema_importer, distributed_tuner]
  pipeline = self._create_pipeline(pipeline_name, pipeline_components)

  self._compile_and_run_pipeline(pipeline)
  self._assertHyperparametersAreWritten(pipeline_name)
def testConstructWithoutCustomConfig(self):
  """Builds a Tuner without ``custom_config`` and checks the result.

  The checks on the constructed component are delegated to the
  ``_verify_output`` helper.
  """
  # No custom_config entry: this test covers construction without it.
  tuner_kwargs = dict(
      examples=self.examples,
      schema=self.schema,
      train_args=self.train_args,
      eval_args=self.eval_args,
      tune_args=self.tune_args,
      module_file='/path/to/module/file',
  )
  constructed_tuner = component.Tuner(**tuner_kwargs)
  self._verify_output(constructed_tuner)
def testAIPlatformDistributedTunerPipeline(self):
  """Runs a Tuner-only pipeline that tunes in parallel on AI Platform Training.

  The pipeline imports the iris examples and schema, then runs the Cloud AI
  Platform Tuner with three parallel trials. Once the pipeline has been
  compiled and run, the test asserts that the tuned hyperparameters were
  written.
  """
  # A random suffix keeps concurrent test runs from colliding on the name.
  pipeline_name = f'kubeflow-aip-dist-tuner-test-{test_utils.random_id()}'

  # NOTE(review): this passes the *trainer* executor's TRAINING_ARGS_KEY to a
  # Tuner, while the Vertex variant of this test uses the tuner executor's
  # TUNING_ARGS_KEY. Confirm which key the tuner executor in this TFX
  # version actually reads.
  caip_tuning_config = {
      ai_platform_trainer_executor.TRAINING_ARGS_KEY:
          self._getCaipTrainingArgs(pipeline_name),
  }

  examples_importer = self.iris_examples_importer
  schema_importer = self.iris_schema_importer
  distributed_tuner = ai_platform_tuner_component.Tuner(
      examples=examples_importer.outputs['result'],
      module_file=self._iris_tuner_module,
      schema=schema_importer.outputs['result'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5),
      # Three trials run concurrently, i.e. a flock of three tuning workers.
      tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3),
      custom_config=caip_tuning_config)

  pipeline_components = [examples_importer, schema_importer, distributed_tuner]
  pipeline = self._create_pipeline(pipeline_name, pipeline_components)

  self._compile_and_run_pipeline(pipeline)
  self._assertHyperparametersAreWritten(pipeline_name)