Example 1
    def expand(self, pcoll):
        windowing_saved = pcoll.windowing
        if windowing_saved.is_default():
            # In this (common) case we can use a trivial trigger driver
            # and avoid the (expensive) window param.
            globally_windowed = window.GlobalWindows.windowed_value(None)
            window_fn = window.GlobalWindows()
            MIN_TIMESTAMP = window.MIN_TIMESTAMP

            def reify_timestamps(element, timestamp=DoFn.TimestampParam):
                key, value = element
                if timestamp == MIN_TIMESTAMP:
                    timestamp = None
                return key, (value, timestamp)

            def restore_timestamps(element):
                key, values = element
                return [
                    globally_windowed.with_value((key, value)) if
                    timestamp is None else window.GlobalWindows.windowed_value(
                        (key, value), timestamp)
                    for (value, timestamp) in values
                ]

        else:
            # The linter is confused.
            # hash(1) is used to force "runtime" selection of _IdentityWindowFn
            # pylint: disable=abstract-class-instantiated
            cls = hash(1) and _IdentityWindowFn
            window_fn = cls(windowing_saved.windowfn.get_window_coder())

            def reify_timestamps(element, timestamp=DoFn.TimestampParam):
                key, value = element
                return key, TimestampedValue(value, timestamp)

            def restore_timestamps(element, window=DoFn.WindowParam):
                # Pass the current window since _IdentityWindowFn wouldn't know how
                # to generate it.
                key, values = element
                return [
                    windowed_value.WindowedValue((key, value.value),
                                                 value.timestamp, [window])
                    for value in values
                ]

        ungrouped = pcoll | Map(reify_timestamps)
        ungrouped._windowing = Windowing(
            window_fn,
            triggerfn=AfterCount(1),
            accumulation_mode=AccumulationMode.DISCARDING,
            timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
        result = (ungrouped | GroupByKey() | FlatMap(restore_timestamps))
        result._windowing = windowing_saved
        return result
Example 2
    def replace_first(pcoll, regex, replacement):
        """
    Returns the line with the first match of the regex replaced by the
    replacement string; lines with no match pass through unchanged.

    Args:
      regex: the regular expression string or (re.compile) pattern.
      replacement: the string to be substituted for each match.
    """
        regex = Regex._regex_compile(regex)
        return pcoll | Map(lambda elem: regex.sub(replacement, elem, 1))
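A minimal usage sketch of this transform, assuming the installed apache_beam release exposes it as beam.Regex.replace_first (the sample data and labels are illustrative):

    import apache_beam as beam

    with beam.Pipeline() as p:
        lines = p | beam.Create(['one fish two fish', 'red fish blue fish'])
        # Only the first match per element is replaced; elements without a
        # match pass through unchanged.
        replaced = lines | beam.Regex.replace_first(r'fish', 'cat')
        replaced | beam.Map(print)  # 'one cat two fish', 'red cat blue fish'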
Example 3
    def expand(self, pcoll):
        windowing_saved = pcoll.windowing
        if windowing_saved.is_default():
            # In this (common) case we can use a trivial trigger driver
            # and avoid the (expensive) window param.
            globally_windowed = window.GlobalWindows.windowed_value(None)
            MIN_TIMESTAMP = window.MIN_TIMESTAMP

            def reify_timestamps(element, timestamp=DoFn.TimestampParam):
                key, value = element
                if timestamp == MIN_TIMESTAMP:
                    timestamp = None
                return key, (value, timestamp)

            def restore_timestamps(element):
                key, values = element
                return [
                    globally_windowed.with_value((key, value)) if
                    timestamp is None else window.GlobalWindows.windowed_value(
                        (key, value), timestamp)
                    for (value, timestamp) in values
                ]

        else:

            def reify_timestamps(element,
                                 timestamp=DoFn.TimestampParam,
                                 window=DoFn.WindowParam):
                key, value = element
                # Transport the window as part of the value and restore it later.
                return key, windowed_value.WindowedValue(
                    value, timestamp, [window])

            def restore_timestamps(element):
                key, windowed_values = element
                return [
                    wv.with_value((key, wv.value)) for wv in windowed_values
                ]

        ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)

        # TODO(BEAM-8104) Using global window as one of the standard window.
        # This is to mitigate the Dataflow Java Runner Harness limitation to
        # accept only standard coders.
        ungrouped._windowing = Windowing(
            window.GlobalWindows(),
            triggerfn=AfterCount(1),
            accumulation_mode=AccumulationMode.DISCARDING,
            timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
        result = (ungrouped
                  | GroupByKey()
                  | FlatMap(restore_timestamps).with_output_types(Any))
        result._windowing = windowing_saved
        return result
Example 4
 def Iterables(delimiter=None):
   """
   Transforms each iterable element of the input PCollection into a
   delimiter-joined string of its items. There is no trailing delimiter.
   """
   if delimiter is None:
     delimiter = ','
   return (
       'IterablesToString' >>
       Map(lambda xs: delimiter.join(str(x) for x in xs)).with_input_types(
           Iterable[Any]).with_output_types(str))
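A short usage sketch, assuming this Iterables helper is the one exposed under beam.ToString in apache_beam (input data is illustrative):

    import apache_beam as beam

    with beam.Pipeline() as p:
        rows = p | beam.Create([['a', 'b', 'c'], [1, 2, 3]])
        # Each iterable element becomes a single delimiter-joined string.
        joined = rows | beam.ToString.Iterables(delimiter='-')
        joined | beam.Map(print)  # 'a-b-c' and '1-2-3'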
Example 5
        def expand(self, pcoll):
            if reify_windows:
                pcoll = pcoll | ParDo(ReifyTimestampWindow())

            keyed_singleton = pcoll.pipeline | Create([(None, None)])

            if use_global_window:
                pcoll = pcoll | WindowInto(window.GlobalWindows())

            keyed_actual = pcoll | "ToVoidKey" >> Map(lambda v: (None, v))

            # This is a CoGroupByKey so that the matcher always runs, even if the
            # PCollection is empty.
            plain_actual = ((keyed_singleton, keyed_actual)
                            | "Group" >> CoGroupByKey()
                            | "Unkey" >> Map(lambda k_values: k_values[1][1]))

            if not use_global_window:
                plain_actual = plain_actual | "AddWindow" >> ParDo(AddWindow())

            plain_actual = plain_actual | "Match" >> Map(matcher)
Example 6
def WithKeys(pcoll, k, *args, **kwargs):
  """PTransform that takes a PCollection, and either a constant key or a
  callable, and returns a PCollection of (K, V), where each of the values in
  the input PCollection has been paired with either the constant key or a key
  computed from the value.  The callable may optionally accept positional or
  keyword arguments, which should be passed to WithKeys directly.  These may
  be either SideInputs or static (non-PCollection) values, such as ints.
  """
  if callable(k):
    if fn_takes_side_inputs(k):
      if all(isinstance(arg, AsSideInput)
             for arg in args) and all(isinstance(kwarg, AsSideInput)
                                      for kwarg in kwargs.values()):
        return pcoll | Map(
            lambda v, *args, **kwargs: (k(v, *args, **kwargs), v),
            *args,
            **kwargs)
      return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
    return pcoll | Map(lambda v: (k(v), v))
  return pcoll | Map(lambda v: (k, v))
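A hedged usage sketch of this transform with both a constant key and a callable key, assuming it is exported as beam.WithKeys (sample data and labels are illustrative):

    import apache_beam as beam

    with beam.Pipeline() as p:
        words = p | beam.Create(['apple', 'banana', 'fig'])
        # Constant key: every element is paired with the same key.
        const_keyed = words | 'ConstKey' >> beam.WithKeys('fruit')
        # Callable key: the key is computed from each value.
        len_keyed = words | 'LenKey' >> beam.WithKeys(len)
        len_keyed | beam.Map(print)  # (5, 'apple'), (6, 'banana'), (3, 'fig')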
Example 7
    def expand(self, pvalue):
        keyed_pc = (pvalue | 'AssignKey' >> Map(lambda x: (uuid.uuid4(), x)))
        if keyed_pc.windowing.windowfn.is_merging():
            raise ValueError(
                'Transform ReadAllFiles cannot be used in the presence '
                'of merging windows')
        if not isinstance(keyed_pc.windowing.triggerfn, DefaultTrigger):
            raise ValueError(
                'Transform ReadAllFiles cannot be used in the presence '
                'of non-trivial triggers')

        return (keyed_pc | 'GroupByKey' >> GroupByKey()
                # Using FlatMap below due to the possibility of key collisions.
                | 'DropKey' >> FlatMap(lambda k_values: k_values[1]))
Example 8
 def test_globally(self):
   l = [window.TimestampedValue(3, 100),
        window.TimestampedValue(1, 200),
        window.TimestampedValue(2, 300)]
   with TestPipeline() as p:
     # Map(lambda x: x) PTransform is added after Create here, because when
     # a PCollection of TimestampedValues is created with Create PTransform,
     # the timestamps are not assigned to it. Adding a Map forces the
     # PCollection to go through a DoFn so that the PCollection consists of
     # the elements with timestamps assigned to them instead of a PCollection
     # of TimestampedValue(element, timestamp).
     pc = p | Create(l) | Map(lambda x: x)
     latest = pc | combine.Latest.Globally()
     assert_that(latest, equal_to([2]))
Example 9
 def expand(self, pcoll):
     key_type, value_type = pcoll.element_type.tuple_types
     sharded_pcoll = pcoll | Map(lambda key_value: (
         ShardedKey(
             key_value[0],
             # Use [uuid, thread id] as the shard id.
             GroupIntoBatches.WithShardedKey._shard_id_prefix + bytes(
                 threading.get_ident().to_bytes(8, 'big'))),
         key_value[1])).with_output_types(typehints.Tuple[
             ShardedKeyType[key_type],  # type: ignore[misc]
             value_type])
      return (sharded_pcoll
              | GroupIntoBatches(self.params.batch_size,
                                 self.params.max_buffering_duration_secs))
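For context, a minimal sketch of the plain GroupIntoBatches transform that the sharded variant above delegates to; the batch size and input data here are illustrative assumptions:

    import apache_beam as beam

    with beam.Pipeline() as p:
        kvs = p | beam.Create([('user', i) for i in range(5)])
        # Values are buffered per key and emitted in batches of at most 2.
        batched = kvs | beam.GroupIntoBatches(2)
        batched | beam.Map(print)  # ('user', [0, 1]), ('user', [2, 3]), ('user', [4])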
Example 10
    def test_top(self):
        with TestPipeline() as pipeline:
            timestamp = 0

            # First for global combines.
            pcoll = pipeline | 'start' >> Create(
                [6, 3, 1, 1, 9, 1, 5, 2, 0, 6])
            result_top = pcoll | 'top' >> combine.Top.Largest(5)
            result_bot = pcoll | 'bot' >> combine.Top.Smallest(4)
            assert_that(result_top,
                        equal_to([[9, 6, 6, 5, 3]]),
                        label='assert:top')
            assert_that(result_bot,
                        equal_to([[0, 1, 1, 1]]),
                        label='assert:bot')

            # Now for global combines without default
            timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
            windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
            result_windowed_top = windowed | 'top-wo-defaults' >> combine.Top.Largest(
                5, has_defaults=False)
            result_windowed_bot = (windowed
                                   | 'bot-wo-defaults' >> combine.Top.Smallest(
                                       4, has_defaults=False))
            assert_that(result_windowed_top,
                        equal_to([[9, 6, 6, 5, 3]]),
                        label='assert:top-wo-defaults')
            assert_that(result_windowed_bot,
                        equal_to([[0, 1, 1, 1]]),
                        label='assert:bot-wo-defaults')

            # Again for per-key combines.
            pcoll = pipeline | 'start-perkey' >> Create(
                [('a', x) for x in [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]])
            result_key_top = pcoll | 'top-perkey' >> combine.Top.LargestPerKey(
                5)
            result_key_bot = pcoll | 'bot-perkey' >> combine.Top.SmallestPerKey(
                4)
            assert_that(result_key_top,
                        equal_to([('a', [9, 6, 6, 5, 3])]),
                        label='key:top')
            assert_that(result_key_bot,
                        equal_to([('a', [0, 1, 1, 1])]),
                        label='key:bot')
Example 11
    def test_builtin_combines(self):
        with TestPipeline() as pipeline:

            vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
            mean = sum(vals) / float(len(vals))
            size = len(vals)
            timestamp = 0

            # First for global combines.
            pcoll = pipeline | 'start' >> Create(vals)
            result_mean = pcoll | 'mean' >> combine.Mean.Globally()
            result_count = pcoll | 'count' >> combine.Count.Globally()
            assert_that(result_mean, equal_to([mean]), label='assert:mean')
            assert_that(result_count, equal_to([size]), label='assert:size')

            # Now for global combines without default
            timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
            windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
            result_windowed_mean = (windowed
                                    | 'mean-wo-defaults' >>
                                    combine.Mean.Globally().without_defaults())
            assert_that(result_windowed_mean,
                        equal_to([mean]),
                        label='assert:mean-wo-defaults')
            result_windowed_count = (
                windowed
                | 'count-wo-defaults' >>
                combine.Count.Globally().without_defaults())
            assert_that(result_windowed_count,
                        equal_to([size]),
                        label='assert:count-wo-defaults')

            # Again for per-key combines.
            pcoll = pipeline | 'start-perkey' >> Create([('a', x)
                                                         for x in vals])
            result_key_mean = pcoll | 'mean-perkey' >> combine.Mean.PerKey()
            result_key_count = pcoll | 'count-perkey' >> combine.Count.PerKey()
            assert_that(result_key_mean,
                        equal_to([('a', mean)]),
                        label='key:mean')
            assert_that(result_key_count,
                        equal_to([('a', size)]),
                        label='key:size')
Example 12
  def test_to_set(self):
    pipeline = TestPipeline()
    the_list = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
    timestamp = 0
    pcoll = pipeline | 'start' >> Create(the_list)
    result = pcoll | 'to set' >> combine.ToSet()

    # Now for global combines without default
    timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
    windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
    result_windowed = (
        windowed
        | 'to set wo defaults' >> combine.ToSet().without_defaults())

    def matcher(expected):
      def match(actual):
        equal_to(expected[0])(actual[0])

      return match

    assert_that(result, matcher(set(the_list)))
    assert_that(
        result_windowed, matcher(set(the_list)), label='to-set-wo-defaults')
    pipeline.run()
Example 13
  def test_to_list_and_to_dict2(self):
    with TestPipeline() as pipeline:
      pairs = [(1, 2), (3, 4), (5, 6)]
      timestamp = 0
      pcoll = pipeline | 'start-pairs' >> Create(pairs)
      result = pcoll | 'to dict' >> combine.ToDict()

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))
      result_windowed = (
          windowed
          | 'to dict wo defaults' >> combine.ToDict().without_defaults())

      def matcher():
        def match(actual):
          equal_to([1])([len(actual)])
          equal_to(pairs)(actual[0].items())

        return match

      assert_that(result, matcher())
      assert_that(result_windowed, matcher(), label='to-dict-wo-defaults')
Example 14
  def test_global_sample(self):
    def is_good_sample(actual):
      assert len(actual) == 1
      assert sorted(actual[0]) in [[1, 1, 2], [1, 2, 2]], actual

    with TestPipeline() as pipeline:
      timestamp = 0
      pcoll = pipeline | 'start' >> Create([1, 1, 2, 2])

      # Now for global combines without default
      timestamped = pcoll | Map(lambda x: TimestampedValue(x, timestamp))
      windowed = timestamped | 'window' >> WindowInto(FixedWindows(60))

      for ix in range(9):
        assert_that(
            pcoll | 'sample-%d' % ix >> combine.Sample.FixedSizeGlobally(3),
            is_good_sample,
            label='check-%d' % ix)
        result_windowed = (
            windowed
            | 'sample-wo-defaults-%d' % ix >>
            combine.Sample.FixedSizeGlobally(3).without_defaults())
        assert_that(
            result_windowed, is_good_sample, label='check-wo-defaults-%d' % ix)
Example 15
 def test_globally_empty(self):
     l = []
     with TestPipeline() as p:
         pc = p | Create(l) | Map(lambda x: x)
         latest = pc | combine.Latest.Globally()
         assert_that(latest, equal_to([None]))
Example 16
File: util.py Project: xmarker/beam
 def expand(self, pcoll):
   return (pcoll
           | 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t))
           | ReshufflePerKey()
           | 'RemoveRandomKeys' >> Map(lambda t: t[1]))
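A minimal sketch of how the Reshuffle transform defined above is typically applied, assuming it is exported as beam.Reshuffle; the data is illustrative:

    import apache_beam as beam

    with beam.Pipeline() as p:
        items = p | beam.Create(range(10))
        # Reshuffle acts as a fusion break and redistributes elements across
        # workers; the output contains the same values as the input.
        redistributed = items | beam.Reshuffle()
        redistributed | beam.Map(print)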
Example 17
def KvSwap(label='KvSwap'):  # pylint: disable=invalid-name
    """Produces a PCollection reversing 2-tuples in a PCollection."""
    return label >> Map(lambda k_v2: (k_v2[1], k_v2[0]))
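A quick sketch of KvSwap on a small keyed PCollection, assuming it is exported as beam.KvSwap (data is illustrative):

    import apache_beam as beam

    with beam.Pipeline() as p:
        pairs = p | beam.Create([('a', 1), ('b', 2)])
        # Keys and values trade places.
        swapped = pairs | beam.KvSwap()
        swapped | beam.Map(print)  # (1, 'a'), (2, 'b')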
Example 18
def Keys(label='Keys'):  # pylint: disable=invalid-name
    """Produces a PCollection of first elements of 2-tuples in a PCollection."""
    return label >> Map(lambda k_v: k_v[0])
Example 19
File: util.py Project: mszb/beam
  Returns:
    A stream with the contents of the opened files.
  """
    if 'b' in mode:
        encoding = None

    with tempfile.NamedTemporaryFile(delete=False) as out_file:
        for shard in glob.glob(glob_pattern):
            with open(shard, 'rb') as in_file:
                out_file.write(in_file.read())
        concatenated_file_name = out_file.name
    return io.open(concatenated_file_name, mode, encoding=encoding)


def _sort_lists(result):
    if isinstance(result, list):
        return sorted(result)
    elif isinstance(result, tuple):
        return tuple(_sort_lists(e) for e in result)
    elif isinstance(result, dict):
        return {k: _sort_lists(v) for k, v in result.items()}
    elif isinstance(result, Iterable) and not isinstance(result, str):
        return sorted(result)
    else:
        return result


# A utility transform that recursively sorts lists for easier testing.
SortLists = Map(_sort_lists)
Example 20
 def test_per_key_empty(self):
     l = []
     with TestPipeline() as p:
         pc = p | Create(l) | Map(lambda x: x)
         latest = pc | combine.Latest.PerKey()
         assert_that(latest, equal_to([]))
Example 21
 def expand(self, pcoll):
     side = pcoll | CombineGlobally(sum).as_singleton_view()
     main = pcoll.pipeline | Create([None])
     return main | Map(lambda _, s: s, side)
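The same singleton side-input pattern written against the public API, as a hedged sketch (names and data are illustrative): a global combine is computed once and supplied to Map as a side input.

    import apache_beam as beam

    with beam.Pipeline() as p:
        nums = p | beam.Create([1, 2, 3, 4])
        total = nums | beam.CombineGlobally(sum)
        # Each element is divided by the global sum, passed in as a
        # singleton side input.
        normalized = nums | beam.Map(
            lambda x, s: x / s, s=beam.pvalue.AsSingleton(total))
        normalized | beam.Map(print)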
Example 22
      pcolls = pcolls.items()
    except AttributeError:
      # Otherwise, pcolls is a list/tuple, so we turn it into (index, pcoll)
      # pairs. The result value constructor makes tuples with len(pcolls) slots.
      pcolls = list(enumerate(pcolls))
      result_ctor_arg = len(pcolls)
      result_ctor = lambda size: tuple([] for _ in range(size))

    # Check input PCollections for PCollection-ness, and that they all belong
    # to the same pipeline.
    for _, pcoll in pcolls:
      self._check_pcollection(pcoll)
      if self.pipeline:
        assert pcoll.pipeline == self.pipeline

    return ([pcoll | 'pair_with_%s' % tag >> Map(_pair_tag_with_value, tag)
             for tag, pcoll in pcolls]
            | Flatten(pipeline=self.pipeline)
            | GroupByKey()
            | Map(_merge_tagged_vals_under_key, result_ctor, result_ctor_arg))


def Keys(label='Keys'):  # pylint: disable=invalid-name
  """Produces a PCollection of first elements of 2-tuples in a PCollection."""
  return label >> Map(lambda k_v: k_v[0])


def Values(label='Values'):  # pylint: disable=invalid-name
  """Produces a PCollection of second elements of 2-tuples in a PCollection."""
  return label >> Map(lambda k_v: k_v[1])
Example 23
def Values(label='Values'):  # pylint: disable=invalid-name
    """Produces a PCollection of second elements of 2-tuples in a PCollection."""
    return label >> Map(lambda k_v1: k_v1[1])
Example 24
    def test_avro_it(self):
        num_records = self.test_pipeline.get_option('records')
        num_records = int(num_records) if num_records else 1000000

        # Seed a `PCollection` with indices that will each be FlatMap'd into
        # `batch_size` records, to avoid having a too-large list in memory at
        # the outset
        batch_size = self.test_pipeline.get_option('batch-size')
        batch_size = int(batch_size) if batch_size else 10000

        # pylint: disable=range-builtin-not-iterating
        batches = range(int(num_records / batch_size))

        def batch_indices(start):
            # pylint: disable=range-builtin-not-iterating
            return range(start * batch_size, (start + 1) * batch_size)

        # A `PCollection` with `num_records` avro records
        records_pcoll = \
            self.test_pipeline \
            | 'create-batches' >> Create(batches) \
            | 'expand-batches' >> FlatMap(batch_indices) \
            | 'create-records' >> Map(record)

        fastavro_output = '/'.join([self.output, 'fastavro'])
        avro_output = '/'.join([self.output, 'avro'])

        self.addCleanup(delete_files, [self.output + '*'])

        # pylint: disable=expression-not-assigned
        records_pcoll \
        | 'write_fastavro' >> WriteToAvro(
            fastavro_output,
            self.SCHEMA,
            use_fastavro=True
        )

        # pylint: disable=expression-not-assigned
        records_pcoll \
        | 'write_avro' >> WriteToAvro(
            avro_output,
            self.SCHEMA,
            use_fastavro=False
        )

        result = self.test_pipeline.run()
        result.wait_until_finish()
        assert result.state == PipelineState.DONE

        fastavro_read_pipeline = TestPipeline(is_integration_test=True)

        fastavro_records = \
            fastavro_read_pipeline \
            | 'create-fastavro' >> Create(['%s*' % fastavro_output]) \
            | 'read-fastavro' >> ReadAllFromAvro(use_fastavro=True) \
            | Map(lambda rec: (rec['number'], rec))

        avro_records = \
            fastavro_read_pipeline \
            | 'create-avro' >> Create(['%s*' % avro_output]) \
            | 'read-avro' >> ReadAllFromAvro(use_fastavro=False) \
            | Map(lambda rec: (rec['number'], rec))

        def check(elem):
            v = elem[1]

            def assertEqual(l, r):
                if l != r:
                    raise BeamAssertException('Assertion failed: %s == %s' %
                                              (l, r))

            assertEqual(sorted(v.keys()), ['avro', 'fastavro'])
            avro_values = v['avro']
            fastavro_values = v['fastavro']
            assertEqual(avro_values, fastavro_values)
            assertEqual(len(avro_values), 1)

        # pylint: disable=expression-not-assigned
        {
            'avro': avro_records,
            'fastavro': fastavro_records
        } \
        | CoGroupByKey() \
        | Map(check)

        fastavro_read_pipeline.run().wait_until_finish()
        assert result.state == PipelineState.DONE
Example 25
def Distinct(pcoll):  # pylint: disable=invalid-name
    """Produces a PCollection containing distinct elements of a PCollection."""
    return (pcoll
            | 'ToPairs' >> Map(lambda v: (v, None))
            | 'Group' >> CombinePerKey(lambda vs: None)
            | 'Distinct' >> Keys())
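A minimal usage sketch, assuming this transform is available as beam.Distinct in the installed SDK (data is illustrative):

    import apache_beam as beam

    with beam.Pipeline() as p:
        values = p | beam.Create([1, 1, 2, 2, 3])
        # Duplicate elements are collapsed; output order is not guaranteed.
        unique = values | beam.Distinct()
        unique | beam.Map(print)  # 1, 2, 3 in some order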
Example 26
def RemoveDuplicates(pcoll):  # pylint: disable=invalid-name
    """Produces a PCollection containing the unique elements of a PCollection."""
    return (pcoll
            | 'ToPairs' >> Map(lambda v: (v, None))
            | 'Group' >> CombinePerKey(lambda vs: None)
            | 'RemoveDuplicates' >> Keys())
Example 27
 def Element():
     """
 Transforms each element of the PCollection to a string.
 """
     return 'ElementToString' >> Map(str)
Example 28
 def expand(self, pcoll):
     input_type = T
     output_type = str
     return (pcoll | ('%s:ElementToString' % self.label >>
                      (Map(lambda x: str(x))).with_input_types(
                          input_type).with_output_types(output_type)))
Example 29
        except AttributeError:
            # Otherwise, pcolls is a list/tuple, so we turn it into (index, pcoll)
            # pairs. The result value constructor makes tuples with len(pcolls) slots.
            pcolls = list(enumerate(pcolls))
            result_ctor_arg = len(pcolls)
            result_ctor = lambda size: tuple([] for _ in range(size))

        # Check input PCollections for PCollection-ness, and that they all belong
        # to the same pipeline.
        for _, pcoll in pcolls:
            self._check_pcollection(pcoll)
            if self.pipeline:
                assert pcoll.pipeline == self.pipeline

        return ([
            pcoll | Map('pair_with_%s' % tag, _pair_tag_with_value, tag)
            for tag, pcoll in pcolls
        ]
                | Flatten(pipeline=self.pipeline)
                | GroupByKey()
                | Map(_merge_tagged_vals_under_key, result_ctor,
                      result_ctor_arg))


def Keys(label='Keys'):  # pylint: disable=invalid-name
    """Produces a PCollection of first elements of 2-tuples in a PCollection."""
    return Map(label, lambda k_v: k_v[0])


def Values(label='Values'):  # pylint: disable=invalid-name
    """Produces a PCollection of second elements of 2-tuples in a PCollection."""
Example 30
    def test_avro_it(self):
        num_records = self.test_pipeline.get_option('records')
        num_records = int(num_records) if num_records else 1000000
        fastavro_output = '/'.join([self.output, 'fastavro'])

        # Seed a `PCollection` with indices that will each be FlatMap'd into
        # `batch_size` records, to avoid having a too-large list in memory at
        # the outset
        batch_size = self.test_pipeline.get_option('batch-size')
        batch_size = int(batch_size) if batch_size else 10000

        # pylint: disable=bad-option-value
        batches = range(int(num_records / batch_size))

        def batch_indices(start):
            # pylint: disable=bad-option-value
            return range(start * batch_size, (start + 1) * batch_size)

        # A `PCollection` with `num_records` avro records
        records_pcoll = \
            self.test_pipeline \
            | 'create-batches' >> Create(batches) \
            | 'expand-batches' >> FlatMap(batch_indices) \
            | 'create-records' >> Map(record)

        # pylint: disable=expression-not-assigned
        records_pcoll \
        | 'write_fastavro' >> WriteToAvro(
            fastavro_output,
            parse_schema(json.loads(self.SCHEMA_STRING)),
        )
        result = self.test_pipeline.run()
        result.wait_until_finish()
        fastavro_pcoll = self.test_pipeline \
                         | 'create-fastavro' >> Create(['%s*' % fastavro_output]) \
                         | 'read-fastavro' >> ReadAllFromAvro()

        mapped_fastavro_pcoll = fastavro_pcoll | "map_fastavro" >> Map(
            lambda x: (x['number'], x))
        mapped_record_pcoll = records_pcoll | "map_record" >> Map(
            lambda x: (x['number'], x))

        def validate_record(elem):
            v = elem[1]

            def assertEqual(l, r):
                if l != r:
                    raise BeamAssertException('Assertion failed: %s == %s' %
                                              (l, r))

            assertEqual(sorted(v.keys()), ['fastavro', 'record_pcoll'])
            record_pcoll_values = v['record_pcoll']
            fastavro_values = v['fastavro']
            assertEqual(record_pcoll_values, fastavro_values)
            assertEqual(len(record_pcoll_values), 1)

        {
            "record_pcoll": mapped_record_pcoll,
            "fastavro": mapped_fastavro_pcoll
        } | CoGroupByKey() | Map(validate_record)

        result = self.test_pipeline.run()
        result.wait_until_finish()

        self.addCleanup(delete_files, [self.output])
        assert result.state == PipelineState.DONE