def test_group_by_key_input_visitor_with_invalid_inputs(self):
    """Verify the group-by-key input visitor rejects unsupported inputs.

    For both ``GroupByKeyOnly`` and ``GroupByKey``, visiting an applied
    transform whose single input PCollection has element type
    ``typehints.TupleSequenceConstraint`` or ``typehints.Set`` is expected
    to raise ``ValueError`` with the "must be of Tuple or Any type" message.
    """
    p = TestPipeline()
    pcoll1 = PCollection(p)
    pcoll2 = PCollection(p)
    err_msg = "Input to GroupByKey must be of Tuple or Any type"
    for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]:
        # Element types are (re)assigned on every pass so that each
        # transform is checked against the same starting inputs.
        pcoll1.element_type = typehints.TupleSequenceConstraint
        pcoll2.element_type = typehints.Set
        for pcoll in [pcoll1, pcoll2]:
            # NOTE(review): assertRaisesRegexp is the Python 2 era spelling;
            # on Python 3 it is a deprecated alias of assertRaisesRegex.
            # Kept as-is in case this file still has to run on Python 2.
            with self.assertRaisesRegexp(ValueError, err_msg):
                DataflowRunner.group_by_key_input_visitor().visit_transform(
                    AppliedPTransform(None, transform, "label", [pcoll]))
def test_group_by_key_input_visitor_with_valid_inputs(self):
    """Verify the group-by-key input visitor accepts supported inputs.

    Three PCollections are given the element types ``None``,
    ``typehints.Any`` and ``KV[Any, Any]``. After the visitor has visited
    an applied ``GroupByKeyOnly`` / ``GroupByKey`` transform that takes the
    PCollection as its only input, the PCollection's element type is
    expected to equal ``KV[Any, Any]``.
    """
    pipeline = TestPipeline()
    inputs = [PCollection(pipeline) for _ in range(3)]
    kv_of_any = typehints.KV[typehints.Any, typehints.Any]

    for gbk in (beam.GroupByKeyOnly(), beam.GroupByKey()):
        # The assertion below expects every input to end up as KV[Any, Any],
        # so the starting types are restored before each transform.
        starting_types = (None, typehints.Any, kv_of_any)
        for candidate, starting_type in zip(inputs, starting_types):
            candidate.element_type = starting_type

        for candidate in inputs:
            applied = AppliedPTransform(None, gbk, "label", [candidate])
            visitor = DataflowRunner.group_by_key_input_visitor()
            visitor.visit_transform(applied)
            self.assertEqual(candidate.element_type, kv_of_any)