def test_coercion_fail(self):
    """Check that types which cannot be coerced to a KV raise ValueError.

    Each case pairs the positional arguments passed to
    ``typehints.coerce_to_kv_type`` with a regex that the message of the
    raised ``ValueError`` must match.
    """
    # ``assertRaisesRegexp`` is a deprecated alias that recent unittest
    # releases no longer provide. Prefer the current name and fall back to
    # the alias only when it is all the base class offers.
    assert_raises_regex = (
        getattr(self, 'assertRaisesRegex', None) or self.assertRaisesRegexp)

    cases = [
        ((str, 'label', 'producer'), r'producer.*compatible'),
        ((Tuple[str], ), r'two components'),
        # It seems that the only Unions that may be successfully coerced are
        # not Unions but Any (e.g. Union[Any, Tuple[Any, Any]] is Any).
        ((Union[str, int], ), r'compatible'),
        ((Union, ), r'compatible'),
        ((typehints.List[Any], ), r'compatible'),
    ]
    for args, regex in cases:
        with assert_raises_regex(ValueError, regex):
            typehints.coerce_to_kv_type(*args)
def test_coercion_success(self):
    """Check the KV type produced for each input that coerces cleanly.

    Every entry is a ``(source type, expected KV type)`` pair. The source
    type must also be reported as compatible with the expected result.
    """
    expectations = [
        (Any, typehints.KV[Any, Any]),
        (typehints.KV[Any, Any], typehints.KV[Any, Any]),
        (typehints.Tuple[str, int], typehints.KV[str, int]),
    ]
    for source_type, expected_kv in expectations:
        coerced = typehints.coerce_to_kv_type(source_type)
        self.assertEqual(coerced, expected_kv)
        self.assertCompatible(source_type, expected_kv)
def visit_transform(self, transform_node):
    """Force KV element types where the visited transform requires them.

    Nodes without a transform are skipped. When the transform reports that
    it requires keyed input, the first input's element type is coerced to a
    KV type and, if the node has exactly one output, that output's element
    type is re-inferred from the coerced input type. Independently, every
    side input that requires keyed input has its element type coerced too.

    Args:
      transform_node: the applied transform being visited.
    """
    transform = transform_node.transform
    if not transform:
        return

    node_label = transform_node.full_label

    if transform.runner_api_requires_keyed_input():
        main_input = transform_node.inputs[0]
        main_input.element_type = typehints.coerce_to_kv_type(
            main_input.element_type, node_label)
        if len(transform_node.outputs) == 1:
            # The runner often has expectations about the output types as
            # well.
            (sole_output,) = transform_node.outputs.values()
            sole_output.element_type = transform.infer_output_type(
                main_input.element_type)

    for side_input in transform.side_inputs:
        if not side_input.requires_keyed_input():
            continue
        side_pvalue = side_input.pvalue
        side_pvalue.element_type = typehints.coerce_to_kv_type(
            side_pvalue.element_type,
            node_label,
            side_input_producer=side_pvalue.producer.full_label)
def visit_transform(self, transform_node):
    """Coerce the input of a keyed-input transform to a KV element type.

    Does nothing unless the node has a transform that reports requiring
    keyed input. Otherwise the first input's element type is coerced to a KV
    type and, when the node has exactly one output, that output's element
    type is re-inferred from the coerced input type.

    Args:
      transform_node: the applied transform being visited.
    """
    transform = transform_node.transform
    if not transform or not transform.runner_api_requires_keyed_input():
        return

    keyed_input = transform_node.inputs[0]
    keyed_input.element_type = typehints.coerce_to_kv_type(
        keyed_input.element_type, transform_node.full_label)

    if len(transform_node.outputs) != 1:
        return

    # The runner often has expectations about the output types as well.
    (sole_output,) = transform_node.outputs.values()
    sole_output.element_type = transform.infer_output_type(
        keyed_input.element_type)
def visit_transform(self, transform_node):
    """Give keyed-input transforms a KV-typed input and a matching output.

    Nodes with no transform, or whose transform does not report requiring
    keyed input, are ignored. For the rest, the element type of the first
    input is coerced to a KV type; if there is exactly one output, its
    element type is then inferred from that coerced input type.

    Args:
      transform_node: the applied transform being visited.
    """
    applied = transform_node.transform
    if not applied:
        return
    if not applied.runner_api_requires_keyed_input():
        return

    first_input = transform_node.inputs[0]
    kv_type = typehints.coerce_to_kv_type(
        first_input.element_type, transform_node.full_label)
    first_input.element_type = kv_type

    if len(transform_node.outputs) == 1:
        # The runner often has expectations about the output types as well.
        [only_output] = transform_node.outputs.values()
        only_output.element_type = applied.infer_output_type(
            first_input.element_type)
def visit_transform(self, transform_node):
    """Pin down the input and output element types of group-by-key nodes.

    For a ``GroupByKey`` or ``_GroupByKeyOnly`` transform, the first input's
    element type is coerced to a KV type. If the node has any outputs, the
    output stored under the ``None`` tag is typed as
    ``KV[key, Iterable[value]]``. Other transforms are ignored.

    Args:
      transform_node: the applied transform being visited.
    """
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms.core import GroupByKey, _GroupByKeyOnly

    if not isinstance(transform_node.transform, (GroupByKey, _GroupByKeyOnly)):
        return

    grouped_input = transform_node.inputs[0]
    grouped_input.element_type = typehints.coerce_to_kv_type(
        grouped_input.element_type, transform_node.full_label)

    # Unpack before looking at the outputs so a malformed KV type is reported
    # even when the node has none.
    key_type, value_type = grouped_input.element_type.tuple_types
    if transform_node.outputs:
        grouped_values = typehints.Iterable[value_type]
        transform_node.outputs[None].element_type = typehints.KV[
            key_type, grouped_values]
def visit_transform(self, transform_node):
    """Rewrite the side inputs of a ParDo node into Dataflow wrappers.

    Only nodes whose transform is a ``ParDo`` are touched. Each side input is
    replaced according to its access pattern:

    * iterable: a ``Map`` step pairing every element with an ``''`` key is
      added to the pipeline graph, and the side input is wrapped in
      ``_DataflowIterableSideInput`` reading from that step's output;
    * multimap: the side input's element type is coerced to a KV type and
      the side input is wrapped in ``_DataflowMultimapSideInput``.

    The resulting list replaces ``side_inputs`` on both the node and its
    transform.

    Args:
      transform_node: the applied transform being visited.

    Raises:
      ValueError: if a side input has any other access pattern.
    """
    if not isinstance(transform_node.transform, ParDo):
        return

    rewritten = []
    for index, side_input in enumerate(transform_node.side_inputs):
        pattern = side_input._side_input_data().access_pattern

        if pattern == common_urns.ITERABLE_SIDE_INPUT:
            # Add a map to ('', value) as Dataflow currently only handles
            # keyed side inputs.
            pipeline = side_input.pvalue.pipeline
            replacement = _DataflowIterableSideInput(side_input)
            keyed_type = typehints.KV[str, side_input.pvalue.element_type]
            keyed_pvalue = beam.pvalue.PCollection(
                pipeline, element_type=keyed_type)
            replacement.pvalue = keyed_pvalue

            parent = transform_node.parent or pipeline._root_transform()
            step_label = transform_node.full_label + '/MapToVoidKey%s' % index
            map_to_void_key = beam.pipeline.AppliedPTransform(
                pipeline,
                beam.Map(lambda x: ('', x)),
                step_label,
                (side_input.pvalue,))

            # Splice the new step into the graph: it produces the keyed
            # PCollection and belongs to the same composite as this node.
            keyed_pvalue.producer = map_to_void_key
            map_to_void_key.add_output(keyed_pvalue)
            parent.add_part(map_to_void_key)
            transform_node.update_input_refcounts()
        elif pattern == common_urns.MULTIMAP_SIDE_INPUT:
            # Ensure the input coder is a KV coder and patch up the
            # access pattern to appease Dataflow.
            multimap_pvalue = side_input.pvalue
            multimap_pvalue.element_type = typehints.coerce_to_kv_type(
                multimap_pvalue.element_type, transform_node.full_label)
            replacement = _DataflowMultimapSideInput(side_input)
        else:
            raise ValueError(
                'Unsupported access pattern for %r: %r' %
                (transform_node.full_label, pattern))

        rewritten.append(replacement)

    # The node and its transform share the same rewritten list.
    transform_node.side_inputs = rewritten
    transform_node.transform.side_inputs = rewritten