def test_extract_pipelineparam_with_types(self):
    """Test _extract_pipelineparams on params that carry a param_type.

    The rendered params are joined with filler text into a single string
    and must come back as [p1, p2, p3]. A list of payloads in which p2
    appears twice must still yield [p1, p2, p3], i.e. duplicates across
    payloads are collapsed and first-seen order is kept.
    """
    p1 = PipelineParam(
        name='param1',
        op_name='op1',
        param_type={'customized_type_a': {'property_a': 'value_a'}},
    )
    p2 = PipelineParam(name='param2', param_type='customized_type_b')
    p3 = PipelineParam(
        name='param3',
        value='value3',
        param_type={'customized_type_c': {'property_c': 'value_c'}},
    )
    separator = ' between '
    rendered = [str(param) for param in (p1, p2, p3)]

    single_payload = separator.join(rendered)
    self.assertListEqual([p1, p2, p3], _extract_pipelineparams(single_payload))

    # p2 occurs in both payloads; the extractor is expected to dedup it.
    split_payload = [separator.join(rendered[:2]), separator.join(rendered[1:])]
    self.assertListEqual([p1, p2, p3], _extract_pipelineparams(split_payload))
def test_extract_pipelineparams_from_dict(self):
    """Test extract_pipelineparams_from_any on a V1ConfigMap's data dict.

    One param is rendered into a data key and another into the matching
    value; both must be found (compared order-insensitively).
    """
    key_param = PipelineParam(name='param1', op_name='op1')
    value_param = PipelineParam(name='param2')
    config_map = V1ConfigMap(data={str(key_param): str(value_param)})

    found = extract_pipelineparams_from_any(config_map)
    expected = sorted([key_param, value_param])
    self.assertListEqual(expected, sorted(found))
def test_str_repr(self):
    """Test the {{pipelineparam:...}} string form of PipelineParam.

    The op segment is empty when no op_name is given, and a param's value
    does not appear in the rendered string.
    """
    cases = [
        ({'name': 'param1', 'op_name': 'op1'},
         '{{pipelineparam:op=op1;name=param1}}'),
        ({'name': 'param2'},
         '{{pipelineparam:op=;name=param2}}'),
        ({'name': 'param3', 'value': 'value3'},
         '{{pipelineparam:op=;name=param3}}'),
    ]
    for kwargs, expected in cases:
        self.assertEqual(expected, str(PipelineParam(**kwargs)))
def test_extract_pipelineparams_from_any(self):
    """Test extract_pipelineparams_from_any on a V1Container.

    Params are supplied both directly as attribute values (name, image)
    and rendered into an env var's string value; all three must be found
    (compared order-insensitively).
    """
    first = PipelineParam(name='param1', op_name='op1')
    second = PipelineParam(name='param2')
    third = PipelineParam(name='param3', value='value3')
    env_value = ' between '.join(str(param) for param in (first, second, third))

    container = V1Container(
        name=first,
        image=second,
        env=[V1EnvVar(name="foo", value=env_value)],
    )

    found = extract_pipelineparams_from_any(container)
    self.assertListEqual(sorted([first, second, third]), sorted(found))
def test_extract_pipelineparams(self):
    """Test _extract_pipelineparams on a string and on a list of strings.

    Both the single joined payload and the two-payload list (where the
    middle param appears twice) must yield the params once each, in
    first-seen order.
    """
    p1 = PipelineParam(name='param1', op_name='op1')
    p2 = PipelineParam(name='param2')
    p3 = PipelineParam(name='param3', value='value3')
    filler = ' between '
    s1, s2, s3 = str(p1), str(p2), str(p3)

    payloads = (
        filler.join((s1, s2, s3)),
        [filler.join((s1, s2)), filler.join((s2, s3))],
    )
    for payload in payloads:
        self.assertListEqual([p1, p2, p3], _extract_pipelineparams(payload))
def test_reusable_component_warnings(self):
    """Only direct ContainerOp construction should warn about reusable components.

    Calling an op loaded via ``load_component_from_text`` must emit neither
    a DeprecationWarning nor the 'reusable' FutureWarning. Constructing
    ``kfp.dsl.ContainerOp`` directly, with or without PipelineParam
    arguments, must emit a FutureWarning whose message matches 'reusable'.
    """
    op1 = load_component_from_text('''\
implementation:
  container:
    image: busybox
''')
    with warnings.catch_warnings(record=True) as warning_messages:
        # Force every warning to be recorded. Otherwise a warning that is
        # ignored by the active filters, or was already shown once, is never
        # captured and the negative check below passes vacuously.
        warnings.simplefilter('always')
        op1()

    # The direct-construction warning asserted below is a FutureWarning, so
    # check for that one explicitly as well as for any DeprecationWarning.
    unexpected_messages = [
        str(message)
        for message in warning_messages
        if issubclass(message.category, DeprecationWarning)
        or (issubclass(message.category, FutureWarning)
            and 'reusable' in str(message.message))
    ]
    self.assertListEqual(unexpected_messages, [])

    with self.assertWarnsRegex(FutureWarning, expected_regex='reusable'):
        kfp.dsl.ContainerOp(name='name', image='image')

    with self.assertWarnsRegex(FutureWarning, expected_regex='reusable'):
        kfp.dsl.ContainerOp(
            name='name',
            image='image',
            arguments=[PipelineParam('param1'),
                       PipelineParam('param2')])
def test_graphcomponent_basic(self):
    """Test graph_component decorator metadata.

    A graph component that calls itself inside a dsl.Condition is invoked
    once within a Pipeline. The group tree must be:
    pipeline -> flip_component -> condition -> recursive flip_component.
    The innermost (recursive) group has no child groups, a non-None
    ``recursive_ref``, and exactly one input named 'param'.
    """
    @graph_component
    def flip_component(flip_result):
        with dsl.Condition(flip_result == 'heads'):
            flip_component(flip_result)

    with Pipeline('pipeline') as p:
        param = PipelineParam(name='param')
        flip_component(param)

    self.assertEqual(1, len(p.groups))

    # Walk down the single-child chain, checking each level before indexing.
    pipeline_group = p.groups[0]
    self.assertEqual(1, len(pipeline_group.groups))

    flip_group = pipeline_group.groups[0]
    self.assertEqual(1, len(flip_group.groups))

    condition_group = flip_group.groups[0]
    self.assertEqual(1, len(condition_group.groups))

    recursive_group = condition_group.groups[0]
    # The nested self-call must not be expanded again (no child groups);
    # presumably it is represented via recursive_ref instead.
    self.assertEqual(0, len(recursive_group.groups))
    self.assertIsNotNone(recursive_group.recursive_ref)
    self.assertEqual(1, len(recursive_group.inputs))
    self.assertEqual('param', recursive_group.inputs[0].name)
def test_invalid(self):
    """PipelineParam raises ValueError for an invalid name such as '123_abc'."""
    with self.assertRaises(ValueError):
        PipelineParam(name='123_abc')