def testUploadPythonPackages_MixedPackages(self):
    """Upload a remote package, a local package and a source package together.

    Checks that the package builder is invoked once for the source path, that
    the local file and the built artifacts are uploaded in a single call, and
    that the returned paths contain the untouched remote URI plus the staged
    URIs of everything that was uploaded.
    """
    source_root = os.path.join('/path/to/package-root/package_name')
    remote_package = 'gs://bucket1/package.tar.gz'
    local_package = 'local-package.tar.gz'

    uploaded_paths = jobs_prep.UploadPythonPackages(
        packages=[remote_package, local_package],
        package_path=source_root,
        staging_location=self.staging_location)

    self.build_packages_mock.assert_called_once_with(source_root, mock.ANY)

    # Only the local file and the freshly built artifacts should be uploaded;
    # the gs:// package is not part of the expected upload call.
    expected_upload = mock.call(
        [
            (local_package, local_package),
            (mock.ANY, 'built-package.tar.gz'),
            (mock.ANY, 'built-package2.whl'),
        ],
        self.bucket_ref,
        'job_name')
    self.upload_mock.assert_has_calls([expected_upload], any_order=True)

    expected_paths = {
        remote_package,
        'gs://bucket/job_name/DEADBEEF/local-package.tar.gz',
        'gs://bucket/job_name/DEADBEEF/built-package.tar.gz',
        'gs://bucket/job_name/DEADBEEF/built-package2.whl',
    }
    self.assertSetEqual(set(uploaded_paths), expected_paths)
def testUploadPythonPackages_OnlyRemotePackages(self):
    """Packages that are all gs:// URIs are returned as given.

    Neither the upload mock nor the build mock may be called, and the result
    must equal the input list.
    """
    remote_uris = [
        'gs://bucket1/package.tar.gz',
        'gs://bucket2/path/package2.tar.gz',
    ]

    result = jobs_prep.UploadPythonPackages(packages=remote_uris)

    self.upload_mock.assert_not_called()
    self.build_packages_mock.assert_not_called()
    self.assertEqual(result, remote_uris)
def testUploadPythonPackages_DuplicateFilenames(self):
    """Two local packages with the same base filename are rejected.

    Expects DuplicateEntriesError with a message naming the clashing file,
    and verifies that nothing was uploaded or built.
    """
    nested_copy = os.path.join('path', 'to', 'package.tar.gz')
    clashing_packages = [nested_copy, 'package.tar.gz']
    expected_message = (
        'Cannot upload multiple packages with the same filename: '
        '[package.tar.gz]')

    with self.AssertRaisesExceptionMatches(
            jobs_prep.DuplicateEntriesError, expected_message):
        jobs_prep.UploadPythonPackages(
            packages=clashing_packages,
            staging_location=self.staging_location)

    # These checks sit after the context manager so they run once the
    # expected exception has been raised and matched.
    self.upload_mock.assert_not_called()
    self.build_packages_mock.assert_not_called()
def testUploadPythonPackages_SourcePackage(self):
    """Only a source package path is given, with no prebuilt packages.

    Checks that the builder is called once with the source path, that the two
    built artifacts are uploaded in one call, and that the returned paths are
    exactly their staged URIs.
    """
    source_root = os.path.join('/path/to/package-root/package_name')

    uploaded_paths = jobs_prep.UploadPythonPackages(
        package_path=source_root,
        staging_location=self.staging_location)

    self.build_packages_mock.assert_called_once_with(source_root, mock.ANY)

    built_artifacts = [
        (mock.ANY, 'built-package.tar.gz'),
        (mock.ANY, 'built-package2.whl'),
    ]
    self.upload_mock.assert_called_once_with(
        built_artifacts, self.bucket_ref, 'job_name')

    expected_paths = {
        'gs://bucket/job_name/DEADBEEF/built-package.tar.gz',
        'gs://bucket/job_name/DEADBEEF/built-package2.whl',
    }
    self.assertSetEqual(set(uploaded_paths), expected_paths)
def testUploadPythonPackages_OnlyLocalPackages(self):
    """Local package files are uploaded under the job's staging location.

    Checks that each local path is uploaded under its base filename, that no
    build takes place, and that the staged URIs come back in input order.
    """
    tarball_path = os.path.join('path', 'to', 'package.tar.gz')
    wheel_path = 'package2.whl'

    uploaded_paths = jobs_prep.UploadPythonPackages(
        packages=[tarball_path, wheel_path],
        staging_location=self.staging_location)

    expected_uploads = [
        (tarball_path, 'package.tar.gz'),
        (wheel_path, 'package2.whl'),
    ]
    self.upload_mock.assert_called_once_with(
        expected_uploads, self.bucket_ref, 'job_name')
    self.build_packages_mock.assert_not_called()

    expected_paths = [
        'gs://bucket/job_name/DEADBEEF/package.tar.gz',
        'gs://bucket/job_name/DEADBEEF/package2.whl',
    ]
    self.assertEqual(uploaded_paths, expected_paths)
def testUploadPythonPackages_EmptyStagingLocation(self):
    """A staging location with an empty object name stages at the bucket root.

    The upload mock must receive '' as its last argument and the returned
    URIs must have no job-name path component.
    """
    bucket_root = storage_util.ObjectReference.FromBucketRef(
        self.bucket_ref, '')
    tarball_path = os.path.join('path', 'to', 'package.tar.gz')
    wheel_path = 'package2.whl'

    uploaded_paths = jobs_prep.UploadPythonPackages(
        packages=[tarball_path, wheel_path],
        staging_location=bucket_root)

    expected_uploads = [
        (tarball_path, 'package.tar.gz'),
        (wheel_path, 'package2.whl'),
    ]
    self.upload_mock.assert_called_once_with(
        expected_uploads, self.bucket_ref, '')
    self.build_packages_mock.assert_not_called()

    expected_paths = [
        'gs://bucket/DEADBEEF/package.tar.gz',
        'gs://bucket/DEADBEEF/package2.whl',
    ]
    self.assertEqual(uploaded_paths, expected_paths)
def SubmitTraining(jobs_client, job, job_dir=None, staging_bucket=None,
                   packages=None, package_path=None, scale_tier=None,
                   config=None, module_name=None, runtime_version=None,
                   python_version=None, stream_logs=None, user_args=None,
                   labels=None, custom_train_server_config=None):
    """Submit a training job.

    Stages the trainer packages, builds the job message, creates the job and,
    when requested, streams its logs until streaming finishes or is
    interrupted.

    Args:
      jobs_client: client used to build, create and fetch the job.
      job: job identifier; passed as `job_id` when computing the staging
        location and as `job_name` when building the job.
      job_dir: optional object exposing `ToUrl()`; used for the staging
        location and forwarded (as a URL) to the built job.
      staging_bucket: optional bucket used to compute the staging location.
      packages: optional packages forwarded to
        `jobs_prep.UploadPythonPackages`.
      package_path: optional source package path forwarded to
        `jobs_prep.UploadPythonPackages`.
      scale_tier: optional value converted through the client's
        `ScaleTierValueValuesEnum`; any falsy value becomes None.
      config: forwarded to `BuildTrainingJob` as `path`.
      module_name: forwarded to `BuildTrainingJob`.
      runtime_version: forwarded to `BuildTrainingJob`.
      python_version: forwarded to `BuildTrainingJob`.
      stream_logs: when truthy, stream the job's logs after creation.
      user_args: forwarded to `BuildTrainingJob`.
      labels: forwarded to `BuildTrainingJob`.
      custom_train_server_config: forwarded to `BuildTrainingJob`.

    Returns:
      The job returned by `jobs_client.Create` when `stream_logs` is falsy;
      otherwise the job re-fetched with `jobs_client.Get` after streaming.

    Raises:
      flags.ArgumentError: if `jobs_prep.NoStagingLocationError` is raised
        while uploading packages or while building the job.
    """
    region = properties.VALUES.compute.region.Get(required=True)
    staging_location = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)

    try:
        trainer_uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_location)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(trainer_uris))

    tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    tier_value = None
    if scale_tier:
        tier_value = tier_enum(scale_tier)

    try:
        job_dir_url = job_dir.ToUrl() if job_dir else None
        job_message = jobs_client.BuildTrainingJob(
            path=config,
            module_name=module_name,
            job_name=job,
            trainer_uri=trainer_uris,
            region=region,
            job_dir=job_dir_url,
            scale_tier=tier_value,
            user_args=user_args,
            runtime_version=runtime_version,
            python_version=python_version,
            labels=labels,
            custom_train_server_config=custom_train_server_config)
    except jobs_prep.NoStagingLocationError:
        # NOTE(review): whether BuildTrainingJob can raise this error is not
        # visible from here; the handler is preserved as written.
        raise flags.ArgumentError(
            'If `--package-path` is not specified, at least one Python package '
            'must be specified via `--packages`.')

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    created_job = jobs_client.Create(project_ref, job_message)

    # The follow-up message is printed up front only when logs are not going
    # to be streamed; in that case there is nothing more to do.
    skip_streaming = not stream_logs
    PrintSubmitFollowUp(created_job.jobId,
                        print_follow_up_message=skip_streaming)
    if skip_streaming:
        return created_job

    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(created_job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(created_job.jobId))
    printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)

    with execution_utils.RaisesKeyboardInterrupt():
        try:
            log_lines = log_utils.SplitMultiline(log_fetcher.YieldLogs())
            printer.Print(log_lines)
        except KeyboardInterrupt:
            log.status.Print('Received keyboard interrupt.\n')
            follow_up = _FOLLOW_UP_MESSAGE.format(
                job_id=created_job.jobId, project=project_ref.Name())
            log.status.Print(follow_up)
        except exceptions.HttpError as err:
            error_text = six.text_type(err)
            log.status.Print('Polling logs failed:\n{}\n'.format(error_text))
            log.info('Failure details:', exc_info=True)
            follow_up = _FOLLOW_UP_MESSAGE.format(
                job_id=created_job.jobId, project=project_ref.Name())
            log.status.Print(follow_up)

    # Re-fetch the job after streaming ends and return that copy.
    job_ref = resources.REGISTRY.Parse(
        created_job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    return jobs_client.Get(job_ref)
def testUploadPythonPackages_UploadRequiredButNoStagingLocationGiven(self):
    """Local packages without a staging location raise NoStagingLocationError."""
    tarball_path = os.path.join('path', 'to', 'package.tar.gz')
    local_packages = [tarball_path, 'package2.whl']

    with self.assertRaises(jobs_prep.NoStagingLocationError):
        jobs_prep.UploadPythonPackages(packages=local_packages)