def test_set_when_already_defined(self):
    """A nested context may shadow a key; exiting restores the outer value.

    Verifies the stack-like semantics of PerThreadLoggingContext: an inner
    context overrides 'xyz', and each exit restores the previous state.
    """
    # 'xyz' must be absent before any context is entered.
    # assertNotIn (rather than assertFalse(... in ...)) yields a useful
    # failure message showing the container's contents.
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
        with logger.PerThreadLoggingContext(xyz='value2'):
            # Inner context shadows the outer binding for the same key.
            self.assertEqual(
                logger.per_thread_worker_data.get_data()['xyz'], 'value2')
        # Leaving the inner context restores the outer value.
        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    # Leaving the outer context removes the key entirely.
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
def test_nested_with_per_thread_info(self):
    """Nested scoped states are reflected as 'step' in formatted log output.

    Exercises JsonLogFormatter with a state-sampler tracker: entering
    step2's scoped state overrides step1's, exiting restores step1, and
    clearing the tracker drops work/stage/step entirely.
    """
    self.maxDiff = None
    tracker = statesampler.StateSampler('stage', CounterFactory())
    statesampler.set_current_tracker(tracker)
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    try:
        with logger.PerThreadLoggingContext(work_item_id='workitem'):
            with tracker.scoped_state('step1', 'process'):
                record = self.create_log_record(**self.SAMPLE_RECORD)
                log_output1 = json.loads(formatter.format(record))
                with tracker.scoped_state('step2', 'process'):
                    record = self.create_log_record(**self.SAMPLE_RECORD)
                    log_output2 = json.loads(formatter.format(record))
                # Back in step1's scope after the nested state exits.
                record = self.create_log_record(**self.SAMPLE_RECORD)
                log_output3 = json.loads(formatter.format(record))
    finally:
        # Always clear the process-global tracker, even on assertion failure,
        # so this test cannot leak sampler state into other tests.
        statesampler.set_current_tracker(None)
    # With no tracker set, the formatted record carries no stage/step info.
    record = self.create_log_record(**self.SAMPLE_RECORD)
    log_output4 = json.loads(formatter.format(record))
    self.assertEqual(
        log_output1,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    self.assertEqual(
        log_output2,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
    self.assertEqual(
        log_output3,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
def test_per_thread_attribute(self):
    """Context data is per-thread: a child thread does not see 'xyz'.

    Spawns a thread running thread_check_attribute while this thread's
    context holds xyz='value', then checks this thread's value survived.
    """
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
        # NOTE(review): assertions raised inside the child thread do not
        # propagate to this test's result — confirm thread_check_attribute
        # failures are surfaced some other way.
        thread = threading.Thread(
            target=self.thread_check_attribute, args=('xyz', ))
        thread.start()
        thread.join()
        # The child thread's own context must not have disturbed ours.
        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertNotIn('xyz', logger.per_thread_worker_data.get_data())
def start(self):
    """Prepare this DoOperation for processing elements.

    Deserializes the user DoFn spec, wires output receivers (main and side
    outputs), materializes side inputs if needed, and constructs and starts
    the DoFnRunner. All work happens inside the scoped start state so it is
    attributed to this operation's start phase.
    """
    with self.scoped_start_state:
        super(DoOperation, self).start()

        # See fn_data in dataflow_runner.py
        # The serialized spec unpacks into the user function, its
        # positional/keyword args, side-input tags/types, and the windowing fn.
        fn, args, kwargs, tags_and_types, window_fn = (pickler.loads(
            self.spec.serialized_fn))

        state = common.DoFnState(self.counter_factory)
        state.step_name = self.name_context.logging_name()

        # Tag to output index map used to dispatch the side output values emitted
        # by the DoFn function to the appropriate receivers. The main output is
        # tagged with None and is associated with its corresponding index.
        self.tagged_receivers = _TaggedReceivers(
            self.counter_factory, self.name_context.logging_name())

        # Output tags are either the literal main-output name or
        # '<OUT>_<original_tag>' for side outputs; anything else is malformed.
        output_tag_prefix = PropertyNames.OUT + '_'
        for index, tag in enumerate(self.spec.output_tags):
            if tag == PropertyNames.OUT:
                original_tag = None
            elif tag.startswith(output_tag_prefix):
                original_tag = tag[len(output_tag_prefix):]
            else:
                raise ValueError(
                    'Unexpected output name for operation: %s' % tag)
            self.tagged_receivers[original_tag] = self.receivers[index]

        # Side inputs are read lazily only if not already provided
        # (presumably pre-populated by a caller in some paths — see
        # self.side_input_maps usage elsewhere).
        if self.side_input_maps is None:
            if tags_and_types:
                self.side_input_maps = list(
                    self._read_side_inputs(tags_and_types))
            else:
                self.side_input_maps = []

        self.dofn_runner = common.DoFnRunner(
            fn, args, kwargs, self.side_input_maps, window_fn,
            tagged_receivers=self.tagged_receivers,
            step_name=self.name_context.logging_name(),
            # Log records emitted while running the DoFn are tagged with this
            # operation's step name.
            logging_context=logger.PerThreadLoggingContext(
                step_name=self.name_context.logging_name()),
            state=state,
            scoped_metrics_container=None)
        # The runner may itself be a Receiver; otherwise adapt it so callers
        # can push elements through a uniform Receiver interface.
        self.dofn_receiver = (self.dofn_runner if isinstance(
            self.dofn_runner, Receiver) else DoFnRunnerReceiver(
                self.dofn_runner))

        self.dofn_runner.start()
def test_nested_with_per_thread_info(self):
    """Nested logging contexts override step_name and restore it on exit."""
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')

    def emit():
        # Format a fresh sample record under the current thread context.
        sample = self.create_log_record(**self.SAMPLE_RECORD)
        return json.loads(formatter.format(sample))

    with logger.PerThreadLoggingContext(
            work_item_id='workitem', stage_name='stage', step_name='step1'):
        outer_before = emit()
        with logger.PerThreadLoggingContext(step_name='step2'):
            inner = emit()
        outer_after = emit()
    outside = emit()

    expected_outer = dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1')
    expected_inner = dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2')
    self.assertEqual(outer_before, expected_outer)
    self.assertEqual(inner, expected_inner)
    self.assertEqual(outer_after, expected_outer)
    # Outside every context, no work/stage/step fields are attached.
    self.assertEqual(outside, self.SAMPLE_OUTPUT)
def test_record_with_per_thread_info(self):
    """A record formatted inside context + scoped state carries work/stage/step."""
    self.maxDiff = None
    tracker = statesampler.StateSampler('stage', CounterFactory())
    statesampler.set_current_tracker(tracker)
    try:
        formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
        with logger.PerThreadLoggingContext(work_item_id='workitem'):
            with tracker.scoped_state('step', 'process'):
                record = self.create_log_record(**self.SAMPLE_RECORD)
                log_output = json.loads(formatter.format(record))
    finally:
        # Always clear the process-global tracker, even on failure, so this
        # test cannot leak sampler state into other tests.
        statesampler.set_current_tracker(None)
    expected_output = dict(self.SAMPLE_OUTPUT)
    expected_output.update(
        {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
    self.assertEqual(log_output, expected_output)
def test_record_with_per_thread_info(self):
    """Formatting inside a full logging context attaches work/stage/step."""
    with logger.PerThreadLoggingContext(
            work_item_id='workitem', stage_name='stage', step_name='step'):
        formatter = logger.JsonLogFormatter(
            job_id='jobid', worker_id='workerid')
        sample = self.create_log_record(**self.SAMPLE_RECORD)
        actual = json.loads(formatter.format(sample))
    # Expected output is the baseline sample plus the context fields.
    expected = dict(
        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step')
    self.assertEqual(actual, expected)
def thread_check_attribute(self, name):
    """Helper run on a worker thread: checks per-thread data is isolated.

    Args:
      name: the per-thread data key to probe; must not be visible from the
        spawning thread's context, and this thread's own context must clean
        itself up on exit.
    """
    # assertNotIn gives a diagnostic failure message (unlike assertFalse).
    self.assertNotIn(name, logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
        self.assertEqual(
            logger.per_thread_worker_data.get_data()[name], 'thread-value')
    self.assertNotIn(name, logger.per_thread_worker_data.get_data())