Example #1
  def _run_read_from(self, transform_node, source):
    """Used when this operation is the result of reading source."""
    if not isinstance(source, NativeSource):
      # Wrap a custom (non-native) source in a single whole-source bundle.
      source = iobase.SourceBundle(1.0, source, None, None)
    output = transform_node.outputs[None]
    element_coder = self._get_coder(output)
    read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])
    # Record where this output is produced:
    # (map_task_index, producer_index, output_index).
    self.outputs[output] = len(self.map_tasks), 0, 0
    self.map_tasks.append([(transform_node.full_label, read_op)])
    return len(self.map_tasks) - 1
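A minimal standalone sketch of the wrapping step above: a non-native source is packed into a single whole-source SourceBundle and handed to WorkerRead. The import paths and the my_source/element_coder placeholders are assumptions for illustration, not part of the original snippet.

# Sketch only: the SourceBundle wrapping performed by _run_read_from.
# Import paths are assumed for the SDK version these snippets target;
# my_source is a placeholder for a BoundedSource instance.
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.runners.worker import operation_specs

my_source = None  # placeholder: a BoundedSource instance
element_coder = coders.BytesCoder()  # placeholder coder for the output

# Fields are (weight, source, start_position, stop_position); None/None means
# the whole source is read as one bundle.
bundle = iobase.SourceBundle(1.0, my_source, None, None)
read_op = operation_specs.WorkerRead(bundle, output_coders=[element_coder])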
Example #2
def create(factory, transform_id, transform_proto, parameter, consumers):
    # Deserialize the pickled custom source carried in the transform payload.
    source = pickler.loads(parameter.value)
    # Read the whole source as one bundle, emitting windowed values encoded
    # with the source's default output coder.
    spec = operation_specs.WorkerRead(
        iobase.SourceBundle(1.0, source, None, None),
        [WindowedValueCoder(source.default_output_coder())])
    return factory.augment_oldstyle_op(
        operations.ReadOperation(transform_proto.unique_name, spec,
                                 factory.counter_factory,
                                 factory.state_sampler),
        transform_proto.unique_name, consumers)
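The factory above only consumes the payload; below is a rough sketch of the serialize/deserialize round trip it relies on, assuming apache_beam.internal.pickler is the pickler module referenced by the snippet.

# Sketch only: the round trip behind pickler.loads(parameter.value).
# Module path is an assumption based on the unqualified `pickler` above.
from apache_beam.internal import pickler

payload = pickler.dumps(['any', 'picklable', 'object'])  # produced runner-side
assert pickler.loads(payload) == ['any', 'picklable', 'object']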
Example #3
def create(factory, transform_id, transform_proto, parameter, consumers):
    # The Dataflow runner harness strips the base64 encoding.
    source = pickler.loads(base64.b64encode(parameter))
    spec = operation_specs.WorkerRead(
        iobase.SourceBundle(1.0, source, None, None),
        [WindowedValueCoder(source.default_output_coder())])
    return factory.augment_oldstyle_op(
        operations.ReadOperation(transform_proto.unique_name, spec,
                                 factory.counter_factory,
                                 factory.state_sampler),
        transform_proto.unique_name, consumers)
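This variant differs from the previous one only in re-applying base64 before unpickling; per the comment, the Dataflow harness strips the base64 layer that the pickler's output normally carries. A hedged sketch of that reasoning, assuming pickler.dumps base64-encodes its result:

# Sketch only: why b64encode is re-applied before pickler.loads.
# Assumes pickler.dumps returns base64-encoded pickle data (an assumption,
# but consistent with the comment in the snippet above).
import base64
from apache_beam.internal import pickler

encoded = pickler.dumps({'a': 1})       # base64-encoded payload
stripped = base64.b64decode(encoded)    # what the harness is said to hand over
assert pickler.loads(base64.b64encode(stripped)) == {'a': 1}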
Example #4
  def create_execution_tree(self, descriptor):
    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
    coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
        json.loads(unpack_function_spec_data(coder.function_spec)))
              for coder in descriptor.coders}

    counter_factory = counters.CounterFactory()
    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    state_sampler = statesampler.StateSampler(
        'fnapi-step%s-' % descriptor.id, counter_factory)
    consumers = collections.defaultdict(lambda: collections.defaultdict(list))
    ops_by_id = {}
    reversed_ops = []

    for transform in reversed(descriptor.primitive_transform):
      # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
      # "s3") from the service through the FnAPI so that msec counters can be
      # reported and correctly plumbed through the service and the UI.
      operation_name = 'fnapis%s' % transform.id

      def only_element(iterable):
        element, = iterable
        return element

      if transform.function_spec.urn == DATA_OUTPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.outputs.keys()))

        op = DataOutputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == DATA_INPUT_URN:
        target = beam_fn_api_pb2.Target(
            primitive_transform_reference=transform.id,
            name=only_element(transform.inputs.keys()))
        op = DataInputOperation(
            operation_name,
            transform.step_name,
            consumers[transform.id],
            counter_factory,
            state_sampler,
            coders[only_element(transform.outputs.values()).coder_reference],
            target,
            self.data_channel_factory.create_data_channel(
                transform.function_spec))

      elif transform.function_spec.urn == PYTHON_DOFN_URN:
        def create_side_input(tag, si):
          # TODO(robertwb): Extract windows (and keys) out of element data.
          return operation_specs.WorkerSideInputSource(
              tag=tag,
              source=SideInputSource(
                  self.state_handler,
                  beam_fn_api_pb2.StateKey(
                      function_spec_reference=si.view_fn.id),
                  coder=unpack_and_deserialize_py_fn(si.view_fn)))
        output_tags = list(transform.outputs.keys())
        spec = operation_specs.WorkerDoFn(
            serialized_fn=unpack_function_spec_data(transform.function_spec),
            output_tags=output_tags,
            input=None,
            side_inputs=[create_side_input(tag, si)
                         for tag, si in transform.side_inputs.items()],
            output_coders=[coders[transform.outputs[out].coder_reference]
                           for out in output_tags])

        op = operations.DoOperation(operation_name, spec, counter_factory,
                                    state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      elif transform.function_spec.urn == IDENTITY_DOFN_URN:
        op = operations.FlattenOperation(operation_name, None, counter_factory,
                                         state_sampler)
        # TODO(robertwb): Move these to the constructor.
        op.step_name = transform.step_name
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(consumer, 0)

      elif transform.function_spec.urn == PYTHON_SOURCE_URN:
        source = load_compressed(unpack_function_spec_data(
            transform.function_spec))
        # TODO(vikasrk): Remove this once custom source is implemented with
        # splittable dofn via the data plane.
        spec = operation_specs.WorkerRead(
            iobase.SourceBundle(1.0, source, None, None),
            [WindowedValueCoder(source.default_output_coder())])
        op = operations.ReadOperation(operation_name, spec, counter_factory,
                                      state_sampler)
        op.step_name = transform.step_name
        output_tags = list(transform.outputs.keys())
        for tag, op_consumers in consumers[transform.id].items():
          for consumer in op_consumers:
            op.add_receiver(
                consumer, output_tags.index(tag))

      else:
        raise NotImplementedError

      # Record consumers.
      for _, inputs in transform.inputs.items():
        for target in inputs.target:
          consumers[target.primitive_transform_reference][target.name].append(
              op)

      reversed_ops.append(op)
      ops_by_id[transform.id] = op

    return list(reversed(reversed_ops)), ops_by_id
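create_execution_tree visits the primitive transforms in reverse so every consumer operation exists before its producer is built; consumers is keyed by producing transform id and then by output name. A small standalone sketch of that structure and of the only_element helper, both lifted from the logic above:

import collections

# transform id -> output name -> list of downstream operations
consumers = collections.defaultdict(lambda: collections.defaultdict(list))
consumers['t1']['out'].append('op_for_t2')  # placeholder consumer object

def only_element(iterable):
  # Unpack a one-element iterable, raising if it has any other length.
  element, = iterable
  return element

assert only_element(consumers['t1'].keys()) == 'out'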
Example #5
    def run_ParDo(self, transform_node):
        transform = transform_node.transform
        output = transform_node.outputs[None]
        element_coder = self._get_coder(output)
        map_task_index, producer_index, output_index = self.outputs[
            transform_node.inputs[0]]

        # If any of this ParDo's side inputs depend on outputs from this map_task,
        # we can't continue growing this map task.
        def is_reachable(leaf, root):
            if leaf == root:
                return True
            else:
                return any(
                    is_reachable(x, root) for x in self.dependencies[leaf])

        if any(
                is_reachable(self.outputs[side_input.pvalue][0],
                             map_task_index)
                for side_input in transform_node.side_inputs):
            # Start a new map task.
            input_element_coder = self._get_coder(transform_node.inputs[0])

            output_buffer = OutputBuffer(input_element_coder)

            fusion_break_write = operation_specs.WorkerInMemoryWrite(
                output_buffer=output_buffer,
                write_windowed_values=True,
                input=(producer_index, output_index),
                output_coders=[input_element_coder])
            self.map_tasks[map_task_index].append(
                (transform_node.full_label + '/Write', fusion_break_write))

            original_map_task_index = map_task_index
            map_task_index, producer_index, output_index = len(
                self.map_tasks), 0, 0

            fusion_break_read = operation_specs.WorkerRead(
                output_buffer.source_bundle(),
                output_coders=[input_element_coder])
            self.map_tasks.append([(transform_node.full_label + '/Read',
                                    fusion_break_read)])

            self.dependencies[map_task_index].add(original_map_task_index)

        def create_side_read(side_input):
            label = self.side_input_labels[side_input]
            output_buffer = self.run_side_write(
                side_input.pvalue,
                '%s/%s' % (transform_node.full_label, label))
            return operation_specs.WorkerSideInputSource(
                output_buffer.source(), label)

        do_op = operation_specs.WorkerDoFn(  #
            serialized_fn=pickler.dumps(
                DataflowRunner._pardo_fn_data(
                    transform_node,
                    lambda side_input: self.side_input_labels[side_input])),
            output_tags=[PropertyNames.OUT] + [
                '%s_%s' % (PropertyNames.OUT, tag)
                for tag in transform.output_tags
            ],
            # Same assumption that DataflowRunner has about coders being compatible
            # across outputs.
            output_coders=[element_coder] * (len(transform.output_tags) + 1),
            input=(producer_index, output_index),
            side_inputs=[
                create_side_read(side_input)
                for side_input in transform_node.side_inputs
            ])

        producer_index = len(self.map_tasks[map_task_index])
        self.outputs[transform_node.outputs[None]] = (map_task_index,
                                                      producer_index, 0)
        for ix, tag in enumerate(transform.output_tags):
            self.outputs[transform_node.outputs[tag]] = (
                map_task_index, producer_index, ix + 1)
        self.map_tasks[map_task_index].append(
            (transform_node.full_label, do_op))

        for side_input in transform_node.side_inputs:
            self.dependencies[map_task_index].add(
                self.outputs[side_input.pvalue][0])
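run_ParDo breaks fusion when a side input is (transitively) produced by the map task it is about to extend; is_reachable walks self.dependencies, which maps a map-task index to the set of task indices it depends on. A standalone sketch of that check, mirroring the recursion above:

import collections

# map-task index -> set of map-task indices it depends on (cf. self.dependencies)
dependencies = collections.defaultdict(set)
dependencies[1].add(0)
dependencies[2].add(1)

def is_reachable(leaf, root):
    # True if `root` is reachable from `leaf` via dependency edges.
    if leaf == root:
        return True
    return any(is_reachable(x, root) for x in dependencies[leaf])

assert is_reachable(2, 0)      # task 2 transitively depends on task 0
assert not is_reachable(0, 2)  # but not the other way around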