Example #1
  def _create_environment(options):
    portable_options = options.view_as(PortableOptions)
    environment_urn = common_urns.environments.DOCKER.urn
    if portable_options.environment_type == 'DOCKER':
      environment_urn = common_urns.environments.DOCKER.urn
    elif portable_options.environment_type == 'PROCESS':
      environment_urn = common_urns.environments.PROCESS.urn

    if environment_urn == common_urns.environments.DOCKER.urn:
      docker_image = (
          portable_options.environment_config
          or PortableRunner.default_docker_image())
      return beam_runner_api_pb2.Environment(
          url=docker_image,
          urn=common_urns.environments.DOCKER.urn,
          payload=beam_runner_api_pb2.DockerPayload(
              container_image=docker_image
          ).SerializeToString())
    elif environment_urn == common_urns.environments.PROCESS.urn:
      config = json.loads(portable_options.environment_config)
      return beam_runner_api_pb2.Environment(
          urn=common_urns.environments.PROCESS.urn,
          payload=beam_runner_api_pb2.ProcessPayload(
              os=(config.get('os') or ''),
              arch=(config.get('arch') or ''),
              command=config.get('command'),
              env=(config.get('env') or '')
          ).SerializeToString())
Example #2
 def test__create_process_environment(self):
     self.assertEqual(
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({
                 'environment_type':
                 "PROCESS",
                 'environment_config':
                 '{"os": "linux", "arch": "amd64", '
                 '"command": "run.sh", '
                 '"env":{"k1": "v1"} }',
             })),
         beam_runner_api_pb2.Environment(
             urn=common_urns.environments.PROCESS.urn,
             payload=beam_runner_api_pb2.ProcessPayload(
                 os='linux',
                 arch='amd64',
                 command='run.sh',
                 env={
                     'k1': 'v1'
                 },
             ).SerializeToString()))
     self.assertEqual(
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({
                 'environment_type':
                 'PROCESS',
                 'environment_config':
                 '{"command": "run.sh"}',
             })),
         beam_runner_api_pb2.Environment(
             urn=common_urns.environments.PROCESS.urn,
             payload=beam_runner_api_pb2.ProcessPayload(
                 command='run.sh', ).SerializeToString()))
Example #3
 def test_stage_resources(self):
     pipeline_options = PipelineOptions([
         '--temp_location', 'gs://test-location/temp', '--staging_location',
         'gs://test-location/staging', '--no_auth'
     ])
     pipeline = beam_runner_api_pb2.Pipeline(
         components=beam_runner_api_pb2.Components(
             environments={
                 'env1':
                 beam_runner_api_pb2.Environment(dependencies=[
                     beam_runner_api_pb2.ArtifactInformation(
                         type_urn=common_urns.artifact_types.FILE.urn,
                         type_payload=beam_runner_api_pb2.
                         ArtifactFilePayload(
                             path='/tmp/foo1').SerializeToString(),
                         role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                         role_payload=beam_runner_api_pb2.
                         ArtifactStagingToRolePayload(
                             staged_name='foo1').SerializeToString()),
                     beam_runner_api_pb2.ArtifactInformation(
                         type_urn=common_urns.artifact_types.FILE.urn,
                         type_payload=beam_runner_api_pb2.
                         ArtifactFilePayload(
                             path='/tmp/bar1').SerializeToString(),
                         role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                         role_payload=beam_runner_api_pb2.
                         ArtifactStagingToRolePayload(
                             staged_name='bar1').SerializeToString())
                 ]),
                 'env2':
                 beam_runner_api_pb2.Environment(dependencies=[
                     beam_runner_api_pb2.ArtifactInformation(
                         type_urn=common_urns.artifact_types.FILE.urn,
                         type_payload=beam_runner_api_pb2.
                         ArtifactFilePayload(
                             path='/tmp/foo2').SerializeToString(),
                         role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                         role_payload=beam_runner_api_pb2.
                         ArtifactStagingToRolePayload(
                             staged_name='foo2').SerializeToString()),
                     beam_runner_api_pb2.ArtifactInformation(
                         type_urn=common_urns.artifact_types.FILE.urn,
                         type_payload=beam_runner_api_pb2.
                         ArtifactFilePayload(
                             path='/tmp/bar2').SerializeToString(),
                         role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                         role_payload=beam_runner_api_pb2.
                         ArtifactStagingToRolePayload(
                             staged_name='bar2').SerializeToString())
                 ])
             }))
     client = apiclient.DataflowApplicationClient(pipeline_options)
     with mock.patch.object(apiclient._LegacyDataflowStager,
                            'stage_job_resources') as mock_stager:
         client._stage_resources(pipeline, pipeline_options)
     mock_stager.assert_called_once_with(
         [('/tmp/foo1', 'foo1'), ('/tmp/bar1', 'bar1'),
          ('/tmp/foo2', 'foo2'), ('/tmp/bar2', 'bar2')],
         staging_location='gs://test-location/staging')
Example #4
 def test__create_external_environment(self):
     self.assertEqual(
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({
                 'environment_type':
                 "EXTERNAL",
                 'environment_config':
                 'localhost:50000',
             })),
         beam_runner_api_pb2.Environment(
             urn=common_urns.environments.EXTERNAL.urn,
             payload=beam_runner_api_pb2.ExternalPayload(
                 endpoint=endpoints_pb2.ApiServiceDescriptor(
                     url='localhost:50000')).SerializeToString()))
     raw_config = ' {"url":"localhost:50000", "params":{"test":"test"}} '
     for env_config in (raw_config, raw_config.lstrip(),
                        raw_config.strip()):
         self.assertEqual(
             PortableRunner._create_environment(
                 PipelineOptions.from_dictionary({
                     'environment_type':
                     "EXTERNAL",
                     'environment_config':
                     env_config,
                 })),
             beam_runner_api_pb2.Environment(
                 urn=common_urns.environments.EXTERNAL.urn,
                 payload=beam_runner_api_pb2.ExternalPayload(
                     endpoint=endpoints_pb2.ApiServiceDescriptor(
                         url='localhost:50000'),
                     params={
                         "test": "test"
                     }).SerializeToString()))
     with self.assertRaises(ValueError):
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({
                 'environment_type':
                 "EXTERNAL",
                 'environment_config':
                 '{invalid}',
             }))
     with self.assertRaises(ValueError) as ctx:
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({
                 'environment_type':
                 "EXTERNAL",
                 'environment_config':
                 '{"params":{"test":"test"}}',
             }))
     self.assertIn('External environment endpoint must be set.',
                   ctx.exception.args)
Example #5
    def test_java_sdk_harness_dedup(self):
        pipeline_options = PipelineOptions([
            '--experiments=beam_fn_api', '--experiments=use_unified_worker',
            '--temp_location', 'gs://any-location/temp'
        ])

        pipeline = Pipeline(options=pipeline_options)
        pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

        proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

        dummy_env_1 = beam_runner_api_pb2.Environment(
            urn=common_urns.environments.DOCKER.urn,
            payload=(beam_runner_api_pb2.DockerPayload(
                container_image='apache/beam_java:dummy_tag')
                     ).SerializeToString())
        proto_pipeline.components.environments['dummy_env_id_1'].CopyFrom(
            dummy_env_1)

        dummy_transform_1 = beam_runner_api_pb2.PTransform(
            environment_id='dummy_env_id_1')
        proto_pipeline.components.transforms['dummy_transform_id_1'].CopyFrom(
            dummy_transform_1)

        dummy_env_2 = beam_runner_api_pb2.Environment(
            urn=common_urns.environments.DOCKER.urn,
            payload=(beam_runner_api_pb2.DockerPayload(
                container_image='apache/beam_java:dummy_tag')
                     ).SerializeToString())
        proto_pipeline.components.environments['dummy_env_id_2'].CopyFrom(
            dummy_env_2)

        dummy_transform_2 = beam_runner_api_pb2.PTransform(
            environment_id='dummy_env_id_2')
        proto_pipeline.components.transforms['dummy_transform_id_2'].CopyFrom(
            dummy_transform_2)

        # Accessing non-public method for testing.
        apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
            proto_pipeline, dict(), pipeline_options)

        # Only one of 'dummy_env_id_1' or 'dummy_env_id_2' should be in the set of
        # environment IDs used by the proto after Java environment de-duping.
        env_ids_from_transforms = [
            proto_pipeline.components.transforms[transform_id].environment_id
            for transform_id in proto_pipeline.components.transforms
        ]
        if 'dummy_env_id_1' in env_ids_from_transforms:
            self.assertTrue('dummy_env_id_2' not in env_ids_from_transforms)
        else:
            self.assertTrue('dummy_env_id_2' in env_ids_from_transforms)
Example #6
  def _create_environment(options):
    portable_options = options.view_as(PortableOptions)
    environment_urn = common_urns.environments.DOCKER.urn
    if portable_options.environment_type == 'DOCKER':
      environment_urn = common_urns.environments.DOCKER.urn
    elif portable_options.environment_type == 'PROCESS':
      environment_urn = common_urns.environments.PROCESS.urn
    elif portable_options.environment_type in ('EXTERNAL', 'LOOPBACK'):
      environment_urn = common_urns.environments.EXTERNAL.urn
    elif portable_options.environment_type:
      if portable_options.environment_type.startswith('beam:env:'):
        environment_urn = portable_options.environment_type
      else:
        raise ValueError(
            'Unknown environment type: %s' % portable_options.environment_type)

    if environment_urn == common_urns.environments.DOCKER.urn:
      docker_image = (
          portable_options.environment_config
          or PortableRunner.default_docker_image())
      return beam_runner_api_pb2.Environment(
          url=docker_image,
          urn=common_urns.environments.DOCKER.urn,
          payload=beam_runner_api_pb2.DockerPayload(
              container_image=docker_image
          ).SerializeToString())
    elif environment_urn == common_urns.environments.PROCESS.urn:
      config = json.loads(portable_options.environment_config)
      return beam_runner_api_pb2.Environment(
          urn=common_urns.environments.PROCESS.urn,
          payload=beam_runner_api_pb2.ProcessPayload(
              os=(config.get('os') or ''),
              arch=(config.get('arch') or ''),
              command=config.get('command'),
              env=(config.get('env') or '')
          ).SerializeToString())
    elif environment_urn == common_urns.environments.EXTERNAL.urn:
      return beam_runner_api_pb2.Environment(
          urn=common_urns.environments.EXTERNAL.urn,
          payload=beam_runner_api_pb2.ExternalPayload(
              endpoint=endpoints_pb2.ApiServiceDescriptor(
                  url=portable_options.environment_config)
          ).SerializeToString())
    else:
      return beam_runner_api_pb2.Environment(
          urn=environment_urn,
          payload=(portable_options.environment_config.encode('ascii')
                   if portable_options.environment_config else None))
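A minimal sketch of the fallback branch above (hedged; the environment_type and config values here are illustrative, not from the source): a custom 'beam:env:*' environment_type is used verbatim as the URN, and the raw environment_config string becomes the ASCII-encoded payload.

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.portability.portable_runner import PortableRunner

# Hypothetical custom environment URN and opaque config string.
env = PortableRunner._create_environment(
    PipelineOptions.from_dictionary({
        'environment_type': 'beam:env:custom:v1',
        'environment_config': 'opaque-config',
    }))
# Expected: env.urn == 'beam:env:custom:v1' and env.payload == b'opaque-config'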
Example #7
 def create_pipeline(self):
   # Must be gRPC so we can send data and split requests concurrently
   # with the bundle process request.
   return beam.Pipeline(
       runner=fn_api_runner.FnApiRunner(
           default_environment=beam_runner_api_pb2.Environment(
               urn=python_urns.EMBEDDED_PYTHON_GRPC)))
Example #8
    def _make_beam_pipeline(self) -> beam.Pipeline:
        """Makes beam pipeline."""
        # TODO(b/142684737): refactor when Beam supports multi-processing by args.
        pipeline_options = PipelineOptions(self._beam_pipeline_args)
        parallelism = pipeline_options.view_as(
            DirectOptions).direct_num_workers

        if parallelism == 0:
            try:
                parallelism = multiprocessing.cpu_count()
            except NotImplementedError as e:
                absl.logging.warning('Cannot get cpu count: %s' % e)
                parallelism = 1
            pipeline_options.view_as(
                DirectOptions).direct_num_workers = parallelism

        absl.logging.info('Using %d process(es) for Beam pipeline execution.' %
                          parallelism)

        if parallelism > 1:
            if beam_runner_api_pb2:
                env = beam_runner_api_pb2.Environment(
                    urn=python_urns.SUBPROCESS_SDK,
                    payload=b'%s -m apache_beam.runners.worker.sdk_worker_main'
                    % (sys.executable or sys.argv[0]).encode('ascii'))
            else:
                env = environments.SubprocessSDKEnvironment(
                    command_string=
                    '%s -m apache_beam.runners.worker.sdk_worker_main' %
                    (sys.executable or sys.argv[0]))
            return beam.Pipeline(
                options=pipeline_options,
                runner=fn_api_runner.FnApiRunner(default_environment=env))

        return beam.Pipeline(argv=self._beam_pipeline_args)
Example #9
 def to_runner_api(self, context):
     urn, typed_param = self.to_runner_api_parameter(context)
     return beam_runner_api_pb2.Environment(
         urn=urn,
         payload=typed_param.SerializeToString() if isinstance(
             typed_param, message.Message) else typed_param if
         (isinstance(typed_param, bytes)
          or typed_param is None) else typed_param.encode('utf-8'))
Example #10
 def test__create_default_environment(self):
     docker_image = PortableRunner.default_docker_image()
     self.assertEqual(
         PortableRunner._create_environment(
             PipelineOptions.from_dictionary({})),
         beam_runner_api_pb2.Environment(
             urn=common_urns.environments.DOCKER.urn,
             payload=beam_runner_api_pb2.DockerPayload(
                 container_image=docker_image).SerializeToString()))
Example #11
 def test__create_docker_environment(self):
   docker_image = 'py-docker'
   self.assertEqual(
       PortableRunner._create_environment(PipelineOptions.from_dictionary({
           'environment_type': 'DOCKER',
           'environment_config': docker_image,
       })), beam_runner_api_pb2.Environment(
           urn=common_urns.environments.DOCKER.urn,
           payload=beam_runner_api_pb2.DockerPayload(
               container_image=docker_image
           ).SerializeToString()))
Example #12
 def to_runner_api(self, context):
   # type: (PipelineContext) -> beam_runner_api_pb2.Environment
   urn, typed_param = self.to_runner_api_parameter(context)
   return beam_runner_api_pb2.Environment(
       urn=urn,
       payload=typed_param.SerializeToString() if isinstance(
           typed_param, message.Message) else typed_param if
       (isinstance(typed_param, bytes) or
        typed_param is None) else typed_param.encode('utf-8'),
       capabilities=self.capabilities(),
       dependencies=self.artifacts())
Example #13
    def test_sdk_harness_container_images_get_set(self):

        pipeline_options = PipelineOptions([
            '--experiments=beam_fn_api', '--experiments=use_unified_worker',
            '--temp_location', 'gs://any-location/temp'
        ])

        pipeline = Pipeline(options=pipeline_options)
        pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

        test_environment = DockerEnvironment(
            container_image='test_default_image')
        proto_pipeline, _ = pipeline.to_runner_api(
            return_context=True, default_environment=test_environment)

        # We have to manually add environments since Dataflow only sets
        # 'sdkHarnessContainerImages' when there are at least two environments.
        dummy_env = beam_runner_api_pb2.Environment(
            urn=common_urns.environments.DOCKER.urn,
            payload=(beam_runner_api_pb2.DockerPayload(
                container_image='dummy_image')).SerializeToString())
        proto_pipeline.components.environments['dummy_env_id'].CopyFrom(
            dummy_env)

        dummy_transform = beam_runner_api_pb2.PTransform(
            environment_id='dummy_env_id')
        proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
            dummy_transform)

        env = apiclient.Environment(
            [],  # packages
            pipeline_options,
            '2.0.0',  # any environment version
            FAKE_PIPELINE_URL,
            proto_pipeline,
            _sdk_image_overrides={
                '.*dummy.*': 'dummy_image',
                '.*test.*': 'test_default_image'
            })
        worker_pool = env.proto.workerPools[0]

        # For the test, a third environment gets added since the actual default
        # container image for Dataflow is different from the 'test_default_image'
        # we've provided above.
        self.assertEqual(3, len(worker_pool.sdkHarnessContainerImages))

        # Container image should be overridden by a Dataflow specific URL.
        self.assertTrue(
            str.startswith(
                (worker_pool.sdkHarnessContainerImages[0]).containerImage,
                'gcr.io/cloud-dataflow/v1beta3/python'))
Example #14
        def get_or_create_environment_with_resource_hints(
            template_env_id,
            resource_hints,
        ):  # type: (str, Dict[str, bytes]) -> str
            """Creates an environment that has necessary hints and returns its id."""
            template_env = self.environments.get_proto_from_id(template_env_id)
            cloned_env = beam_runner_api_pb2.Environment()
            cloned_env.CopyFrom(template_env)
            cloned_env.resource_hints.clear()
            cloned_env.resource_hints.update(resource_hints)

            return self.environments.get_by_proto(
                cloned_env,
                label='environment_with_resource_hints',
                deduplicate=True)
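The clone-and-replace pattern above can also be exercised directly on an Environment proto. A hedged standalone sketch (the resource hint URN and byte values are illustrative assumptions, not taken from the source):

from apache_beam.portability.api import beam_runner_api_pb2

template_env = beam_runner_api_pb2.Environment(
    urn='beam:env:docker:v1',
    resource_hints={'beam:resources:min_ram_bytes:v1': b'1073741824'})
cloned_env = beam_runner_api_pb2.Environment()
cloned_env.CopyFrom(template_env)
cloned_env.resource_hints.clear()
cloned_env.resource_hints.update({'beam:resources:min_ram_bytes:v1': b'2147483648'})
# template_env keeps its original hints; cloned_env carries the replacement.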
Example #15
 def test_environment_override_translation(self):
   self.default_properties.append('--experiments=beam_fn_api')
   self.default_properties.append('--worker_harness_container_image=FOO')
   remote_runner = DataflowRunner()
   p = Pipeline(remote_runner,
                options=PipelineOptions(self.default_properties))
   (p | ptransform.Create([1, 2, 3])  # pylint: disable=expression-not-assigned
    | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
    | ptransform.GroupByKey())
   p.run()
   self.assertEqual(
       list(remote_runner.proto_pipeline.components.environments.values()),
       [beam_runner_api_pb2.Environment(
           urn=common_urns.environments.DOCKER.urn,
           payload=beam_runner_api_pb2.DockerPayload(
               container_image='FOO').SerializeToString())])
Example #16
 def __init__(self, proto=None, default_environment_url=None):
     if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
         proto = beam_runner_api_pb2.Components(
             coders=dict(proto.coders.items()),
             windowing_strategies=dict(proto.windowing_strategies.items()),
             environments=dict(proto.environments.items()))
     for name, cls in self._COMPONENT_TYPES.items():
         setattr(self, name,
                 _PipelineContextMap(self, cls, getattr(proto, name, None)))
     if default_environment_url:
         self._default_environment_id = self.environments.get_id(
             Environment(
                 beam_runner_api_pb2.Environment(
                     url=default_environment_url)))
     else:
         self._default_environment_id = None
Example #17
  def test_default_environment_get_set(self):

    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp'
    ])

    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    test_environment = DockerEnvironment(container_image='test_default_image')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=test_environment)

    dummy_env = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=(
            beam_runner_api_pb2.DockerPayload(
                container_image='dummy_image')).SerializeToString())
    proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

    dummy_transform = beam_runner_api_pb2.PTransform(
        environment_id='dummy_env_id')
    proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
        dummy_transform)

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL,
        proto_pipeline,
        _sdk_image_overrides={
            '.*dummy.*': 'dummy_image', '.*test.*': 'test_default_image'
        })
    worker_pool = env.proto.workerPools[0]

    self.assertEqual(2, len(worker_pool.sdkHarnessContainerImages))

    images_from_proto = [
        sdk_info.containerImage
        for sdk_info in worker_pool.sdkHarnessContainerImages
    ]
    self.assertIn('test_default_image', images_from_proto)
Example #18
  def test_pipeline_sdk_not_overridden(self):
    pipeline_options = PipelineOptions([
        '--experiments=beam_fn_api',
        '--experiments=use_unified_worker',
        '--temp_location',
        'gs://any-location/temp',
        '--worker_harness_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

    dummy_env = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=(
            beam_runner_api_pb2.DockerPayload(
                container_image='dummy_prefix/dummy_name:dummy_tag')
        ).SerializeToString())
    proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

    dummy_transform = beam_runner_api_pb2.PTransform(
        environment_id='dummy_env_id')
    proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
        dummy_transform)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, dict(), pipeline_options)

    self.assertEqual(2, len(proto_pipeline.components.environments))

    from apache_beam.utils import proto_utils
    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)
Example #19
 def __init__(self, proto=None, default_environment_url=None):
     if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
         proto = beam_runner_api_pb2.Components(
             coders=dict(proto.coders.items()),
             windowing_strategies=dict(proto.windowing_strategies.items()),
             environments=dict(proto.environments.items()))
     for name, cls in self._COMPONENT_TYPES.items():
         setattr(self, name,
                 _PipelineContextMap(self, cls, getattr(proto, name, None)))
     if default_environment_url:
         self._default_environment_id = self.environments.get_id(
             Environment(
                 beam_runner_api_pb2.Environment(
                     url=default_environment_url,
                     urn=common_urns.environments.DOCKER.urn,
                     payload=beam_runner_api_pb2.DockerPayload(
                         container_image=default_environment_url).
                     SerializeToString())))
     else:
         self._default_environment_id = None
Example #20
 def _make_beam_pipeline(self) -> beam.Pipeline:
     """Makes beam pipeline."""
     pipeline_options = PipelineOptions(self._beam_pipeline_args)
     direct_num_workers = pipeline_options.view_as(
         DirectOptions).direct_num_workers
     # TODO(b/141578059): refactor when direct_num_workers can be directly used
     # as argv. Currently runner and options need to be set separately for
     # multi-process, in other cases, use beam_pipeline_args to set pipeline.
     # Default value for direct_num_workers is 1 if unset:
     #   https://github.com/apache/beam/blob/master/sdks/python/apache_beam/options/pipeline_options.py
     if direct_num_workers > 1:
         tf.logging.info('Multi-process with %d workers' %
                         direct_num_workers)
         return beam.Pipeline(
             options=pipeline_options,
             runner=fn_api_runner.
             FnApiRunner(default_environment=beam_runner_api_pb2.Environment(
                 urn=python_urns.SUBPROCESS_SDK,
                 payload=b'%s -m apache_beam.runners.worker.sdk_worker_main'
                 % (sys.executable or sys.argv[0]).encode('ascii'))))
     else:
         return beam.Pipeline(argv=self._beam_pipeline_args)
Example #21
  def __init__(
      self,
      default_environment=None,
      bundle_repeat=0,
      use_state_iterables=False):
    """Creates a new Fn API Runner.

    Args:
      default_environment: the default environment to use for UserFns.
      bundle_repeat: replay every bundle this many extra times, for profiling
          and debugging
      use_state_iterables: Intentionally split gbk iterables over state API
          (for testing)
    """
    super(FnApiRunner, self).__init__()
    self._last_uid = -1
    self._default_environment = (
        default_environment
        or beam_runner_api_pb2.Environment(urn=python_urns.EMBEDDED_PYTHON))
    self._bundle_repeat = bundle_repeat
    self._progress_frequency = None
    self._profiler_factory = None
    self._use_state_iterables = use_state_iterables
Example #22
def generate_examples(input_transform,
                      output_dir,
                      problem_name,
                      splits,
                      min_hop_size_seconds,
                      max_hop_size_seconds,
                      num_replications,
                      min_pitch,
                      max_pitch,
                      encode_performance_fn,
                      encode_score_fns=None,
                      augment_fns=None,
                      absolute_timing=False,
                      random_crop_length=None):
    """Generate data for a Score2Perf problem.

  Args:
    input_transform: The input PTransform object that reads input NoteSequence
        protos, or dictionary mapping split names to such PTransform objects.
        Should produce `(id, NoteSequence)` tuples.
    output_dir: The directory to write the resulting TFRecord file containing
        examples.
    problem_name: Name of the Tensor2Tensor problem, used as a base filename
        for generated data.
    splits: A dictionary of split names and their probabilities. Probabilities
        should add up to 1. If `input_transform` is a dictionary, this argument
        will be ignored.
    min_hop_size_seconds: Minimum hop size in seconds at which input
        NoteSequence protos can be split. Can also be a dictionary mapping split
        name to minimum hop size.
    max_hop_size_seconds: Maximum hop size in seconds at which input
        NoteSequence protos can be split. If zero or None, will not split at
        all. Can also be a dictionary mapping split name to maximum hop size.
    num_replications: Number of times input NoteSequence protos will be
        replicated prior to splitting.
    min_pitch: Minimum MIDI pitch value; notes with lower pitch will be dropped.
    max_pitch: Maximum MIDI pitch value; notes with greater pitch will be
        dropped.
    encode_performance_fn: Required performance encoding function.
    encode_score_fns: Optional dictionary of named score encoding functions.
    augment_fns: Optional list of data augmentation functions. Only applied in
        the 'train' split.
    absolute_timing: If True, each score will use absolute instead of tempo-
        relative timing. Since chord inference depends on having beats, the
        score will only contain melody.
    random_crop_length: If specified, crop each encoded performance to this
        length. Cannot be specified if using scores.

  Raises:
    ValueError: If split probabilities do not add up to 1, or if splits are not
        provided but `input_transform` is not a dictionary.
  """
    # Make sure Beam's log messages are not filtered.
    logging.getLogger().setLevel(logging.INFO)

    if isinstance(input_transform, dict):
        split_names = input_transform.keys()
    else:
        if not splits:
            raise ValueError(
                'Split probabilities must be provided if input is not presplit.'
            )
        split_names, split_probabilities = zip(*splits.items())
        cumulative_splits = list(
            zip(split_names, np.cumsum(split_probabilities)))
        if round(cumulative_splits[-1][1], 4) != 1.0:
            raise ValueError('Split probabilities must sum to 1; got %f' %
                             cumulative_splits[-1][1])

    # Check for existence of prior outputs. Since the number of shards may be
    # different, the prior outputs will not necessarily be overwritten and must
    # be deleted explicitly.
    output_filenames = [
        os.path.join(output_dir, '%s-%s.tfrecord' % (problem_name, split_name))
        for split_name in split_names
    ]
    for split_name, output_filename in zip(split_names, output_filenames):
        existing_output_filenames = tf.gfile.Glob(output_filename + '*')
        if existing_output_filenames:
            tf.logging.info(
                'Data files already exist for split %s in problem %s, deleting.',
                split_name, problem_name)
            for filename in existing_output_filenames:
                tf.gfile.Remove(filename)

    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options.split(','))
    pipeline_options.view_as(DirectOptions).direct_num_workers = NUM_THREADS

    with beam.Pipeline(
            options=pipeline_options,
            runner=fn_api_runner.
            FnApiRunner(default_environment=beam_runner_api_pb2.Environment(
                urn=python_urns.SUBPROCESS_SDK,
                payload=b'%s -m apache_beam.runners.worker.sdk_worker_main' %
                sys.executable.encode('ascii')))) as p:
        if isinstance(input_transform, dict):
            # Input data is already partitioned into splits.
            split_partitions = [
                p | 'input_transform_%s' % split_name >>
                input_transform[split_name] for split_name in split_names
            ]
        else:
            # Read using a single PTransform.
            p |= 'input_transform' >> input_transform
            split_partitions = p | 'partition' >> beam.Partition(
                functools.partial(select_split, cumulative_splits),
                len(cumulative_splits))

        for split_name, output_filename, s in zip(split_names,
                                                  output_filenames,
                                                  split_partitions):
            if isinstance(min_hop_size_seconds, dict):
                min_hop = min_hop_size_seconds[split_name]
            else:
                min_hop = min_hop_size_seconds
            if isinstance(max_hop_size_seconds, dict):
                max_hop = max_hop_size_seconds[split_name]
            else:
                max_hop = max_hop_size_seconds
            s |= 'preshuffle_%s' % split_name >> beam.Reshuffle()
            s |= 'filter_invalid_notes_%s' % split_name >> beam.Map(
                functools.partial(filter_invalid_notes, min_pitch, max_pitch))
            s |= 'extract_examples_%s' % split_name >> beam.ParDo(
                ExtractExamplesDoFn(
                    min_hop, max_hop, num_replications if split_name == 'train'
                    else 1, encode_performance_fn, encode_score_fns,
                    augment_fns if split_name == 'train' else None,
                    absolute_timing, random_crop_length))
            s |= 'shuffle_%s' % split_name >> beam.Reshuffle()
            s |= 'write_%s' % split_name >> beam.io.WriteToTFRecord(
                output_filename,
                coder=beam.coders.ProtoCoder(tf.train.Example))
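A hedged worked illustration of the split bookkeeping above (the split names and probabilities are made up): probabilities become cumulative boundaries, and the last boundary must round to 1.0 or a ValueError is raised.

import numpy as np

splits = {'train': 0.8, 'eval': 0.1, 'test': 0.1}
split_names, split_probabilities = zip(*splits.items())
cumulative_splits = list(zip(split_names, np.cumsum(split_probabilities)))
# Roughly [('train', 0.8), ('eval', 0.9), ('test', 1.0)], up to float rounding.
assert round(cumulative_splits[-1][1], 4) == 1.0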
Example #23
 def create_pipeline(self):
     return beam.Pipeline(runner=fn_api_runner.FnApiRunner(
         default_environment=beam_runner_api_pb2.Environment(
             urn=python_urns.EMBEDDED_PYTHON_GRPC, payload=b'2')))
Example #24
    def _create_environment(options):
        portable_options = options.view_as(PortableOptions)
        # Do not set a Runner. Otherwise this can cause problems in Java's
        # PipelineOptions, i.e. ClassNotFoundException, if the corresponding Runner
        # does not exist in the Java SDK. In portability, the entry point is clearly
        # defined via the JobService.
        portable_options.view_as(StandardOptions).runner = None
        environment_urn = common_urns.environments.DOCKER.urn
        if portable_options.environment_type == 'DOCKER':
            environment_urn = common_urns.environments.DOCKER.urn
        elif portable_options.environment_type == 'PROCESS':
            environment_urn = common_urns.environments.PROCESS.urn
        elif portable_options.environment_type in ('EXTERNAL', 'LOOPBACK'):
            environment_urn = common_urns.environments.EXTERNAL.urn
        elif portable_options.environment_type:
            if portable_options.environment_type.startswith('beam:env:'):
                environment_urn = portable_options.environment_type
            else:
                raise ValueError('Unknown environment type: %s' %
                                 portable_options.environment_type)

        if environment_urn == common_urns.environments.DOCKER.urn:
            docker_image = (portable_options.environment_config
                            or PortableRunner.default_docker_image())
            return beam_runner_api_pb2.Environment(
                urn=common_urns.environments.DOCKER.urn,
                payload=beam_runner_api_pb2.DockerPayload(
                    container_image=docker_image).SerializeToString())
        elif environment_urn == common_urns.environments.PROCESS.urn:
            config = json.loads(portable_options.environment_config)
            return beam_runner_api_pb2.Environment(
                urn=common_urns.environments.PROCESS.urn,
                payload=beam_runner_api_pb2.ProcessPayload(
                    os=(config.get('os') or ''),
                    arch=(config.get('arch') or ''),
                    command=config.get('command'),
                    env=(config.get('env') or '')).SerializeToString())
        elif environment_urn == common_urns.environments.EXTERNAL.urn:

            def looks_like_json(environment_config):
                import re
                return re.match(r'\s*\{.*\}\s*$', environment_config)

            if looks_like_json(portable_options.environment_config):
                config = json.loads(portable_options.environment_config)
                url = config.get('url')
                if not url:
                    raise ValueError(
                        'External environment endpoint must be set.')
                params = config.get('params')
            else:
                url = portable_options.environment_config
                params = None

            return beam_runner_api_pb2.Environment(
                urn=common_urns.environments.EXTERNAL.urn,
                payload=beam_runner_api_pb2.ExternalPayload(
                    endpoint=endpoints_pb2.ApiServiceDescriptor(url=url),
                    params=params).SerializeToString())
        else:
            return beam_runner_api_pb2.Environment(
                urn=environment_urn,
                payload=(portable_options.environment_config.encode('ascii')
                         if portable_options.environment_config else None))
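As a closing usage sketch (hedged; this call pattern mirrors the tests above and is not part of the module itself), an Environment built by _create_environment can be decoded back with proto_utils.parse_Bytes for inspection:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability.portable_runner import PortableRunner
from apache_beam.utils import proto_utils

env = PortableRunner._create_environment(
    PipelineOptions.from_dictionary({
        'environment_type': 'PROCESS',
        'environment_config': '{"command": "run.sh"}',
    }))
process_payload = proto_utils.parse_Bytes(
    env.payload, beam_runner_api_pb2.ProcessPayload)
assert process_payload.command == 'run.sh'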