def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
    """The GroupByKey input visitor must not touch non-GroupByKey transforms.

    A Flatten and a Map are each applied over the same PCollection. After the
    visitor has visited each one, the input's element type must still be
    ``typehints.Any``.
    """
    pipeline = TestPipeline()
    source = PCollection(pipeline)
    candidates = [beam.Flatten(), beam.Map(lambda x: x)]
    for candidate in candidates:
        # Reset before each visit so every transform starts from the same type.
        source.element_type = typehints.Any
        visitor = common.group_by_key_input_visitor()
        node = AppliedPTransform(None, candidate, "label", {'in': source})
        visitor.visit_transform(node)
        self.assertEqual(source.element_type, typehints.Any)
def test_group_by_key_input_visitor_with_invalid_inputs(self):
    """GroupByKey over an input that is not KV-compatible raises ValueError.

    Two inputs are tried, one typed ``str`` and one typed ``typehints.Set``.
    For each, the visitor must raise a ValueError whose message names the
    transform label and the expected ``KV[Any, Any]`` type.
    """
    pipeline = TestPipeline()
    str_input = PCollection(pipeline)
    set_input = PCollection(pipeline)
    str_input.element_type = str
    set_input.element_type = typehints.Set
    expected_error = (r"Input to 'label' must be compatible with KV\[Any, Any\]. "
                      "Found .*")
    for bad_input in (str_input, set_input):
        with self.assertRaisesRegex(ValueError, expected_error):
            # Build the visitor and the node inside the context so that any
            # failure during construction is checked the same way as the visit.
            visitor = common.group_by_key_input_visitor()
            node = AppliedPTransform(
                None, beam.GroupByKey(), "label", {'in': bad_input})
            visitor.visit_transform(node)
def test_group_by_key_input_visitor_with_valid_inputs(self):
    """KV-compatible GroupByKey inputs are normalized to ``KV[Any, Any]``.

    Inputs typed ``None``, ``typehints.Any`` and ``KV[Any, Any]`` are all
    accepted by the visitor. After each visit, the input's element type must
    equal ``KV[Any, Any]``.
    """
    p = TestPipeline()
    pcoll1 = PCollection(p)
    pcoll2 = PCollection(p)
    pcoll3 = PCollection(p)

    pcoll1.element_type = None
    pcoll2.element_type = typehints.Any
    pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
    for pcoll in [pcoll1, pcoll2, pcoll3]:
        # Main inputs are given as a tag -> PCollection mapping, the same form
        # used by the other group_by_key_input_visitor tests. The previous
        # bare list ([pcoll]) did not match that form.
        # NOTE(review): the tag name itself is assumed to be irrelevant to the
        # visitor; confirm against AppliedPTransform's main_inputs handling.
        applied = AppliedPTransform(
            None, beam.GroupByKey(), "label", {'in': pcoll})
        # Give the node an output so the visit has a complete transform to
        # inspect.
        applied.outputs[None] = PCollection(None)
        common.group_by_key_input_visitor().visit_transform(applied)
        self.assertEqual(
            pcoll.element_type, typehints.KV[typehints.Any, typehints.Any])
def test_gbk_then_flatten_input_visitor(self):
    """Both Dataflow visitors together reconcile Flatten and GroupByKey types.

    Two Create transforms with a ``None`` key (one with a str value, one with
    an int value) are flattened and then grouped. After the GroupByKey input
    visitor and the Flatten input visitor have run, in that order, the
    flattened PCollection must have a tuple element type equal to each of its
    inputs' types.
    """
    runner = DataflowRunner()
    options = PipelineOptions(self.default_properties)
    pipeline = TestPipeline(runner=runner, options=options)
    str_valued = pipeline | 'c1' >> beam.Create({None: 'a'})
    int_valued = pipeline | 'c2' >> beam.Create({None: 3})
    flattened = (str_valued, int_valued) | beam.Flatten()
    _ = flattened | beam.GroupByKey()

    # Guard against a vacuous test: before the visitors run, the flattened
    # type must not already be a tuple. This may change if type inference
    # changes.
    self.assertNotIsInstance(flattened.element_type, typehints.TupleConstraint)

    pipeline.visit(common.group_by_key_input_visitor())
    pipeline.visit(DataflowRunner.flatten_input_visitor())

    # The Dataflow runner requires GroupByKey inputs to be tuples *and*
    # Flatten inputs to equal Flatten outputs; check both properties.
    self.assertIsInstance(flattened.element_type, typehints.TupleConstraint)
    for source in (str_valued, int_valued):
        self.assertEqual(flattened.element_type, source.element_type)