Example #1
    def testYieldLogsStopsAppropriatelyShorterContinueInterval(self):
        expected_logs = [
            Log('2017-01-20T17:28:22.929735908Z', 'foo0'),
            Log('2017-01-20T17:28:22.929735909Z', 'foo1'),
        ]
        self.log_fetcher_mock.side_effect = [[expected_logs[0]], [],
                                             [expected_logs[1]], [], []]

        # Note that polling_interval / continue_interval = 10 / 5 = 2, so we
        # expect 2 _ContinueFunc calls for every log poll.
        fetcher = stream.LogFetcher(continue_func=self._ContinueFunc,
                                    polling_interval=10,
                                    continue_interval=5)

        logs = fetcher.YieldLogs()

        self.assertEqual(next(logs), expected_logs[0])
        self.assertEqual(self.time_slept, 0)
        self.assertEqual(self.log_fetcher_mock.call_count, 1)
        self.assertEqual(self.continue_func_calls, [])

        self.assertEqual(next(logs), expected_logs[1])
        self.assertEqual(self.time_slept, 20)
        self.assertEqual(self.log_fetcher_mock.call_count, 3)
        self.assertEqual(self.continue_func_calls, [0, 0, 1, 1])

        with self.assertRaises(StopIteration):
            next(logs)
        self.assertEqual(self.time_slept, 40)
        self.assertEqual(self.log_fetcher_mock.call_count, 5)
        self.assertEqual(self.continue_func_calls, [0, 0, 1, 1, 0, 0, 1, 1, 2])
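The assertions above pin down the scheduling fairly tightly: the fetcher polls first, then runs continue checks between sleeps. Below is a hypothetical reconstruction of the loop, not the actual LogFetcher source, that reproduces the exact poll counts, sleep totals, and _ContinueFunc arguments asserted in this test and in Example #8 further down.

import time

def _YieldLogsSketch(fetch, continue_func, polling_interval, continue_interval):
    # Hypothetical sketch: advance a virtual clock in steps of the smaller
    # interval (assuming one interval divides the other, as in these tests),
    # poll on polling ticks, and run the continue check on continue ticks.
    step = min(polling_interval, continue_interval)
    elapsed = 0
    num_empty_polls = 0
    while True:
        if elapsed % polling_interval == 0:
            batch = fetch()  # stands in for the mocked common.FetchLogs
            num_empty_polls = 0 if batch else num_empty_polls + 1
            for entry in batch:
                yield entry
        if elapsed % continue_interval == 0:
            if not continue_func(num_empty_polls):
                return
        time.sleep(step)
        elapsed += step

With polling_interval=10 and continue_interval=5, every 10-second polling cycle contains two 5-second sleeps and two continue checks, which produces exactly the [0, 0, 1, 1] pattern asserted between the first and second log.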
Example #2
def StreamLogs(job, task_name, polling_interval, allow_multiline_logs):
    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(job, task_name),
        polling_interval=polling_interval,
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(job))
    return log_utils.SplitMultiline(log_fetcher.YieldLogs(),
                                    allow_multiline=allow_multiline_logs)
Example #3
  def Run(self, args):
    """Run the stream-logs command."""
    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(args.job, args.task_name),
        polling_interval=args.polling_interval,
        continue_func=log_utils.MakeContinueFunction(args.job))
    return log_utils.SplitMultiline(
        log_fetcher.YieldLogs(), allow_multiline=args.allow_multiline_logs)
Example #4
    def Run(self, args):
        printer = logs_util.LogPrinter()
        printer.RegisterFormatter(logs_util.FormatRequestLogEntry)
        printer.RegisterFormatter(logs_util.FormatAppEntry)
        project = properties.VALUES.core.project.Get(required=True)
        filters = logs_util.GetFilters(project, args.logs, args.service,
                                       args.version, args.level)

        log_fetcher = stream.LogFetcher(filters=filters, polling_interval=1)
        for log_entry in log_fetcher.YieldLogs():
            log.out.Print(printer.Format(log_entry))
Example #5
    def SetUp(self):
        # This is a list of lists where each poll returns the next batch of logs.
        self.logs = []
        self.continue_func_calls = []
        self.fetcher = stream.LogFetcher(continue_func=self._ContinueFunc)
        self.log_fetcher_mock = self.StartObjectPatch(common, 'FetchLogs')
        self.time_slept = 0

        def _IncrementSleepTime(x):
            self.time_slept += x

        self.sleep_mock = self.StartPatch('time.sleep',
                                          side_effect=_IncrementSleepTime)
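The _ContinueFunc referenced here is not part of the snippet. A minimal implementation consistent with the assertions in Examples #1 and #8, recording each call and stopping once two or more consecutive polls come back empty, might look like this (an assumption, not the original helper):

    def _ContinueFunc(self, num_empty_polls):
        # Hypothetical: record the argument for the tests' assertions and
        # keep polling only while fewer than 2 consecutive polls were empty.
        self.continue_func_calls.append(num_empty_polls)
        return num_empty_polls < 2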
Example #6
    def testFiltersAreAdded(self):
        continue_func = lambda num_empty_polls: num_empty_polls == 0
        filters = ['insertId>=foo', 'random irrelevant filter']
        custom_fetcher = stream.LogFetcher(continue_func=continue_func,
                                           filters=filters)
        log1 = Log('2017-01-20T17:28:22.929735908Z', 'foo')
        log2 = Log('2017-01-20T17:28:22.929735908Z', 'foo2')
        self.log_fetcher_mock.return_value = [log1, log2]

        logs = custom_fetcher.GetLogs()

        self.assertEqual([log1, log2], logs)
        _, kwargs = self.log_fetcher_mock.call_args
        filter_string = kwargs['log_filter']
        for filter_ in filters:
            self.assertIn(filter_, filter_string)
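Note that the test only asserts that each user filter appears somewhere in the log_filter keyword argument handed to the mocked FetchLogs; it does not show how the fetcher combines them. A plausible combination, assumed here purely for illustration, is a simple conjunction:

# Assumption for illustration only: user filters joined into one AND expression.
filters = ['insertId>=foo', 'random irrelevant filter']
log_filter = ' AND '.join(filters)  # would satisfy both assertIn checks above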
Example #7
def StreamLogs(name, continue_function, polling_interval, task_name,
               allow_multiline):
  """Returns the streaming log of the job by id.

  Args:
    name: string id of the entity.
    continue_function: One-arg function that takes the number of consecutive
      empty polls and returns a boolean deciding whether to keep polling. If
      not given, polls indefinitely.
    polling_interval: number of seconds to sleep between polls.
    task_name: string name of the task.
    allow_multiline: whether logs with multiline messages are allowed.
  """
  log_fetcher = stream.LogFetcher(
      filters=_LogFilters(name, task_name=task_name),
      polling_interval=polling_interval,
      continue_interval=_CONTINUE_INTERVAL,
      continue_func=continue_function)
  return _SplitMultiline(log_fetcher.YieldLogs(), allow_multiline)
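A hypothetical invocation of this helper; the job id, interval, and continue function below are illustrative values, not taken from the source:

entries = StreamLogs(
    name='my_training_job',  # illustrative job id
    continue_function=lambda num_empty_polls: num_empty_polls < 10,
    polling_interval=60,
    task_name=None,
    allow_multiline=False)
for entry in entries:
    print(entry)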
Example #8
    def testYieldLogsStopsAppropriatelyLongerContinueInterval(self):
        expected_logs = [
            Log('2017-01-20T17:28:22.929735908Z', 'foo0'),
            Log('2017-01-20T17:28:22.929735909Z', 'foo1'),
            Log('2017-01-20T17:28:22.929735910Z', 'foo2')
        ]
        self.log_fetcher_mock.side_effect = [[expected_logs[0]],
                                             [expected_logs[1]], [],
                                             [expected_logs[2]], [], [], []]

        # Note that continue_interval / polling_interval = 20 / 10 = 2, so we
        # expect 2 log polls for every _ContinueFunc call.
        fetcher = stream.LogFetcher(continue_func=self._ContinueFunc,
                                    polling_interval=10,
                                    continue_interval=20)

        logs = fetcher.YieldLogs()

        self.assertEqual(next(logs), expected_logs[0])
        self.assertEqual(self.time_slept, 0)
        self.assertEqual(self.log_fetcher_mock.call_count, 1)
        self.assertEqual(self.continue_func_calls, [])

        self.assertEqual(next(logs), expected_logs[1])
        self.assertEqual(self.time_slept, 10)
        self.assertEqual(self.log_fetcher_mock.call_count, 2)
        self.assertEqual(self.continue_func_calls, [0])

        self.assertEqual(next(logs), expected_logs[2])
        self.assertEqual(self.time_slept, 30)
        self.assertEqual(self.log_fetcher_mock.call_count, 4)
        self.assertEqual(self.continue_func_calls, [0, 1])

        with self.assertRaises(StopIteration):
            next(logs)
        self.assertEqual(self.time_slept, 60)
        self.assertEqual(self.log_fetcher_mock.call_count, 7)
        self.assertEqual(self.continue_func_calls, [0, 1, 1, 3])
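Under the hypothetical tick model sketched after Example #1, step = min(10, 20) = 10, so every tick polls but only every other tick runs a continue check; that reproduces the numbers asserted here: 7 polls, 60 seconds slept, and checks receiving 0, 1, 1, and finally 3.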
Example #9
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   python_version=None,
                   stream_logs=None,
                   user_args=None,
                   labels=None,
                   custom_train_server_config=None):
    """Submit a training job."""
    region = properties.VALUES.compute.region.Get(required=True)
    staging_location = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
    try:
        uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_location)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(uris))

    scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    scale_tier = scale_tier_enum(scale_tier) if scale_tier else None

    try:
        job = jobs_client.BuildTrainingJob(
            path=config,
            module_name=module_name,
            job_name=job,
            trainer_uri=uris,
            region=region,
            job_dir=job_dir.ToUrl() if job_dir else None,
            scale_tier=scale_tier,
            user_args=user_args,
            runtime_version=runtime_version,
            python_version=python_version,
            labels=labels,
            custom_train_server_config=custom_train_server_config)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If `--package-path` is not specified, at least one Python package '
            'must be specified via `--packages`.')

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    job = jobs_client.Create(project_ref, job)
    if not stream_logs:
        PrintSubmitFollowUp(job.jobId, print_follow_up_message=True)
        return job
    else:
        PrintSubmitFollowUp(job.jobId, print_follow_up_message=False)

    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(job.jobId),
        polling_interval=properties.VALUES.ml_engine.polling_interval.GetInt(),
        continue_interval=_CONTINUE_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(job.jobId))

    printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
    with execution_utils.RaisesKeyboardInterrupt():
        try:
            printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
        except KeyboardInterrupt:
            log.status.Print('Received keyboard interrupt.\n')
            log.status.Print(
                _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                          project=project_ref.Name()))
        except exceptions.HttpError as err:
            log.status.Print('Polling logs failed:\n{}\n'.format(
                six.text_type(err)))
            log.info('Failure details:', exc_info=True)
            log.status.Print(
                _FOLLOW_UP_MESSAGE.format(job_id=job.jobId,
                                          project=project_ref.Name()))

    job_ref = resources.REGISTRY.Parse(
        job.jobId,
        params={'projectsId': properties.VALUES.core.project.GetOrFail},
        collection='ml.projects.jobs')
    job = jobs_client.Get(job_ref)

    return job
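A hypothetical call to this helper; client construction and argument values below are illustrative assumptions, not taken from the source:

# Illustrative only: names and values are assumptions.
client = jobs.JobsClient()
job = SubmitTraining(
    jobs_client=client,
    job='my_training_job',
    package_path='trainer/',
    staging_bucket=bucket_ref,  # assumed: a parsed Cloud Storage bucket resource
    module_name='trainer.task',
    scale_tier='BASIC',
    stream_logs=True)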
Example #10
    def Run(self, args):
        """This is what gets called when the user runs this command.

    Args:
      args: an argparse namespace. All the arguments that were provided to this
        command invocation.

    Returns:
      Some value that we want to have printed later.
    """
        region = properties.VALUES.compute.region.Get(required=True)
        staging_location = jobs_prep.GetStagingLocation(
            staging_bucket=args.staging_bucket,
            job_id=args.job,
            job_dir=args.job_dir)
        try:
            uris = jobs_prep.UploadPythonPackages(
                packages=args.packages,
                package_path=args.package_path,
                staging_location=staging_location)
        except jobs_prep.NoStagingLocationError:
            raise flags.ArgumentError(
                'If local packages are provided, the `--staging-bucket` or '
                '`--job-dir` flag must be given.')
        log.debug('Using {0} as trainer uris'.format(uris))

        scale_tier_enum = (jobs.GetMessagesModule()
                           .GoogleCloudMlV1beta1TrainingInput
                           .ScaleTierValueValuesEnum)
        scale_tier = (scale_tier_enum(args.scale_tier)
                      if args.scale_tier else None)
        job = jobs.BuildTrainingJob(
            path=args.config,
            module_name=args.module_name,
            job_name=args.job,
            trainer_uri=uris,
            region=region,
            job_dir=args.job_dir.ToUrl() if args.job_dir else None,
            scale_tier=scale_tier,
            user_args=args.user_args,
            runtime_version=args.runtime_version)

        jobs_client = jobs.JobsClient()
        project_ref = resources.REGISTRY.Parse(
            properties.VALUES.core.project.Get(required=True),
            collection='ml.projects')
        job = jobs_client.Create(project_ref, job)
        log.status.Print('Job [{}] submitted successfully.'.format(job.jobId))
        if getattr(args, 'async'):  # 'async' is a reserved word in Python 3.7+
            log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
            return job

        log_fetcher = stream.LogFetcher(
            filters=log_utils.LogFilters(job.jobId),
            polling_interval=_POLLING_INTERVAL,
            continue_func=log_utils.MakeContinueFunction(job.jobId))

        printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)

        def _CtrlCHandler(signal, frame):
            del signal, frame  # Unused
            raise KeyboardInterrupt

        with execution_utils.CtrlCSection(_CtrlCHandler):
            try:
                printer.Print(log_utils.SplitMultiline(
                    log_fetcher.YieldLogs()))
            except KeyboardInterrupt:
                log.status.Print('Received keyboard interrupt.')
                log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))

        job_ref = resources.REGISTRY.Parse(job.jobId,
                                           collection='ml.projects.jobs')
        job = jobs_client.Get(job_ref)
        # If the job itself failed, we will return a failure status.
        if job.state is not job.StateValueValuesEnum.SUCCEEDED:
            self.exit_code = 1

        return job
Example #11
def SubmitTraining(jobs_client,
                   job,
                   job_dir=None,
                   staging_bucket=None,
                   packages=None,
                   package_path=None,
                   scale_tier=None,
                   config=None,
                   module_name=None,
                   runtime_version=None,
                   async_=None,
                   user_args=None):
    """Submit a training job."""
    region = properties.VALUES.compute.region.Get(required=True)
    staging_location = jobs_prep.GetStagingLocation(
        staging_bucket=staging_bucket, job_id=job, job_dir=job_dir)
    try:
        uris = jobs_prep.UploadPythonPackages(
            packages=packages,
            package_path=package_path,
            staging_location=staging_location)
    except jobs_prep.NoStagingLocationError:
        raise flags.ArgumentError(
            'If local packages are provided, the `--staging-bucket` or '
            '`--job-dir` flag must be given.')
    log.debug('Using {0} as trainer uris'.format(uris))

    scale_tier_enum = jobs_client.training_input_class.ScaleTierValueValuesEnum
    scale_tier = scale_tier_enum(scale_tier) if scale_tier else None

    job = jobs_client.BuildTrainingJob(
        path=config,
        module_name=module_name,
        job_name=job,
        trainer_uri=uris,
        region=region,
        job_dir=job_dir.ToUrl() if job_dir else None,
        scale_tier=scale_tier,
        user_args=user_args,
        runtime_version=runtime_version)

    project_ref = resources.REGISTRY.Parse(
        properties.VALUES.core.project.Get(required=True),
        collection='ml.projects')
    job = jobs_client.Create(project_ref, job)
    log.status.Print('Job [{}] submitted successfully.'.format(job.jobId))
    if async_:
        log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
        return job

    log_fetcher = stream.LogFetcher(
        filters=log_utils.LogFilters(job.jobId),
        polling_interval=_POLLING_INTERVAL,
        continue_func=log_utils.MakeContinueFunction(job.jobId))

    printer = resource_printer.Printer(log_utils.LOG_FORMAT, out=log.err)
    with execution_utils.RaisesKeyboardInterrupt():
        try:
            printer.Print(log_utils.SplitMultiline(log_fetcher.YieldLogs()))
        except KeyboardInterrupt:
            log.status.Print('Received keyboard interrupt.\n')
            log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))
        except exceptions.HttpError as err:
            log.status.Print('Polling logs failed:\n{}\n'.format(str(err)))
            log.info('Failure details:', exc_info=True)
            log.status.Print(_FOLLOW_UP_MESSAGE.format(job_id=job.jobId))

    job_ref = resources.REGISTRY.Parse(job.jobId,
                                       collection='ml.projects.jobs')
    job = jobs_client.Get(job_ref)

    return job