Beispiel #1
0
    def testUploadPythonPackages_MixedPackages(self):
        """Remote packages pass through; local and built packages are uploaded.

        The ``gs://`` package is returned unchanged, while the local package
        and the packages built from ``package_path`` appear together in one
        upload call and are reported under the staging location.
        """
        package_path = '/path/to/package-root/package_name'
        packages = ['gs://bucket1/package.tar.gz', 'local-package.tar.gz']

        storage_paths = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=self.staging_location)

        self.build_packages_mock.assert_called_once_with(
            package_path, mock.ANY)
        # The local package is listed first, then the built packages; the
        # built packages' local paths are not pinned (mock.ANY).
        expected_upload = mock.call(
            [('local-package.tar.gz', 'local-package.tar.gz'),
             (mock.ANY, 'built-package.tar.gz'),
             (mock.ANY, 'built-package2.whl')],
            self.bucket_ref,
            'job_name')
        self.upload_mock.assert_has_calls([expected_upload], any_order=True)
        self.assertSetEqual(
            set(storage_paths), {
                'gs://bucket1/package.tar.gz',
                'gs://bucket/job_name/DEADBEEF/local-package.tar.gz',
                'gs://bucket/job_name/DEADBEEF/built-package.tar.gz',
                'gs://bucket/job_name/DEADBEEF/built-package2.whl'
            })
Beispiel #2
0
    def testUploadPythonPackages_OnlyRemotePackages(self):
        """Remote-only packages are returned as-is with no build or upload."""
        remote_packages = ['gs://bucket1/package.tar.gz',
                           'gs://bucket2/path/package2.tar.gz']

        result = jobs_prep.UploadPythonPackages(packages=remote_packages)

        self.assertEqual(result, remote_packages)
        self.build_packages_mock.assert_not_called()
        self.upload_mock.assert_not_called()
Beispiel #3
0
    def testUploadPythonPackages_DuplicateFilenames(self):
        """Local packages sharing a filename are rejected before any upload."""
        nested_package = os.path.join('path', 'to', 'package.tar.gz')
        packages = [nested_package, 'package.tar.gz']
        expected_message = (
            'Cannot upload multiple packages with the same filename: '
            '[package.tar.gz]')

        with self.AssertRaisesExceptionMatches(
                jobs_prep.DuplicateEntriesError, expected_message):
            jobs_prep.UploadPythonPackages(
                packages=packages, staging_location=self.staging_location)

        # Neither the build step nor the upload may have run.
        self.build_packages_mock.assert_not_called()
        self.upload_mock.assert_not_called()
Beispiel #4
0
    def testUploadPythonPackages_SourcePackage(self):
        """Packages built from ``package_path`` are uploaded and returned."""
        package_path = '/path/to/package-root/package_name'

        storage_paths = jobs_prep.UploadPythonPackages(
            package_path=package_path, staging_location=self.staging_location)

        self.build_packages_mock.assert_called_once_with(
            package_path, mock.ANY)
        # Only the uploaded file names are pinned; local build paths are ANY.
        self.upload_mock.assert_called_once_with(
            [(mock.ANY, 'built-package.tar.gz'),
             (mock.ANY, 'built-package2.whl')], self.bucket_ref, 'job_name')
        self.assertSetEqual(
            set(storage_paths), {
                'gs://bucket/job_name/DEADBEEF/built-package.tar.gz',
                'gs://bucket/job_name/DEADBEEF/built-package2.whl'
            })
Beispiel #5
0
    def testUploadPythonPackages_OnlyLocalPackages(self):
        """Local packages are uploaded under their file names, with no build."""
        local_tarball = os.path.join('path', 'to', 'package.tar.gz')
        local_wheel = 'package2.whl'
        packages = [local_tarball, local_wheel]
        staged_prefix = 'gs://bucket/job_name/DEADBEEF/'

        storage_paths = jobs_prep.UploadPythonPackages(
            packages=packages, staging_location=self.staging_location)

        expected_uploads = [(local_tarball, 'package.tar.gz'),
                            (local_wheel, 'package2.whl')]
        self.upload_mock.assert_called_once_with(
            expected_uploads, self.bucket_ref, 'job_name')
        self.build_packages_mock.assert_not_called()
        expected_paths = [staged_prefix + name
                          for name in ('package.tar.gz', 'package2.whl')]
        self.assertEqual(storage_paths, expected_paths)
Beispiel #6
0
    def testUploadPythonPackages_EmptyStagingLocation(self):
        """With an empty object name, staged paths omit the job_name segment."""
        packages = [
            os.path.join('path', 'to', 'package.tar.gz'), 'package2.whl'
        ]
        root_location = storage_util.ObjectReference.FromBucketRef(
            self.bucket_ref, '')

        storage_paths = jobs_prep.UploadPythonPackages(
            packages=packages, staging_location=root_location)

        self.upload_mock.assert_called_once_with(
            [(packages[0], 'package.tar.gz'), (packages[1], 'package2.whl')],
            self.bucket_ref, '')
        self.build_packages_mock.assert_not_called()
        self.assertEqual(storage_paths, [
            'gs://bucket/DEADBEEF/package.tar.gz',
            'gs://bucket/DEADBEEF/package2.whl',
        ])
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   python_version=None,
                   stream_logs=None,
                   user_args=None,
                   labels=None,
                   custom_train_server_config=None):
    """Submit a training job.

    Stages the job's Python packages, builds the training job message,
    creates the job, and (optionally) streams its logs.

    Args:
      jobs_client: client providing ``training_input_class``,
        ``BuildTrainingJob``, ``Create`` and ``Get``.
      job: the job name, used as the job ID for staging and the job message.
      job_dir: optional location; used to choose the staging location, and
        its ``ToUrl()`` is passed to ``BuildTrainingJob`` as ``job_dir``.
      staging_bucket: optional bucket used to choose the staging location.
      packages: Python packages handed to ``jobs_prep.UploadPythonPackages``.
      package_path: optional package source handed to the same call.
      scale_tier: optional name converted to ``ScaleTierValueValuesEnum``.
      config: forwarded to ``BuildTrainingJob`` as ``path``.
      module_name, runtime_version, python_version, user_args, labels,
        custom_train_server_config: forwarded to ``BuildTrainingJob``.
      stream_logs: when falsy, return right after the job is created;
        otherwise stream logs, then re-fetch the job.

    Returns:
      The created job when not streaming, else the job fetched via
      ``jobs_client.Get`` after log streaming stops.

    Raises:
      flags.ArgumentError: when ``jobs_prep.NoStagingLocationError`` is
        raised while uploading packages or while building the job.
    """
    region = properties.VALUES.compute.region.Get(required=True)
    staging_location = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)

    try:
        trainer_uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_location)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(trainer_uris))

    tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    tier_value = None
    if scale_tier:
        tier_value = tier_enum(scale_tier)

    try:
        training_job = jobs_client.BuildTrainingJob(
            path=config,
            module_name=module_name,
            job_name=job,
            trainer_uri=trainer_uris,
            region=region,
            job_dir=job_dir.ToUrl() if job_dir else None,
            scale_tier=tier_value,
            user_args=user_args,
            runtime_version=runtime_version,
            python_version=python_version,
            labels=labels,
            custom_train_server_config=custom_train_server_config)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If `--package-path` is not specified, at least one Python package '
            'must be specified via `--packages`.')

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    created_job = jobs_client.Create(project_ref, training_job)
    # The follow-up hint is only printed here when we are not about to stream.
    PrintSubmitFollowUp(created_job.jobId,
                        print_follow_up_message=not stream_logs)
    if not stream_logs:
        return created_job

    def _PrintFollowUp():
        # Shown when streaming stops early, so the user can resume following.
        log.status.Print(
            _FOLLOW_UP_MESSAGE.format(job_id=created_job.jobId,
                                      project=project_ref.Name()))

    fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(created_job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(created_job.jobId))

    log_printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
    with execution_utils.RaisesKeyboardInterrupt():
        try:
            log_printer.Print(log_utils.SplitMultiline(fetcher.YieldLogs()))
        except KeyboardInterrupt:
            log.status.Print('Received keyboard interrupt.\n')
            _PrintFollowUp()
        except exceptions.HttpError as err:
            log.status.Print('Polling logs failed:\n{}\n'.format(
                six.text_type(err)))
            log.info('Failure details:', exc_info=True)
            _PrintFollowUp()

    job_ref = resources.REGISTRY.Parse(
        created_job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    return jobs_client.Get(job_ref)
Beispiel #8
0
 def testUploadPythonPackages_UploadRequiredButNoStagingLocationGiven(self):
     """Local packages without a staging location raise an error."""
     local_packages = [os.path.join('path', 'to', 'package.tar.gz'),
                       'package2.whl']
     with self.assertRaises(jobs_prep.NoStagingLocationError):
         jobs_prep.UploadPythonPackages(packages=local_packages)