예제 #1
0
def StreamLogs(api_version, job, task_name, polling_interval,
               allow_multiline_logs):
  """Builds the log stream for a job.

  Args:
    api_version: forwarded, together with job, to
      log_utils.MakeContinueFunction.
    job: job identifier used for both the log filters and the continue
      function.
    task_name: task identifier passed to log_utils.LogFilters.
    polling_interval: passed through to stream.LogFetcher.
    allow_multiline_logs: passed to log_utils.SplitMultiline as
      allow_multiline.

  Returns:
    The result of log_utils.SplitMultiline applied to the fetcher's
    YieldLogs() output.
  """
  job_filters = log_utils.LogFilters(job, task_name)
  keep_polling = log_utils.MakeContinueFunction(job, api_version)
  fetcher = stream.LogFetcher(
      filters=job_filters,
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=keep_polling)
  raw_entries = fetcher.YieldLogs()
  return log_utils.SplitMultiline(
      raw_entries, allow_multiline=allow_multiline_logs)
예제 #2
0
 def Run(self, args):
   """Stream the logs of the job named by the parsed arguments.

   Args:
     args: parsed command arguments; this method reads job, task_name,
       polling_interval and allow_multiline_logs from it.

   Returns:
     The result of log_utils.SplitMultiline applied to the fetcher's
     YieldLogs() output.
   """
   job_filters = log_utils.LogFilters(args.job, args.task_name)
   interval = args.polling_interval
   keep_polling = log_utils.MakeContinueFunction(args.job)
   fetcher = stream.LogFetcher(
       filters=job_filters,
       polling_interval=interval,
       continue_func=keep_polling)
   raw_entries = fetcher.YieldLogs()
   return log_utils.SplitMultiline(
       raw_entries, allow_multiline=args.allow_multiline_logs)
예제 #3
0
    def Run(self, args):
        """Submit a training job and, unless async was requested, stream logs.

        Args:
          args: an argparse namespace. All the arguments that were provided
            to this command invocation.

        Returns:
          The job returned by jobs_client.Create when the async flag is set;
          otherwise the job re-fetched with jobs_client.Get once log
          streaming has stopped (including after a keyboard interrupt).

        Raises:
          flags.ArgumentError: if jobs_prep.UploadPythonPackages raises
            NoStagingLocationError.
        """
        region = properties.VALUES.compute.region.Get(required=True)
        staging_location = jobs_prep.GetStagingLocation(
            staging_bucket=args.staging_bucket,
            job_id=args.job,
            job_dir=args.job_dir)
        try:
            uris = jobs_prep.UploadPythonPackages(
                packages=args.packages,
                package_path=args.package_path,
                staging_location=staging_location)
        except jobs_prep.NoStagingLocationError:
            raise flags.ArgumentError(
                'If local packages are provided, the `--staging-bucket` or '
                '`--job-dir` flag must be given.')
        log.debug('Using {0} as trainer uris'.format(uris))

        scale_tier_enum = (jobs.GetMessagesModule(
        ).GoogleCloudMlV1beta1TrainingInput.ScaleTierValueValuesEnum)
        scale_tier = scale_tier_enum(
            args.scale_tier) if args.scale_tier else None
        job = jobs.BuildTrainingJob(
            path=args.config,
            module_name=args.module_name,
            job_name=args.job,
            trainer_uri=uris,
            region=region,
            job_dir=args.job_dir.ToUrl() if args.job_dir else None,
            scale_tier=scale_tier,
            user_args=args.user_args,
            runtime_version=args.runtime_version)

        jobs_client = jobs.JobsClient()
        project_ref = resources.REGISTRY.Parse(
            properties.VALUES.core.project.Get(required=True),
            collection='ml.projects')
        job = jobs_client.Create(project_ref, job)
        log.status.Print('Job [{}] submitted successfully.'.format(job.jobId))
        # `async` is a reserved keyword from Python 3.7 on, so the attribute
        # form `args.async` no longer parses; getattr reads the same
        # attribute on every Python version.
        if getattr(args, 'async'):
            log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
            return job

        log_fetcher = stream.LogFetcher(
            filters=log_utils.LogFilters(job.jobId),
            polling_interval=_POLLING_INTERVAL,
            continue_func=log_utils.MakeContinueFunction(job.jobId))

        printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)

        def _CtrlCHandler(signal, frame):
            del signal, frame  # Unused
            raise KeyboardInterrupt

        with execution_utils.CtrlCSection(_CtrlCHandler):
            try:
                printer.Print(log_utils.SplitMultiline(
                    log_fetcher.YieldLogs()))
            except KeyboardInterrupt:
                # Stop streaming but still fall through to report the job's
                # current state below.
                log.status.Print('Received keyboard interrupt.')
                log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))

        job_ref = resources.REGISTRY.Parse(job.jobId,
                                           collection='ml.projects.jobs')
        job = jobs_client.Get(job_ref)
        # If the job itself failed, we will return a failure status.
        if job.state is not job.StateValueValuesEnum.SUCCEEDED:
            self.exit_code = 1

        return job
예제 #4
0
def SubmitTraining(jobs_client, job, job_dir=None, staging_bucket=None,
                   packages=None, package_path=None, scale_tier=None,
                   config=None, module_name=None, runtime_version=None,
                   stream_logs=None, user_args=None):
  """Uploads packages, creates a training job and optionally tails its logs.

  Args:
    jobs_client: client used to build, create and fetch the job.
    job: value used as the job id for staging and as the job name.
    job_dir: optional object with a ToUrl() method; also used to pick the
      staging location.
    staging_bucket: passed to jobs_prep.GetStagingLocation.
    packages: passed to jobs_prep.UploadPythonPackages.
    package_path: passed to jobs_prep.UploadPythonPackages.
    scale_tier: when truthy, converted with the client's
      ScaleTierValueValuesEnum.
    config: passed to BuildTrainingJob as path.
    module_name: passed to BuildTrainingJob.
    runtime_version: passed to BuildTrainingJob.
    stream_logs: when falsy, return right after submission instead of
      streaming logs.
    user_args: passed to BuildTrainingJob.

  Returns:
    The created job when stream_logs is falsy; otherwise the job as
    re-fetched with jobs_client.Get after log streaming ends.

  Raises:
    flags.ArgumentError: if jobs_prep.UploadPythonPackages raises
      NoStagingLocationError.
  """
  region = properties.VALUES.compute.region.Get(required=True)
  staging_location = jobs_prep.GetStagingLocation(
      staging_bucket=staging_bucket,
      job_id=job,
      job_dir=job_dir)
  try:
    trainer_uris = jobs_prep.UploadPythonPackages(
        packages=packages,
        package_path=package_path,
        staging_location=staging_location)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If local packages are provided, the `--staging-bucket` or '
        '`--job-dir` flag must be given.')
  log.debug('Using {0} as trainer uris'.format(trainer_uris))

  tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
  tier = None
  if scale_tier:
    tier = tier_enum(scale_tier)

  job_dir_url = None
  if job_dir:
    job_dir_url = job_dir.ToUrl()

  job_request = jobs_client.BuildTrainingJob(
      path=config,
      module_name=module_name,
      job_name=job,
      trainer_uri=trainer_uris,
      region=region,
      job_dir=job_dir_url,
      scale_tier=tier,
      user_args=user_args,
      runtime_version=runtime_version)

  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  created_job = jobs_client.Create(project_ref, job_request)

  # The follow-up hint is printed up front only when logs are not streamed.
  submit_only = not stream_logs
  PrintSubmitFollowUp(created_job.jobId, print_follow_up_message=submit_only)
  if submit_only:
    return created_job

  def _PrintFollowUp():
    # Shown when streaming stops early, whether interrupted or failed.
    log.status.Print(_FOLLOW_UP_MESSAGE.format(
        job_id=created_job.jobId, project=project_ref.Name()))

  fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(created_job.jobId),
      polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(created_job.jobId))
  printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)

  with execution_utils.RaisesKeyboardInterrupt():
    try:
      printer.Print(log_utils.SplitMultiline(fetcher.YieldLogs()))
    except KeyboardInterrupt:
      log.status.Print('Received keyboard interrupt.\n')
      _PrintFollowUp()
    except exceptions.HttpError as err:
      log.status.Print('Polling logs failed:\n{}\n'.format(str(err)))
      log.info('Failure details:', exc_info=True)
      _PrintFollowUp()

  job_ref = resources.REGISTRY.Parse(
      created_job.jobId,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='ml.projects.jobs')
  return jobs_client.Get(job_ref)