コード例 #1
0
 def ValidateConfig(self):
   """Check constraints among the custom training configuration flags.

   Enforces two rules: --main-image-uri and --runtime-version are mutually
   exclusive, and the CUSTOM scale tier requires --main-machine-type.

   Returns:
     True, when every constraint is satisfied.

   Raises:
     flags.ArgumentError: if any constraint is violated.
   """
   image_and_runtime = self.main_image_uri and self.runtime_version
   if image_and_runtime:
     raise flags.ArgumentError('Only one of --main-image-uri,'
                               ' --runtime-version can be set.')
   is_custom_tier = self.scale_tier and self.scale_tier.name == 'CUSTOM'
   if is_custom_tier and not self.main_machine_type:
     raise flags.ArgumentError('--main-machine-type is required if '
                               'scale-tier is set to `CUSTOM`.')
   return True
コード例 #2
0
def _ValidateSubmitPredictionArgs(model_dir, version):
    if model_dir and version:
        raise flags.ArgumentError(
            '`--version` cannot be set with `--model-dir`')
コード例 #3
0
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   python_version=None,
                   stream_logs=None,
                   user_args=None,
                   labels=None,
                   custom_train_server_config=None):
    """Submit a training job, optionally streaming its logs until completion.

    Stages trainer packages, builds and creates the training job, and — when
    `stream_logs` is truthy — polls and prints the job's logs, re-fetching the
    job afterwards to return its latest state.

    Args:
      jobs_client: jobs API client; must expose `training_input_class`,
        `BuildTrainingJob`, `Create`, and `Get`.
      job: str, the job identifier to submit (rebound to the created job
        resource below).
      job_dir: optional job directory resource supporting `ToUrl()`.
      staging_bucket: optional Cloud Storage bucket for staging local packages.
      packages: optional list of str, trainer package paths to upload.
      package_path: optional str, local Python source tree to build and upload.
      scale_tier: optional str, converted to the client's scale-tier enum.
      config: optional str, path to a job configuration file.
      module_name: optional str, Python module to run.
      runtime_version: optional str, runtime version for the job.
      python_version: optional str, Python version for the job.
      stream_logs: bool; when falsy, return right after creating the job.
      user_args: optional extra arguments forwarded to the trainer.
      labels: optional labels to attach to the job.
      custom_train_server_config: optional custom training server settings.

    Returns:
      The created job resource; when logs were streamed, the job as re-fetched
      after streaming ended.

    Raises:
      flags.ArgumentError: if a staging location is needed but missing, or if
        no Python packages could be determined.
    """
    # Region comes from gcloud properties and is mandatory for job routing.
    region = properties.VALUES.compute.region.Get(required=True)
    staging_location = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
    try:
        uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_location)
    except jobs_prep.NoStagingLocationError:
        # Translate the internal staging error into a user-facing flag error.
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(uris))

    # Convert the scale-tier name (if any) into the client's enum type.
    scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    scale_tier = scale_tier_enum(scale_tier) if scale_tier else None

    try:
        job = jobs_client.BuildTrainingJob(
            path=config,
            module_name=module_name,
            job_name=job,
            trainer_uri=uris,
            region=region,
            job_dir=job_dir.ToUrl() if job_dir else None,
            scale_tier=scale_tier,
            user_args=user_args,
            runtime_version=runtime_version,
            python_version=python_version,
            labels=labels,
            custom_train_server_config=custom_train_server_config)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If `--package-path` is not specified, at least one Python package '
            'must be specified via `--packages`.')

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    job = jobs_client.Create(project_ref, job)
    # Without log streaming, print the follow-up message and return at once.
    if not stream_logs:
        PrintSubmitFollowUp(job.jobId, print_follow_up_message=True)
        return job
    else:
        PrintSubmitFollowUp(job.jobId, print_follow_up_message=False)

    # Poll the job's logs until the job finishes or the user interrupts.
    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(job.jobId))

    printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
    with execution_utils.RaisesKeyboardInterrupt():
        try:
            printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
        except KeyboardInterrupt:
            # Streaming was cancelled; point the user at how to resume.
            log.status.Print('Received keyboard interrupt.\n')
            log.status.Print(
                _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                          project=project_ref.Name()))
        except exceptions.HttpError as err:
            # Polling failure is non-fatal: report it and fall through to
            # returning the job's latest state.
            log.status.Print('Polling logs failed:\n{}\n'.format(
                six.text_type(err)))
            log.info('Failure details:', exc_info=True)
            log.status.Print(
                _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                          project=project_ref.Name()))

    # Re-fetch the job so the caller sees its state after streaming ended.
    job_ref = resources.REGISTRY.Parse(
        job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    job = jobs_client.Get(job_ref)

    return job
コード例 #4
0
def UploadPythonPackages(packages=(), package_path=None, staging_location=None,
                         supports_container_training=False):
  """Stages the Python packages an AI Platform job needs, building as needed.

  Packages can arrive in three forms:

    1. A local, pre-built package file — uploaded to `staging_location`.
    2. A Cloud Storage-hosted package (a path starting with 'gs://') —
       passed through untouched.
    3. A local source tree (`--package-path`) — built with setuptools via
       `BuildPackages`, with the resulting artifacts uploaded.

  Uploads go to a prefixed location under the user's staging bucket: a file
  `package.tar.gz` lands at `gs://<bucket>/<job name>/<checksum>/package.tar.gz`.
  setuptools must be installed locally for case 3.

  Args:
    packages: list of str. Paths to extra tar.gz packages to upload, if any.
      If empty, a package_path must be provided.
    package_path: str. Relative path to a source directory to be built, if
      any. If omitted, one or more packages must be provided.
    staging_location: storage_util.ObjectReference. Cloud Storage prefix to
      which archives are uploaded. Unnecessary if only remote packages are
      given.
    supports_container_training: bool, if this release track supports
      container training. If container training is requested then uploads are
      not required.

  Returns:
    list of str. Fully qualified Cloud Storage URLs (`gs://..`) of the
      uploaded packages.

  Raises:
    ValueError: If packages is empty, and building package_path produces no
      tar archives.
    SetuptoolsFailedError: If the setup.py file fails to successfully build.
    MissingInitError: If the package doesn't contain an `__init__.py` file.
    DuplicateEntriesError: If multiple files with the same name were provided.
    ArgumentError: if no packages were found in the given path, or uploads
      were required but no staging_location was given.
  """
  remote_paths = []
  local_paths = []
  # Partition the given packages: 'gs://' paths need no upload.
  for pkg in packages:
    bucket = (remote_paths
              if storage_util.ObjectReference.IsStorageUrl(pkg)
              else local_paths)
    bucket.append(pkg)

  if package_path:
    package_root = os.path.dirname(os.path.abspath(package_path))
    # Uploading must happen while the temporary build output still exists,
    # so it stays inside the context manager.
    with _TempDirOrBackup(package_root) as working_dir:
      output_dir = os.path.join(working_dir, 'output')
      local_paths.extend(BuildPackages(package_path, output_dir))
      remote_paths.extend(_UploadFilesByPath(local_paths, staging_location))
  elif local_paths:
    remote_paths.extend(_UploadFilesByPath(local_paths, staging_location))

  # Custom container training runs without any uploaded Python packages.
  if not (remote_paths or supports_container_training):
    raise flags.ArgumentError(_NO_PACKAGES_ERROR_MSG)
  return remote_paths