示例#1
0
 def testMakeProcess_Distributed(self):
     """Launches a 'distributed' task and checks what it writes to stdout.

     The fixture package is copied into a run directory under the test's
     temp path, started through local_train.MakeProcess with stdout piped,
     and the captured output is parsed as YAML and compared against the
     expected TF_CONFIG, PWD and ARGS values.
     """
     task_module = 'test_package.test_task'
     task_args = ['foo']
     cluster_spec = {'distributed': ['address_1']}

     # Stage a private copy of the fixture package to run from.
     fixture_root = self.Resource('tests', 'unit', 'command_lib', 'ml_engine',
                                  'test_data', 'package_root')
     run_dir = os.path.join(self.temp_path, 'run_dir')
     shutil.copytree(fixture_root, run_dir)

     # For this task type MakeProcess hands back an object with
     # communicate(), so the child's stdout can be collected here.
     process = local_train.MakeProcess(
         task_module,
         run_dir,
         args=task_args,
         task_type='distributed',
         index=0,
         cluster=cluster_spec,
         stdout=subprocess.PIPE)
     captured_stdout, _ = process.communicate()

     buf = io.BytesIO()
     buf.write(captured_stdout)

     expected_tf_config = {
         'job': {'job_name': task_module, 'args': task_args},
         'task': {'type': 'distributed', 'index': 0},
         'cluster': cluster_spec,
         'environment': 'cloud',
     }
     expected_report = {
         'TF_CONFIG': expected_tf_config,
         'PWD': run_dir,
         'ARGS': ['foo'],
     }
     self.assertEqual(yaml.load(buf.getvalue()), expected_report)
示例#2
0
 def testUseSystemPython(self):
     """Checks the executed command does not start with CLOUDSDK_PYTHON's value.

     CLOUDSDK_PYTHON is set to a dummy string in a copy of the environment,
     Exec is patched out, and the first element of the command passed to
     Exec is asserted to differ from that dummy string.
     """
     fake_environ = os.environ.copy()
     fake_environ['CLOUDSDK_PYTHON'] = 'DUMMY_STRING'
     # NOTE(review): the copy is supplied as `return_value`, not as the
     # replacement object itself -- confirm against StartPatch that the code
     # under test actually observes CLOUDSDK_PYTHON through this patch.
     self.StartPatch('os.environ', return_value=fake_environ)
     exec_patch = self.StartPatch('googlecloudsdk.core.execution_utils.Exec')

     local_train.MakeProcess('foo', 'bar', task_type='master')

     # The command is the first positional argument of the recorded call.
     positional_args = exec_patch.call_args[0]
     command = positional_args[0]
     self.assertNotEqual(command[0], 'DUMMY_STRING')
示例#3
0
 def testMakeProcess_Master(self):
     """Runs the fixture task as 'master' and expects a zero return code."""
     # Stage a private copy of the fixture package to run from.
     fixture_root = self.Resource('tests', 'unit', 'command_lib', 'ml_engine',
                                  'test_data', 'package_root')
     run_dir = os.path.join(self.temp_path, 'run_dir')
     shutil.copytree(fixture_root, run_dir)

     # Only the return code can be checked here because of the unusual
     # semantics of MakeProcess.
     exit_status = local_train.MakeProcess(
         'test_package.test_task',
         run_dir,
         args=['foo'],
         task_type='master',
         index=0,
         cluster={},
         stdout=subprocess.PIPE)

     self.assertEqual(exit_status, 0)
  def Run(self, args):
    """Runs the training module locally and records the result.

    When --distributed is set the module is run through
    local_train.RunDistributed; otherwise a single process is started with
    local_train.MakeProcess for the primary node.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      None. The value produced by the local run is stored on self.exit_code
      rather than returned.
    """
    package_path = args.package_path or files.GetCWD()
    # Mimic behavior of ai-platform jobs submit training
    package_root = os.path.dirname(os.path.abspath(package_path))
    # Work on a copy: appending --job-dir below must not mutate the list
    # owned by the parsed namespace (re-using `args` would otherwise
    # accumulate duplicate flags).
    user_args = list(args.user_args or [])
    if args.job_dir:
      user_args.extend(('--job-dir', args.job_dir))

    if args.distributed:
      # Both counts fall back to 2 only when the flag was omitted (None); an
      # explicit 0 is passed through unchanged.
      worker_count = 2 if args.worker_count is None else args.worker_count
      ps_count = (2 if args.parameter_server_count is None
                  else args.parameter_server_count)
      retval = local_train.RunDistributed(
          args.module_name,
          package_root,
          ps_count,
          worker_count,
          args.evaluator_count or 0,
          args.start_port,
          user_args=user_args)
    else:
      # These flags are not used for a non-distributed run; warn instead of
      # failing.
      if args.parameter_server_count:
        log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(
            flag='--parameter-server-count'))
      if args.worker_count:
        log.warning(_BAD_FLAGS_WARNING_MESSAGE.format(flag='--worker-count'))
      retval = local_train.MakeProcess(
          args.module_name,
          package_root,
          args=user_args,
          task_type=local_train.GetPrimaryNodeName())
    # Don't raise an exception because the users will already see the message.
    # We want this to mimic calling the script directly as much as possible.
    self.exit_code = retval
示例#5
0
 def testArgs(self):
     """Checks user args appear in the executed command from index 3 on."""
     exec_patch = self.StartPatch('googlecloudsdk.core.execution_utils.Exec')
     user_args = ['foo', 'bar']

     local_train.MakeProcess('baz', 'zap', task_type='master', args=user_args)

     # The command is the first positional argument of the recorded call.
     positional_args = exec_patch.call_args[0]
     command = positional_args[0]
     self.assertEqual(command[3:], user_args)