def _create_user_defined_function_operation(factory, transform_proto, consumers,
                                            udfs_proto, beam_operation_cls,
                                            internal_operation_cls):
  """Builds a Beam operation wrapping the given user-defined functions.

  If the serialized fn proto carries a ``key_type`` the operation is keyed and
  gets a RemoteKeyedStateBackend; otherwise a plain operation is returned.
  """
  tags = list(transform_proto.outputs.keys())
  coders_by_tag = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=udfs_proto,
      output_tags=tags,
      input=None,
      side_inputs=None,
      output_coders=[coders_by_tag[t] for t in tags])
  common_args = (transform_proto.unique_name, spec, factory.counter_factory,
                 factory.state_sampler, consumers, internal_operation_cls)
  if not hasattr(spec.serialized_fn, "key_type"):
    # Stateless operation: no state backend required.
    return beam_operation_cls(*common_args)
  # keyed operation, need to create the KeyedStateBackend.
  fn_proto = spec.serialized_fn
  state_backend = RemoteKeyedStateBackend(
      factory.state_handler,
      from_proto(fn_proto.key_type),
      fn_proto.state_cache_size,
      fn_proto.map_state_read_cache_size,
      fn_proto.map_state_write_cache_size)
  return beam_operation_cls(*common_args, state_backend)
def create(factory, transform_id, transform_proto, parameter, consumers):
  """Creates a DoOperation from a pickled ParDo parameter.

  The pickled payload either bundles (serialized_fn, side_input_data) or is
  just a serialized fn with no side inputs.
  """
  unpickled = pickler.loads(parameter.value)
  if len(unpickled) == 2:
    # Has side input data.
    serialized_fn, side_input_data = unpickled
  else:
    # No side input data.
    serialized_fn, side_input_data = parameter.value, []

  def _side_input_source(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
        key=side_input_tag(transform_id, tag))
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(factory.state_handler, state_key, coder=coder))

  tags = list(transform_proto.outputs.keys())
  coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=tags,
      input=None,
      side_inputs=[_side_input_source(t, c) for t, c in side_input_data],
      output_coders=[coders[t] for t in tags])

  do_operation = operations.DoOperation(
      transform_proto.unique_name, spec, factory.counter_factory,
      factory.state_sampler)
  return factory.augment_oldstyle_op(
      do_operation, transform_proto.unique_name, consumers, tags)
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_inputs_proto=None):
  """Creates a DoOperation whose side inputs are served via the Fn State API."""
  side_input_maps = []
  if side_inputs_proto:
    coders_by_input_tag = factory.get_input_coders(transform_proto)
    decoded = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in side_inputs_proto.items()]
    # Side-input tags look like "side<N>"; order them numerically by N.
    decoded.sort(key=lambda item: int(item[0][4:]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler, transform_id, tag, data,
            coders_by_input_tag[tag])
        for tag, data in decoded]

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' not in output_tags:
      return tag
    return 'out' if tag == 'None' else 'out_' + tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set: recover it from the main-input PCollection, i.e. the
    # single input that is not a side input.
    side_input_tags = side_inputs_proto or ()
    main_inputs = [pcoll for tag, pcoll in transform_proto.inputs.items()
                   if tag not in side_input_tags]
    pcoll_id, = main_inputs
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name, spec, factory.counter_factory,
          factory.state_sampler, side_input_maps),
      transform_proto.unique_name, consumers, output_tags)
def _create_user_defined_function_operation(factory, transform_proto, consumers,
                                            udfs_proto, beam_operation_cls,
                                            internal_operation_cls):
  """Builds a Beam operation for user-defined functions.

  Three cases: a keyed Table API operation (proto has ``key_type``), a keyed
  DataStream stateful operation, or a stateless operation.
  """
  tags = list(transform_proto.outputs.keys())
  coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=udfs_proto,
      output_tags=tags,
      input=None,
      side_inputs=None,
      output_coders=[coders[tag] for tag in tags])
  name = common.NameContext(transform_proto.unique_name)
  serialized_fn = spec.serialized_fn

  def _make_state_backend(key_row_coder, window_coder):
    # Cache sizes come from the serialized fn proto.
    return RemoteKeyedStateBackend(
        factory.state_handler,
        key_row_coder,
        window_coder,
        serialized_fn.state_cache_size,
        serialized_fn.map_state_read_cache_size,
        serialized_fn.map_state_write_cache_size)

  if hasattr(serialized_fn, "key_type"):
    # keyed operation, need to create the KeyedStateBackend.
    row_schema = serialized_fn.key_type.row_schema
    key_row_coder = FlattenRowCoder(
        [from_proto(f.type) for f in row_schema.fields])
    window_coder = None
    if serialized_fn.HasField('group_window'):
      window_coder = (
          TimeWindowCoder() if serialized_fn.group_window.is_time_window
          else CountWindowCoder())
    return beam_operation_cls(
        name, spec, factory.counter_factory, factory.state_sampler, consumers,
        internal_operation_cls,
        _make_state_backend(key_row_coder, window_coder))
  if internal_operation_cls == datastream_operations.StatefulOperation:
    # Keyed DataStream operation: key coder comes from the type-info proto.
    key_row_coder = from_type_info_proto(serialized_fn.key_type_info)
    return beam_operation_cls(
        name, spec, factory.counter_factory, factory.state_sampler, consumers,
        internal_operation_cls, _make_state_backend(key_row_coder, None))
  # Stateless operation.
  return beam_operation_cls(
      name, spec, factory.counter_factory, factory.state_sampler, consumers,
      internal_operation_cls)
def _create_user_defined_function_operation(factory, transform_proto, consumers,
                                            udfs_proto, operation_cls):
  """Creates a stateless operation executing the given user-defined functions."""
  tags = list(transform_proto.outputs.keys())
  tag_to_coder = factory.get_output_coders(transform_proto)
  worker_spec = operation_specs.WorkerDoFn(
      serialized_fn=udfs_proto,
      output_tags=tags,
      input=None,
      side_inputs=None,
      output_coders=[tag_to_coder[t] for t in tags])
  return operation_cls(
      transform_proto.unique_name, worker_spec, factory.counter_factory,
      factory.state_sampler, consumers)
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers, serialized_fn,
    side_input_data):
  """Creates a DoOperation whose side inputs are read through SideInputSource."""
  def _side_source(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
        key=side_input_tag(transform_id, tag))
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(factory.state_handler, state_key, coder=coder))

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' not in output_tags:
      return tag
    return 'out' if tag == 'None' else 'out_' + tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set: derive it from the (sole) input PCollection.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(t) for t in output_tags],
      input=None,
      side_inputs=[_side_source(t, c) for t, c in side_input_data],
      output_coders=[output_coders[t] for t in output_tags])

  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name, spec, factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name, consumers, output_tags)
def _create_stateful_user_defined_function_operation(factory, transform_proto,
                                                     consumers, udfs_proto,
                                                     beam_operation_cls,
                                                     internal_operation_cls):
  """Creates a keyed (stateful) operation for the given user-defined functions.

  Builds a WorkerDoFn spec, derives the key coder from the serialized fn's
  ``key_type_info`` and wires up a RemoteKeyedStateBackend for state access.
  """
  output_tags = list(transform_proto.outputs.keys())
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=udfs_proto,
      output_tags=output_tags,
      input=None,
      side_inputs=None,
      # Fix: iterate output_tags (the ordered tag list), not the output_coders
      # dict, so the coder list order is guaranteed to match output_tags —
      # consistent with every other *_operation creator in this file.
      output_coders=[output_coders[tag] for tag in output_tags])
  key_type_info = spec.serialized_fn.key_type_info
  key_row_coder = from_type_info_proto(key_type_info.field[0].type)
  # NOTE(review): cache sizes are hard-coded to 1000 here, while sibling
  # creators read them from the serialized fn proto — confirm intentional.
  keyed_state_backend = RemoteKeyedStateBackend(
      factory.state_handler, key_row_coder, 1000, 1000, 1000)
  return beam_operation_cls(
      transform_proto.unique_name, spec, factory.counter_factory,
      factory.state_sampler, consumers, internal_operation_cls,
      keyed_state_backend)
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, pardo_proto=None, operation_cls=operations.DoOperation):
  """Creates a (possibly stateful/splittable) ParDo operation.

  Args:
    factory: operation factory providing coders, counters and state access.
    transform_id: id of the transform within the process bundle descriptor.
    transform_proto: the transform's proto (inputs, outputs, unique_name).
    consumers: downstream consumers, passed to augment_oldstyle_op.
    serialized_fn: pickled DoFn data; re-pickled with windowing if unset.
    pardo_proto: optional ParDoPayload carrying side inputs, timer specs,
      state specs and the splittable flag.
    operation_cls: operation class to instantiate (DoOperation by default).

  Returns:
    The operation returned by factory.augment_oldstyle_op; when splittable,
    it is additionally annotated with ``input_info``.
  """
  if pardo_proto and pardo_proto.side_inputs:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in pardo_proto.side_inputs.items()]
    # Side-input tags are of the form "side<N>" or "side<N>-..."; sort by N.
    tagged_side_inputs.sort(
        key=lambda tag_si: int(
            re.match('side([0-9]+)(-.*)?$', tag_si[0]).group(1)))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler, transform_id, tag, si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match out prefix injected by dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    # The main input is the one input tag that is neither a side input nor a
    # timer; there must be exactly one (enforced by the tuple unpacking).
    if pardo_proto:
      other_input_tags = set.union(
          set(pardo_proto.side_inputs), set(pardo_proto.timer_specs))
    else:
      other_input_tags = ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in other_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs
                      or pardo_proto.splittable):
    main_input_coder = None
    timer_inputs = {}
    # Classify each input: timer stream, side input, or the single main input.
    for tag, pcoll_id in transform_proto.inputs.items():
      if tag in pardo_proto.timer_specs:
        timer_inputs[tag] = pcoll_id
      elif tag in pardo_proto.side_inputs:
        pass
      else:
        # Must be the main input
        assert main_input_coder is None
        main_input_tag = tag
        main_input_coder = factory.get_windowed_coder(pcoll_id)
    assert main_input_coder is not None

    if pardo_proto.timer_specs or pardo_proto.state_specs:
      # User state/timers require key and window coders from the main input.
      user_state_context = FnApiUserStateContext(
          factory.state_handler,
          transform_id,
          main_input_coder.key_coder(),
          main_input_coder.window_coder,
          timer_specs=pardo_proto.timer_specs)
    else:
      user_state_context = None
  else:
    user_state_context = None
    timer_inputs = None

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  result = factory.augment_oldstyle_op(
      operation_cls(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps,
          user_state_context,
          timer_inputs=timer_inputs),
      transform_proto.unique_name,
      consumers,
      output_tags)
  if pardo_proto and pardo_proto.splittable:
    # main_input_tag/main_input_coder are only bound in the stateful/splittable
    # branch above; splittable implies that branch ran, so they exist here.
    result.input_info = (
        transform_id, main_input_tag, main_input_coder,
        transform_proto.outputs.keys())
  return result
def create_execution_tree(self, descriptor):
  """Builds the operations for a process bundle descriptor.

  Walks the primitive transforms in reverse so that each operation's
  consumers are registered before the operation itself is constructed.

  Returns:
    A pair (ops in execution order, dict mapping transform id to op).
  """
  # TODO(vikasrk): Add an id field to Coder proto and use that instead.
  coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
      json.loads(unpack_function_spec_data(coder.function_spec)))
            for coder in descriptor.coders}

  counter_factory = counters.CounterFactory()
  # TODO(robertwb): Figure out the correct prefix to use for output counters
  # from StateSampler.
  state_sampler = statesampler.StateSampler(
      'fnapi-step%s-' % descriptor.id, counter_factory)
  # consumers[transform_id][output_name] -> list of downstream operations.
  consumers = collections.defaultdict(lambda: collections.defaultdict(list))
  ops_by_id = {}
  reversed_ops = []

  for transform in reversed(descriptor.primitive_transform):
    # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
    # "s3") from the service through the FnAPI so that msec counters can be
    # reported and correctly plumbed through the service and the UI.
    operation_name = 'fnapis%s' % transform.id

    def only_element(iterable):
      # Asserts (via unpacking) that the iterable holds exactly one element.
      element, = iterable
      return element

    if transform.function_spec.urn == DATA_OUTPUT_URN:
      target = beam_fn_api_pb2.Target(
          primitive_transform_reference=transform.id,
          name=only_element(transform.outputs.keys()))

      op = DataOutputOperation(
          operation_name,
          transform.step_name,
          consumers[transform.id],
          counter_factory,
          state_sampler,
          coders[only_element(transform.outputs.values()).coder_reference],
          target,
          self.data_channel_factory.create_data_channel(
              transform.function_spec))

    elif transform.function_spec.urn == DATA_INPUT_URN:
      target = beam_fn_api_pb2.Target(
          primitive_transform_reference=transform.id,
          name=only_element(transform.inputs.keys()))
      op = DataInputOperation(
          operation_name,
          transform.step_name,
          consumers[transform.id],
          counter_factory,
          state_sampler,
          coders[only_element(transform.outputs.values()).coder_reference],
          target,
          self.data_channel_factory.create_data_channel(
              transform.function_spec))

    elif transform.function_spec.urn == PYTHON_DOFN_URN:

      def create_side_input(tag, si):
        # TODO(robertwb): Extract windows (and keys) out of element data.
        return operation_specs.WorkerSideInputSource(
            tag=tag,
            source=SideInputSource(
                self.state_handler,
                beam_fn_api_pb2.StateKey(
                    function_spec_reference=si.view_fn.id),
                coder=unpack_and_deserialize_py_fn(si.view_fn)))

      output_tags = list(transform.outputs.keys())
      spec = operation_specs.WorkerDoFn(
          serialized_fn=unpack_function_spec_data(transform.function_spec),
          output_tags=output_tags,
          input=None,
          side_inputs=[create_side_input(tag, si)
                       for tag, si in transform.side_inputs.items()],
          output_coders=[coders[transform.outputs[out].coder_reference]
                         for out in output_tags])

      op = operations.DoOperation(operation_name, spec, counter_factory,
                                  state_sampler)
      # TODO(robertwb): Move these to the constructor.
      op.step_name = transform.step_name
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, output_tags.index(tag))

    elif transform.function_spec.urn == IDENTITY_DOFN_URN:
      op = operations.FlattenOperation(operation_name, None, counter_factory,
                                       state_sampler)
      # TODO(robertwb): Move these to the constructor.
      op.step_name = transform.step_name
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, 0)

    elif transform.function_spec.urn == PYTHON_SOURCE_URN:
      source = load_compressed(unpack_function_spec_data(
          transform.function_spec))
      # TODO(vikasrk): Remove this once custom source is implemented with
      # splittable dofn via the data plane.
      spec = operation_specs.WorkerRead(
          iobase.SourceBundle(1.0, source, None, None),
          [WindowedValueCoder(source.default_output_coder())])
      op = operations.ReadOperation(operation_name, spec, counter_factory,
                                    state_sampler)
      op.step_name = transform.step_name
      output_tags = list(transform.outputs.keys())
      for tag, op_consumers in consumers[transform.id].items():
        for consumer in op_consumers:
          op.add_receiver(consumer, output_tags.index(tag))

    else:
      raise NotImplementedError

    # Record consumers.
    for _, inputs in transform.inputs.items():
      for target in inputs.target:
        consumers[target.primitive_transform_reference][target.name].append(
            op)

    reversed_ops.append(op)
    ops_by_id[transform.id] = op

  return list(reversed(reversed_ops)), ops_by_id
def run_ParDo(self, transform_node):
  """Translates a ParDo transform node into a WorkerDoFn in a map task.

  Appends the resulting do_op to the current map task, breaking fusion
  (starting a new map task through an in-memory write/read pair) when a side
  input would otherwise create a dependency cycle within the same map task.
  Mutates self.map_tasks, self.outputs and self.dependencies.
  """
  transform = transform_node.transform
  output = transform_node.outputs[None]
  element_coder = self._get_coder(output)
  map_task_index, producer_index, output_index = self.outputs[
      transform_node.inputs[0]]

  # If any of this ParDo's side inputs depend on outputs from this map_task,
  # we can't continue growing this map task.
  def is_reachable(leaf, root):
    if leaf == root:
      return True
    else:
      return any(is_reachable(x, root) for x in self.dependencies[leaf])

  if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)
         for side_input in transform_node.side_inputs):
    # Start a new map tasks.
    input_element_coder = self._get_coder(transform_node.inputs[0])
    output_buffer = OutputBuffer(input_element_coder)

    # Write the main input to an in-memory buffer at the end of the current
    # map task...
    fusion_break_write = operation_specs.WorkerInMemoryWrite(
        output_buffer=output_buffer,
        write_windowed_values=True,
        input=(producer_index, output_index),
        output_coders=[input_element_coder])
    self.map_tasks[map_task_index].append(
        (transform_node.full_label + '/Write', fusion_break_write))

    original_map_task_index = map_task_index
    map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0

    # ...and read it back as the first operation of a brand-new map task.
    fusion_break_read = operation_specs.WorkerRead(
        output_buffer.source_bundle(),
        output_coders=[input_element_coder])
    self.map_tasks.append(
        [(transform_node.full_label + '/Read', fusion_break_read)])

    self.dependencies[map_task_index].add(original_map_task_index)

  def create_side_read(side_input):
    # Materializes the side-input PCollection and wraps it as a side source.
    label = self.side_input_labels[side_input]
    output_buffer = self.run_side_write(
        side_input.pvalue, '%s/%s' % (transform_node.full_label, label))
    return operation_specs.WorkerSideInputSource(
        output_buffer.source(), label)

  do_op = operation_specs.WorkerDoFn(  #
      serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(
          transform_node,
          lambda side_input: self.side_input_labels[side_input])),
      output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)
                                         for tag in transform.output_tags
                                        ],
      # Same assumption that DataflowRunner has about coders being compatible
      # across outputs.
      output_coders=[element_coder] * (len(transform.output_tags) + 1),
      input=(producer_index, output_index),
      side_inputs=[create_side_read(side_input)
                   for side_input in transform_node.side_inputs])

  producer_index = len(self.map_tasks[map_task_index])
  # Main output occupies index 0; tagged outputs follow in declaration order.
  self.outputs[transform_node.outputs[None]] = (
      map_task_index, producer_index, 0)
  for ix, tag in enumerate(transform.output_tags):
    self.outputs[transform_node.outputs[
        tag]] = map_task_index, producer_index, ix + 1
  self.map_tasks[map_task_index].append((transform_node.full_label, do_op))

  for side_input in transform_node.side_inputs:
    self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])