Example 1
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            beam_operation_cls,
                                            internal_operation_cls):
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])

    if hasattr(spec.serialized_fn, "key_type"):
        # keyed operation, need to create the KeyedStateBackend.
        key_row_coder = from_proto(spec.serialized_fn.key_type)
        keyed_state_backend = RemoteKeyedStateBackend(
            factory.state_handler, key_row_coder,
            spec.serialized_fn.state_cache_size,
            spec.serialized_fn.map_state_read_cache_size,
            spec.serialized_fn.map_state_write_cache_size)

        return beam_operation_cls(transform_proto.unique_name, spec,
                                  factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls, keyed_state_backend)
    else:
        return beam_operation_cls(transform_proto.unique_name, spec,
                                  factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls)
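
The split above hinges on a single proto field: a serialized user-defined function that carries a key_type goes down the keyed path and gets a RemoteKeyedStateBackend; everything else is built stateless. A minimal sketch of that dispatch, with SimpleNamespace as a hypothetical stand-in for the real proto:

from types import SimpleNamespace

def pick_code_path(serialized_fn):
    # Mirrors the check above: the presence of key_type selects the
    # keyed (stateful) construction path.
    if hasattr(serialized_fn, "key_type"):
        return "keyed: build a RemoteKeyedStateBackend first"
    return "stateless: construct the operation directly"

print(pick_code_path(SimpleNamespace(key_type="ROW<BIGINT>")))  # keyed
print(pick_code_path(SimpleNamespace()))                        # stateless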
Example 2
def create(factory, transform_id, transform_proto, parameter, consumers):
    dofn_data = pickler.loads(parameter.value)
    if len(dofn_data) == 2:
        # Has side input data.
        serialized_fn, side_input_data = dofn_data
    else:
        # No side input data.
        serialized_fn, side_input_data = parameter.value, []

    def create_side_input(tag, coder):
        # TODO(robertwb): Extract windows (and keys) out of element data.
        # TODO(robertwb): Extract state key from ParDoPayload.
        return operation_specs.WorkerSideInputSource(
            tag=tag,
            source=SideInputSource(factory.state_handler,
                                   beam_fn_api_pb2.StateKey.MultimapSideInput(
                                       key=side_input_tag(transform_id, tag)),
                                   coder=coder))

    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=serialized_fn,
        output_tags=output_tags,
        input=None,
        side_inputs=[
            create_side_input(tag, coder) for tag, coder in side_input_data
        ],
        output_coders=[output_coders[tag] for tag in output_tags])
    return factory.augment_oldstyle_op(
        operations.DoOperation(transform_proto.unique_name, spec,
                               factory.counter_factory, factory.state_sampler),
        transform_proto.unique_name, consumers, output_tags)
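
Note the shape-based unpacking at the top of create: a pickled 2-tuple means side-input metadata travelled alongside the serialized fn; anything else is treated as a bare fn with no side inputs. A rough, self-contained illustration, with the standard pickle module standing in for Beam's pickler:

import pickle

payload = pickle.dumps(('<serialized fn>', [('side0', '<coder>')]))
dofn_data = pickle.loads(payload)
if len(dofn_data) == 2:
    # Has side input data.
    serialized_fn, side_input_data = dofn_data
else:
    # No side input data: keep the raw payload as the fn.
    serialized_fn, side_input_data = payload, []
print(serialized_fn, side_input_data)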
Example 3
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_inputs_proto=None):

  if side_inputs_proto:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in side_inputs_proto.items()]
    tagged_side_inputs.sort(key=lambda tag_si: int(tag_si[0][4:]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    side_input_tags = side_inputs_proto or ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in side_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps),
      transform_proto.unique_name,
      consumers,
      output_tags)
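
The mutate_tag helper exists only to line worker-side output names up with the ones the Dataflow runner injects: when the untagged main output ('None') is present, it is renamed 'out' and every other tag gains an 'out_' prefix; otherwise tags pass through untouched. The same renaming as a standalone sketch:

def mutate_tag(tag, output_tags):
    # Same logic as above, with output_tags passed in explicitly.
    if 'None' in output_tags:
        return 'out' if tag == 'None' else 'out_' + tag
    return tag

print([mutate_tag(t, ['None', 'errors']) for t in ['None', 'errors']])
# -> ['out', 'out_errors']
print(mutate_tag('errors', ['main', 'errors']))  # -> 'errors' (unchanged)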
Example 4
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            beam_operation_cls,
                                            internal_operation_cls):
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])
    name = common.NameContext(transform_proto.unique_name)

    serialized_fn = spec.serialized_fn
    if hasattr(serialized_fn, "key_type"):
        # keyed operation, need to create the KeyedStateBackend.
        row_schema = serialized_fn.key_type.row_schema
        key_row_coder = FlattenRowCoder(
            [from_proto(f.type) for f in row_schema.fields])
        if serialized_fn.HasField('group_window'):
            if serialized_fn.group_window.is_time_window:
                window_coder = TimeWindowCoder()
            else:
                window_coder = CountWindowCoder()
        else:
            window_coder = None
        keyed_state_backend = RemoteKeyedStateBackend(
            factory.state_handler, key_row_coder, window_coder,
            serialized_fn.state_cache_size,
            serialized_fn.map_state_read_cache_size,
            serialized_fn.map_state_write_cache_size)

        return beam_operation_cls(name, spec, factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls, keyed_state_backend)
    elif internal_operation_cls == datastream_operations.StatefulOperation:
        key_row_coder = from_type_info_proto(serialized_fn.key_type_info)
        keyed_state_backend = RemoteKeyedStateBackend(
            factory.state_handler, key_row_coder, None,
            serialized_fn.state_cache_size,
            serialized_fn.map_state_read_cache_size,
            serialized_fn.map_state_write_cache_size)
        return beam_operation_cls(name, spec, factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls, keyed_state_backend)
    else:
        return beam_operation_cls(name, spec, factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls)
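
Compared with Example 1, this variant adds a second decision for keyed table operations: a window coder is chosen from the optional group_window message before the state backend is built. A sketch of that selection, using hypothetical stand-in classes that mimic the two proto accessors used above (HasField and the is_time_window flag):

class GroupWindow:
    def __init__(self, is_time_window):
        self.is_time_window = is_time_window

class FakeSerializedFn:
    # Hypothetical stand-in exposing the proto-style HasField check.
    def __init__(self, group_window=None):
        self.group_window = group_window
    def HasField(self, name):
        return name == 'group_window' and self.group_window is not None

def pick_window_coder(serialized_fn):
    if serialized_fn.HasField('group_window'):
        if serialized_fn.group_window.is_time_window:
            return 'TimeWindowCoder'
        return 'CountWindowCoder'
    return None  # unwindowed keyed operation

print(pick_window_coder(FakeSerializedFn(GroupWindow(True))))   # TimeWindowCoder
print(pick_window_coder(FakeSerializedFn(GroupWindow(False))))  # CountWindowCoder
print(pick_window_coder(FakeSerializedFn()))                    # None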
Example 5
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            operation_cls):
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])

    return operation_cls(transform_proto.unique_name, spec,
                         factory.counter_factory, factory.state_sampler,
                         consumers)
Example 6
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, side_input_data):
  def create_side_input(tag, coder):
    # TODO(robertwb): Extract windows (and keys) out of element data.
    # TODO(robertwb): Extract state key from ParDoPayload.
    return operation_specs.WorkerSideInputSource(
        tag=tag,
        source=SideInputSource(
            factory.state_handler,
            beam_fn_api_pb2.StateKey.MultimapSideInput(
                key=side_input_tag(transform_id, tag)),
            coder=coder))
  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag
  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    pcoll_id, = transform_proto.inputs.values()
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=[
          create_side_input(tag, coder) for tag, coder in side_input_data],
      output_coders=[output_coders[tag] for tag in output_tags])
  return factory.augment_oldstyle_op(
      operations.DoOperation(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers,
      output_tags)
Example 7
def _create_stateful_user_defined_function_operation(factory, transform_proto,
                                                     consumers, udfs_proto,
                                                     beam_operation_cls,
                                                     internal_operation_cls):
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])
    key_type_info = spec.serialized_fn.key_type_info
    key_row_coder = from_type_info_proto(key_type_info.field[0].type)
    keyed_state_backend = RemoteKeyedStateBackend(factory.state_handler,
                                                  key_row_coder, 1000, 1000,
                                                  1000)

    return beam_operation_cls(transform_proto.unique_name, spec,
                              factory.counter_factory, factory.state_sampler,
                              consumers, internal_operation_cls,
                              keyed_state_backend)
Example 8
def _create_pardo_operation(
    factory, transform_id, transform_proto, consumers,
    serialized_fn, pardo_proto=None, operation_cls=operations.DoOperation):

  if pardo_proto and pardo_proto.side_inputs:
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag, si in pardo_proto.side_inputs.items()]
    tagged_side_inputs.sort(
        key=lambda tag_si: int(re.match('side([0-9]+)(-.*)?$',
                                        tag_si[0]).group(1)))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag])
        for tag, si in tagged_side_inputs]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  # Hack to match the 'out' prefix injected by the Dataflow runner.
  def mutate_tag(tag):
    if 'None' in output_tags:
      if tag == 'None':
        return 'out'
      else:
        return 'out_' + tag
    else:
      return tag

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    if pardo_proto:
      other_input_tags = set.union(
          set(pardo_proto.side_inputs), set(pardo_proto.timer_specs))
    else:
      other_input_tags = ()
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in other_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))

  if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs
                      or pardo_proto.splittable):
    main_input_coder = None
    timer_inputs = {}
    for tag, pcoll_id in transform_proto.inputs.items():
      if tag in pardo_proto.timer_specs:
        timer_inputs[tag] = pcoll_id
      elif tag in pardo_proto.side_inputs:
        pass
      else:
        # Must be the main input
        assert main_input_coder is None
        main_input_tag = tag
        main_input_coder = factory.get_windowed_coder(pcoll_id)
    assert main_input_coder is not None

    if pardo_proto.timer_specs or pardo_proto.state_specs:
      user_state_context = FnApiUserStateContext(
          factory.state_handler,
          transform_id,
          main_input_coder.key_coder(),
          main_input_coder.window_coder,
          timer_specs=pardo_proto.timer_specs)
    else:
      user_state_context = None
  else:
    user_state_context = None
    timer_inputs = None

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=[mutate_tag(tag) for tag in output_tags],
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  result = factory.augment_oldstyle_op(
      operation_cls(
          transform_proto.unique_name,
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps,
          user_state_context,
          timer_inputs=timer_inputs),
      transform_proto.unique_name,
      consumers,
      output_tags)
  if pardo_proto and pardo_proto.splittable:
    result.input_info = (
        transform_id, main_input_tag, main_input_coder,
        transform_proto.outputs.keys())
  return result
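
The side-input ordering here is stricter than the slice-based sort key in Example 3: the regex accepts tags of the form side<N> or side<N>-<suffix> and sorts numerically on N, so side10 correctly lands after side2. The sort key in isolation:

import re

tags = ['side10-big', 'side2', 'side0-main']
tags.sort(key=lambda t: int(re.match('side([0-9]+)(-.*)?$', t).group(1)))
print(tags)  # -> ['side0-main', 'side2', 'side10-big']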
Example 9
  def create_execution_tree(self, descriptor):
    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
    coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
        json.loads(unpack_function_spec_data(coder.function_spec)))
              for coder in descriptor.coders}

    counter_factory = counters.CounterFactory()
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)
    consumers = collections.defaultdict(lambda: collections.defaultdict(list))
    ops_by_id = {}
    reversed_ops = []

    for transform in reversed(descriptor.primitive_transform):
      # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
      # "s3") from the service through the FnAPI so that msec counters can be
      # reported and correctly plumbed through the service and the UI.
      operation_name = 'fnapis%s' % transform.id

      def only_element(iterable):
        element, = iterable
        return element

      if transform.function_spec.urn == DATA_OUTPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.outputs.keys()))

        op = DataOutputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == DATA_INPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.inputs.keys()))
        op = DataInputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == PYTHON_DOFN_URN:
        def create_side_input(tag, si):
          # TODO(robertwb): Extract windows (and keys) out of element data.
          return operation_specs.WorkerSideInputSource(
              tag=tag,
              source=SideInputSource(
                  self.state_handler,
                  beam_fn_api_pb2.StateKey(
                      function_spec_reference=si.view_fn.id),
                  coder=unpack_and_deserialize_py_fn(si.view_fn)))
        output_tags = list(transform.outputs.keys())
        spec = operation_specs.WorkerDoFn(
            serialized_fn=unpack_function_spec_data(transform.function_spec),
            output_tags=output_tags,
            input=None,
            side_inputs=[create_side_input(tag, si)
                         for tag, si in transform.side_inputs.items()],
            output_coders=[coders[transform.outputs[out].coder_reference]
                           for out in output_tags])

        op = operations.DoOperation(operation_name, spec, counter_factory,
                                    state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      elif transform.function_spec.urn == IDENTITY_DOFN_URN:
        op = operations.FlattenOperation(operation_name, None, counter_factory,
                                         state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(consumer, 0)

      elif transform.function_spec.urn == PYTHON_SOURCE_URN:
        source = load_compressed(unpack_function_spec_data(
            transform.function_spec))
        # TODO(vikasrk): Remove this once custom source is implemented with
        # splittable dofn via the data plane.
        spec = operation_specs.WorkerRead(
            iobase.SourceBundle(1.0, source, None, None),
            [WindowedValueCoder(source.default_output_coder())])
        op = operations.ReadOperation(operation_name, spec, counter_factory,
                                      state_sampler)
        op.step_name = transform.step_name
        output_tags = list(transform.outputs.keys())
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      else:
        raise NotImplementedError

      # Record consumers.
      for _, inputs in transform.inputs.items():
        for target in inputs.target:
          consumers[target.primitive_transform_reference][target.name].append(
              op)

      reversed_ops.append(op)
      ops_by_id[transform.id] = op

    return list(reversed(reversed_ops)), ops_by_id
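
Note the construction order in create_execution_tree: transforms are walked in reverse topological order, so by the time an operation is built, all of its downstream consumers already exist in the consumers map; the accumulated list is then reversed back into execution order. The wiring pattern reduced to its essentials (strings stand in for real Operation objects):

import collections

consumers = collections.defaultdict(lambda: collections.defaultdict(list))
# (transform_id, {producer_output_name: producer_id}), topological order.
transforms = [('read', {}), ('pardo', {'out': 'read'})]
reversed_ops = []
for transform_id, inputs in reversed(transforms):
    op = transform_id  # stand-in for the real Operation
    reversed_ops.append(op)
    for name, producer_id in inputs.items():
        consumers[producer_id][name].append(op)
print(list(reversed(reversed_ops)))  # -> ['read', 'pardo']
print(dict(consumers['read']))       # -> {'out': ['pardo']}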
Example 10
    def run_ParDo(self, transform_node):
        transform = transform_node.transform
        output = transform_node.outputs[None]
        element_coder = self._get_coder(output)
        map_task_index, producer_index, output_index = self.outputs[
            transform_node.inputs[0]]

        # If any of this ParDo's side inputs depend on outputs from this map_task,
        # we can't continue growing this map task.
        def is_reachable(leaf, root):
            if leaf == root:
                return True
            else:
                return any(
                    is_reachable(x, root) for x in self.dependencies[leaf])

        if any(
                is_reachable(self.outputs[side_input.pvalue][0],
                             map_task_index)
                for side_input in transform_node.side_inputs):
            # Start a new map task.
            input_element_coder = self._get_coder(transform_node.inputs[0])

            output_buffer = OutputBuffer(input_element_coder)

            fusion_break_write = operation_specs.WorkerInMemoryWrite(
                output_buffer=output_buffer,
                write_windowed_values=True,
                input=(producer_index, output_index),
                output_coders=[input_element_coder])
            self.map_tasks[map_task_index].append(
                (transform_node.full_label + '/Write', fusion_break_write))

            original_map_task_index = map_task_index
            map_task_index, producer_index, output_index = len(
                self.map_tasks), 0, 0

            fusion_break_read = operation_specs.WorkerRead(
                output_buffer.source_bundle(),
                output_coders=[input_element_coder])
            self.map_tasks.append([(transform_node.full_label + '/Read',
                                    fusion_break_read)])

            self.dependencies[map_task_index].add(original_map_task_index)

        def create_side_read(side_input):
            label = self.side_input_labels[side_input]
            output_buffer = self.run_side_write(
                side_input.pvalue,
                '%s/%s' % (transform_node.full_label, label))
            return operation_specs.WorkerSideInputSource(
                output_buffer.source(), label)

        do_op = operation_specs.WorkerDoFn(
            serialized_fn=pickler.dumps(
                DataflowRunner._pardo_fn_data(
                    transform_node,
                    lambda side_input: self.side_input_labels[side_input])),
            output_tags=[PropertyNames.OUT] + [
                '%s_%s' % (PropertyNames.OUT, tag)
                for tag in transform.output_tags
            ],
            # Same assumption that DataflowRunner has about coders being compatible
            # across outputs.
            output_coders=[element_coder] * (len(transform.output_tags) + 1),
            input=(producer_index, output_index),
            side_inputs=[
                create_side_read(side_input)
                for side_input in transform_node.side_inputs
            ])

        producer_index = len(self.map_tasks[map_task_index])
        self.outputs[transform_node.outputs[None]] = (map_task_index,
                                                      producer_index, 0)
        for ix, tag in enumerate(transform.output_tags):
            self.outputs[transform_node.
                         outputs[tag]] = map_task_index, producer_index, ix + 1
        self.map_tasks[map_task_index].append(
            (transform_node.full_label, do_op))

        for side_input in transform_node.side_inputs:
            self.dependencies[map_task_index].add(
                self.outputs[side_input.pvalue][0])
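
The fusion-break decision above rests on is_reachable, a depth-first walk over the map-task dependency graph: if a side input's producing map task transitively depends on the current map task, fusing the ParDo into it would create a cycle, so a new map task is started instead. The test in isolation, over a hypothetical dependency map:

dependencies = {0: set(), 1: {0}, 2: {1}}  # task -> tasks it depends on

def is_reachable(leaf, root):
    if leaf == root:
        return True
    return any(is_reachable(x, root) for x in dependencies[leaf])

print(is_reachable(2, 0))  # -> True (2 depends on 1, which depends on 0)
print(is_reachable(0, 2))  # -> False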