def testRunDistributedErrorMaster(self):
    # Scenario: the first two subprocess.Popen calls succeed and the third
    # raises RuntimeError. The error must propagate out of RunDistributed and
    # both already-started processes must be handed to KillSubprocess.
    first_proc = self._mockMakeProcess()
    second_proc = self._mockMakeProcess()
    popen_mock = self.StartPatch(
        'subprocess.Popen',
        side_effect=[first_proc, second_proc, RuntimeError])
    kill_mock = self.StartPatch(
        'googlecloudsdk.core.execution_utils.KillSubprocess')

    # Positional arguments after the path: 2 parameter servers, 2 workers,
    # 0 evaluators, start port 0.
    with self.assertRaises(RuntimeError):
        local_train.RunDistributed(
            'test_package.test_task', self.temp_path, 2, 2, 0, 0)

    # Run the registered exit handlers by hand so their effects can be
    # asserted (presumably this is where the kill calls are issued - verify
    # against local_train).
    atexit._run_exitfuncs()

    self.assertEqual(popen_mock.call_count, 3)
    expected_kills = [mock.call(first_proc), mock.call(second_proc)]
    kill_mock.assert_has_calls(expected_kills, any_order=True)
def testRunDistributed(self):
    # Scenario: a successful distributed run with one parameter server and one
    # worker. Two subprocesses are spawned through Popen, Exec is invoked
    # exactly once, and both spawned processes are passed to KillSubprocess
    # once the exit handlers have run.
    exec_mock = self.StartPatch('googlecloudsdk.core.execution_utils.Exec')
    ps_proc = mock.MagicMock()
    worker_proc = mock.MagicMock()
    popen_mock = self.StartPatch(
        'subprocess.Popen', side_effect=[ps_proc, worker_proc])
    kill_mock = self.StartPatch(
        'googlecloudsdk.core.execution_utils.KillSubprocess')

    # Positional arguments after the path: 1 parameter server, 1 worker,
    # 0 evaluators, start port 0.
    local_train.RunDistributed(
        'test_package.test_task', self.temp_path, 1, 1, 0, 0)

    self.assertEqual(popen_mock.call_count, 2)
    exec_mock.assert_called_once()

    # Trigger the registered exit handlers manually before checking cleanup.
    atexit._run_exitfuncs()
    expected_kills = [mock.call(ps_proc), mock.call(worker_proc)]
    kill_mock.assert_has_calls(expected_kills, any_order=True)
def Run(self, args):
    """Run the user's training module locally.

    This is what gets called when the user runs this command. Depending on
    `args.distributed`, the module is launched either through
    `local_train.RunDistributed` or as a single process through
    `local_train.MakeProcess`.

    Args:
      args: an argparse namespace. All the arguments that were provided to
        this command invocation.

    Returns:
      None. The value returned by the launched run is stored in
      `self.exit_code` rather than being returned or raised.
    """
    package_path = args.package_path or files.GetCWD()
    # Mimic behavior of ai-platform jobs submit training
    package_root = os.path.dirname(os.path.abspath(package_path))

    # Work on a copy: appending --job-dir must not mutate the list that is
    # owned by the parsed-arguments namespace.
    user_args = list(args.user_args or [])
    if args.job_dir:
        user_args.extend(('--job-dir', args.job_dir))

    # An explicit 0 is honoured; only an omitted flag falls back to 2.
    worker_count = 2 if args.worker_count is None else args.worker_count
    ps_count = (2 if args.parameter_server_count is None
                else args.parameter_server_count)

    if args.distributed:
        retval = local_train.RunDistributed(
            args.module_name,
            package_root,
            ps_count,
            worker_count,
            args.evaluator_count or 0,
            args.start_port,
            user_args=user_args)
    else:
        # These flags are passed only to the distributed launch above, so
        # warn when they were supplied without --distributed.
        if args.parameter_server_count:
            log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(
                flag='--parameter-server-count'))
        if args.worker_count:
            log.warning(
                _BAD_FLAGS_WARNING_MESSAGE.format(flag='--worker-count'))
        retval = local_train.MakeProcess(
            args.module_name,
            package_root,
            args=user_args,
            task_type=local_train.GetPrimaryNodeName())

    # Don't raise an exception because the users will already see the message.
    # We want this to mimic calling the script directly as much as possible.
    self.exit_code = retval