Пример #1
0
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   python_version=None,
                   stream_logs=None,
                   user_args=None,
                   labels=None,
                   custom_train_server_config=None):
    """Upload trainer packages, create a training job, optionally stream logs.

    Args:
      jobs_client: client used to build, create and re-fetch the job.
      job: identifier of the new job; used as the staging job id and as the
        job name.
      job_dir: optional object exposing `ToUrl()`; used when resolving the
        staging location and passed to the job builder as a URL.
      staging_bucket: optional value forwarded to the staging-location lookup.
      packages: forwarded to the package upload step.
      package_path: forwarded to the package upload step.
      scale_tier: optional value converted with the client's
        `ScaleTierValueValuesEnum`; falsy values become None.
      config: passed to the job builder as `path`.
      module_name: forwarded unchanged to the job builder.
      runtime_version: forwarded unchanged to the job builder.
      python_version: forwarded unchanged to the job builder.
      stream_logs: when falsy, return right after the job is created;
        otherwise print the job's logs before returning.
      user_args: forwarded unchanged to the job builder.
      labels: forwarded unchanged to the job builder.
      custom_train_server_config: forwarded unchanged to the job builder.

    Returns:
      The job returned by `jobs_client.Create` when logs are not streamed,
      otherwise the job re-fetched through `jobs_client.Get` after streaming
      stops.

    Raises:
      flags.ArgumentError: if the upload step or the build step raises
        `jobs_prep.NoStagingLocationError`.
    """
    compute_region = properties.VALUES.compute.region.Get(required=True)
    staging_target = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
    try:
        trainer_uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_target)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(trainer_uris))

    # Translate the raw scale tier into the client's enum type, if one was
    # given at all.
    tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    if scale_tier:
        tier_value = tier_enum(scale_tier)
    else:
        tier_value = None

    try:
        job_spec = jobs_client.BuildTrainingJob(
            path=config,
            module_name=module_name,
            job_name=job,
            trainer_uri=trainer_uris,
            region=compute_region,
            job_dir=job_dir.ToUrl() if job_dir else None,
            scale_tier=tier_value,
            user_args=user_args,
            runtime_version=runtime_version,
            python_version=python_version,
            labels=labels,
            custom_train_server_config=custom_train_server_config)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If `--package-path` is not specified, at least one Python package '
            'must be specified via `--packages`.')

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    created_job = jobs_client.Create(project_ref, job_spec)

    # The follow-up message is requested only when logs will not be streamed.
    PrintSubmitFollowUp(created_job.jobId,
                        print_follow_up_message=not stream_logs)
    if not stream_logs:
        return created_job

    def _PrintFollowUp():
        # Emitted when log streaming is cut short (interrupt or HTTP error).
        log.status.Print(
            _FOLLOW_UP_MESSAGE.format(job_id=created_job.jobId,
                                      project=project_ref.Name()))

    fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(created_job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(created_job.jobId))

    log_printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
    with execution_utils.RaisesKeyboardInterrupt():
        try:
            log_lines = log_utils.SplitMultiline(fetcher.YieldLogs())
            log_printer.Print(log_lines)
        except KeyboardInterrupt:
            log.status.Print('Received keyboard interrupt.\n')
            _PrintFollowUp()
        except exceptions.HttpError as err:
            log.status.Print('Polling logs failed:\n{}\n'.format(
                six.text_type(err)))
            log.info('Failure details:', exc_info=True)
            _PrintFollowUp()

    # Re-fetch the job after streaming ends, whether it finished normally or
    # was cut short by one of the handled errors above.
    job_ref = resources.REGISTRY.Parse(
        created_job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    return jobs_client.Get(job_ref)
Пример #2
0
 def testGetStagingLocation_JobDirAndStagingBucket(self):
     # With job id, job dir and staging bucket all supplied, the location is
     # expected to be the job id object under the staging bucket.
     actual = jobs_prep.GetStagingLocation(
         job_id='job_id',
         job_dir=self.job_dir,
         staging_bucket=self.staging_bucket)
     expected = storage_util.ObjectReference.FromUrl(
         'gs://staging-bucket/job_id')
     self.assertEqual(actual, expected)
Пример #3
0
 def testGetStagingLocation_JobDir(self):
     # With only a job dir supplied, the location is expected to be a
     # `packages` object beneath that job dir.
     actual = jobs_prep.GetStagingLocation(job_dir=self.job_dir)
     expected = storage_util.ObjectReference.FromUrl(
         'gs://job-bucket/job-dir/packages')
     self.assertEqual(actual, expected)
Пример #4
0
 def testGetStagingLocation_JobDirEmptyPrefix(self):
     # A job dir that is just a bucket (no object prefix) should yield a
     # `packages` object directly under the bucket.
     bucket_only_dir = storage_util.ObjectReference.FromUrl(
         'gs://job-bucket', allow_empty_object=True)
     actual = jobs_prep.GetStagingLocation(job_dir=bucket_only_dir)
     expected = storage_util.ObjectReference.FromUrl(
         'gs://job-bucket/packages')
     self.assertEqual(actual, expected)
Пример #5
0
 def testGetStagingLocation_NoFlags(self):
     # Called without any arguments, the lookup is expected to give None.
     actual = jobs_prep.GetStagingLocation()
     self.assertEqual(actual, None)