def testYieldLogsStopsAppropriatelyShorterContinueInterval(self):
  """Checks continue checks run twice per poll when the interval is halved."""
  want = [
      Log('2017-01-20T17:28:22.929735908Z', 'foo0'),
      Log('2017-01-20T17:28:22.929735909Z', 'foo1'),
  ]
  # Each element is the batch returned by one FetchLogs poll.
  self.log_fetcher_mock.side_effect = [[want[0]], [], [want[1]], [], []]
  # Note that 10 / 5 = 2, so we'll expect 2 _ContinueFunc calls for every log
  # poll
  fetcher = stream.LogFetcher(continue_func=self._ContinueFunc,
                              polling_interval=10,
                              continue_interval=5)
  log_stream = fetcher.YieldLogs()

  self.assertEqual(next(log_stream), want[0])
  self.assertEqual(self.time_slept, 0)
  self.assertEqual(self.log_fetcher_mock.call_count, 1)
  self.assertEqual(self.continue_func_calls, [])

  self.assertEqual(next(log_stream), want[1])
  self.assertEqual(self.time_slept, 20)
  self.assertEqual(self.log_fetcher_mock.call_count, 3)
  self.assertEqual(self.continue_func_calls, [0, 0, 1, 1])

  with self.assertRaises(StopIteration):
    next(log_stream)
  self.assertEqual(self.time_slept, 40)
  self.assertEqual(self.log_fetcher_mock.call_count, 5)
  self.assertEqual(self.continue_func_calls, [0, 0, 1, 1, 0, 0, 1, 1, 2])
def StreamLogs(job, task_name, polling_interval, allow_multiline_logs):
  """Streams the job's logs, splitting multiline entries as requested."""
  fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(job, task_name),
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(job))
  log_generator = fetcher.YieldLogs()
  return log_utils.SplitMultiline(
      log_generator, allow_multiline=allow_multiline_logs)
def Run(self, args):
  """Run the stream-logs command."""
  # Build a fetcher that polls using the filters derived from the job/task.
  fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(args.job, args.task_name),
      polling_interval=args.polling_interval,
      continue_func=log_utils.MakeContinueFunction(args.job))
  log_generator = fetcher.YieldLogs()
  return log_utils.SplitMultiline(
      log_generator, allow_multiline=args.allow_multiline_logs)
def Run(self, args):
  """Polls for log entries and prints each one as it arrives."""
  printer = logs_util.LogPrinter()
  # Registration order matters: request-log formatting is tried first.
  for formatter in (logs_util.FormatRequestLogEntry, logs_util.FormatAppEntry):
    printer.RegisterFormatter(formatter)
  project = properties.VALUES.core.project.Get(required=True)
  filters = logs_util.GetFilters(project, args.logs, args.service,
                                 args.version, args.level)
  fetcher = stream.LogFetcher(filters=filters, polling_interval=1)
  for entry in fetcher.YieldLogs():
    log.out.Print(printer.Format(entry))
def SetUp(self):
  # This is a list of lists where each poll returns the next batch of logs.
  self.logs = []
  self.continue_func_calls = []
  self.fetcher = stream.LogFetcher(continue_func=self._ContinueFunc)
  self.log_fetcher_mock = self.StartObjectPatch(common, 'FetchLogs')
  # Accumulate requested sleep durations instead of actually sleeping.
  self.time_slept = 0

  def _RecordSleep(seconds):
    self.time_slept += seconds

  self.sleep_mock = self.StartPatch('time.sleep', side_effect=_RecordSleep)
def testFiltersAreAdded(self):
  """User-supplied filters must appear in the filter string FetchLogs gets."""
  continue_func = lambda num_empty_polls: num_empty_polls == 0
  filters = ['insertId>=foo', 'random irrelevant filter']
  custom_fetcher = stream.LogFetcher(continue_func=continue_func,
                                     filters=filters)
  expected = [Log('2017-01-20T17:28:22.929735908Z', 'foo'),
              Log('2017-01-20T17:28:22.929735908Z', 'foo2')]
  self.log_fetcher_mock.return_value = expected
  self.assertEqual(expected, custom_fetcher.GetLogs())
  # Inspect the keyword arguments of the (single) FetchLogs invocation.
  _, kwargs = self.log_fetcher_mock.call_args
  filter_string = kwargs['log_filter']
  for filter_ in filters:
    self.assertIn(filter_, filter_string)
def StreamLogs(name, continue_function, polling_interval, task_name,
               allow_multiline):
  """Returns the streaming log of the job by id.

  Args:
    name: string id of the entity.
    continue_function: One-arg function mapping the number of consecutive
      empty polls to a boolean deciding whether polling should continue. If
      not given, keep polling indefinitely.
    polling_interval: amount of time to sleep between each poll.
    task_name: String name of task.
    allow_multiline: Tells us if logs with multiline messages are okay or not.

  Returns:
    An iterable of log entries, split according to the multiline setting.
  """
  fetcher = stream.LogFetcher(
      filters=_LogFilters(name, task_name=task_name),
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=continue_function)
  return _SplitMultiline(fetcher.YieldLogs(), allow_multiline)
def testYieldLogsStopsAppropriatelyLongerContinueInterval(self):
  """Checks two polls happen per continue check when the interval doubles."""
  want = [
      Log('2017-01-20T17:28:22.929735908Z', 'foo0'),
      Log('2017-01-20T17:28:22.929735909Z', 'foo1'),
      Log('2017-01-20T17:28:22.929735910Z', 'foo2'),
  ]
  # Each element is the batch returned by one FetchLogs poll.
  self.log_fetcher_mock.side_effect = [
      [want[0]], [want[1]], [], [want[2]], [], [], []]
  # Note that 20 / 10 = 2, so we'll expect 2 log polls for every _ContinueFunc
  # call
  fetcher = stream.LogFetcher(continue_func=self._ContinueFunc,
                              polling_interval=10,
                              continue_interval=20)
  log_stream = fetcher.YieldLogs()

  self.assertEqual(next(log_stream), want[0])
  self.assertEqual(self.time_slept, 0)
  self.assertEqual(self.log_fetcher_mock.call_count, 1)
  self.assertEqual(self.continue_func_calls, [])

  self.assertEqual(next(log_stream), want[1])
  self.assertEqual(self.time_slept, 10)
  self.assertEqual(self.log_fetcher_mock.call_count, 2)
  self.assertEqual(self.continue_func_calls, [0])

  self.assertEqual(next(log_stream), want[2])
  self.assertEqual(self.time_slept, 30)
  self.assertEqual(self.log_fetcher_mock.call_count, 4)
  self.assertEqual(self.continue_func_calls, [0, 1])

  with self.assertRaises(StopIteration):
    next(log_stream)
  self.assertEqual(self.time_slept, 60)
  self.assertEqual(self.log_fetcher_mock.call_count, 7)
  self.assertEqual(self.continue_func_calls, [0, 1, 1, 3])
def SubmitTraining(jobs_client, job, job_dir=None, staging_bucket=None,
                   packages=None, package_path=None, scale_tier=None,
                   config=None, module_name=None, runtime_version=None,
                   python_version=None, stream_logs=None, user_args=None,
                   labels=None, custom_train_server_config=None):
  """Submit a training job and optionally stream its logs until it finishes.

  Args:
    jobs_client: client used to build, create, and re-fetch the job.
    job: string job id to submit; rebound below to the created job message.
    job_dir: optional job directory object; its .ToUrl() is forwarded.
    staging_bucket: optional bucket used to stage local packages.
    packages: optional local Python packages to upload.
    package_path: optional path to a Python package to upload.
    scale_tier: optional scale tier value, converted to the client's enum.
    config: optional path to a job configuration file.
    module_name: optional name of the Python module to run.
    runtime_version: optional runtime version string.
    python_version: optional Python version string.
    stream_logs: if truthy, poll and print the job's logs until done.
    user_args: optional extra arguments forwarded to the trainer.
    labels: labels forwarded to BuildTrainingJob.
    custom_train_server_config: config forwarded to BuildTrainingJob.

  Returns:
    The created job message; when logs were streamed, re-fetched so it
    reflects the job's state after streaming ended.

  Raises:
    flags.ArgumentError: if no staging location could be determined for
      local packages, or no Python package was specified at all.
  """
  region = properties.VALUES.compute.region.Get(required=True)
  staging_location = jobs_prep.GetStagingLocation(
      staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
  try:
    uris = jobs_prep.UploadPythonPackages(
        packages=packages, package_path=package_path,
        staging_location=staging_location)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If local packages are provided, the `--staging-bucket` or '
        '`--job-dir` flag must be given.')
  log.debug('Using {0} as trainer uris'.format(uris))
  scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
  scale_tier = scale_tier_enum(scale_tier) if scale_tier else None
  try:
    job = jobs_client.BuildTrainingJob(
        path=config,
        module_name=module_name,
        job_name=job,
        trainer_uri=uris,
        region=region,
        job_dir=job_dir.ToUrl() if job_dir else None,
        scale_tier=scale_tier,
        user_args=user_args,
        runtime_version=runtime_version,
        python_version=python_version,
        labels=labels,
        custom_train_server_config=custom_train_server_config)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If `--package-path` is not specified, at least one Python package '
        'must be specified via `--packages`.')
  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  job = jobs_client.Create(project_ref, job)
  if not stream_logs:
    PrintSubmitFollowUp(job.jobId, print_follow_up_message=True)
    return job
  else:
    PrintSubmitFollowUp(job.jobId, print_follow_up_message=False)
  log_fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(job.jobId),
      polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(job.jobId))
  printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
  # Both interrupt and polling failure are non-fatal: print a follow-up
  # message telling the user how to keep watching the job, then fall through.
  with execution_utils.RaisesKeyboardInterrupt():
    try:
      printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
    except KeyboardInterrupt:
      log.status.Print('Received keyboard interrupt.\n')
      log.status.Print(
          _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                    project=project_ref.Name()))
    except exceptions.HttpError as err:
      log.status.Print('Polling logs failed:\n{}\n'.format(
          six.text_type(err)))
      log.info('Failure details:', exc_info=True)
      log.status.Print(
          _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                    project=project_ref.Name()))
  # Re-fetch the job so the returned message reflects its latest state.
  job_ref = resources.REGISTRY.Parse(
      job.jobId,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='ml.projects.jobs')
  job = jobs_client.Get(job_ref)
  return job
def Run(self, args):
  """This is what gets called when the user runs this command.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    The created (and, unless run with --async, re-fetched) job message.

  Raises:
    flags.ArgumentError: if local packages are given without a staging
      location.
  """
  region = properties.VALUES.compute.region.Get(required=True)
  staging_location = jobs_prep.GetStagingLocation(
      staging_bucket=args.staging_bucket, job_id=args.job,
      job_dir=args.job_dir)
  try:
    uris = jobs_prep.UploadPythonPackages(
        packages=args.packages, package_path=args.package_path,
        staging_location=staging_location)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If local packages are provided, the `--staging-bucket` or '
        '`--job-dir` flag must be given.')
  log.debug('Using {0} as trainer uris'.format(uris))
  scale_tier_enum = (jobs.GetMessagesModule(
  ).GoogleCloudMlV1beta1TrainingInput.ScaleTierValueValuesEnum)
  scale_tier = scale_tier_enum(args.scale_tier) if args.scale_tier else None
  job = jobs.BuildTrainingJob(
      path=args.config,
      module_name=args.module_name,
      job_name=args.job,
      trainer_uri=uris,
      region=region,
      job_dir=args.job_dir.ToUrl() if args.job_dir else None,
      scale_tier=scale_tier,
      user_args=args.user_args,
      runtime_version=args.runtime_version)
  jobs_client = jobs.JobsClient()
  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  job = jobs_client.Create(project_ref, job)
  log.status.Print('Job [{}] submitted successfully.'.format(job.jobId))
  # `async` is a reserved keyword in Python 3.7+, so `args.async` is a
  # SyntaxError on modern interpreters; getattr reads the same attribute.
  if getattr(args, 'async'):
    log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
    return job
  log_fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(job.jobId),
      polling_interval=_POLLING_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(job.jobId))
  printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)

  def _CtrlCHandler(signal, frame):
    del signal, frame  # Unused
    raise KeyboardInterrupt

  with execution_utils.CtrlCSection(_CtrlCHandler):
    try:
      printer.Print(log_utils.SplitMultiline(
          log_fetcher.YieldLogs()))
    except KeyboardInterrupt:
      log.status.Print('Received keyboard interrupt.')
      log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
  job_ref = resources.REGISTRY.Parse(job.jobId, collection='ml.projects.jobs')
  job = jobs_client.Get(job_ref)
  # If the job itself failed, we will return a failure status.
  if job.state is not job.StateValueValuesEnum.SUCCEEDED:
    self.exit_code = 1
  return job
def SubmitTraining(jobs_client, job, job_dir=None, staging_bucket=None,
                   packages=None, package_path=None, scale_tier=None,
                   config=None, module_name=None, runtime_version=None,
                   async_=None, user_args=None):
  """Submit a training job, streaming its logs unless async_ is set.

  Args:
    jobs_client: client used to build, create, and re-fetch the job.
    job: string job id to submit; rebound below to the created job message.
    job_dir: optional job directory object; its .ToUrl() is forwarded.
    staging_bucket: optional bucket used to stage local packages.
    packages: optional local Python packages to upload.
    package_path: optional path to a Python package to upload.
    scale_tier: optional scale tier value, converted to the client's enum.
    config: optional path to a job configuration file.
    module_name: optional name of the Python module to run.
    runtime_version: optional runtime version string.
    async_: if truthy, return right after submission instead of polling logs.
    user_args: optional extra arguments forwarded to the trainer.

  Returns:
    The created job message; when logs were streamed, re-fetched so it
    reflects the job's state after streaming ended.

  Raises:
    flags.ArgumentError: if local packages are given but no staging
      location could be determined.
  """
  region = properties.VALUES.compute.region.Get(required=True)
  staging_location = jobs_prep.GetStagingLocation(
      staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
  try:
    uris = jobs_prep.UploadPythonPackages(
        packages=packages, package_path=package_path,
        staging_location=staging_location)
  except jobs_prep.NoStagingLocationError:
    raise flags.ArgumentError(
        'If local packages are provided, the `--staging-bucket` or '
        '`--job-dir` flag must be given.')
  log.debug('Using {0} as trainer uris'.format(uris))
  scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
  scale_tier = scale_tier_enum(scale_tier) if scale_tier else None
  job = jobs_client.BuildTrainingJob(
      path=config,
      module_name=module_name,
      job_name=job,
      trainer_uri=uris,
      region=region,
      job_dir=job_dir.ToUrl() if job_dir else None,
      scale_tier=scale_tier,
      user_args=user_args,
      runtime_version=runtime_version)
  project_ref = resources.REGISTRY.Parse(
      properties.VALUES.core.project.Get(required=True),
      collection='ml.projects')
  job = jobs_client.Create(project_ref, job)
  log.status.Print('Job [{}] submitted successfully.'.format(job.jobId))
  if async_:
    log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
    return job
  log_fetcher = stream.LogFetcher(
      filters=log_utils.LogFilters(job.jobId),
      polling_interval=_POLLING_INTERVAL,
      continue_func=log_utils.MakeContinueFunction(job.jobId))
  printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
  # Both interrupt and polling failure are non-fatal: print a follow-up
  # message telling the user how to keep watching the job, then fall through.
  with execution_utils.RaisesKeyboardInterrupt():
    try:
      printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
    except KeyboardInterrupt:
      log.status.Print('Received keyboard interrupt.\n')
      log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
    except exceptions.HttpError as err:
      log.status.Print('Polling logs failed:\n{}\n'.format(str(err)))
      log.info('Failure details:', exc_info=True)
      log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
  # Re-fetch the job so the returned message reflects its latest state.
  job_ref = resources.REGISTRY.Parse(job.jobId, collection='ml.projects.jobs')
  job = jobs_client.Get(job_ref)
  return job