def Run(self, args):
    stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
    scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
    scale_tier_name = scale_tier.name if scale_tier else None
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    custom_container_config = (
        jobs_util.TrainingCustomInputServerConfig.FromArgs(args))
    custom_container_config.ValidateConfig()

    job = jobs_util.SubmitTraining(
        jobs_client, args.job,
        job_dir=args.job_dir,
        staging_bucket=args.staging_bucket,
        packages=args.packages,
        package_path=args.package_path,
        scale_tier=scale_tier_name,
        config=args.config,
        module_name=args.module_name,
        runtime_version=args.runtime_version,
        python_version=args.python_version,
        labels=labels,
        stream_logs=stream_logs,
        user_args=args.user_args,
        custom_train_server_config=custom_container_config)
    # If the job itself failed, we will return a failure status.
    if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
      self.exit_code = 1
    return job
Example #2
def Run(self, args):
    stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
    scale_tier = jobs_util.ScaleTierFlagMap().GetEnumForChoice(args.scale_tier)
    scale_tier_name = scale_tier.name if scale_tier else None
    jobs_client = jobs.JobsClient()
    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    custom_container_config = (
        jobs_util.TrainingCustomInputServerConfig.FromArgs(
            args, self._SUPPORT_TPU_TF_VERSION))
    custom_container_config.ValidateConfig()
    job = jobs_util.SubmitTraining(
        jobs_client,
        args.job,
        job_dir=args.job_dir,
        staging_bucket=args.staging_bucket,
        packages=args.packages,
        package_path=args.package_path,
        scale_tier=scale_tier_name,
        config=args.config,
        module_name=args.module_name,
        runtime_version=args.runtime_version,
        python_version=args.python_version,
        network=args.network if hasattr(args, 'network') else None,
        service_account=args.service_account,
        labels=labels,
        stream_logs=stream_logs,
        user_args=args.user_args,
        kms_key=_GetAndValidateKmsKey(args),
        custom_train_server_config=custom_container_config,
        enable_web_access=args.enable_web_access)
    # If the job itself failed, we will return a failure status.
    if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
        self.exit_code = 1
    return job
Example #3
def MakeContinueFunction(job_id):
    """Returns a function to decide if log fetcher should continue polling.

    Args:
      job_id: String id of job.

    Returns:
      A one-argument function that decides if the log fetcher should continue.
    """
    jobs_client = jobs.JobsClient()
    project_id = properties.VALUES.core.project.Get(required=True)
    job_ref = resources.REGISTRY.Create('ml.projects.jobs',
                                        jobsId=job_id,
                                        projectsId=project_id)

    def ShouldContinue(periods_without_logs):
        """Returns whether to continue polling the logs.

        Returns False only once we've checked the job and it is finished; we
        only check whether the job is finished once we've gone >1 interval
        without getting any new logs.

        Args:
          periods_without_logs: integer number of empty polls.

        Returns:
          True if we haven't tried polling more than once or if the job is
          not finished.
        """
        if periods_without_logs <= 1:
            return True
        return jobs_client.Get(job_ref).endTime is None

    return ShouldContinue
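The closure returned by MakeContinueFunction is the stop condition a log fetcher consults between polls: keep going while logs are still arriving, and only ask the Jobs API whether the job has finished (endTime set) after more than one empty poll. A minimal sketch of the driving loop; fetch_logs and interval_seconds are hypothetical stand-ins, not SDK names:

import time

def PollLogs(fetch_logs, should_continue, interval_seconds=10):
    """Drives a continue-callback like the one built above.

    Args:
      fetch_logs: zero-argument callable returning a (possibly empty) list
        of new log entries; a hypothetical stand-in for the real fetcher.
      should_continue: one-argument function such as ShouldContinue above.
      interval_seconds: seconds to sleep between polls.
    """
    periods_without_logs = 0
    while should_continue(periods_without_logs):
        entries = fetch_logs()
        if entries:
            periods_without_logs = 0  # Fresh logs reset the idle counter.
            for entry in entries:
                print(entry)
        else:
            periods_without_logs += 1  # Count another empty poll.
        time.sleep(interval_seconds)

For instance, PollLogs(my_fetcher, MakeContinueFunction(job_id)) would poll until the job's endTime is set and the logs have gone quiet.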
Example #4
def Run(self, args):
    job = jobs_util.Describe(jobs.JobsClient(), args.job)
    self.job = job  # Hack to make the Epilog method work
    if args.summarize:
        if args.format:
            log.warning('--format is ignored when --summarize is present')
        args.format = jobs_util.GetSummaryFormat(job)
    return job
Example #5
def Run(self, args):
    return jobs_util.SubmitPrediction(
        jobs.JobsClient(),
        args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=args.data_format,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count)
Example #6
def Run(self, args):
    data_format = jobs_util.DataFormatFlagMap().GetEnumForChoice(
        args.data_format)
    return jobs_util.SubmitPrediction(
        jobs.JobsClient(),
        args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=data_format.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size)
Example #7
def Run(self, args):
    stream_logs = jobs_util.GetStreamLogs(args.async_, args.stream_logs)
    job = jobs_util.SubmitTraining(
        jobs.JobsClient('v1'),
        args.job,
        job_dir=args.job_dir,
        staging_bucket=args.staging_bucket,
        packages=args.packages,
        package_path=args.package_path,
        scale_tier=args.scale_tier,
        config=args.config,
        module_name=args.module_name,
        runtime_version=args.runtime_version,
        stream_logs=stream_logs,
        user_args=args.user_args)
    # If the job itself failed, we will return a failure status.
    if stream_logs and job.state is not job.StateValueValuesEnum.SUCCEEDED:
        self.exit_code = 1
    return job
Example #8
def Run(self, args):
    data_format = jobs_util.DataFormatFlagMap().GetEnumForChoice(
        args.data_format)
    jobs_client = jobs.JobsClient()

    labels = jobs_util.ParseCreateLabels(jobs_client, args)
    return jobs_util.SubmitPrediction(
        jobs_client, args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=data_format.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size,
        signature_name=args.signature_name,
        labels=labels,
        accelerator_type=args.accelerator_type,
        accelerator_count=args.accelerator_count)
Example #9
def Run(self, args):
    return jobs_util.Cancel(jobs.JobsClient(), args.job)
Example #10
def Run(self, args):
    jobs_client = jobs.JobsClient()
    updated_job = jobs_util.Update(jobs_client, args)
    log.UpdatedResource(args.job, kind='ml engine job')
    return updated_job
Example #11
def Run(self, args):
    return jobs_util.List(jobs.JobsClient('v1'))
Example #12
def SetUp(self):
    self.jobs_client = jobs.JobsClient()
Example #13
def Run(self, args):
    job = jobs_util.Describe(jobs.JobsClient('v1'), args.job)
    self.job = job  # Hack to make the Epilog method work
    return job
Example #14
def Run(self, args):
    jobs_client = jobs.JobsClient()
    operations_client = operations.OperationsClient()
    jobs_util.Update(jobs_client, operations_client, args)
    log.UpdatedResource(args.job, kind='ml engine job')