Example 1
    def test_runner_api_transformation_with_subscription(
            self, unused_mock_pubsub):
        source = _PubSubSource(
            topic=None,
            subscription='projects/fakeprj/subscriptions/a_subscription',
            id_label='a_label',
            timestamp_attribute='b_label',
            with_attributes=True)
        transform = Read(source)

        context = pipeline_context.PipelineContext()
        proto_transform_spec = transform.to_runner_api(context)
        self.assertEqual(common_urns.composites.PUBSUB_READ.urn,
                         proto_transform_spec.urn)

        pubsub_read_payload = (proto_utils.parse_Bytes(
            proto_transform_spec.payload,
            beam_runner_api_pb2.PubSubReadPayload))
        self.assertEqual('projects/fakeprj/subscriptions/a_subscription',
                         pubsub_read_payload.subscription)
        self.assertEqual('a_label', pubsub_read_payload.id_attribute)
        self.assertEqual('b_label', pubsub_read_payload.timestamp_attribute)
        self.assertEqual('', pubsub_read_payload.topic)
        self.assertTrue(pubsub_read_payload.with_attributes)

        proto_transform = beam_runner_api_pb2.PTransform(
            unique_name="dummy_label", spec=proto_transform_spec)

        transform_from_proto = Read.from_runner_api_parameter(
            proto_transform, pubsub_read_payload, None)
        self.assertTrue(isinstance(transform_from_proto, Read))
        self.assertTrue(isinstance(transform_from_proto.source, _PubSubSource))
        self.assertTrue(transform_from_proto.source.with_attributes)
        self.assertEqual('projects/fakeprj/subscriptions/a_subscription',
                         transform_from_proto.source.full_subscription)
Example 2
    def test_root_transforms(self):
        root_create = Create('create', [[1, 2, 3]])

        class DummySource(iobase.BoundedSource):
            pass

        root_read = Read('read', DummySource())
        root_flatten = Flatten('flatten', pipeline=self.pipeline)

        pbegin = pvalue.PBegin(self.pipeline)
        pcoll_create = pbegin | root_create
        pbegin | root_read
        pcoll_create | FlatMap(lambda x: x)
        [] | root_flatten

        self.pipeline.visit(self.visitor)

        root_transforms = sorted(
            [t.transform for t in self.visitor.root_transforms])
        self.assertEqual(root_transforms,
                         sorted([root_read, root_create, root_flatten]))

        pbegin_consumers = sorted(
            [c.transform for c in self.visitor.value_to_consumers[pbegin]])
        self.assertEqual(pbegin_consumers, sorted([root_read, root_create]))
        self.assertEqual(len(self.visitor.step_names), 4)
Example 3
    def test_side_inputs(self):
        class SplitNumbersFn(DoFn):
            def process(self, element):
                if element < 0:
                    yield pvalue.SideOutputValue('tag_negative', element)
                else:
                    yield element

        class ProcessNumbersFn(DoFn):
            def process(self, element, negatives):
                yield element

        class DummySource(iobase.BoundedSource):
            pass

        root_read = Read(DummySource())

        result = (self.pipeline
                  | 'read' >> root_read
                  | ParDo(SplitNumbersFn()).with_outputs('tag_negative',
                                                         main='positive'))
        positive, negative = result
        positive | ParDo(ProcessNumbersFn(), AsList(negative))

        self.pipeline.visit(self.visitor)

        root_transforms = sorted(
            [t.transform for t in self.visitor.root_transforms])
        self.assertEqual(root_transforms, sorted([root_read]))
        self.assertEqual(len(self.visitor.step_names), 3)
        self.assertEqual(len(self.visitor.views), 1)
        self.assertTrue(isinstance(self.visitor.views[0], pvalue.AsList))
Example 4
    def test(self):
        def format_record(record):
            import base64
            return base64.b64encode(record[1])

        def make_insert_mutations(element):
            import uuid  # pylint: disable=reimported
            from apache_beam.io.gcp.experimental.spannerio import WriteMutation
            ins_mutation = WriteMutation.insert(table='test',
                                                columns=('id', 'data'),
                                                values=[(str(uuid.uuid1()),
                                                         element)])
            return [ins_mutation]

        (  # pylint: disable=expression-not-assigned
            self.pipeline
            | 'Produce rows' >> Read(
                SyntheticSource(self.parse_synthetic_source_options()))
            | 'Count messages' >> ParDo(CountMessages(self.metrics_namespace))
            | 'Format' >> Map(format_record)
            | 'Make mutations' >> FlatMap(make_insert_mutations)
            | 'Measure time' >> ParDo(MeasureTime(self.metrics_namespace))
            | 'Write to Spanner' >> WriteToSpanner(
                project_id=self.project,
                instance_id=self.spanner_instance,
                database_id=self.TEST_DATABASE,
                max_batch_size_bytes=5120))
Example 5
    def _create_input_data(self):
        """
    Runs an additional pipeline which creates test data and waits for its
    completion.
    """
        SCHEMA = parse_table_schema_from_json(
            '{"fields": [{"name": "data", "type": "BYTES"}]}')

        def format_record(record):
            # Since Synthetic Source returns data as a dictionary, we keep only
            # one part of each record.
            import base64
            return {'data': base64.b64encode(record[1])}

        with TestPipeline() as p:
            (  # pylint: disable=expression-not-assigned
                p
                | 'Produce rows' >> Read(
                    SyntheticSource(self.parse_synthetic_source_options()))
                | 'Format' >> Map(format_record)
                | 'Write to BigQuery' >> WriteToBigQuery(
                    dataset=self.input_dataset,
                    table=self.input_table,
                    schema=SCHEMA,
                    create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
                    write_disposition=BigQueryDisposition.WRITE_EMPTY))
Example 6
    def _create_input_data(self):
        """
    Runs an additional pipeline which creates test data and waits for its
    completion.
    """
        def format_record(record):
            import base64
            return base64.b64encode(record[1])

        def make_insert_mutations(element):
            import uuid
            from apache_beam.io.gcp.experimental.spannerio import WriteMutation
            ins_mutation = WriteMutation.insert(table='test_data',
                                                columns=('id', 'data'),
                                                values=[(str(uuid.uuid1()),
                                                         element)])
            return [ins_mutation]

        with TestPipeline() as p:
            (  # pylint: disable=expression-not-assigned
                p
                | 'Produce rows' >> Read(
                    SyntheticSource(self.parse_synthetic_source_options()))
                | 'Format' >> Map(format_record)
                | 'Make mutations' >> FlatMap(make_insert_mutations)
                | 'Write to Spanner' >> WriteToSpanner(
                    project_id=self.project,
                    instance_id=self.spanner_instance,
                    database_id=self.spanner_database,
                    max_batch_size_bytes=5120))
Example 7
 def test_track_pcoll_unbounded(self):
     pipeline = TestPipeline()
     pcoll1 = pipeline | 'read' >> Read(FakeUnboundedSource())
     pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
     pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
     self.assertIs(pcoll1.is_bounded, False)
     self.assertIs(pcoll2.is_bounded, False)
     self.assertIs(pcoll3.is_bounded, False)
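FakeUnboundedSource is referenced above but not defined in this snippet. One plausible stand-in, assuming the test only needs the boundedness flag (the pipeline is built but never run), is a stub source that reports itself as unbounded; this is a sketch, not the actual test helper:

from apache_beam.io import iobase

class FakeUnboundedSource(iobase.SourceBase):
    """Minimal source stub that only reports itself as unbounded."""

    def is_bounded(self):
        # Read(...) consults this flag when creating its output PCollection, so
        # returning False marks every downstream PCollection as unbounded.
        return False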
Example 8
 def test(self):
   self.result = (self.pipeline
                  | 'Read from BigQuery' >> Read(BigQuerySource(
                      dataset=self.input_dataset, table=self.input_table))
                  | 'Count messages' >> ParDo(CountMessages(
                      self.metrics_namespace))
                  | 'Measure time' >> ParDo(MeasureTime(
                      self.metrics_namespace))
                  | 'Count' >> Count.Globally())
Example 9
 def test(self):
   output = (
       self.pipeline
       | 'Read from BigQuery' >> Read(
           BigQuerySource(dataset=self.input_dataset, table=self.input_table))
       | 'Count messages' >> ParDo(CountMessages(self.metrics_namespace))
       | 'Measure time' >> ParDo(MeasureTime(self.metrics_namespace))
       | 'Count' >> Count.Globally())
   assert_that(output, equal_to([self.input_options['num_records']]))
Example 10
 def get_replacement_transform(self, ptransform):
   from apache_beam import pvalue
   from apache_beam.io import iobase

   class Read(iobase.Read):
     override = True

     def expand(self, pbegin):
       return pvalue.PCollection(
           self.pipeline, is_bounded=self.source.is_bounded())

   return Read(ptransform.source).with_output_types(
       ptransform.get_type_hints().simple_output_type('Read'))
Example 11
 def test_metrics_in_fake_source(self):
     pipeline = TestPipeline()
     pcoll = pipeline | Read(FakeSource([1, 2, 3, 4, 5, 6]))
     assert_that(pcoll, equal_to([1, 2, 3, 4, 5, 6]))
     res = pipeline.run()
     metric_results = res.metrics().query()
     outputs_counter = metric_results['counters'][0]
     self.assertEqual(outputs_counter.key.step, 'Read')
     self.assertEqual(outputs_counter.key.metric.name, 'outputs')
     self.assertEqual(outputs_counter.committed, 6)
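Indexing metric_results['counters'][0] works here only because the pipeline reports a single counter; with more steps it is safer to narrow the query. A small continuation of the test above, assuming the standard MetricsFilter API:

     from apache_beam.metrics.metric import MetricsFilter

     # Narrow the query to the 'outputs' counter reported by the 'Read' step
     # instead of relying on list position.
     filtered = res.metrics().query(
         MetricsFilter().with_step('Read').with_name('outputs'))
     outputs_counter = filtered['counters'][0]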
Example 12
def run(argv=None):
    """Main entry point; defines and runs the wordcount pipeline."""

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        dest='input',
        default='gs://dataflow-samples/shakespeare/kinglear.txt',
        help='Input file to process.')
    parser.add_argument('--output',
                        dest='output',
                        required=True,
                        help='Output file to write results to.')
    known_args, pipeline_args = parser.parse_known_args(argv)

    ###############################################
    # (1) Create the pipeline
    ###############################################

    # First, create a PipelineOptions object.
    # It lets you configure options such as the pipeline runner that will execute
    # the pipeline and any runner-specific settings that runner requires.
    pipeline_options = PipelineOptions(pipeline_args)

    # Example of editing the created PipelineOptions object directly.
    # Because this pipeline uses a DoFn transform, enable the save_main_session option.
    pipeline_options.view_as(SetupOptions).save_main_session = True

    # Create the pipeline (p6) from the options
    p6 = beam.Pipeline(options=pipeline_options)  # pipeline: BigQuery in → BigQuery out

    ###############################################
    # (2) Configure the transforms
    ###############################################

    # Attach the transforms to p6
    query = 'select * from babynames.testtable3'
    (p6 | 'read' >> Read(
        beam.io.BigQuerySource(
            project='gcp-project-210712', use_standard_sql=False, query=query))
     | 'pair' >> beam.Map(lambda x: (x['name'], x['count']))
     | 'groupby' >> beam.GroupByKey()
     | 'modify' >> beam.Map(modify_data2)
     | 'write' >> beam.io.Write(
         beam.io.BigQuerySink(
             'babynames.testtable4',
             schema='name:string, sum:integer',
             create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
             write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))

    # The full table schema must be written out, or the write step will fail with an error.

    ###############################################
    # (3) Run the pipeline
    ###############################################
    p6.run().wait_until_finish()
Example 13
def run(argv=None):
  """Main entry point; defines and runs the wordcount pipeline."""

  parser = argparse.ArgumentParser()
  parser.add_argument('--input',
                      dest='input',
                      default='gs://dataflow-samples/shakespeare/kinglear.txt',
                      help='Input file to process.')
  parser.add_argument('--output',
                      dest='output',
                      required=True,
                      help='Output file to write results to.')
  known_args, pipeline_args = parser.parse_known_args(argv)


  ###############################################
  # (1) Create the pipeline
  ###############################################

  # First, create a PipelineOptions object.
  # It lets you configure options such as the pipeline runner that will execute
  # the pipeline and any runner-specific settings that runner requires.
  pipeline_options = PipelineOptions(pipeline_args)

  # Example of editing the created PipelineOptions object directly.
  # Because this pipeline uses a DoFn transform, enable the save_main_session option.
  pipeline_options.view_as(SetupOptions).save_main_session = True

  # Create the pipeline (p4) from the options
  p4 = beam.Pipeline(options=pipeline_options)  # pipeline: BigQuery in → BigQuery out

  ###############################################
  # (2) Configure the transforms
  ###############################################

  # Attach the transforms to p4
  query = 'select * from babynames.names2012'

  (p4
   | 'read' >> Read(beam.io.BigQuerySource(
       project='gcp-project-210712', use_standard_sql=False, query=query))
   | 'modify' >> beam.Map(modify_data1)
   | 'write' >> beam.io.Write(beam.io.BigQuerySink(
       'babynames.testtable2',
       schema='name:string, count:integer',
       create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
       write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE)))

  ###############################################
  # (3) Run the pipeline
  ###############################################
  result4 = p4.run()

  # Wait for the pipeline to finish.
  # Without this call the function returns immediately, and any code after this
  # point would run before the pipeline has completed.
  # Note: with the DataflowRunner, Ctrl-C does not stop the pipeline; it has to be
  # stopped from the Google Cloud console.
  result4.wait_until_finish()
Example 14
def run(argv=None):
    """Main entry point; defines and runs the wordcount pipeline."""

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        dest='input',
        default='gs://dataflow-samples/shakespeare/kinglear.txt',
        help='Input file to process.')
    parser.add_argument('--output',
                        dest='output',
                        required=True,
                        help='Output file to write results to.')
    known_args, pipeline_args = parser.parse_known_args(argv)

    ###############################################
    # (1) Create the pipeline
    ###############################################

    # First, create a PipelineOptions object.
    # It lets you configure options such as the pipeline runner that will execute
    # the pipeline and any runner-specific settings that runner requires.
    pipeline_options = PipelineOptions(pipeline_args)

    # Example of editing the created PipelineOptions object directly.
    # Because this pipeline uses a DoFn transform, enable the save_main_session option.
    pipeline_options.view_as(SetupOptions).save_main_session = True

    # Create the pipeline (p2) from the options
    p2 = beam.Pipeline(options=pipeline_options)  # pipeline: BigQuery in → text out

    ###############################################
    # (2) Configure the transforms
    ###############################################

    # Attach the transforms to p2
    query = 'select * from babynames.names2012 limit 10000'
    (p2
     | 'read' >> Read(beam.io.BigQuerySource(
         project='gcp-project-210712', use_standard_sql=False, query=query))
     | 'write' >> WriteToText('gs://gcp_dataflowsample/query_result.txt',
                              num_shards=1))

    ###############################################
    # (3) Run the pipeline
    ###############################################
    result2 = p2.run()

    # Wait for the pipeline to finish.
    # Without this call the function returns immediately, and any code after this
    # point would run before the pipeline has completed.
    # Note: with the DataflowRunner, Ctrl-C does not stop the pipeline; it has to be
    # stopped from the Google Cloud console.
    result2.wait_until_finish()
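Read(beam.io.BigQuerySource(...)) is the older BigQuery read API; on recent Beam releases (2.25+) the same step is usually written with beam.io.ReadFromBigQuery. A rough equivalent of the read-and-write above, under that assumption:

    (p2
     | 'read' >> beam.io.ReadFromBigQuery(
         project='gcp-project-210712', query=query, use_standard_sql=False)
     # Query-based reads stage an export in GCS first, so a --temp_location
     # pipeline option (or an explicit gcs_location argument) is typically needed.
     | 'write' >> WriteToText('gs://gcp_dataflowsample/query_result.txt',
                              num_shards=1))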
Example 15
  def test_track_pcoll_unbounded_flatten(self):
    pipeline = TestPipeline()
    pcoll1_bounded = pipeline | 'label1' >> Create([1, 2, 3])
    pcoll2_bounded = pcoll1_bounded | 'do1' >> FlatMap(lambda x: [x + 1])

    pcoll1_unbounded = pipeline | 'read' >> Read(FakeUnboundedSource())
    pcoll2_unbounded = pcoll1_unbounded | 'do2' >> FlatMap(lambda x: [x + 1])

    merged = (pcoll2_bounded, pcoll2_unbounded) | beam.Flatten()

    self.assertIs(pcoll1_bounded.is_bounded, True)
    self.assertIs(pcoll2_bounded.is_bounded, True)
    self.assertIs(pcoll1_unbounded.is_bounded, False)
    self.assertIs(pcoll2_unbounded.is_bounded, False)
    self.assertIs(merged.is_bounded, False)
Example 16
 def test_read(self):
     schema = 'struct<a:int,b:struct<x:string,y:boolean>>'
     files = []
     with tempfile.NamedTemporaryFile() as f1, \
          tempfile.NamedTemporaryFile() as f2:
         files.append(f1.name)
         with pyorc.Writer(f1, schema) as writer:
             writer.write((1, ('x', True)))
         files.append(f2.name)
         with pyorc.Writer(f2, schema) as writer:
             writer.write((2, ('y', False)))
             writer.write((3, ('z', False)))
         with TestPipeline() as p:
             pc = (p | Read(
                 FileSource(
                     file_patterns=files,
                     reader=OrcReader(pyorc_options={
                         'struct_repr': pyorc.StructRepr.DICT,
                     }))))
             assert_that(
                 pc,
                 equal_to([
                     {
                         'a': 1,
                         'b': {
                             'x': 'x',
                             'y': True,
                         },
                     },
                     {
                         'a': 2,
                         'b': {
                             'x': 'y',
                             'y': False,
                         },
                     },
                     {
                         'a': 3,
                         'b': {
                             'x': 'z',
                             'y': False,
                         },
                     },
                 ]))
Example 17
  def get_replacement_transform(self, ptransform):
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam import pvalue
    from apache_beam.io import iobase

    # This is purposely subclassed from the Read transform to take advantage of
    # the existing windowing, typing, and display data.
    class Read(iobase.Read):
      override = True

      def expand(self, pbegin):
        return pvalue.PCollection.from_(pbegin)

    # Use the source's coder type hint as this replacement's output. Otherwise,
    # the typing information is not properly forwarded to the DataflowRunner and
    # will choose the incorrect coder for this transform.
    return Read(ptransform.source).with_output_types(
        ptransform.source.coder.to_type_hint())
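In Beam, a replacement transform like this is normally returned from a PTransformOverride that a runner applies to the graph before execution. A minimal sketch of such a wrapper, assuming the older get_replacement_transform hook used above (ReadOverride is an illustrative name, not a Beam class):

from apache_beam.io import iobase
from apache_beam.pipeline import PTransformOverride

class ReadOverride(PTransformOverride):
  def matches(self, applied_ptransform):
    # Only match Read transforms that have not already been replaced.
    transform = applied_ptransform.transform
    return (isinstance(transform, iobase.Read)
            and not getattr(transform, 'override', False))

  def get_replacement_transform(self, ptransform):
    ...  # return the overriding Read subclass, as in the method above

# A runner would apply the override before translating the pipeline, e.g.:
# pipeline.replace_all([ReadOverride()])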
Example 18
    def test(self):
        SCHEMA = parse_table_schema_from_json(
            '{"fields": [{"name": "data", "type": "BYTES"}]}')

        def format_record(record):
            # Since Synthetic Source returns data as a dictionary, we keep only
            # one part of each record.
            return {'data': base64.b64encode(record[1])}

        # pylint: disable=expression-not-assigned
        (self.pipeline
         | 'ProduceRows' >> Read(
             SyntheticSource(self.parseTestPipelineOptions()))
         | 'Format' >> Map(format_record)
         | 'WriteToBigQuery' >> WriteToBigQuery(
             self.output_dataset + '.' + self.output_table,
             schema=SCHEMA,
             create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
             write_disposition=BigQueryDisposition.WRITE_EMPTY))
Example 19
  def test(self):
    def to_pubsub_message(element):
      import uuid
      from apache_beam.io import PubsubMessage
      return PubsubMessage(
          data=element[1],
          attributes={'id': str(uuid.uuid1()).encode('utf-8')},
      )

    _ = (
        self.pipeline
        | 'Create input' >> Read(
            SyntheticSource(self.parse_synthetic_source_options()))
        | 'Format to pubsub message in bytes' >> beam.Map(to_pubsub_message)
        | 'Measure time' >> beam.ParDo(MeasureTime(self.metrics_namespace))
        | 'Write to Pubsub' >> beam.io.WriteToPubSub(
            self.topic_name,
            with_attributes=True,
            id_label='id',
        ))
Example 20
    def test(self):
        SCHEMA = parse_table_schema_from_json(
            '{"fields": [{"name": "data", "type": "BYTES"}]}')

        def format_record(record):
            # Since Synthetic Source returns data as a dictionary, we keep only
            # one part of each record.
            return {'data': base64.b64encode(record[1])}

        (  # pylint: disable=expression-not-assigned
            self.pipeline
            | 'Produce rows' >> Read(
                SyntheticSource(self.parse_synthetic_source_options()))
            | 'Count messages' >> ParDo(CountMessages(self.metrics_namespace))
            | 'Format' >> Map(format_record)
            | 'Measure time' >> ParDo(MeasureTime(self.metrics_namespace))
            | 'Write to BigQuery' >> WriteToBigQuery(
                dataset=self.output_dataset,
                table=self.output_table,
                schema=SCHEMA,
                create_disposition=BigQueryDisposition.CREATE_IF_NEEDED,
                write_disposition=BigQueryDisposition.WRITE_TRUNCATE))
Example 21
    def test_root_transforms(self):
        class DummySource(iobase.BoundedSource):
            pass

        root_read = Read(DummySource())
        root_flatten = Flatten(pipeline=self.pipeline)

        pbegin = pvalue.PBegin(self.pipeline)
        pcoll_read = pbegin | 'read' >> root_read
        pcoll_read | FlatMap(lambda x: x)
        [] | 'flatten' >> root_flatten

        self.pipeline.visit(self.visitor)

        root_transforms = sorted(
            [t.transform for t in self.visitor.root_transforms])

        self.assertEqual(root_transforms, sorted([root_read, root_flatten]))

        pbegin_consumers = sorted(
            [c.transform for c in self.visitor.value_to_consumers[pbegin]])
        self.assertEqual(pbegin_consumers, sorted([root_read]))
        self.assertEqual(len(self.visitor.step_names), 3)
Example 22
    def test_runner_api_transformation_properties_none(self,
                                                       unused_mock_pubsub):
        # Confirming that properties stay None after a runner API transformation.
        source = _PubSubSource(topic='projects/fakeprj/topics/a_topic',
                               with_attributes=True)
        transform = Read(source)

        context = pipeline_context.PipelineContext()
        proto_transform_spec = transform.to_runner_api(context)
        self.assertEqual(common_urns.composites.PUBSUB_READ.urn,
                         proto_transform_spec.urn)

        pubsub_read_payload = (proto_utils.parse_Bytes(
            proto_transform_spec.payload,
            beam_runner_api_pb2.PubSubReadPayload))

        proto_transform = beam_runner_api_pb2.PTransform(
            unique_name="dummy_label", spec=proto_transform_spec)

        transform_from_proto = Read.from_runner_api_parameter(
            proto_transform, pubsub_read_payload, None)
        self.assertIsNone(transform_from_proto.source.full_subscription)
        self.assertIsNone(transform_from_proto.source.id_label)
        self.assertIsNone(transform_from_proto.source.timestamp_attribute)
Example 23
def run(argv=None, comments=None):
    """Run the beam pipeline.

    Args:
        argv: (optional) the command line flags to parse.
        comments: (optional) a list of comment JSON objects to
            process. Used in unit-tests to avoid requiring a BigQuery source.
    """
    args, pipeline_args = _parse_args(argv)

    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    p = beam.Pipeline(options=pipeline_options)

    if comments is not None:
        comments = p | ("Read in-memory comments") >> beam.Create(comments)
    else:
        comments = p | ("Read " + args.reddit_table) >> Read(
            BigQuerySource(args.reddit_table))

    comments |= (
        "Normalise comments" >> beam.Map(
            partial(normalise_comment, max_length=args.max_length)))

    thread_id_to_comments = comments | (
        "Key by thread id" >> beam.Map(
            lambda comment: (comment.thread_id, comment)))
    threads = thread_id_to_comments | (
        "Group comments by thread ID" >> beam.GroupByKey())
    threads = threads | ("Get threads" >> beam.Map(lambda t: t[1]))

    examples = threads | (
        "Create {} examples".format(args.dataset_format) >> beam.FlatMap(
            partial(create_examples,
                    parent_depth=args.parent_depth,
                    min_length=args.min_length,
                    format=args.dataset_format,
                    )))
    examples = _shuffle(examples)

    # [START dataflow_molecules_split_to_train_and_eval_datasets]
    # Split the dataset into a training set and an evaluation set
    assert 0 < (100 - args.train_split * 100) < 100, 'eval_percent must be in the range (0, 100)'
    eval_percent = 100 - args.train_split*100
    train_dataset, eval_dataset = (
        examples
        | 'Split dataset' >> beam.Partition(
            lambda elem, _: int(random.uniform(0, 100) < eval_percent), 2))
    # [END dataflow_molecules_split_to_train_and_eval_datasets]

    if args.dataset_format == _JSON_FORMAT:
        write_sink = WriteToText
        file_name_suffix = ".json"
        serialize_fn = json.dumps

    serialized_train_examples = train_dataset | (
        "serialize {} examples".format('train') >> beam.Map(serialize_fn))
    (
        serialized_train_examples | ("write " + 'train')
        >> write_sink(
            os.path.join(args.output_dir, 'train'),
            file_name_suffix=file_name_suffix,
            num_shards=args.num_shards_train,
        )
    )

    serialized_test_examples = eval_dataset | (
        "serialize {} examples".format('valid') >> beam.Map(serialize_fn))
    (
        serialized_test_examples | ("write " + 'valid')
        >> write_sink(
            os.path.join(args.output_dir, 'valid'),
            file_name_suffix=file_name_suffix,
            num_shards=args.num_shards_train,
        )
    )

    result = p.run()
    result.wait_until_finish()
Example 24
 def test_fake_read(self):
     with TestPipeline() as pipeline:
         pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
         assert_that(pcoll, equal_to([1, 2, 3]))
Example 25
 def test_fake_read(self):
     # FakeSource mock requires DirectRunner.
     pipeline = TestPipeline(runner='DirectRunner')
     pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
     assert_that(pcoll, equal_to([1, 2, 3]))
     pipeline.run()
Example 26
 def expand(self, pvalue):
   return pvalue.pipeline | Read(self._source)
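An expand method like this only makes sense inside a composite PTransform that stores its source at construction time. A self-contained sketch of such a wrapper, assuming _source is set in __init__ (ReadViaSource is a hypothetical name):

import apache_beam as beam
from apache_beam.io.iobase import Read

class ReadViaSource(beam.PTransform):
  """Composite transform that defers to the primitive Read transform."""

  def __init__(self, source):
    super().__init__()
    self._source = source

  def expand(self, pvalue):
    # Apply Read against the enclosing pipeline, exactly as in the snippet above.
    return pvalue.pipeline | Read(self._source)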
Example 27
 def test_read(self):
     pipeline = TestPipeline()
     pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
     assert_that(pcoll, equal_to([1, 2, 3]))
     pipeline.run()
Example 28
def run(argv=None, comments=None):
    """Run the beam pipeline.

    Args:
        argv: (optional) the command line flags to parse.
        comments: (optional) a list of comment JSON objects to
            process. Used in unit-tests to avoid requiring a BigQuery source.
    """
    args, pipeline_args = _parse_args(argv)

    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    p = beam.Pipeline(options=pipeline_options)

    if comments is not None:
        comments = p | ("Read in-memory comments") >> beam.Create(comments)
    else:
        comments = p | ("Read " + args.reddit_table) >> Read(
            BigQuerySource(args.reddit_table))

    comments |= ("Normalise comments" >> beam.Map(
        partial(normalise_comment, max_length=args.max_length)))

    thread_id_to_comments = comments | (
        "Key by thread id" >> beam.Map(lambda comment:
                                       (comment.thread_id, comment)))
    threads = thread_id_to_comments | (
        "Group comments by thread ID" >> beam.GroupByKey())
    threads = threads | ("Get threads" >> beam.Map(lambda t: t[1]))

    examples = threads | (
        "Create {} examples".format(args.dataset_format) >> beam.FlatMap(
            partial(
                create_examples,
                parent_depth=args.parent_depth,
                min_length=args.min_length,
                format=args.dataset_format,
            )))
    examples = _shuffle(examples)

    examples |= "split train and test" >> beam.ParDo(
        _TrainTestSplitFn(train_split=args.train_split)).with_outputs(
            _TrainTestSplitFn.TEST_TAG, _TrainTestSplitFn.TRAIN_TAG)

    if args.dataset_format == _JSON_FORMAT:
        write_sink = WriteToText
        file_name_suffix = ".json"
        serialize_fn = json.dumps
    else:
        assert args.dataset_format == _TF_FORMAT
        write_sink = WriteToTFRecord
        file_name_suffix = ".tfrecord"
        serialize_fn = _features_to_serialized_tf_example

    for name, tag in [("train", _TrainTestSplitFn.TRAIN_TAG),
                      ("test", _TrainTestSplitFn.TEST_TAG)]:

        serialized_examples = examples[tag] | (
            "serialize {} examples".format(name) >> beam.Map(serialize_fn))
        (serialized_examples | ("write " + name) >> write_sink(
            os.path.join(args.output_dir, name),
            file_name_suffix=file_name_suffix,
            num_shards=args.num_shards_train,
        ))

    result = p.run()
    result.wait_until_finish()
Example 29
 def expand(self, pvalue):
   return pvalue.pipeline | Read(_TFRecordSource(*self._args))
Example 30
 def expand(self, pvalue):  # pylint: disable=arguments-differ
     return pvalue.pipeline | Read(self._source)