def test_create_counter_distribution(self):
  """Counter and distribution updates accumulate into the step's container.

  Verifies that Metrics.counter/distribution return delegating objects and
  that the accumulated values survive deletion of those delegates.
  """
  sampler = statesampler.StateSampler('', counters.CounterFactory())
  statesampler.set_current_tracker(sampler)
  state1 = sampler.scoped_state(
      'mystep', 'myState', metrics_container=MetricsContainer('mystep'))
  try:
    sampler.start()
    with state1:
      counter_ns = 'aCounterNamespace'
      distro_ns = 'aDistributionNamespace'
      name = 'a_name'
      counter = Metrics.counter(counter_ns, name)
      distro = Metrics.distribution(distro_ns, name)
      counter.inc(10)
      counter.dec(3)
      distro.update(10)
      distro.update(2)
      self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
      self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))

      # The delegating objects hold no state of their own; deleting them
      # must not discard the values recorded in the metrics container.
      del distro
      del counter

      container = MetricsEnvironment.current_container()
      # 10 - 3 == 7.
      self.assertEqual(
          container.counters[MetricName(counter_ns, name)].get_cumulative(),
          7)
      # DistributionData for the updates 10 and 2.
      self.assertEqual(
          container.distributions[MetricName(distro_ns, name)]
          .get_cumulative(),
          DistributionData(12, 2, 2, 10))
  finally:
    # Always stop the sampler so a failing assertion does not leak a
    # running sampler into later tests (matches the try/finally style of
    # the sibling tests in this file).
    sampler.stop()
def test_create_process_wide(self):
  """Process-wide counters land on the process-wide container only."""
  sampler = statesampler.StateSampler('', counters.CounterFactory())
  statesampler.set_current_tracker(sampler)
  scoped = sampler.scoped_state(
      'mystep', 'myState', metrics_container=MetricsContainer('mystep'))
  try:
    sampler.start()
    with scoped:
      urn = "my:custom:urn"
      labels = {'key': 'value'}
      process_wide_counter = InternalMetrics.counter(
          urn=urn, labels=labels, process_wide=True)
      # With process_wide=True the increment must be routed to the
      # process-wide container rather than the current step's container.
      process_wide_counter.inc(10)
      self.assertTrue(
          isinstance(process_wide_counter, Metrics.DelegatingCounter))
      del process_wide_counter

      name = MetricName(None, None, urn=urn, labels=labels)
      # The process-wide container received the increment...
      self.assertEqual(
          MetricsEnvironment.process_wide_container().get_counter(
              name).get_cumulative(),
          10)
      # ...while the current (per-step) container did not.
      self.assertEqual(
          MetricsEnvironment.current_container().get_counter(
              name).get_cumulative(),
          0)
  finally:
    sampler.stop()
def test_nested_with_per_thread_info(self):
  """Log records carry the step of the innermost active scoped state.

  Entering a nested state switches the 'step' field, leaving it restores
  the outer step, and clearing the tracker drops stage/step entirely.
  """
  self.maxDiff = None
  tracker = statesampler.StateSampler('stage', CounterFactory())
  statesampler.set_current_tracker(tracker)
  formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
  with logger.PerThreadLoggingContext(work_item_id='workitem'):
    with tracker.scoped_state('step1', 'process'):
      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output1 = json.loads(formatter.format(record))
      with tracker.scoped_state('step2', 'process'):
        record = self.create_log_record(**self.SAMPLE_RECORD)
        log_output2 = json.loads(formatter.format(record))
      # Back in step1 after leaving the nested step2 state.
      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output3 = json.loads(formatter.format(record))
  # With no tracker (and no per-thread context) the record has no
  # work/stage/step fields at all.
  statesampler.set_current_tracker(None)
  record = self.create_log_record(**self.SAMPLE_RECORD)
  log_output4 = json.loads(formatter.format(record))

  self.assertEqual(
      log_output1,
      dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
  self.assertEqual(
      log_output2,
      dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
  self.assertEqual(
      log_output3,
      dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
  self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
def test_basic_sampler(self):
  """Sampled per-state msec counters approximate real time spent per state."""
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler('basic', counter_factory,
                                      sampling_period_ms=1)

  # Run basic workload transitioning between 3 states.
  sampler.start()
  with sampler.scoped_state('statea'):
    time.sleep(0.1)
    with sampler.scoped_state('stateb'):
      time.sleep(0.2 / 2)
      with sampler.scoped_state('statec'):
        time.sleep(0.3)
      # Second stint in stateb; together with the first it totals 200ms.
      time.sleep(0.2 / 2)
  sampler.stop()
  sampler.commit_counters()

  # Test that sampled state timings are close to their expected values.
  expected_counter_values = {
      'basic-statea-msecs': 100,
      'basic-stateb-msecs': 200,
      'basic-statec-msecs': 300,
  }
  for counter in counter_factory.get_counters():
    self.assertIn(counter.name, expected_counter_values)
    expected_value = expected_counter_values[counter.name]
    actual_value = counter.value()
    # Allow 25% deviation: sampling-based timing is inherently noisy.
    self.assertGreater(actual_value, expected_value * 0.75)
    self.assertLess(actual_value, expected_value * 1.25)
def test_sampler_transition_overhead(self):
  """Bounds the per-transition bookkeeping overhead of the state sampler."""
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler('overhead-', counter_factory,
                                      sampling_period_ms=10)

  # Run basic workload transitioning between 3 states.
  state_a = sampler.scoped_state('step1', 'statea')
  state_b = sampler.scoped_state('step1', 'stateb')
  state_c = sampler.scoped_state('step1', 'statec')
  start_time = time.time()
  sampler.start()
  for _ in range(100000):
    with state_a:
      with state_b:
        for _ in range(10):
          with state_c:
            pass
  sampler.stop()
  elapsed_time = time.time() - start_time
  state_transition_count = sampler.get_info().transition_count
  # Average microseconds of wall time per recorded state transition.
  overhead_us = 1000000.0 * elapsed_time / state_transition_count

  # TODO: This test is flaky when it is run under load. A better solution
  # would be to change the test structure to not depend on specific timings.
  # NOTE(review): doubling the measured value before asserting < 10us makes
  # the effective bound 5us on the raw measurement — confirm this is the
  # intended direction of the flakiness mitigation.
  overhead_us = 2 * overhead_us
  _LOGGER.info('Overhead per transition: %fus', overhead_us)
  # Conservative upper bound on overhead in microseconds (we expect this to
  # take 0.17us when compiled in opt mode or 0.48 us when compiled with in
  # debug mode).
  self.assertLess(overhead_us, 10.0)
def execute_map_tasks(self, ordered_map_tasks):
  """Executes the given map tasks in order, one executor per task.

  Args:
    ordered_map_tasks: iterable of (stage, map_task) pairs, where each
      map_task is an iterable of (stage_name, operation) pairs.
  """
  tt = time.time()
  for ix, (_, map_task) in enumerate(ordered_map_tasks):
    logging.info('Running %s', map_task)
    t = time.time()
    stage_names, all_operations = zip(*map_task)
    # TODO(robertwb): The DataflowRunner worker receives system step names
    # (e.g. "s3") that are used to label the output msec counters. We use the
    # operation names here, but this is not the same scheme used by the
    # DataflowRunner; the result is that the output msec counters are named
    # differently.
    system_names = stage_names
    # Create the CounterFactory and StateSampler for this MapTask.
    # TODO(robertwb): Output counters produced here are currently ignored.
    counter_factory = CounterFactory()
    state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)
    map_executor = operations.SimpleMapTaskExecutor(
        operation_specs.MapTask(
            all_operations, 'S%02d' % ix, system_names, stage_names,
            system_names),
        counter_factory, state_sampler)
    # Keep a reference so executors remain accessible after this call.
    self.executors.append(map_executor)
    map_executor.execute()
    logging.info(
        'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)
  logging.info('Total time: %0.3f sec', time.time() - tt)
def test_sampler_transition_overhead(self):
  """Bounds the per-transition bookkeeping overhead of the state sampler."""
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler('overhead-', counter_factory,
                                      sampling_period_ms=10)

  # Run basic workload transitioning between 3 states.
  state_a = sampler.scoped_state('step1', 'statea')
  state_b = sampler.scoped_state('step1', 'stateb')
  state_c = sampler.scoped_state('step1', 'statec')
  start_time = time.time()
  sampler.start()
  for _ in range(100000):
    with state_a:
      with state_b:
        for _ in range(10):
          with state_c:
            pass
  sampler.stop()
  elapsed_time = time.time() - start_time
  state_transition_count = sampler.get_info().transition_count
  # Average microseconds of wall time per recorded state transition.
  overhead_us = 1000000.0 * elapsed_time / state_transition_count
  logging.info('Overhead per transition: %fus', overhead_us)
  # Conservative upper bound on overhead in microseconds (we expect this to
  # take 0.17us when compiled in opt mode or 0.48 us when compiled with in
  # debug mode).
  self.assertLess(overhead_us, 10.0)
def run_benchmark(num_runs=50, input_per_source=4000, num_sources=4):
  """Benchmarks reading elements from multiple side-input sources.

  Args:
    num_runs: number of timed iterations to average over.
    input_per_source: number of elements produced by each fake source.
    num_sources: number of fake sources to read from.
  """
  print("Number of runs:", num_runs)
  print("Input size:", num_sources * input_per_source)
  print("Sources:", num_sources)

  times = []
  for i in range(num_runs):
    counter_factory = CounterFactory()
    state_sampler = statesampler.StateSampler('basic', counter_factory)
    with state_sampler.scoped_state('step1', 'state'):
      si_counter = opcounters.SideInputReadCounter(
          counter_factory, state_sampler, 'step1', 1)
      # NOTE(review): the line below overwrites the SideInputReadCounter
      # just constructed, so the benchmark measures the no-op counter path.
      # If the intent was to benchmark real counters, delete this line.
      si_counter = opcounters.NoOpTransformIOCounter()
      sources = [
          FakeSource(long_generator(i, input_per_source))
          for i in range(num_sources)]
      iterator_fn = sideinputs.get_iterator_fn_for_sources(
          sources, read_counter=si_counter)
      start = time.time()
      list(iterator_fn())
      time_cost = time.time() - start
      times.append(time_cost)

  print("Runtimes:", times)

  # Use true division: floor division ('//') on these sub-second float
  # timings would floor the averages to 0.0.
  avg_runtime = sum(times) / len(times)
  print("Average runtime:", avg_runtime)
  print("Time per element:", avg_runtime / (input_per_source * num_sources))
def test_metrics(self):
  """Counter, meter and distribution from the metric group hit the container."""
  sampler = statesampler.StateSampler('', counters.CounterFactory())
  statesampler.set_current_tracker(sampler)
  state1 = sampler.scoped_state(
      'mystep', 'myState', metrics_container=MetricsContainer('mystep'))
  try:
    sampler.start()
    with state1:
      counter = MetricTests.base_metric_group.counter("my_counter")
      meter = MetricTests.base_metric_group.meter("my_meter")
      distribution = MetricTests.base_metric_group.distribution(
          "my_distribution")
      container = MetricsEnvironment.current_container()

      # All metrics start at zero / an empty distribution.
      self.assertEqual(0, counter.get_count())
      self.assertEqual(0, meter.get_count())
      self.assertEqual(
          DistributionData(0, 0, 0, 0),
          container.get_distribution(
              MetricName('[]', 'my_distribution')).get_cumulative())

      counter.inc(-2)
      meter.mark_event(3)
      distribution.update(10)
      distribution.update(2)

      # Counters may go negative; the distribution reflects updates 10 and 2.
      self.assertEqual(-2, counter.get_count())
      self.assertEqual(3, meter.get_count())
      self.assertEqual(
          DistributionData(12, 2, 2, 10),
          container.get_distribution(
              MetricName('[]', 'my_distribution')).get_cumulative())
  finally:
    # Ensure the sampler is stopped even when an assertion fails.
    sampler.stop()
def test_basic_sampler(self):
  """Checks that sampled msec counters match the actual per-state durations."""
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler(
      'basic', counter_factory, sampling_period_ms=1)

  # Duration of the fastest state. Total test duration is 6 times longer.
  state_duration_ms = 1000
  margin_of_error = 0.25
  # Run basic workload transitioning between 3 states.
  sampler.start()
  with sampler.scoped_state('step1', 'statea'):
    time.sleep(state_duration_ms / 1000)
    self.assertEqual(
        sampler.current_state().name,
        CounterName('statea-msecs', step_name='step1', stage_name='basic'))
    with sampler.scoped_state('step1', 'stateb'):
      time.sleep(state_duration_ms / 1000)
      self.assertEqual(
          sampler.current_state().name,
          CounterName('stateb-msecs', step_name='step1', stage_name='basic'))
      with sampler.scoped_state('step1', 'statec'):
        time.sleep(3 * state_duration_ms / 1000)
        self.assertEqual(
            sampler.current_state().name,
            CounterName(
                'statec-msecs', step_name='step1', stage_name='basic'))
      # Second stint in stateb: its expected total below is 2x the base
      # duration (1x before entering statec plus this 1x after).
      time.sleep(state_duration_ms / 1000)
  sampler.stop()
  sampler.commit_counters()

  if not statesampler.FAST_SAMPLER:
    # The slow sampler does not implement sampling, so we won't test it.
    return

  # Test that sampled state timings are close to their expected values.
  # yapf: disable
  expected_counter_values = {
      CounterName('statea-msecs', step_name='step1', stage_name='basic'):
          state_duration_ms,
      CounterName('stateb-msecs', step_name='step1', stage_name='basic'):
          2 * state_duration_ms,
      CounterName('statec-msecs', step_name='step1', stage_name='basic'):
          3 * state_duration_ms,
  }
  # yapf: enable
  for counter in counter_factory.get_counters():
    self.assertIn(counter.name, expected_counter_values)
    expected_value = expected_counter_values[counter.name]
    actual_value = counter.value()
    deviation = float(abs(actual_value - expected_value)) / expected_value
    _LOGGER.info('Sampling deviation from expectation: %f', deviation)
    # Tolerate margin_of_error relative deviation: sampling-based timing
    # is inherently noisy.
    self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error))
    self.assertLess(actual_value, expected_value * (1.0 + margin_of_error))
def __init__(
    self, process_bundle_descriptor, state_handler, data_channel_factory):
  """Stores the bundle inputs and builds the operation execution tree."""
  descriptor = process_bundle_descriptor
  self.process_bundle_descriptor = descriptor
  self.state_handler = state_handler
  self.data_channel_factory = data_channel_factory

  # TODO(robertwb): Figure out the correct prefix to use for output counters
  # from StateSampler.
  self.counter_factory = counters.CounterFactory()
  self.state_sampler = statesampler.StateSampler(
      'fnapi-step-%s' % descriptor.id, self.counter_factory)

  self.ops = self.create_execution_tree(descriptor)
def run(self):
  """Worker-thread main loop: pulls tasks off the queue until shutdown.

  Installs a per-thread StateSampler as the current tracker so tasks run
  on this thread can record their state against it.
  """
  state_sampler = statesampler.StateSampler('', counters.CounterFactory())
  statesampler.set_current_tracker(state_sampler)
  while not self.shutdown_requested:
    task = self._get_task_or_none()
    if task:
      try:
        # Re-check shutdown: it may have been requested while blocked in
        # _get_task_or_none().
        if not self.shutdown_requested:
          # _update_name(task) / _update_name() presumably sets and then
          # resets a visible name for the running task — confirm against
          # the class definition.
          self._update_name(task)
          task.call(state_sampler)
          self._update_name()
      finally:
        # Mark the task done even if it raised, so queue.join() callers
        # are never blocked forever.
        self.queue.task_done()
def test_record_with_per_thread_info(self):
  """Log records include work item, stage and step from thread-local state."""
  self.maxDiff = None
  tracker = statesampler.StateSampler('stage', CounterFactory())
  statesampler.set_current_tracker(tracker)
  formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
  with logger.PerThreadLoggingContext(work_item_id='workitem'):
    with tracker.scoped_state('step', 'process'):
      record = self.create_log_record(**self.SAMPLE_RECORD)
      log_output = json.loads(formatter.format(record))
  expected_output = dict(self.SAMPLE_OUTPUT)
  expected_output.update(
      {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
  self.assertEqual(log_output, expected_output)
  # Clear the global tracker so later tests do not see stale state.
  statesampler.set_current_tracker(None)
def test_basic_sampler(self):
  """Checks sampled msec counters against actual time spent in each state."""
  # Set up state sampler.
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler('basic', counter_factory,
                                      sampling_period_ms=1)

  # Run basic workload transitioning between 3 states.
  sampler.start()
  with sampler.scoped_state('step1', 'statea'):
    time.sleep(0.1)
    self.assertEqual(
        sampler.current_state().name,
        CounterName(
            'statea-msecs', step_name='step1', stage_name='basic'))
    with sampler.scoped_state('step1', 'stateb'):
      time.sleep(0.2 / 2)
      self.assertEqual(
          sampler.current_state().name,
          CounterName(
              'stateb-msecs', step_name='step1', stage_name='basic'))
      with sampler.scoped_state('step1', 'statec'):
        time.sleep(0.3)
        self.assertEqual(
            sampler.current_state().name,
            CounterName(
                'statec-msecs', step_name='step1', stage_name='basic'))
      # Second stint in stateb: together with the first it totals 200ms.
      time.sleep(0.2 / 2)
  sampler.stop()
  sampler.commit_counters()

  if not statesampler.FAST_SAMPLER:
    # The slow sampler does not implement sampling, so we won't test it.
    return

  # Test that sampled state timings are close to their expected values.
  expected_counter_values = {
      CounterName('statea-msecs', step_name='step1', stage_name='basic'): 100,
      CounterName('stateb-msecs', step_name='step1', stage_name='basic'): 200,
      CounterName('statec-msecs', step_name='step1', stage_name='basic'): 300,
  }
  for counter in counter_factory.get_counters():
    self.assertIn(counter.name, expected_counter_values)
    expected_value = expected_counter_values[counter.name]
    actual_value = counter.value()
    deviation = float(abs(actual_value - expected_value)) / expected_value
    logging.info('Sampling deviation from expectation: %f', deviation)
    # Allow 25% deviation: sampling-based timing is inherently noisy.
    self.assertGreater(actual_value, expected_value * 0.75)
    self.assertLess(actual_value, expected_value * 1.25)
def test_basic_counters(self):
  """SideInputReadCounter produces msecs/byte-count counters labelled by
  both the declaring step and the step current at read time."""
  counter_factory = CounterFactory()
  sampler = statesampler.StateSampler('stage1', counter_factory)
  sampler.start()

  with sampler.scoped_state('step1', 'stateA'):
    counter = opcounters.SideInputReadCounter(
        counter_factory, sampler, declaring_step='step1', input_index=1)
    with sampler.scoped_state('step2', 'stateB'):
      with counter:
        counter.add_bytes_read(10)
      # Re-labels subsequent IO against the currently-running step
      # ('step2' here), producing the second pair of counters below.
      counter.update_current_step()

  sampler.stop()
  sampler.commit_counters()

  actual_counter_names = set(
      [c.name for c in counter_factory.get_counters()])
  expected_counter_names = set([
      # Counter names for STEP 1
      counters.CounterName(
          'read-sideinput-msecs',
          stage_name='stage1',
          step_name='step1',
          io_target=counters.side_input_id('step1', 1)),
      counters.CounterName(
          'read-sideinput-byte-count',
          step_name='step1',
          io_target=counters.side_input_id('step1', 1)),
      # Counter names for STEP 2
      counters.CounterName(
          'read-sideinput-msecs',
          stage_name='stage1',
          step_name='step1',
          io_target=counters.side_input_id('step2', 1)),
      counters.CounterName(
          'read-sideinput-byte-count',
          step_name='step1',
          io_target=counters.side_input_id('step2', 1)),
  ])
  # Superset check: the factory may create additional counters
  # (e.g. sampler state counters) beyond the side-input ones.
  self.assertTrue(actual_counter_names.issuperset(expected_counter_names))
def create_execution_tree(self, descriptor):
  """Builds the operation graph for a bundle descriptor.

  Returns the operations ordered so that producers precede consumers
  (operations must be started in that order).
  """
  # TODO(robertwb): Figure out the correct prefix to use for output counters
  # from StateSampler.
  counter_factory = counters.CounterFactory()
  state_sampler = statesampler.StateSampler(
      'fnapi-step%s' % descriptor.id, counter_factory)

  transform_factory = BeamTransformFactory(
      descriptor, self.data_channel_factory, counter_factory, state_sampler,
      self.state_handler)

  # Reverse mapping: pcollection id -> ids of the transforms consuming it.
  pcoll_consumers = collections.defaultdict(list)
  for transform_id, transform_proto in descriptor.transforms.items():
    for pcoll_id in transform_proto.inputs.values():
      pcoll_consumers[pcoll_id].append(transform_id)

  @memoize
  def get_operation(transform_id):
    # Recursively build downstream consumers first, then create this
    # operation wired to them by output tag.
    transform_consumers = {
        tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
        for tag, pcoll_id
        in descriptor.transforms[transform_id].outputs.items()
    }
    return transform_factory.create_operation(
        transform_id, transform_consumers)

  # Operations must be started (hence returned) in order.
  @memoize
  def topological_height(transform_id):
    # 1 + max height of any downstream consumer; sinks have height 1.
    return 1 + max([0] + [
        topological_height(consumer)
        for pcoll in descriptor.transforms[transform_id].outputs.values()
        for consumer in pcoll_consumers[pcoll]
    ])

  # Taller (more upstream) transforms come first.
  return [
      get_operation(transform_id) for transform_id in sorted(
          descriptor.transforms, key=topological_height, reverse=True)
  ]
def __init__(
    self, process_bundle_descriptor, state_handler, data_channel_factory):
  """Initialize a bundle processor.

  Args:
    process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
      a description of the stage that this ``BundleProcessor`` is to execute.
    state_handler (beam_fn_api_pb2_grpc.BeamFnStateServicer).
    data_channel_factory (``data_plane.DataChannelFactory``).
  """
  self.process_bundle_descriptor = process_bundle_descriptor
  self.state_handler = state_handler
  self.data_channel_factory = data_channel_factory

  # TODO(robertwb): Figure out the correct prefix to use for output counters
  # from StateSampler.
  self.counter_factory = counters.CounterFactory()
  self.state_sampler = statesampler.StateSampler(
      'fnapi-step-%s' % self.process_bundle_descriptor.id,
      self.counter_factory)

  self.ops = self.create_execution_tree(self.process_bundle_descriptor)
  # Set up every operation before any bundle processing begins.
  for op in self.ops.values():
    op.setup()
  # Lock presumably serializing bundle-split requests against processing —
  # confirm usage at the call sites.
  self.splitting_lock = threading.Lock()
def create_execution_tree(self, descriptor):
  """Builds operations for a descriptor's primitive transforms.

  Returns:
    A pair (ops, ops_by_id): ops is the list of operations in execution
    order (producers before consumers); ops_by_id maps each primitive
    transform id to its operation.
  """
  # TODO(vikasrk): Add an id field to Coder proto and use that instead.
  coders = {
      coder.function_spec.id: operation_specs.get_coder_from_spec(
          json.loads(unpack_function_spec_data(coder.function_spec)))
      for coder in descriptor.coders}

  counter_factory = counters.CounterFactory()
  # TODO(robertwb): Figure out the correct prefix to use for output counters
  # from StateSampler.
  state_sampler = statesampler.StateSampler(
      'fnapi-step%s-' % descriptor.id, counter_factory)

  # transform id -> output tag -> list of consuming operations.
  consumers = collections.defaultdict(lambda: collections.defaultdict(list))
  ops_by_id = {}
  reversed_ops = []

  # Iterate in reverse so each transform's consumers already exist when
  # the transform's operation is constructed.
  for transform in reversed(descriptor.primitive_transform):
    # TODO(robertwb): Figure out how to plumb through the operation name
    # (e.g. "s3") from the service through the FnAPI so that msec counters
    # can be reported and correctly plumbed through the service and the UI.
    operation_name = 'fnapis%s' % transform.id

    def only_element(iterable):
      # Unpacks a single-element iterable; raises if it has != 1 element.
      element, = iterable
      return element

    if transform.function_spec.urn == DATA_OUTPUT_URN:
      target = beam_fn_api_pb2.Target(
          primitive_transform_reference=transform.id,
          name=only_element(transform.outputs.keys()))
      op = DataOutputOperation(
          operation_name,
          transform.step_name,
          consumers[transform.id],
          counter_factory,
          state_sampler,
          coders[only_element(transform.outputs.values()).coder_reference],
          target,
          self.data_channel_factory.create_data_channel(
              transform.function_spec))

    elif transform.function_spec.urn == DATA_INPUT_URN:
      target = beam_fn_api_pb2.Target(
          primitive_transform_reference=transform.id,
          name=only_element(transform.inputs.keys()))
      op = DataInputOperation(
          operation_name,
          transform.step_name,
          consumers[transform.id],
          counter_factory,
          state_sampler,
          coders[only_element(transform.outputs.values()).coder_reference],
          target,
          self.data_channel_factory.create_data_channel(
              transform.function_spec))

    elif transform.function_spec.urn == PYTHON_DOFN_URN:

      def create_side_input(tag, si):
        # TODO(robertwb): Extract windows (and keys) out of element data.
        return operation_specs.WorkerSideInputSource(
            tag=tag,
            source=SideInputSource(
                self.state_handler,
                beam_fn_api_pb2.StateKey(
                    function_spec_reference=si.view_fn.id),
                coder=unpack_and_deserialize_py_fn(si.view_fn)))

      output_tags = list(transform.outputs.keys())
      spec = operation_specs.WorkerDoFn(
          serialized_fn=unpack_function_spec_data(transform.function_spec),
          output_tags=output_tags,
          input=None,
          side_inputs=[create_side_input(tag, si)
                       for tag, si in transform.side_inputs.items()],
          output_coders=[coders[transform.outputs[out].coder_reference]
                         for out in output_tags])

      op = operations.DoOperation(
          operation_name, spec, counter_factory, state_sampler)
      # TODO(robertwb): Move these to the constructor.
      op.step_name = transform.step_name
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, output_tags.index(tag))

    elif transform.function_spec.urn == IDENTITY_DOFN_URN:
      op = operations.FlattenOperation(
          operation_name, None, counter_factory, state_sampler)
      # TODO(robertwb): Move these to the constructor.
      op.step_name = transform.step_name
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, 0)

    elif transform.function_spec.urn == PYTHON_SOURCE_URN:
      source = load_compressed(
          unpack_function_spec_data(transform.function_spec))
      # TODO(vikasrk): Remove this once custom source is implemented with
      # splittable dofn via the data plane.
      spec = operation_specs.WorkerRead(
          iobase.SourceBundle(1.0, source, None, None),
          [WindowedValueCoder(source.default_output_coder())])
      op = operations.ReadOperation(
          operation_name, spec, counter_factory, state_sampler)
      op.step_name = transform.step_name
      output_tags = list(transform.outputs.keys())
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, output_tags.index(tag))

    else:
      raise NotImplementedError

    # Record consumers: register this op as a consumer of each of its
    # inputs so upstream transforms (visited later) can wire it up.
    for _, inputs in transform.inputs.items():
      for target in inputs.target:
        consumers[target.primitive_transform_reference][target.name].append(
            op)

    reversed_ops.append(op)
    ops_by_id[transform.id] = op

  # Reverse back to producer-before-consumer (execution) order.
  return list(reversed(reversed_ops)), ops_by_id