def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
    """Check that the flatten input visitor retypes every Flatten input.

    Builds ``num_inputs`` PCollections typed ``input_type``, wires them into
    an applied ``beam.Flatten()`` whose single output is typed
    ``output_type``, runs ``DataflowRunner.flatten_input_visitor()`` on it and
    asserts that each input's ``element_type`` now equals ``output_type``.

    Args:
      input_type: element type assigned to every input PCollection.
      output_type: element type assigned to the output PCollection, and the
        type expected on every input once the visitor has run.
      num_inputs: number of input PCollections to create.
    """
    p = TestPipeline()
    inputs = []
    for _ in range(num_inputs):
        input_pcoll = PCollection(p)
        input_pcoll.element_type = input_type
        inputs.append(input_pcoll)

    output_pcoll = PCollection(p)
    output_pcoll.element_type = output_type

    flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
    flatten.add_output(output_pcoll, None)
    DataflowRunner.flatten_input_visitor().visit_transform(flatten)

    # Check each input individually; asserting on inputs[0] repeatedly would
    # leave every other input unverified.
    for input_pcoll in inputs:
        self.assertEqual(input_pcoll.element_type, output_type)
def test_group_by_key_input_visitor_with_valid_inputs(self):
    """Inputs typed None, Any or KV[Any, Any] all end up as KV[Any, Any].

    Exercised for both ``_GroupByKeyOnly`` and ``beam.GroupByKey``: each
    input PCollection is visited through
    ``DataflowRunner.group_by_key_input_visitor()`` and its ``element_type``
    is then expected to equal ``typehints.KV[typehints.Any, typehints.Any]``.
    """
    pipeline = TestPipeline()
    pcolls = [PCollection(pipeline) for _ in range(3)]

    for transform in (_GroupByKeyOnly(), beam.GroupByKey()):
        # Reset the starting types for every transform, since the visitor
        # overwrites them in place.
        starting_types = (
            None,
            typehints.Any,
            typehints.KV[typehints.Any, typehints.Any],
        )
        for pcoll, element_type in zip(pcolls, starting_types):
            pcoll.element_type = element_type

        for pcoll in pcolls:
            applied = AppliedPTransform(None, transform, "label", [pcoll])
            applied.outputs[None] = PCollection(None)
            visitor = DataflowRunner.group_by_key_input_visitor()
            visitor.visit_transform(applied)
            expected = typehints.KV[typehints.Any, typehints.Any]
            self.assertEqual(pcoll.element_type, expected)
def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
    """The GroupByKey input visitor must not retype non-GroupByKey inputs.

    For ``beam.Flatten()`` and a ``beam.Map`` identity, the input
    PCollection starts as ``typehints.Any`` and is expected to still be
    ``typehints.Any`` after the visitor has processed the applied transform.
    """
    pipeline = TestPipeline()
    pcoll = PCollection(pipeline)
    non_gbk_transforms = (beam.Flatten(), beam.Map(lambda x: x))

    for transform in non_gbk_transforms:
        pcoll.element_type = typehints.Any
        visitor = DataflowRunner.group_by_key_input_visitor()
        applied = AppliedPTransform(None, transform, "label", [pcoll])
        visitor.visit_transform(applied)
        self.assertEqual(pcoll.element_type, typehints.Any)
def test_group_by_key_input_visitor_with_valid_inputs(self):
    """Inputs typed None, Any or KV[Any, Any] all end up as KV[Any, Any].

    Each input PCollection is fed to an applied ``beam.GroupByKey()`` and
    visited through ``DataflowRunner.group_by_key_input_visitor()``; its
    ``element_type`` is then expected to equal
    ``typehints.KV[typehints.Any, typehints.Any]``.
    """
    pipeline = TestPipeline()
    pcolls = [PCollection(pipeline) for _ in range(3)]
    starting_types = (
        None,
        typehints.Any,
        typehints.KV[typehints.Any, typehints.Any],
    )
    for pcoll, element_type in zip(pcolls, starting_types):
        pcoll.element_type = element_type

    for pcoll in pcolls:
        applied = AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll])
        applied.outputs[None] = PCollection(None)
        visitor = DataflowRunner.group_by_key_input_visitor()
        visitor.visit_transform(applied)
        expected = typehints.KV[typehints.Any, typehints.Any]
        self.assertEqual(pcoll.element_type, expected)
def test_group_by_key_input_visitor_with_invalid_inputs(self):
    """Check that the GroupByKey input visitor rejects unsupported inputs.

    For both ``_GroupByKeyOnly`` and ``beam.GroupByKey``, an input
    PCollection whose element type is ``typehints.TupleSequenceConstraint``
    or ``typehints.Set`` must make ``visit_transform`` raise ``ValueError``
    with a message matching ``err_msg``.
    """
    p = TestPipeline()
    pcoll1 = PCollection(p)
    pcoll2 = PCollection(p)
    # The expected message does not depend on the transform; build it once.
    err_msg = "Input to GroupByKey must be of Tuple or Any type"

    for transform in [_GroupByKeyOnly(), beam.GroupByKey()]:
        pcoll1.element_type = typehints.TupleSequenceConstraint
        pcoll2.element_type = typehints.Set
        for pcoll in [pcoll1, pcoll2]:
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp
            # alias, which was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(ValueError, err_msg):
                DataflowRunner.group_by_key_input_visitor().visit_transform(
                    AppliedPTransform(None, transform, "label", [pcoll]))
def test_side_input_visitor(self):
    """After the side-input visitor runs, both side inputs are multimaps.

    A ``beam.Map`` is given one ``AsSingleton`` and one ``AsMultiMap`` side
    input over the same PCollection. Once
    ``DataflowRunner.side_input_visitor()`` has visited the applied
    transform, it must still expose two side inputs and each one's access
    pattern must equal the MULTIMAP URN.
    """
    pipeline = TestPipeline()
    source = pipeline | beam.Create([])

    singleton_view = beam.pvalue.AsSingleton(source)
    multimap_view = beam.pvalue.AsMultiMap(source)
    transform = beam.Map(
        lambda x, y, z: (x, y, z), singleton_view, multimap_view)

    applied_transform = AppliedPTransform(None, transform, "label", [source])
    visitor = DataflowRunner.side_input_visitor()
    visitor.visit_transform(applied_transform)

    side_inputs = applied_transform.side_inputs
    self.assertEqual(2, len(side_inputs))
    for side_input in applied_transform.side_inputs:
        access_pattern = side_input._side_input_data().access_pattern
        self.assertEqual(common_urns.side_inputs.MULTIMAP.urn, access_pattern)
def test_group_by_key_input_visitor_with_invalid_inputs(self):
    """Inputs typed ``str`` or ``typehints.Set`` must be rejected.

    Visiting an applied ``beam.GroupByKey()`` labelled ``"label"`` over such
    an input is expected to raise ``ValueError`` with a message matching
    ``err_msg``.
    """
    pipeline = TestPipeline()
    pcolls = [PCollection(pipeline), PCollection(pipeline)]
    for pcoll, bad_type in zip(pcolls, (str, typehints.Set)):
        pcoll.element_type = bad_type

    err_msg = (
        r"Input to 'label' must be compatible with KV\[Any, Any\]. "
        "Found .*")

    for pcoll in pcolls:
        with self.assertRaisesRegex(ValueError, err_msg):
            visitor = DataflowRunner.group_by_key_input_visitor()
            visitor.visit_transform(
                AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll]))
def test_group_by_key_input_visitor_with_valid_inputs(self):
    """Inputs typed None, Any or KV[Any, Any] all end up as KV[Any, Any].

    Exercised for both ``beam.GroupByKeyOnly`` and ``beam.GroupByKey``: each
    input PCollection is visited through
    ``runner.group_by_key_input_visitor()`` and its ``element_type`` is then
    expected to equal ``typehints.KV[typehints.Any, typehints.Any]``.
    """
    # NOTE(review): `runner` is not defined in this method; it is presumably
    # a module-level name -- confirm against the file's imports.
    pipeline = TestPipeline()
    pcolls = [PCollection(pipeline) for _ in range(3)]

    for transform in (beam.GroupByKeyOnly(), beam.GroupByKey()):
        # Reset the starting types for every transform, since the visitor
        # overwrites them in place.
        starting_types = (
            None,
            typehints.Any,
            typehints.KV[typehints.Any, typehints.Any],
        )
        for pcoll, element_type in zip(pcolls, starting_types):
            pcoll.element_type = element_type

        for pcoll in pcolls:
            visitor = runner.group_by_key_input_visitor()
            visitor.visit_transform(
                AppliedPTransform(None, transform, "label", [pcoll]))
            expected = typehints.KV[typehints.Any, typehints.Any]
            self.assertEqual(pcoll.element_type, expected)