def testMakeProcess_Distributed(self):
  """A distributed task sees the expected TF_CONFIG, working dir and args.

  The packaged test task prints its environment as YAML on stdout; this test
  parses that output and compares it to the configuration passed in.
  """
  source_pkg = self.Resource('tests', 'unit', 'command_lib', 'ml_engine',
                             'test_data', 'package_root')
  work_dir = os.path.join(self.temp_path, 'run_dir')
  shutil.copytree(source_pkg, work_dir)

  task_module = 'test_package.test_task'
  user_args = ['foo']
  cluster_spec = {'distributed': ['address_1']}

  # NOTE(review): this test relies on MakeProcess returning a Popen-like
  # object (it calls .communicate() on the result) -- confirm against
  # local_train.
  proc = local_train.MakeProcess(
      task_module,
      work_dir,
      args=user_args,
      task_type='distributed',
      index=0,
      cluster=cluster_spec,
      stdout=subprocess.PIPE)
  captured, _ = proc.communicate()

  expected_tf_config = {
      'job': {'job_name': task_module, 'args': user_args},
      'task': {'type': 'distributed', 'index': 0},
      'cluster': cluster_spec,
      'environment': 'cloud',
  }
  expected = {
      'TF_CONFIG': expected_tf_config,
      'PWD': work_dir,
      'ARGS': ['foo'],
  }
  self.assertEqual(yaml.load(captured), expected)
def testUseSystemPython(self):
  """MakeProcess must not launch the interpreter named by CLOUDSDK_PYTHON."""
  environ_cp = os.environ.copy()
  environ_cp['CLOUDSDK_PYTHON'] = 'DUMMY_STRING'
  # Replace os.environ itself with the modified copy. Patching with
  # `return_value=` made os.environ a MagicMock whose item lookups never yield
  # 'DUMMY_STRING', so the assertion below passed vacuously.
  # Assumes StartPatch forwards keyword arguments to mock.patch, as the
  # previous return_value= usage implies.
  self.StartPatch('os.environ', new=environ_cp)
  exec_mock = self.StartPatch('googlecloudsdk.core.execution_utils.Exec')
  local_train.MakeProcess('foo', 'bar', task_type='master')
  # First positional argument of the (mocked) Exec call is the command list.
  exec_cmd = exec_mock.call_args[0][0]
  self.assertNotEqual(exec_cmd[0], 'DUMMY_STRING')
def testMakeProcess_Master(self):
  """Running the packaged module as the master task exits with status 0."""
  work_dir = os.path.join(self.temp_path, 'run_dir')
  shutil.copytree(
      self.Resource('tests', 'unit', 'command_lib', 'ml_engine', 'test_data',
                    'package_root'),
      work_dir)
  # Only the exit status is checked: the result of MakeProcess is compared
  # directly against 0, because its semantics make richer inspection awkward.
  exit_status = local_train.MakeProcess(
      'test_package.test_task',
      work_dir,
      args=['foo'],
      task_type='master',
      index=0,
      cluster={},
      stdout=subprocess.PIPE)
  self.assertEqual(exit_status, 0)
def Run(self, args):
  """Run the user's training module locally, optionally as a distributed job.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    None. The value returned by local_train.RunDistributed / MakeProcess is
    stored on self.exit_code instead of being returned or raised.
  """
  package_path = args.package_path or files.GetCWD()
  # Mimic behavior of ai-platform jobs submit training
  package_root = os.path.dirname(os.path.abspath(package_path))
  # Copy so that appending --job-dir does not mutate args.user_args in place.
  user_args = list(args.user_args or [])
  if args.job_dir:
    user_args.extend(('--job-dir', args.job_dir))

  if args.distributed:
    # Unspecified counts default to 2 parameter servers and 2 workers.
    ps_count = (2 if args.parameter_server_count is None
                else args.parameter_server_count)
    worker_count = 2 if args.worker_count is None else args.worker_count
    retval = local_train.RunDistributed(
        args.module_name,
        package_root,
        ps_count,
        worker_count,
        args.evaluator_count or 0,
        args.start_port,
        user_args=user_args)
  else:
    # These flags are only used with --distributed. Warn whenever the user
    # supplied them, including an explicit 0, consistent with the `is None`
    # default handling above.
    if args.parameter_server_count is not None:
      log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(
          flag='--parameter-server-count'))
    if args.worker_count is not None:
      log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(flag='--worker-count'))
    retval = local_train.MakeProcess(
        args.module_name,
        package_root,
        args=user_args,
        task_type=local_train.GetPrimaryNodeName())
  # Don't raise an exception because the users will already see the message.
  # We want this to mimic calling the script directly as much as possible.
  self.exit_code = retval
def testArgs(self):
  """User args appear in the Exec command from index 3 onward."""
  exec_mock = self.StartPatch('googlecloudsdk.core.execution_utils.Exec')
  expected_args = ['foo', 'bar']
  local_train.MakeProcess('baz', 'zap', task_type='master', args=expected_args)
  # call_args unpacks into (positional args, keyword args); the command list
  # is the first positional argument passed to Exec.
  positional, _ = exec_mock.call_args
  command = positional[0]
  self.assertEqual(command[3:], expected_args)