def Run(self, args):
  """Submits an AI Platform training job and optionally streams its logs.

  Args:
    args: argparse.Namespace, the parsed command-line flags.

  Returns:
    The created training job resource. If logs were streamed and the job
    did not end in SUCCEEDED, ``self.exit_code`` is set to 1.
  """
  # BUG FIX: `async` became a reserved keyword in Python 3, so
  # `args.async` is a SyntaxError. The argparse dest for `--async` is
  # `async_` (matching the other Run() variant in this file).
  stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
  scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
  # GetEnumForChoice returns None when no tier was chosen; pass None through.
  scale_tier_name = scale_tier.name if scale_tier else None
  jobs_client = jobs.JobsClient()
  labels = jobs_util.ParseCreateLabels(jobs_client, args)
  custom_container_config = (
      jobs_util.TrainingCustomInputServerConfig.FromArgs(args))
  # Validate the custom-container flags before submitting anything.
  custom_container_config.ValidateConfig()
  job = jobs_util.SubmitTraining(
      jobs_client,
      args.job,
      job_dir=args.job_dir,
      staging_bucket=args.staging_bucket,
      packages=args.packages,
      package_path=args.package_path,
      scale_tier=scale_tier_name,
      config=args.config,
      module_name=args.module_name,
      runtime_version=args.runtime_version,
      python_version=args.python_version,
      labels=labels,
      stream_logs=stream_logs,
      user_args=args.user_args,
      custom_train_server_config=custom_container_config)
  # If the job itself failed, we will return a failure status.
  if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
    self.exit_code = 1
  return job
def _AddSubmitTrainingArgs(parser):
  """Add arguments for `jobs submit training` command."""
  # Registration order determines the order flags appear in `--help`.
  flags.JOB_NAME.AddToParser(parser)
  flags.PACKAGE_PATH.AddToParser(parser)
  flags.PACKAGES.AddToParser(parser)
  flags.GetModuleNameFlag(required=False).AddToParser(parser)
  compute_flags.AddRegionFlag(parser, 'machine learning training job',
                              'submit')
  flags.CONFIG.AddToParser(parser)
  flags.STAGING_BUCKET.AddToParser(parser)
  flags.GetJobDirFlag(upload_help=True).AddToParser(parser)
  flags.GetUserArgs(local=False).AddToParser(parser)
  jobs_util.ScaleTierFlagMap().choice_arg.AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  flags.AddPythonVersionFlag(parser, 'during training')
  # `--async` and `--stream-logs` are opposites, so they are mutually
  # exclusive; async-by-default is the current behavior and `--async` is
  # kept only for backward compatibility.
  sync_group = parser.add_mutually_exclusive_group()
  # TODO(b/36195821): Use the flag deprecation machinery when it supports the
  # store_true action
  sync_group.add_argument(
      '--async',
      action='store_true',
      help=(
          '(DEPRECATED) Display information about the operation in progress '
          'without waiting for the operation to complete. '
          'Enabled by default and can be omitted; use `--stream-logs` to run '
          'synchronously.'))
  sync_group.add_argument(
      '--stream-logs',
      action='store_true',
      help=('Block until job completion and stream the logs while the job runs.'
            '\n\n'
            'Note that even if command execution is halted, the job will still '
            'run until cancelled with\n\n'
            '    $ gcloud ai-platform jobs cancel JOB_ID'))
  labels_util.AddCreateLabelsFlags(parser)
def Run(self, args):
  """Submit an AI Platform training job built from the parsed flags.

  Args:
    args: argparse.Namespace, the parsed command-line flags.

  Returns:
    The job resource returned by the service. When logs were streamed and
    the job finished in any state other than SUCCEEDED, ``self.exit_code``
    is set to 1.
  """
  follow_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
  tier_enum = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
  tier_name = None
  if tier_enum:
    tier_name = tier_enum.name
  client = jobs.JobsClient()
  job_labels = jobs_util.ParseCreateLabels(client, args)
  server_config = jobs_util.TrainingCustomInputServerConfig.FromArgs(
      args, self._SUPPORT_TPU_TF_VERSION)
  # Fail fast on inconsistent custom-container flags before submission.
  server_config.ValidateConfig()
  submitted = jobs_util.SubmitTraining(
      client,
      args.job,
      job_dir=args.job_dir,
      staging_bucket=args.staging_bucket,
      packages=args.packages,
      package_path=args.package_path,
      scale_tier=tier_name,
      config=args.config,
      module_name=args.module_name,
      runtime_version=args.runtime_version,
      python_version=args.python_version,
      # Not every release track registers --network; tolerate its absence.
      network=getattr(args, 'network', None),
      service_account=args.service_account,
      labels=job_labels,
      stream_logs=follow_logs,
      user_args=args.user_args,
      kms_key=_GetAndValidateKmsKey(args),
      custom_train_server_config=server_config,
      enable_web_access=args.enable_web_access)
  # A synchronous run that did not end in SUCCEEDED is reported as failure.
  if follow_logs and (
      submitted.state is not submitted.StateValueValuesEnum.SUCCEEDED):
    self.exit_code = 1
  return submitted
def Run(self, args):
  """Submits an AI Platform training job and optionally streams its logs.

  Args:
    args: argparse.Namespace, the parsed command-line flags.

  Returns:
    The created training job resource. If logs were streamed and the job
    did not end in SUCCEEDED, ``self.exit_code`` is set to 1.
  """
  # BUG FIX: `async` became a reserved keyword in Python 3, so
  # `args.async` is a SyntaxError. The argparse dest for `--async` is
  # `async_` (matching the other Run() variant in this file).
  stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
  scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
  # GetEnumForChoice returns None when no tier was chosen; pass None through.
  scale_tier_name = scale_tier.name if scale_tier else None
  job = jobs_util.SubmitTraining(
      jobs.JobsClient(),
      args.job,
      job_dir=args.job_dir,
      staging_bucket=args.staging_bucket,
      packages=args.packages,
      package_path=args.package_path,
      scale_tier=scale_tier_name,
      config=args.config,
      module_name=args.module_name,
      runtime_version=args.runtime_version,
      stream_logs=stream_logs,
      user_args=args.user_args)
  # If the job itself failed, we will return a failure status.
  if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
    self.exit_code = 1
  return job