def testDo(self, mock_runner):
  """Do() hands its arguments to mock_runner.start_aip_training.

  Expects the call to carry the inputs, outputs and exec properties passed to
  Do(), the executor class path, a training-input dict holding the project id
  and job dir, and None as the final argument.

  Args:
    mock_runner: stand-in for the training runner; presumably injected by a
      mock.patch decorator that is not visible here — confirm.
  """
  # NOTE(review): another method named `testDo` appears in this view; if both
  # are defined in the same class, the later definition shadows this one and
  # it will never run — confirm.
  ai_platform_trainer_executor.Executor().Do(
      self._inputs, self._outputs, self._exec_properties)

  expected_training_inputs = {
      'project': self._project_id,
      'jobDir': self._job_dir,
  }
  mock_runner.start_aip_training.assert_called_with(
      self._inputs,
      self._outputs,
      self._exec_properties,
      self._executor_class_path,
      expected_training_inputs,
      None,  # this test sets no job id override
  )
def testDo(self):
  """Do() forwards default settings to start_cloud_training.

  This test configures no overrides, so the runner is expected to receive
  None, False and None as its three trailing positional arguments. The other
  tests in this file fill those positions with job_id, enable_vertex and
  vertex_region.
  """
  executor = ai_platform_trainer_executor.Executor()
  executor.Do(self._inputs, self._outputs,
              self._serialize_custom_config_under_test())

  # Serialize again for the expected value rather than reusing the object
  # handed to Do().
  expected_exec_properties = self._serialize_custom_config_under_test()
  expected_training_inputs = {
      'project': self._project_id,
      'jobDir': self._job_dir,
  }
  self.mock_runner.start_cloud_training.assert_called_with(
      self._inputs,
      self._outputs,
      expected_exec_properties,
      self._executor_class_path,
      expected_training_inputs,
      None,   # job_id
      False,  # enable_vertex
      None,   # vertex_region
  )
def testDoWithJobIdOverride(self, mock_runner):
  """A job id in custom_config is passed through to start_aip_training.

  Writes 'overridden_job_id' under JOB_ID_KEY in the 'custom_config' exec
  property. It then expects that value as the last argument of the runner
  call.

  Args:
    mock_runner: stand-in for the training runner; presumably injected by a
      mock.patch decorator that is not visible here — confirm.
  """
  # NOTE(review): another method named `testDoWithJobIdOverride` appears in
  # this view; if both are defined in the same class, the later definition
  # shadows this one and it will never run — confirm.
  overridden_job_id = 'overridden_job_id'
  custom_config = self._exec_properties['custom_config']
  custom_config[ai_platform_trainer_executor.JOB_ID_KEY] = overridden_job_id

  ai_platform_trainer_executor.Executor().Do(
      self._inputs, self._outputs, self._exec_properties)

  expected_training_inputs = {
      'project': self._project_id,
      'jobDir': self._job_dir,
  }
  mock_runner.start_aip_training.assert_called_with(
      self._inputs,
      self._outputs,
      self._exec_properties,
      self._executor_class_path,
      expected_training_inputs,
      overridden_job_id,
  )
def testDoWithJobIdOverride(self):
  """A job id set in custom_config reaches start_cloud_training.

  Stores 'overridden_job_id' under JOB_ID_KEY inside the custom-config exec
  property. It then expects that id in the job_id position, with
  enable_vertex False and vertex_region None.
  """
  overridden_job_id = 'overridden_job_id'
  custom_config = self._exec_properties[
      standard_component_specs.CUSTOM_CONFIG_KEY]
  custom_config[ai_platform_trainer_executor.JOB_ID_KEY] = overridden_job_id

  ai_platform_trainer_executor.Executor().Do(
      self._inputs, self._outputs, self._serialize_custom_config_under_test())

  expected_exec_properties = self._serialize_custom_config_under_test()
  expected_training_inputs = {
      'project': self._project_id,
      'jobDir': self._job_dir,
  }
  self.mock_runner.start_cloud_training.assert_called_with(
      self._inputs,
      self._outputs,
      expected_exec_properties,
      self._executor_class_path,
      expected_training_inputs,
      overridden_job_id,
      False,  # enable_vertex
      None,   # vertex_region
  )
def testDoWithEnableVertexOverride(self):
  """Vertex settings in custom_config reach start_cloud_training.

  Sets ENABLE_VERTEX_KEY to True and VERTEX_REGION_KEY to 'us-central2'
  inside the custom-config exec property. It then expects both values as the
  last two positional arguments of the runner call, after a job_id of None.
  """
  enable_vertex = True
  vertex_region = 'us-central2'
  custom_config = self._exec_properties[
      standard_component_specs.CUSTOM_CONFIG_KEY]
  custom_config.update({
      ai_platform_trainer_executor.ENABLE_VERTEX_KEY: enable_vertex,
      ai_platform_trainer_executor.VERTEX_REGION_KEY: vertex_region,
  })

  executor = ai_platform_trainer_executor.Executor()
  executor.Do(self._inputs, self._outputs,
              self._serialize_custom_config_under_test())

  expected_exec_properties = self._serialize_custom_config_under_test()
  expected_training_inputs = {
      'project': self._project_id,
      'jobDir': self._job_dir,
  }
  self.mock_runner.start_cloud_training.assert_called_with(
      self._inputs,
      self._outputs,
      expected_exec_properties,
      self._executor_class_path,
      expected_training_inputs,
      None,  # job_id
      enable_vertex,
      vertex_region,
  )