Exemplo n.º 1
0
  def Run(self, args):
    """Submits an AI Platform training job and optionally waits on it.

    Args:
      args: argparse.Namespace, the parsed arguments for this command.

    Returns:
      The created Job message. If `--stream-logs` was requested and the job
      did not finish in state SUCCEEDED, `self.exit_code` is set to 1 so the
      command reports failure.
    """
    # `async` became a reserved keyword in Python 3.7, so the parsed flag
    # value is exposed as `async_` (matching the other Run implementations);
    # `args.async` is a SyntaxError under Python 3.
    stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
    scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
    # GetEnumForChoice returns None when no --scale-tier was given.
    scale_tier_name = scale_tier.name if scale_tier else None
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    custom_container_config = (
        jobs_util.TrainingCustomInputServerConfig.FromArgs(args))
    # Raises if the custom-container flag combination is inconsistent.
    custom_container_config.ValidateConfig()

    job = jobs_util.SubmitTraining(
        jobs_client, args.job,
        job_dir=args.job_dir,
        staging_bucket=args.staging_bucket,
        packages=args.packages,
        package_path=args.package_path,
        scale_tier=scale_tier_name,
        config=args.config,
        module_name=args.module_name,
        runtime_version=args.runtime_version,
        python_version=args.python_version,
        labels=labels,
        stream_logs=stream_logs,
        user_args=args.user_args,
        custom_train_server_config=custom_container_config)
    # If the job itself failed, we will return a failure status.
    if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
      self.exit_code = 1
    return job
Exemplo n.º 2
0
def _AddSubmitTrainingArgs(parser):
  """Registers the flags for the `jobs submit training` command.

  Args:
    parser: argparse parser to which the flags are added. Registration order
      is preserved because it determines the order in generated help text.
  """
  flags.JOB_NAME.AddToParser(parser)
  flags.PACKAGE_PATH.AddToParser(parser)
  flags.PACKAGES.AddToParser(parser)
  flags.GetModuleNameFlag(required=False).AddToParser(parser)
  compute_flags.AddRegionFlag(parser, 'machine learning training job',
                              'submit')
  flags.CONFIG.AddToParser(parser)
  flags.STAGING_BUCKET.AddToParser(parser)
  flags.GetJobDirFlag(upload_help=True).AddToParser(parser)
  flags.GetUserArgs(local=False).AddToParser(parser)
  jobs_util.ScaleTierFlagMap().choice_arg.AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  flags.AddPythonVersionFlag(parser, 'during training')

  # --async and --stream-logs are opposite wait modes, so only one may be
  # given.
  # TODO(b/36195821): Use the flag deprecation machinery when it supports the
  # store_true action
  wait_mode_group = parser.add_mutually_exclusive_group()
  async_help = (
      '(DEPRECATED) Display information about the operation in progress '
      'without waiting for the operation to complete. '
      'Enabled by default and can be omitted; use `--stream-logs` to run '
      'synchronously.')
  wait_mode_group.add_argument('--async', action='store_true', help=async_help)
  stream_logs_help = (
      'Block until job completion and stream the logs while the job runs.'
      '\n\n'
      'Note that even if command execution is halted, the job will still '
      'run until cancelled with\n\n'
      '    $ gcloud ai-platform jobs cancel JOB_ID')
  wait_mode_group.add_argument(
      '--stream-logs', action='store_true', help=stream_logs_help)
  labels_util.AddCreateLabelsFlags(parser)
 def Run(self, args):
   """Creates an AI Platform training job from the parsed arguments.

   Args:
     args: argparse.Namespace with the `jobs submit training` flag values.

   Returns:
     The submitted Job message. When `--stream-logs` is active and the job
     ends in any state other than SUCCEEDED, `self.exit_code` is set to 1.
   """
   # The --async flag surfaces as `async_` because `async` is a Python 3
   # keyword.
   should_stream = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
   tier_enum = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
   tier_name = None if tier_enum is None else tier_enum.name
   client = jobs.JobsClient()
   job_labels = jobs_util.ParseCreateLabels(client, args)
   server_config = jobs_util.TrainingCustomInputServerConfig.FromArgs(
       args, self._SUPPORT_TPU_TF_VERSION)
   # Fail fast on an inconsistent custom-container flag combination.
   server_config.ValidateConfig()
   job = jobs_util.SubmitTraining(
       client,
       args.job,
       job_dir=args.job_dir,
       staging_bucket=args.staging_bucket,
       packages=args.packages,
       package_path=args.package_path,
       scale_tier=tier_name,
       config=args.config,
       module_name=args.module_name,
       runtime_version=args.runtime_version,
       python_version=args.python_version,
       network=args.network if hasattr(args, 'network') else None,
       service_account=args.service_account,
       labels=job_labels,
       stream_logs=should_stream,
       user_args=args.user_args,
       kms_key=_GetAndValidateKmsKey(args),
       custom_train_server_config=server_config,
       enable_web_access=args.enable_web_access)
   # When streaming, mirror the job's terminal state in the exit code.
   if should_stream and job.state is not job.StateValueValuesEnum.SUCCEEDED:
     self.exit_code = 1
   return job
Exemplo n.º 4
0
 def Run(self, args):
   """Submits an AI Platform training job and optionally streams its logs.

   Args:
     args: argparse.Namespace, the parsed arguments for this command.

   Returns:
     The created Job message. If `--stream-logs` was requested and the job
     finished in a state other than SUCCEEDED, `self.exit_code` is set to 1.
   """
   # `async` became a reserved keyword in Python 3.7, so the parsed flag
   # value is exposed as `async_` (matching the other Run implementations);
   # `args.async` is a SyntaxError under Python 3.
   stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
   scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
   # GetEnumForChoice returns None when no --scale-tier was given.
   scale_tier_name = scale_tier.name if scale_tier else None
   job = jobs_util.SubmitTraining(
       jobs.JobsClient(),
       args.job,
       job_dir=args.job_dir,
       staging_bucket=args.staging_bucket,
       packages=args.packages,
       package_path=args.package_path,
       scale_tier=scale_tier_name,
       config=args.config,
       module_name=args.module_name,
       runtime_version=args.runtime_version,
       stream_logs=stream_logs,
       user_args=args.user_args)
   # If the job itself failed, we will return a failure status.
   if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
     self.exit_code = 1
   return job