def process_bundle(self, inputs, expected_outputs):
    """Run a single bundle and collect the data it produces.

    The bundle descriptor is registered the first time this is called
    (tracked through ``self._registered``). The inputs are then written to
    the data plane, the process-bundle instruction is pushed, and every
    output chunk whose target is listed in ``expected_outputs`` is appended
    to the buffer returned by ``self._get_buffer``.

    Args:
        inputs: Mapping of ``(transform_id, name)`` to an iterable of data
            chunks written to that target.
        expected_outputs: Mapping of ``(transform_id, output_name)`` to the
            value handed to ``self._get_buffer`` for that target.

    Returns:
        The result obtained from the process-bundle future.

    Raises:
        RuntimeError: If the registration response or the bundle result
            carries an error.
    """
    # Unique id for the instruction processing this bundle.
    BundleManager._uid_counter += 1
    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter

    # Register the bundle descriptor, if needed.
    registration_future = None
    if not self._registered:
        register_request = beam_fn_api_pb2.InstructionRequest(
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[self._bundle_descriptor]))
        registration_future = self._controller.control_handler.push(
            register_request)
        self._registered = True

    # Write all the input data to the channel.
    for (transform_id, name), elements in inputs.items():
        data_out = self._controller.data_plane_handler.output_stream(
            process_bundle_id,
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id, name=name))
        for element_data in elements:
            data_out.write(element_data)
        data_out.close()

    # Actually start the bundle.
    if registration_future and registration_future.get().error:
        raise RuntimeError(registration_future.get().error)
    bundle_request = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=self._bundle_descriptor.id))
    result_future = self._controller.control_handler.push(bundle_request)

    def bundle_failed():
        # Passed as abort_callback below: truthy once the bundle future is
        # done and its result carries an error.
        return result_future.is_done() and result_future.get().error

    with ProgressRequester(
            self._controller, process_bundle_id, self._progress_frequency):
        # Gather all output data.
        expected_targets = [
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id,
                name=output_name)
            for transform_id, output_name in expected_outputs
        ]
        logging.debug('Gather all output data from %s.', expected_targets)
        received = self._controller.data_plane_handler.input_elements(
            process_bundle_id,
            expected_targets,
            abort_callback=bundle_failed)
        for output in received:
            target_tuple = (
                output.target.primitive_transform_reference,
                output.target.name)
            if target_tuple not in expected_outputs:
                continue
            buffer_id = expected_outputs[target_tuple]
            self._get_buffer(buffer_id).append(output.data)

        logging.debug('Wait for the bundle to finish.')
        result = result_future.get()

    if result.error:
        raise RuntimeError(result.error)
    return result
def _run_map_task(self, map_task, control_handler, state_handler,
                  data_plane_handler, data_operation_spec):
    """Register a map task, run it as one bundle and fill its sinks.

    The registration produced by ``self._map_task_registration`` is pushed
    first, the input data is written to the data plane, and the
    process-bundle instruction is pushed. Once the matching result is
    pulled, every output chunk addressed to a known sink is decoded with
    the sink's first output coder and appended to its ``output_buffer``.

    Args:
        map_task: Forwarded to ``self._map_task_registration``.
        control_handler: Object used to push instructions and pull results.
        state_handler: Forwarded to ``self._map_task_registration``.
        data_plane_handler: Object providing ``output_stream`` and
            ``input_elements``.
        data_operation_spec: Forwarded to ``self._map_task_registration``.

    Raises:
        RuntimeError: If the bundle result carries an error.
    """
    registration, sinks, input_data = self._map_task_registration(
        map_task, state_handler, data_operation_spec)
    control_handler.push(registration)

    descriptor_id = registration.register.process_bundle_descriptor[0].id
    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=descriptor_id))

    for (transform_id, name), encoded_input in input_data.items():
        data_out = data_plane_handler.output_stream(
            process_bundle.instruction_id,
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id, name=name))
        data_out.write(encoded_input)
        data_out.close()

    control_handler.push(process_bundle)
    while True:
        result = control_handler.pull()
        if result.instruction_id != process_bundle.instruction_id:
            # Not the response for our bundle; keep pulling.
            continue
        if result.error:
            raise RuntimeError(result.error)

        expected_targets = [
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id,
                name=output_name)
            for transform_id, output_name in sinks
        ]
        for output in data_plane_handler.input_elements(
                process_bundle.instruction_id, expected_targets):
            target_tuple = (
                output.target.primitive_transform_reference,
                output.target.name)
            if target_tuple not in sinks:
                # Unconsumed output.
                continue
            sink_op = sinks[target_tuple]
            coder = sink_op.output_coders[0]
            input_stream = create_InputStream(output.data)
            decoded = []
            while input_stream.size() > 0:
                decoded.append(
                    coder.get_impl().decode_from_stream(input_stream, True))
            if not sink_op.write_windowed_values:
                # The sink wants bare values, so drop the windowed wrapper.
                decoded = [item.value for item in decoded]
            for item in decoded:
                sink_op.output_buffer.append(item)
        return
def create(factory, transform_id, transform_proto, grpc_port, consumers):
    """Build the DataInputOperation for a gRPC-fed input transform.

    If the single output has exactly one consumer, that consumer is a
    ``DoOperation``, and the output PCollection is one of its timer inputs,
    the consumer list is rewritten in place to hold a ``TimerConsumer``.

    Args:
        factory: Provides coders, counter factory, state sampler and the
            data channel factory.
        transform_id: Used as the target's primitive transform reference.
        transform_proto: Transform description; must have a single output.
        grpc_port: Port description; its ``coder_id`` selects the coder.
        consumers: Mapping holding a single list of consuming operations.

    Returns:
        The constructed ``DataInputOperation``.
    """
    # Timers are the one special case where we don't want to call the
    # (unlabeled) operation.process() method, which we detect here.
    # TODO(robertwb): Consider generalizing if there are any more cases.
    output_pcoll = only_element(transform_proto.outputs.values())
    output_consumers = only_element(consumers.values())
    if len(output_consumers) == 1:
        sole_consumer = only_element(output_consumers)
        if isinstance(sole_consumer, operations.DoOperation):
            for tag, pcoll_id in sole_consumer.timer_inputs.items():
                if pcoll_id == output_pcoll:
                    output_consumers[:] = [TimerConsumer(tag, sole_consumer)]
                    break

    output_name = only_element(list(transform_proto.outputs.keys()))
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference=transform_id, name=output_name)

    if not grpc_port.coder_id:
        logging.error(
            'Missing required coder_id on grpc_port for %s; '
            'using deprecated fallback.', transform_id)
        output_coder = factory.get_only_output_coder(transform_proto)
    else:
        output_coder = factory.get_coder(grpc_port.coder_id)

    return DataInputOperation(
        transform_proto.unique_name,
        transform_proto.unique_name,
        consumers,
        factory.counter_factory,
        factory.state_sampler,
        output_coder,
        input_target=target,
        data_channel=factory.data_channel_factory.create_data_channel(
            grpc_port))
def as_target(op_input):
    """Build the inputs mapping that points at a producing operation.

    Args:
        op_input: Pair ``(input_op_index, input_output_index)`` naming the
            producing operation in ``map_task`` and which of its outputs
            is consumed.

    Returns:
        A dict with the single key ``'ignored_input_tag'`` mapped to a
        ``Target.List`` holding one target for the producer's output.
    """
    producer_index, producer_output_index = op_input
    producer_op = map_task[producer_index][1]
    producer_target = beam_fn_api_pb2.Target(
        primitive_transform_reference=transform_index_to_id[producer_index],
        name=output_tags(producer_op)[producer_output_index])
    return {
        'ignored_input_tag':
            beam_fn_api_pb2.Target.List(target=[producer_target])
    }
def create(factory, transform_id, transform_proto, grpc_port, consumers):
    """Build the DataInputOperation for a data-input transform.

    Args:
        factory: Provides the output coder, counter factory, state sampler
            and data channel factory.
        transform_id: Used as the target's primitive transform reference.
        transform_proto: Transform description; must have a single output,
            whose key becomes the target name.
        grpc_port: Passed to the data channel factory.
        consumers: Passed through to the operation.

    Returns:
        The constructed ``DataInputOperation``.
    """
    output_name = only_element(transform_proto.outputs.keys())
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference=transform_id, name=output_name)
    return DataInputOperation(
        transform_proto.unique_name,
        transform_proto.unique_name,
        consumers,
        factory.counter_factory,
        factory.state_sampler,
        factory.get_only_output_coder(transform_proto),
        input_target=target,
        data_channel=factory.data_channel_factory.create_data_channel(
            grpc_port))
def create(factory, transform_id, transform_proto, grpc_port, consumers):
    """Build the DataOutputOperation for a data-output transform.

    Args:
        factory: Provides the input coder, counter factory, state sampler
            and data channel factory.
        transform_id: Used as the target's primitive transform reference.
        transform_proto: Transform description; must have a single input,
            whose key becomes the target name.
        grpc_port: Passed to the data channel factory.
        consumers: Passed through to the operation.

    Returns:
        The constructed ``DataOutputOperation``.
    """
    input_name = only_element(transform_proto.inputs.keys())
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference=transform_id, name=input_name)
    return DataOutputOperation(
        transform_proto.unique_name,
        transform_proto.unique_name,
        consumers,
        factory.counter_factory,
        factory.state_sampler,
        # TODO(robertwb): Perhaps this could be distinct from the input coder?
        factory.get_only_input_coder(transform_proto),
        target=target,
        data_channel=factory.data_channel_factory.create_data_channel(
            grpc_port))
def _data_channel_test_one_direction(self, from_channel, to_channel):
    """Check that data written to one channel is read back from the other.

    Covers a single write, then writes interleaved across two instruction
    ids and two targets.

    Args:
        from_channel: Channel whose ``output_stream`` is written to.
        to_channel: Channel whose ``input_elements`` is read from.
    """
    def send(instruction_id, target, data):
        stream = from_channel.output_stream(instruction_id, target)
        stream.write(data)
        stream.close()

    def expected(instruction_id, target, data):
        # The message the receiving side should yield for one write.
        return beam_fn_api_pb2.Elements.Data(
            instruction_reference=instruction_id, target=target, data=data)

    # Both targets are built once here; the original rebuilt target_2 a
    # second time with identical arguments further down.
    target_1 = beam_fn_api_pb2.Target(
        primitive_transform_reference='1', name='out')
    target_2 = beam_fn_api_pb2.Target(
        primitive_transform_reference='2', name='out')

    # Single write.
    send('0', target_1, 'abc')
    self.assertEqual(
        list(to_channel.input_elements('0', [target_1])),
        [expected('0', target_1, 'abc')])

    # Multiple interleaved writes to multiple instructions.
    send('1', target_1, 'abc')
    send('2', target_1, 'def')
    self.assertEqual(
        list(to_channel.input_elements('1', [target_1])),
        [expected('1', target_1, 'abc')])
    send('2', target_2, 'ghi')
    self.assertEqual(
        list(to_channel.input_elements('2', [target_1, target_2])),
        [expected('2', target_1, 'def'), expected('2', target_2, 'ghi')])
def create(factory, transform_id, transform_proto, grpc_port, consumers):
    """Build the DataOutputOperation for a gRPC-backed output transform.

    The coder comes from ``grpc_port.coder_id`` when that is set. Otherwise
    an error is logged and the transform's only input coder is used.

    Args:
        factory: Provides coders, counter factory, state sampler and the
            data channel factory.
        transform_id: Used as the target's primitive transform reference
            and in the fallback log message.
        transform_proto: Transform description; must have a single input,
            whose key becomes the target name.
        grpc_port: Port description passed to the data channel factory.
        consumers: Passed through to the operation.

    Returns:
        The constructed ``DataOutputOperation``.
    """
    input_name = only_element(list(transform_proto.inputs.keys()))
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference=transform_id, name=input_name)

    if not grpc_port.coder_id:
        logging.error(
            'Missing required coder_id on grpc_port for %s; '
            'using deprecated fallback.', transform_id)
        output_coder = factory.get_only_input_coder(transform_proto)
    else:
        output_coder = factory.get_coder(grpc_port.coder_id)

    return DataOutputOperation(
        transform_proto.unique_name,
        transform_proto.unique_name,
        consumers,
        factory.counter_factory,
        factory.state_sampler,
        output_coder,
        target=target,
        data_channel=factory.data_channel_factory.create_data_channel(
            grpc_port))
def run_stage(self, controller, pipeline_components, stage, pcoll_buffers,
              safe_coders):
    """Execute one stage as a bundle, including its side inputs.

    Data input/output transforms of the stage are rewritten to carry the
    controller's data operation spec. A bundle descriptor is built from the
    stage, inputs are written from ``pcoll_buffers``, side-input contents
    are appended to the controller's state handler, and the bundle is
    registered and started. After the matching result is pulled, output
    chunks are appended to ``pcoll_buffers``.

    Args:
        controller: Provides ``data_operation_spec()``, the control, data
            plane and state handlers.
        pipeline_components: Source of pcollections, coders, windowing
            strategies, environments and transforms.
        stage: Object with ``name`` and ``transforms``.
        pcoll_buffers: Mapping from buffer id to buffer; read for inputs
            and appended to for outputs.
        safe_coders: Mapping applied to coder ids before they are looked up
            in the pipeline context.

    Returns:
        The result pulled for the process-bundle instruction.

    Raises:
        RuntimeError: If the bundle result carries an error.
        NotImplementedError: If an output buffer id starts with neither
            ``'materialize:'`` nor ``'group:'``.
    """
    context = pipeline_context.PipelineContext(pipeline_components)
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
        # Returns maps of transform names to PCollection identifiers.
        # Also mutates IO stages to point to the data data_operation_spec.
        data_input = {}
        data_side_input = {}
        data_output = {}
        io_urns = (bundle_processor.DATA_INPUT_URN,
                   bundle_processor.DATA_OUTPUT_URN)
        for transform in stage.transforms:
            urn = transform.spec.urn
            if urn in io_urns:
                pcoll_id = transform.spec.payload
                if urn == bundle_processor.DATA_INPUT_URN:
                    target = (transform.unique_name,
                              only_element(transform.outputs))
                    data_input[target] = pcoll_id
                elif urn == bundle_processor.DATA_OUTPUT_URN:
                    target = (transform.unique_name,
                              only_element(transform.inputs))
                    data_output[target] = pcoll_id
                else:
                    raise NotImplementedError
                if data_operation_spec:
                    transform.spec.payload = (
                        data_operation_spec.SerializeToString())
                else:
                    transform.spec.payload = ""
            elif urn == urns.PARDO_TRANSFORM:
                payload = proto_utils.parse_Bytes(
                    transform.spec.payload,
                    beam_runner_api_pb2.ParDoPayload)
                for tag, si in payload.side_inputs.items():
                    data_side_input[transform.unique_name, tag] = (
                        'materialize:' + transform.inputs[tag],
                        beam.pvalue.SideInputData.from_runner_api(si, None))
        return data_input, data_side_input, data_output

    def new_grouping_buffer(buffer_id):
        # The text after 'group:' names the original GBK transform, whose
        # single input and output determine the coders and windowing.
        original_gbk_transform = buffer_id.split(':', 1)[1]
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(transform_proto.inputs.values())
        output_pcoll = only_element(transform_proto.outputs.values())
        pre_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = context.coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        windowing_strategy = context.windowing_strategies[
            pipeline_components.pcollections[output_pcoll]
            .windowing_strategy_id]
        return _GroupingBuffer(
            pre_gbk_coder, post_gbk_coder, windowing_strategy)

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={
            transform.unique_name: transform
            for transform in stage.transforms
        },
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor]))

    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=process_bundle_descriptor.id))

    # Write all the input data to the channel.
    for (transform_id, name), pcoll_id in data_input.items():
        data_out = controller.data_plane_handler.output_stream(
            process_bundle.instruction_id,
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id, name=name))
        for element_data in pcoll_buffers[pcoll_id]:
            data_out.write(element_data)
        data_out.close()

    # Store the required side inputs into state.
    for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
        elements_by_window = _WindowGroupingBuffer(si)
        for element_data in pcoll_buffers[pcoll_id]:
            elements_by_window.append(element_data)
        for window, elements_data in elements_by_window.items():
            state_key = beam_fn_api_pb2.StateKey(
                multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                    ptransform_id=transform_id,
                    side_input_id=tag,
                    window=window))
            controller.state_handler.blocking_append(
                state_key, elements_data, process_bundle.instruction_id)

    # Register and start running the bundle.
    logging.debug('Register and start running the bundle')
    controller.control_handler.push(process_bundle_registration)
    controller.control_handler.push(process_bundle)

    # Wait for the bundle to finish.
    logging.debug('Wait for the bundle to finish.')
    while True:
        result = controller.control_handler.pull()
        if result and result.instruction_id == process_bundle.instruction_id:
            if result.error:
                raise RuntimeError(result.error)
            break

    expected_targets = [
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=output_name)
        for transform_id, output_name in data_output
    ]

    # Gather all output data.
    logging.debug('Gather all output data from %s.', expected_targets)
    for output in controller.data_plane_handler.input_elements(
            process_bundle.instruction_id, expected_targets):
        target_tuple = (
            output.target.primitive_transform_reference,
            output.target.name)
        if target_tuple not in data_output:
            continue
        pcoll_id = data_output[target_tuple]
        if pcoll_id.startswith('materialize:'):
            # Just store the data chunks for replay.
            pass
        elif pcoll_id.startswith('group:'):
            # This is a grouping write, create a grouping buffer if needed.
            if pcoll_id not in pcoll_buffers:
                pcoll_buffers[pcoll_id] = new_grouping_buffer(pcoll_id)
        else:
            # These should be the only two identifiers we produce for now,
            # but special side input writes may go here.
            raise NotImplementedError(pcoll_id)
        pcoll_buffers[pcoll_id].append(output.data)
    return result
def _map_task_registration(self, map_task, state_handler,
                           data_operation_spec):
    """Translate a map task into a Fn API registration request.

    Each operation of the map task becomes one primitive transform. A read
    that is not a windowed in-memory source becomes two: a data-input
    transform that injects the encoded source, and a DoFn that reads it.
    Side-input contents of DoFns are encoded and stored through
    ``state_handler``.

    Args:
        map_task: Sequence of ``(stage_name, operation)`` pairs.
        state_handler: Receives ``Clear``/``Append`` calls carrying the
            encoded contents of every DoFn side input.
        data_operation_spec: Message packed into the function spec of data
            input/output transforms; nothing is packed when it is falsy.

    Returns:
        A tuple ``(registration, runner_sinks, input_data)`` where
        ``registration`` is the ``InstructionRequest`` registering the
        bundle descriptor, ``runner_sinks`` maps ``(transform_id, 'out')``
        to the in-memory write operation receiving that output, and
        ``input_data`` maps ``(transform_id, target_name)`` to the encoded
        bytes to send to that target.

    Raises:
        TypeError: If the map task contains an unsupported operation type.
    """
    input_data = {}
    runner_sinks = {}
    transforms = []
    transform_index_to_id = {}

    # Maps coders to new coder objects and references.
    coders = {}

    def coder_id(coder):
        if coder not in coders:
            coders[coder] = beam_fn_api_pb2.Coder(
                function_spec=sdk_worker.pack_function_spec_data(
                    json.dumps(coder.as_cloud_object()),
                    sdk_worker.PYTHON_CODER_URN,
                    id=self._next_uid()))
        return coders[coder].function_spec.id

    def output_tags(op):
        return getattr(op, 'output_tags', ['out'])

    def as_target(op_input):
        input_op_index, input_output_index = op_input
        input_op = map_task[input_op_index][1]
        return {
            'ignored_input_tag':
                beam_fn_api_pb2.Target.List(target=[
                    beam_fn_api_pb2.Target(
                        primitive_transform_reference=transform_index_to_id[
                            input_op_index],
                        name=output_tags(input_op)[input_output_index])
                ])
        }

    def outputs(op):
        return {
            tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
            for tag, coder in zip(output_tags(op), op.output_coders)
        }

    for op_ix, (stage_name, operation) in enumerate(map_task):
        transform_id = transform_index_to_id[op_ix] = self._next_uid()
        # Start every operation with a fresh mapping. The DoFn branch fills
        # it in place, so reusing the dict left over from the previous
        # iteration would leak an earlier DoFn's side inputs into this
        # transform (or raise NameError if a DoFn came first).
        side_inputs = {}
        if isinstance(operation, operation_specs.WorkerInMemoryWrite):
            # Write this data back to the runner.
            fn = beam_fn_api_pb2.FunctionSpec(
                urn=sdk_worker.DATA_OUTPUT_URN, id=self._next_uid())
            if data_operation_spec:
                fn.data.Pack(data_operation_spec)
            inputs = as_target(operation.input)
            runner_sinks[(transform_id, 'out')] = operation

        elif isinstance(operation, operation_specs.WorkerRead):
            # A Read is either translated to a direct injection of windowed
            # values into the sdk worker, or an injection of the source
            # object into the sdk worker as data followed by an SDF that
            # reads that source.
            if (isinstance(operation.source.source,
                           maptask_executor_runner.InMemorySource)
                    and isinstance(
                        operation.source.source.default_output_coder(),
                        WindowedValueCoder)):
                output_stream = create_OutputStream()
                element_coder = (
                    operation.source.source.default_output_coder()
                    .get_impl())
                # Re-encode the elements in the nested context and
                # concatenate them together
                for element in operation.source.source.read(None):
                    element_coder.encode_to_stream(
                        element, output_stream, True)
                target_name = self._next_uid()
                input_data[(transform_id, target_name)] = output_stream.get()
                fn = beam_fn_api_pb2.FunctionSpec(
                    urn=sdk_worker.DATA_INPUT_URN, id=self._next_uid())
                if data_operation_spec:
                    fn.data.Pack(data_operation_spec)
                inputs = {target_name: beam_fn_api_pb2.Target.List()}
            else:
                # Read the source object from the runner.
                source_coder = beam.coders.DillCoder()
                input_transform_id = self._next_uid()
                output_stream = create_OutputStream()
                source_coder.get_impl().encode_to_stream(
                    GlobalWindows.windowed_value(operation.source),
                    output_stream,
                    True)
                target_name = self._next_uid()
                input_data[(input_transform_id, target_name)] = (
                    output_stream.get())
                input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
                    id=input_transform_id,
                    function_spec=beam_fn_api_pb2.FunctionSpec(
                        urn=sdk_worker.DATA_INPUT_URN,
                        id=self._next_uid()),
                    # TODO(robertwb): Possible name collision.
                    step_name=stage_name + '/inject_source',
                    inputs={target_name: beam_fn_api_pb2.Target.List()},
                    outputs={
                        'out':
                            beam_fn_api_pb2.PCollection(
                                coder_reference=coder_id(source_coder))
                    })
                if data_operation_spec:
                    input_ptransform.function_spec.data.Pack(
                        data_operation_spec)
                transforms.append(input_ptransform)

                # Read the elements out of the source.
                fn = sdk_worker.pack_function_spec_data(
                    OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
                    sdk_worker.PYTHON_DOFN_URN,
                    id=self._next_uid())
                inputs = {
                    'ignored_input_tag':
                        beam_fn_api_pb2.Target.List(target=[
                            beam_fn_api_pb2.Target(
                                primitive_transform_reference=(
                                    input_transform_id),
                                name='out')
                        ])
                }

        elif isinstance(operation, operation_specs.WorkerDoFn):
            fn = sdk_worker.pack_function_spec_data(
                operation.serialized_fn,
                sdk_worker.PYTHON_DOFN_URN,
                id=self._next_uid())
            inputs = as_target(operation.input)
            # Store the contents of each side input for state access.
            for si in operation.side_inputs:
                assert isinstance(si.source, iobase.BoundedSource)
                element_coder = si.source.default_output_coder()
                view_id = self._next_uid()
                # TODO(robertwb): Actually flesh out the ViewFn API.
                side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
                    view_fn=sdk_worker.serialize_and_pack_py_fn(
                        element_coder,
                        urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
                        id=view_id))
                # Re-encode the elements in the nested context and
                # concatenate them together
                output_stream = create_OutputStream()
                for element in si.source.read(
                        si.source.get_range_tracker(None, None)):
                    element_coder.get_impl().encode_to_stream(
                        element, output_stream, True)
                elements_data = output_stream.get()
                state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(
                    key=view_id)
                state_handler.Clear(state_key)
                state_handler.Append(state_key, elements_data)

        elif isinstance(operation, operation_specs.WorkerFlatten):
            fn = sdk_worker.pack_function_spec_data(
                operation.serialized_fn,
                sdk_worker.IDENTITY_DOFN_URN,
                id=self._next_uid())
            inputs = {
                'ignored_input_tag':
                    beam_fn_api_pb2.Target.List(target=[
                        beam_fn_api_pb2.Target(
                            primitive_transform_reference=(
                                transform_index_to_id[input_op_index]),
                            name=output_tags(map_task[input_op_index][1])[
                                input_output_index])
                        for input_op_index, input_output_index
                        in operation.inputs
                    ])
            }

        else:
            raise TypeError(operation)

        ptransform = beam_fn_api_pb2.PrimitiveTransform(
            id=transform_id,
            function_spec=fn,
            step_name=stage_name,
            inputs=inputs,
            side_inputs=side_inputs,
            outputs=outputs(operation))
        transforms.append(ptransform)

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        coders=coders.values(),
        primitive_transform=transforms)
    return beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor
                                      ])), runner_sinks, input_data
def run_stage(
        self, controller, pipeline_components, stage, pcoll_buffers,
        safe_coders):
    """Execute one stage as a bundle; side inputs are not supported.

    Data input/output transforms of the stage are rewritten to carry the
    controller's data operation spec. A bundle descriptor is built from the
    stage, inputs are written from ``pcoll_buffers``, and the bundle is
    registered and started. After the matching result is pulled, output
    chunks are appended to ``pcoll_buffers``.

    Args:
        controller: Provides ``data_operation_spec()`` and the control and
            data plane handlers.
        pipeline_components: Source of pcollections, coders, windowing
            strategies, environments and transforms.
        stage: Object with ``name`` and ``transforms``.
        pcoll_buffers: Mapping from buffer id to buffer; read for inputs
            and appended to for outputs.
        safe_coders: Mapping applied to coder ids before they are looked up
            in the pipeline context's coders.

    Raises:
        RuntimeError: If the bundle result carries an error.
        NotImplementedError: If side inputs were collected, or an output
            buffer id starts with neither ``'materialize:'`` nor
            ``'group:'``.
    """
    coders = pipeline_context.PipelineContext(pipeline_components).coders
    data_operation_spec = controller.data_operation_spec()

    def extract_endpoints(stage):
        # Returns maps of transform names to PCollection identifiers.
        # Also mutates IO stages to point to the data data_operation_spec.
        data_input = {}
        data_side_input = {}
        data_output = {}
        io_urns = (bundle_processor.DATA_INPUT_URN,
                   bundle_processor.DATA_OUTPUT_URN)
        for transform in stage.transforms:
            pcoll_id = transform.spec.payload
            urn = transform.spec.urn
            if urn not in io_urns:
                continue
            if urn == bundle_processor.DATA_INPUT_URN:
                target = (transform.unique_name,
                          only_element(transform.outputs))
                data_input[target] = pcoll_id
            elif urn == bundle_processor.DATA_OUTPUT_URN:
                target = (transform.unique_name,
                          only_element(transform.inputs))
                data_output[target] = pcoll_id
            else:
                raise NotImplementedError
            if data_operation_spec:
                transform.spec.payload = (
                    data_operation_spec.SerializeToString())
                transform.spec.any_param.Pack(data_operation_spec)
            else:
                transform.spec.payload = ""
                transform.spec.any_param.Clear()
        return data_input, data_side_input, data_output

    def new_grouping_buffer(buffer_id):
        # The text after 'group:' names the original GBK transform, whose
        # single input and output determine the two coders.
        original_gbk_transform = buffer_id.split(':', 1)[1]
        transform_proto = pipeline_components.transforms[
            original_gbk_transform]
        input_pcoll = only_element(transform_proto.inputs.values())
        output_pcoll = only_element(transform_proto.outputs.values())
        pre_gbk_coder = coders[safe_coders[
            pipeline_components.pcollections[input_pcoll].coder_id]]
        post_gbk_coder = coders[safe_coders[
            pipeline_components.pcollections[output_pcoll].coder_id]]
        return _GroupingBuffer(pre_gbk_coder, post_gbk_coder)

    logging.info('Running %s', stage.name)
    logging.debug('       %s', stage)
    data_input, data_side_input, data_output = extract_endpoints(stage)
    if data_side_input:
        raise NotImplementedError('Side inputs.')

    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._next_uid(),
        transforms={
            transform.unique_name: transform
            for transform in stage.transforms
        },
        pcollections=dict(pipeline_components.pcollections.items()),
        coders=dict(pipeline_components.coders.items()),
        windowing_strategies=dict(
            pipeline_components.windowing_strategies.items()),
        environments=dict(pipeline_components.environments.items()))

    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[process_bundle_descriptor]))

    process_bundle = beam_fn_api_pb2.InstructionRequest(
        instruction_id=self._next_uid(),
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=process_bundle_descriptor.id))

    # Write all the input data to the channel.
    for (transform_id, name), pcoll_id in data_input.items():
        data_out = controller.data_plane_handler.output_stream(
            process_bundle.instruction_id,
            beam_fn_api_pb2.Target(
                primitive_transform_reference=transform_id, name=name))
        for element_data in pcoll_buffers[pcoll_id]:
            data_out.write(element_data)
        data_out.close()

    # Register and start running the bundle.
    controller.control_handler.push(process_bundle_registration)
    controller.control_handler.push(process_bundle)

    # Wait for the bundle to finish.
    while True:
        result = controller.control_handler.pull()
        if result and result.instruction_id == process_bundle.instruction_id:
            if result.error:
                raise RuntimeError(result.error)
            break

    # Gather all output data.
    expected_targets = [
        beam_fn_api_pb2.Target(
            primitive_transform_reference=transform_id, name=output_name)
        for transform_id, output_name in data_output
    ]
    for output in controller.data_plane_handler.input_elements(
            process_bundle.instruction_id, expected_targets):
        target_tuple = (
            output.target.primitive_transform_reference,
            output.target.name)
        if target_tuple not in data_output:
            continue
        pcoll_id = data_output[target_tuple]
        if pcoll_id.startswith('materialize:'):
            # Just store the data chunks for replay.
            pass
        elif pcoll_id.startswith('group:'):
            # This is a grouping write, create a grouping buffer if needed.
            if pcoll_id not in pcoll_buffers:
                pcoll_buffers[pcoll_id] = new_grouping_buffer(pcoll_id)
        else:
            # These should be the only two identifiers we produce for now,
            # but special side input writes may go here.
            raise NotImplementedError(pcoll_id)
        pcoll_buffers[pcoll_id].append(output.data)
def create_execution_tree_from_fn_api(self, descriptor):
    """Build the list of operations described by a bundle descriptor.

    Transforms are visited in reverse order, so the consumers of a
    transform's outputs have already been recorded when the transform's own
    operation is created. The operations are returned in the descriptor's
    original order.

    Args:
        descriptor: Bundle descriptor providing ``id``, ``coders`` and
            ``primitive_transform``.

    Returns:
        A list of operations, one per primitive transform.

    Raises:
        NotImplementedError: If a transform's function spec URN is not one
            of the URNs handled here.
    """
    def only_element(iterable):
        element, = iterable
        return element

    def create_side_input(tag, si):
        # TODO(robertwb): Extract windows (and keys) out of element data.
        return operation_specs.WorkerSideInputSource(
            tag=tag,
            source=SideInputSource(
                self.state_handler,
                beam_fn_api_pb2.StateKey.MultimapSideInput(
                    key=si.view_fn.id.encode('utf-8')),
                coder=unpack_and_deserialize_py_fn(si.view_fn)))

    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
    coders = {}
    for coder_proto in descriptor.coders:
        coders[coder_proto.function_spec.id] = (
            operation_specs.get_coder_from_spec(
                json.loads(
                    unpack_function_spec_data(coder_proto.function_spec))))

    counter_factory = counters.CounterFactory()
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)

    # producer transform id -> output name -> list of consuming operations.
    consumers = collections.defaultdict(
        lambda: collections.defaultdict(list))
    ops_by_id = {}
    reversed_ops = []

    for transform in reversed(descriptor.primitive_transform):
        # TODO(robertwb): Figure out how to plumb through the operation name
        # (e.g. "s3") from the service through the FnAPI so that msec
        # counters can be reported and correctly plumbed through the service
        # and the UI.
        operation_name = 'fnapis%s' % transform.id
        urn = transform.function_spec.urn

        if urn == DATA_OUTPUT_URN:
            target = beam_fn_api_pb2.Target(
                primitive_transform_reference=transform.id,
                name=only_element(transform.outputs.keys()))
            op = DataOutputOperation(
                operation_name,
                transform.step_name,
                consumers[transform.id],
                counter_factory,
                state_sampler,
                coders[only_element(
                    transform.outputs.values()).coder_reference],
                target,
                self.data_channel_factory.create_data_channel(
                    transform.function_spec))

        elif urn == DATA_INPUT_URN:
            target = beam_fn_api_pb2.Target(
                primitive_transform_reference=transform.id,
                name=only_element(transform.inputs.keys()))
            op = DataInputOperation(
                operation_name,
                transform.step_name,
                consumers[transform.id],
                counter_factory,
                state_sampler,
                coders[only_element(
                    transform.outputs.values()).coder_reference],
                target,
                self.data_channel_factory.create_data_channel(
                    transform.function_spec))

        elif urn == PYTHON_DOFN_URN:
            output_tags = list(transform.outputs.keys())
            spec = operation_specs.WorkerDoFn(
                serialized_fn=unpack_function_spec_data(
                    transform.function_spec),
                output_tags=output_tags,
                input=None,
                side_inputs=[
                    create_side_input(tag, si)
                    for tag, si in transform.side_inputs.items()
                ],
                output_coders=[
                    coders[transform.outputs[out].coder_reference]
                    for out in output_tags
                ])
            op = operations.DoOperation(
                operation_name, spec, counter_factory, state_sampler)
            # TODO(robertwb): Move these to the constructor.
            op.step_name = transform.step_name
            for tag, op_consumers in consumers[transform.id].items():
                for consumer in op_consumers:
                    op.add_receiver(consumer, output_tags.index(tag))

        elif urn == IDENTITY_DOFN_URN:
            op = operations.FlattenOperation(
                operation_name, None, counter_factory, state_sampler)
            # TODO(robertwb): Move these to the constructor.
            op.step_name = transform.step_name
            for tag, op_consumers in consumers[transform.id].items():
                for consumer in op_consumers:
                    op.add_receiver(consumer, 0)

        elif urn == PYTHON_SOURCE_URN:
            source = load_compressed(
                unpack_function_spec_data(transform.function_spec))
            # TODO(vikasrk): Remove this once custom source is implemented
            # with splittable dofn via the data plane.
            spec = operation_specs.WorkerRead(
                iobase.SourceBundle(1.0, source, None, None),
                [WindowedValueCoder(source.default_output_coder())])
            op = operations.ReadOperation(
                operation_name, spec, counter_factory, state_sampler)
            op.step_name = transform.step_name
            output_tags = list(transform.outputs.keys())
            for tag, op_consumers in consumers[transform.id].items():
                for consumer in op_consumers:
                    op.add_receiver(consumer, output_tags.index(tag))

        else:
            raise NotImplementedError

        # Record consumers.
        for inputs in transform.inputs.values():
            for target in inputs.target:
                producer_outputs = consumers[
                    target.primitive_transform_reference]
                producer_outputs[target.name].append(op)

        reversed_ops.append(op)
        ops_by_id[transform.id] = op

    return reversed_ops[::-1]