Example #1
0
 def testDo(self, mock_runner):
   """Checks that Do() forwards its arguments to start_aip_training.

   The runner should receive the executor's inputs, outputs and exec
   properties, the executor class path, a training-inputs dict holding only
   the project id and job dir, and None in the final (job id) slot.

   Args:
     mock_runner: mocked runner; presumably injected by a mock.patch
       decorator that is not visible here -- confirm.
   """
   expected_training_inputs = {
       'project': self._project_id,
       'jobDir': self._job_dir,
   }
   job_id = None

   trainer = ai_platform_trainer_executor.Executor()
   trainer.Do(self._inputs, self._outputs, self._exec_properties)

   mock_runner.start_aip_training.assert_called_with(
       self._inputs,
       self._outputs,
       self._exec_properties,
       self._executor_class_path,
       expected_training_inputs,
       job_id)
Example #2
0
 def testDo(self):
     """Checks that Do() forwards its arguments to start_cloud_training.

     With the default custom config, the runner should receive the inputs,
     outputs, serialized exec properties, the executor class path, a
     training-inputs dict holding only the project id and job dir, and no
     job id, Vertex flag or Vertex region override.
     """
     trainer = ai_platform_trainer_executor.Executor()
     trainer.Do(self._inputs, self._outputs,
                self._serialize_custom_config_under_test())

     # Serialize again instead of reusing the object given to Do(), so the
     # expectation is built independently of whatever Do() received.
     expected_exec_properties = self._serialize_custom_config_under_test()
     expected_training_inputs = {
         'project': self._project_id,
         'jobDir': self._job_dir,
     }
     job_id, enable_vertex, vertex_region = None, False, None

     self.mock_runner.start_cloud_training.assert_called_with(
         self._inputs, self._outputs, expected_exec_properties,
         self._executor_class_path, expected_training_inputs,
         job_id, enable_vertex, vertex_region)
Example #3
0
 def testDoWithJobIdOverride(self, mock_runner):
   """Checks that a job id set in custom_config reaches start_aip_training.

   Args:
     mock_runner: mocked runner; presumably injected by a mock.patch
       decorator that is not visible here -- confirm.
   """
   job_id = 'overridden_job_id'
   custom_config = self._exec_properties['custom_config']
   custom_config[ai_platform_trainer_executor.JOB_ID_KEY] = job_id

   trainer = ai_platform_trainer_executor.Executor()
   trainer.Do(self._inputs, self._outputs, self._exec_properties)

   expected_training_inputs = {
       'project': self._project_id,
       'jobDir': self._job_dir,
   }
   mock_runner.start_aip_training.assert_called_with(
       self._inputs,
       self._outputs,
       self._exec_properties,
       self._executor_class_path,
       expected_training_inputs,
       job_id)
Example #4
0
 def testDoWithJobIdOverride(self):
     """Checks that a job id set in custom_config reaches start_cloud_training.

     The Vertex flag and region slots should keep their non-override values
     (False and None).
     """
     job_id = 'overridden_job_id'
     custom_config = self._exec_properties[
         standard_component_specs.CUSTOM_CONFIG_KEY]
     custom_config[ai_platform_trainer_executor.JOB_ID_KEY] = job_id

     trainer = ai_platform_trainer_executor.Executor()
     trainer.Do(self._inputs, self._outputs,
                self._serialize_custom_config_under_test())

     # Serialized a second time so the expectation reflects the mutated
     # custom_config independently of the object handed to Do().
     expected_exec_properties = self._serialize_custom_config_under_test()
     expected_training_inputs = {
         'project': self._project_id,
         'jobDir': self._job_dir,
     }
     enable_vertex, vertex_region = False, None

     self.mock_runner.start_cloud_training.assert_called_with(
         self._inputs, self._outputs, expected_exec_properties,
         self._executor_class_path, expected_training_inputs,
         job_id, enable_vertex, vertex_region)
Example #5
0
 def testDoWithEnableVertexOverride(self):
     """Checks that Vertex settings in custom_config reach start_cloud_training.

     The enable-Vertex flag and the Vertex region written into custom_config
     should be passed through as the last two runner arguments, with no
     job id override.
     """
     enable_vertex = True
     vertex_region = 'us-central2'
     custom_config = self._exec_properties[
         standard_component_specs.CUSTOM_CONFIG_KEY]
     custom_config[
         ai_platform_trainer_executor.ENABLE_VERTEX_KEY] = enable_vertex
     custom_config[
         ai_platform_trainer_executor.VERTEX_REGION_KEY] = vertex_region

     trainer = ai_platform_trainer_executor.Executor()
     trainer.Do(self._inputs, self._outputs,
                self._serialize_custom_config_under_test())

     # Serialized a second time so the expectation reflects the mutated
     # custom_config independently of the object handed to Do().
     expected_exec_properties = self._serialize_custom_config_under_test()
     expected_training_inputs = {
         'project': self._project_id,
         'jobDir': self._job_dir,
     }
     job_id = None

     self.mock_runner.start_cloud_training.assert_called_with(
         self._inputs, self._outputs, expected_exec_properties,
         self._executor_class_path, expected_training_inputs,
         job_id, enable_vertex, vertex_region)