예제 #1
0
 def test_basic(self):
   """Entering a Pipeline makes it the default; leaving the context clears it."""
   with Pipeline('somename') as p:
     self.assertIsNotNone(Pipeline.get_default_pipeline())
     # Created only for their side effect of registering on ``p``.
     ContainerOp(name='op1', image='image')
     ContainerOp(name='op2', image='image')

   self.assertIsNone(Pipeline.get_default_pipeline())
   # Ops created inside the context are retrievable from ``p.ops`` by name.
   self.assertEqual(p.ops['op1'].name, 'op1')
   self.assertEqual(p.ops['op2'].name, 'op2')
예제 #2
0
    def test_decorator(self):
        """@pipeline registers a (name, description) pair for each function."""

        @pipeline(name='p1', description='description1')
        def my_pipeline1():
            pass

        @pipeline(name='p2', description='description2')
        def my_pipeline2():
            pass

        expectations = (
            (my_pipeline1, ('p1', 'description1')),
            (my_pipeline2, ('p2', 'description2')),
        )
        for decorated_fn, expected in expectations:
            self.assertEqual(expected,
                             Pipeline.get_pipeline_functions()[decorated_fn])
예제 #3
0
    def test_basic(self):
        """Nested OpsGroup contexts build a matching tree of groups."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))
            with OpsGroup(group_type='exit_handler'):
                ContainerOp(name='op1', image='image')
                with OpsGroup(group_type='branch'):
                    ContainerOp(name='op2', image='image')
                    ContainerOp(name='op3', image='image')
                with OpsGroup(group_type='loop'):
                    ContainerOp(name='op4', image='image')

        # The pipeline's initial group is still the only top-level group, and
        # the exit_handler group is its single child.
        self.assertEqual(1, len(p.groups))
        root_group = p.groups[0]
        self.assertEqual(1, len(root_group.groups))

        exit_group = root_group.groups[0]
        self.assertEqual('exit_handler', exit_group.type)
        self.assertEqual(2, len(exit_group.groups))
        self.assertEqual(1, len(exit_group.ops))
        self.assertEqual('op1', exit_group.ops[0].name)

        # Both children are leaf groups holding the ops created inside them:
        # index 0 is the branch group, index 1 the loop group.
        for child_index, expected_op_names in ((0, ['op2', 'op3']),
                                               (1, ['op4'])):
            child_group = exit_group.groups[child_index]
            self.assertFalse(child_group.groups)
            self.assertCountEqual([op.name for op in child_group.ops],
                                  expected_op_names)
예제 #4
0
    def test_decorator_metadata(self):
        """@pipeline records parameter types and defaults as a PipelineMeta."""
        @pipeline(name='p1', description='description1')
        def my_pipeline1(a: {'Schema': {'file_type': 'csv'}} = 'good',
                         b: Integer() = 12):
            pass

        golden_meta = PipelineMeta(name='p1', description='description1')
        expected_a = ParameterMeta(
            name='a',
            description='',
            param_type=TypeMeta(name='Schema',
                                properties={'file_type': 'csv'}),
            default='good')
        expected_b = ParameterMeta(
            name='b',
            description='',
            param_type=TypeMeta(name='Integer'),
            default=12)
        for expected_param in (expected_a, expected_b):
            golden_meta.inputs.append(expected_param)

        registered_meta = Pipeline.get_pipeline_functions()[my_pipeline1]
        self.assertEqual(registered_meta, golden_meta)
예제 #5
0
    def test_graphcomponent_basic(self):
        """Test graph_component decorator metadata.

        A graph component that calls itself inside a Condition should produce
        the group chain: pipeline -> flip_component -> condition -> recursive
        flip_component, where the innermost group is a leaf marked recursive.
        """
        @graph_component
        def flip_component(flip_result):
            with dsl.Condition(flip_result == 'heads'):
                flip_component(flip_result)
            return {'flip_result': flip_result}

        with Pipeline('pipeline') as p:
            param = PipelineParam(name='param')
            flip_component(param)

            # Walk the group chain one level at a time. Each level is asserted
            # to have exactly one child before that child is indexed.
            self.assertEqual(1, len(p.groups))
            pipeline_group = p.groups[0]
            self.assertEqual(1, len(pipeline_group.groups))
            original_group = pipeline_group.groups[0]  # flip_component
            self.assertEqual(1, len(original_group.groups))
            condition_group = original_group.groups[0]  # condition
            self.assertEqual(1, len(condition_group.groups))
            recursive_group = condition_group.groups[0]  # recursive call
            self.assertEqual(0, len(recursive_group.groups))

            # The inner call is recorded via recursive_ref rather than by
            # expanding the graph again.
            self.assertIsNotNone(recursive_group.recursive_ref)
            self.assertEqual(1, len(recursive_group.inputs))
            self.assertEqual('param', recursive_group.inputs[0].name)

            self.assertIn('flip_result', original_group.outputs)
            self.assertEqual('param', original_group.outputs['flip_result'])
예제 #6
0
    def test_deprecation_warnings(self):
        """Each listed ContainerOp mutator emits PendingDeprecationWarning."""
        with Pipeline('somename'):
            op = ContainerOp(name='op1', image='image')

        # Each entry performs one deprecated operation on ``op``. They run in
        # this order, each inside its own assertWarns context.
        deprecated_actions = (
            lambda: setattr(op, 'env_variables',
                            [V1EnvVar(name="foo", value="bar")]),
            lambda: setattr(op, 'image', 'image2'),
            lambda: op.set_memory_request('10M'),
            lambda: op.set_memory_limit('10M'),
            lambda: op.set_cpu_request('100m'),
            lambda: op.set_cpu_limit('1'),
            lambda: op.set_gpu_limit('1'),
            lambda: op.add_env_variable(V1EnvVar(name="foo", value="bar")),
            lambda: op.add_volume_mount(
                V1VolumeMount(mount_path='/secret/gcp-credentials',
                              name='gcp-credentials')),
        )
        for action in deprecated_actions:
            with self.assertWarns(PendingDeprecationWarning):
                action()
예제 #7
0
    def test_basic(self):
        """ContainerOp wires inputs, outputs, sidecars and container kwargs."""
        with Pipeline('somename'):
            param1 = PipelineParam('param1')
            param2 = PipelineParam('param2')
            op1 = ContainerOp(
                name='op1',
                image='image',
                arguments=['%s hello %s %s' % (param1, param2, param1)],
                sidecars=[Sidecar(name='sidecar0', image='image0')],
                container_kwargs={
                    'env': [V1EnvVar(name='env1', value='value1')]
                },
                file_outputs={'out1': '/tmp/b'})
            # Re-bind the result of each add_sidecar call, exactly as a chained
            # ``.add_sidecar(...).add_sidecar(...)`` expression would.
            op1 = op1.add_sidecar(Sidecar(name='sidecar1', image='image1'))
            op1 = op1.add_sidecar(Sidecar(name='sidecar2', image='image2'))

        self.assertCountEqual([param.name for param in op1.inputs],
                              ['param1', 'param2'])
        self.assertCountEqual(list(op1.outputs), ['out1'])
        self.assertCountEqual([out.op_name for out in op1.outputs.values()],
                              ['op1'])
        self.assertEqual(op1.output.name, 'out1')
        self.assertCountEqual([car.name for car in op1.sidecars],
                              ['sidecar0', 'sidecar1', 'sidecar2'])
        self.assertCountEqual([car.image for car in op1.sidecars],
                              ['image0', 'image1', 'image2'])
        self.assertCountEqual([var.name for var in op1.container.env],
                              ['env1'])
예제 #8
0
 def test_after_op(self):
   """Test that op2.after(op1) records op1 as a dependency of op2."""
   with Pipeline('somename'):
     op1 = ContainerOp(name='op1', image='image')
     op2 = ContainerOp(name='op2', image='image')
     op2.after(op1)
   self.assertCountEqual(op2.dependent_op_names, [op1.name])
예제 #9
0
    def test_type_check_with_same_representation(self):
        """Test type check at the decorator.

        With type checking enabled, wiring ``a_op`` outputs into ``b_op``
        inputs whose declared types match (GCSPath() -> GCSPath(),
        'GcsUri' -> 'GcsUri', identical customized_type dicts) must not raise.
        The test passes if building the pipeline completes without error.
        """
        # Enable type checking for this test only. The cleanup restores the
        # previous global value so other tests are not affected.
        self.addCleanup(setattr, kfp, 'TYPE_CHECK', kfp.TYPE_CHECK)
        kfp.TYPE_CHECK = True

        @component
        def a_op(field_l: Integer()) -> {
                'field_m': GCSPath(),
                'field_n': {
                    'customized_type': {
                        'property_a': 'value_a',
                        'property_b': 'value_b'
                    }
                },
                'field_o': 'GcsUri'
        }:
            return ContainerOp(name='operator a',
                               image='gcr.io/ml-pipeline/component-a',
                               arguments=[
                                   '--field-l',
                                   field_l,
                               ],
                               file_outputs={
                                   'field_m': '/schema.txt',
                                   'field_n': '/feature.txt',
                                   'field_o': '/output.txt'
                               })

        @component
        def b_op(
            field_x: {
                'customized_type': {
                    'property_a': 'value_a',
                    'property_b': 'value_b'
                }
            }, field_y: 'GcsUri', field_z: GCSPath()
        ) -> {
                'output_model_uri': 'GcsUri'
        }:
            return ContainerOp(name='operator b',
                               image='gcr.io/ml-pipeline/component-b',
                               command=[
                                   'python3',
                                   field_x,
                               ],
                               arguments=[
                                   '--field-y',
                                   field_y,
                                   '--field-z',
                                   field_z,
                               ],
                               file_outputs={
                                   'output_model_uri': '/schema.txt',
                               })

        with Pipeline('pipeline'):
            a = a_op(field_l=12)
            b_op(field_x=a.outputs['field_n'],
                 field_y=a.outputs['field_o'],
                 field_z=a.outputs['field_m'])
예제 #10
0
    def test_type_check_with_openapi_schema(self):
        """Test type check at the decorator.

        With type checking enabled, types spelled as strings or as type
        objects ('GCSPath' -> GCSPath(), 'Integer' -> Integer()), and identical
        customized types carrying an ``openapi_schema_validator``, must be
        accepted. The test passes if building the pipeline completes without
        error.
        """
        # Enable type checking for this test only. The cleanup restores the
        # previous global value so other tests are not affected.
        self.addCleanup(setattr, kfp, 'TYPE_CHECK', kfp.TYPE_CHECK)
        kfp.TYPE_CHECK = True

        @component
        def a_op(field_l: Integer()) -> {
                'field_m': 'GCSPath',
                'field_n': {
                    'customized_type': {
                        'openapi_schema_validator':
                        '{"type": "string", "pattern": "^gs://.*$"}'
                    }
                },
                'field_o': 'Integer'
        }:
            return ContainerOp(name='operator a',
                               image='gcr.io/ml-pipeline/component-b',
                               arguments=[
                                   '--field-l',
                                   field_l,
                               ],
                               file_outputs={
                                   'field_m': '/schema.txt',
                                   'field_n': '/feature.txt',
                                   'field_o': '/output.txt'
                               })

        @component
        def b_op(
            field_x: {
                'customized_type': {
                    'openapi_schema_validator':
                    '{"type": "string", "pattern": "^gs://.*$"}'
                }
            }, field_y: Integer(), field_z: GCSPath()
        ) -> {
                'output_model_uri': 'GcsUri'
        }:
            return ContainerOp(name='operator b',
                               image='gcr.io/ml-pipeline/component-a',
                               command=[
                                   'python3',
                                   field_x,
                               ],
                               arguments=[
                                   '--field-y',
                                   field_y,
                                   '--field-z',
                                   field_z,
                               ],
                               file_outputs={
                                   'output_model_uri': '/schema.txt',
                               })

        with Pipeline('pipeline'):
            a = a_op(field_l=12)
            b_op(field_x=a.outputs['field_n'],
                 field_y=a.outputs['field_o'],
                 field_z=a.outputs['field_m'])
예제 #11
0
 def test_invalid_exit_op(self):
     """An exit op that already depends on another op is rejected."""
     with self.assertRaises(ValueError):
         with Pipeline('somename'):
             upstream_op = ContainerOp(name='op1', image='image')
             exit_op = ContainerOp(name='exit', image='image')
             # Give the exit op a dependency; ExitHandler should then refuse it.
             exit_op.after(upstream_op)
             with ExitHandler(exit_op=exit_op):
                 pass
예제 #12
0
 def test_use_azure_secret(self):
     """use_azure_secret adds four Azure env vars read from the named secret."""
     with Pipeline('somename'):
         op1 = ContainerOp(name='op1', image='image')
         op1 = op1.apply(use_azure_secret('foo'))
         expected_names = ['AZ_SUBSCRIPTION_ID', 'AZ_TENANT_ID',
                           'AZ_CLIENT_ID', 'AZ_CLIENT_SECRET']
         # Use TestCase assertions rather than bare ``assert``: bare asserts
         # are stripped under ``python -O`` and the test would check nothing.
         self.assertEqual(len(op1.env_variables), len(expected_names))
         for env_var, expected in zip(op1.env_variables, expected_names):
             self.assertEqual(env_var.name, expected)
             secret_ref = env_var.value_from.secret_key_ref
             self.assertEqual(secret_ref.name, 'foo')
             self.assertEqual(secret_ref.key, expected)
예제 #13
0
    def test_basic(self):
        """Ops created inside an ExitHandler belong to its exit_handler group."""
        with Pipeline('somename') as p:
            exit_op = ContainerOp(name='exit', image='image')
            with ExitHandler(exit_op=exit_op):
                ContainerOp(name='op1', image='image')

        handler_group = p.groups[0].groups[0]
        self.assertEqual('exit_handler', handler_group.type)
        self.assertEqual('exit', handler_group.exit_op.name)
        guarded_ops = handler_group.ops
        self.assertEqual(1, len(guarded_ops))
        self.assertEqual('op1', guarded_ops[0].name)
예제 #14
0
 def test_basic(self):
   """ContainerOp collects the params used in its arguments and file outputs."""
   with Pipeline('somename'):
     first = PipelineParam('param1')
     second = PipelineParam('param2')
     op1 = ContainerOp(
         name='op1',
         image='image',
         arguments=['%s hello %s %s' % (first, second, first)],
         file_outputs={'out1': '/tmp/b'})

   self.assertCountEqual([param.name for param in op1.inputs],
                         ['param1', 'param2'])
   self.assertCountEqual(list(op1.outputs), ['out1'])
   self.assertCountEqual([out.op_name for out in op1.outputs.values()],
                         ['op1'])
   self.assertEqual(op1.output.name, 'out1')
예제 #15
0
    def test_basic(self):
        """Condition names are assigned on entry, with or without a user name."""
        with Pipeline('somename'):
            param1 = 'pizza'
            # NOTE: param1 is a plain str, so ``param1 == 'pizza'`` is the bool
            # True. This test only exercises naming, not PipelineParam operands.
            unnamed = Condition(param1 == 'pizza')
            self.assertEqual(unnamed.name, None)
            with unnamed:
                pass
            self.assertEqual(unnamed.name, 'condition-1')

            named = Condition(param1 == 'pizza', '[param1 is pizza]')
            self.assertEqual(named.name, '[param1 is pizza]')
            with named:
                pass
            self.assertEqual(named.name, 'condition-[param1 is pizza]-2')
예제 #16
0
    def test_use_aws_secret(self):
        """use_aws_secret adds one env var per key, each read from the secret."""
        with Pipeline('somename'):
            op1 = ContainerOp(name='op1', image='image')
            op1 = op1.apply(
                use_aws_secret('myaws-secret', 'key_id', 'access_key'))
            expected_keys = ['key_id', 'access_key']
            # Use TestCase assertions rather than bare ``assert``: bare asserts
            # are stripped under ``python -O`` and the test would check nothing.
            self.assertEqual(len(op1.env_variables), len(expected_keys))
            for env_var, expected in zip(op1.env_variables, expected_keys):
                self.assertEqual(env_var.name, expected)
                secret_ref = env_var.value_from.secret_key_ref
                self.assertEqual(secret_ref.name, 'myaws-secret')
                self.assertEqual(secret_ref.key, expected)
예제 #17
0
    def test_recursive_opsgroups_with_prefix_names(self):
        """A graph whose name is a prefix of an open graph is not recursive."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))

            # Open a graph group named 'foo_bar' by calling __enter__ directly.
            # No matching __exit__ is called in this test.
            outer_graph = dsl._ops_group.Graph('foo_bar')
            outer_graph.__enter__()
            self.assertFalse(outer_graph.recursive_ref)
            self.assertEqual('graph-foo-bar-1', outer_graph.name)

            # While 'foo_bar' is still open, open 'foo', which is a prefix of
            # the first name. A shared prefix must not be treated as recursion.
            inner_graph = dsl._ops_group.Graph('foo')
            inner_graph.__enter__()
            self.assertFalse(inner_graph.recursive_ref)
예제 #18
0
    def test_basic(self):
        """Mounting another op's pvolume makes the consumer depend on that op."""
        with Pipeline("somename"):
            vol = VolumeOp(name="myvol_creation",
                           resource_name="myvol",
                           size="1Gi")
            op1 = ContainerOp(name="op1",
                              image="image",
                              pvolumes={"/mnt": vol.volume})
            op2 = ContainerOp(name="op2",
                              image="image",
                              pvolumes={"/data": op1.pvolume})

        # The VolumeOp's volume has no dependencies. The volume exposed by op1,
        # and op2 (which mounts it), both depend on op1.
        self.assertEqual(vol.volume.dependent_names, [])
        self.assertEqual(op1.pvolume.dependent_names, [op1.name])
        self.assertEqual(op2.dependent_names, [op1.name])
예제 #19
0
  def _sanitize_and_inject_artifact(self, pipeline: dsl.Pipeline) -> None:
    """Sanitize op, output-param and artifact names of ``pipeline`` in place.

    For every op in ``pipeline.ops``, the following are passed through
    ``sanitize_k8s_name``:
      * the op name;
      * its output parameter names and op names;
      * its dependency names;
      * the keys of its file outputs (ContainerOp) or attribute outputs
        (ResourceOp);
      * for a ContainerOp, the keys of its input artifact paths and artifact
        arguments.
    ``pipeline.ops`` is then replaced by a dict keyed by the sanitized names.

    Note: despite its name, this variant does not inject any artifact location.

    Args:
      pipeline: the pipeline whose ops are sanitized (mutated in place).
    """
    sanitized_ops = {}

    for op in pipeline.ops.values():
      sanitized_name = sanitize_k8s_name(op.name)
      op.name = sanitized_name

      for param in op.outputs.values():
        param.name = sanitize_k8s_name(param.name, True)
        if param.op_name:
          param.op_name = sanitize_k8s_name(param.op_name)

      # ``op.output`` may be a _MultipleOutputsError instance (presumably a
      # placeholder when there is no single output); only rename a real param.
      if op.output is not None and not isinstance(
          op.output, dsl._container_op._MultipleOutputsError):
        op.output.name = sanitize_k8s_name(op.output.name, True)
        op.output.op_name = sanitize_k8s_name(op.output.op_name)

      if op.dependent_names:
        op.dependent_names = [
            sanitize_k8s_name(name) for name in op.dependent_names
        ]

      if isinstance(op, dsl.ContainerOp) and op.file_outputs is not None:
        op.file_outputs = {
            sanitize_k8s_name(key, True): value
            for key, value in op.file_outputs.items()
        }
      elif isinstance(op, dsl.ResourceOp) and op.attribute_outputs is not None:
        op.attribute_outputs = {
            sanitize_k8s_name(key, True): value
            for key, value in op.attribute_outputs.items()
        }

      if isinstance(op, dsl.ContainerOp):
        if op.input_artifact_paths:
          op.input_artifact_paths = {
              sanitize_k8s_name(key, True): value
              for key, value in op.input_artifact_paths.items()
          }
        if op.artifact_arguments:
          op.artifact_arguments = {
              sanitize_k8s_name(key, True): value
              for key, value in op.artifact_arguments.items()
          }

      # NOTE(review): if two op names sanitize to the same value, the later op
      # silently replaces the earlier one here.
      sanitized_ops[sanitized_name] = op

    pipeline.ops = sanitized_ops
예제 #20
0
    def _sanitize_and_inject_artifact(self,
                                      pipeline: dsl.Pipeline,
                                      pipeline_conf=None):
        """Sanitize operator/param names and inject pipeline artifact location.

        For each op in ``pipeline.ops``:
          * if the op has an ``artifact_location`` attribute that is unset
            (falsy) and a pipeline-level artifact location is configured, the
            pipeline-level location is copied onto the op;
          * its name, output parameter names/op names, dependency names and the
            keys of its file outputs (ContainerOp) or attribute outputs
            (ResourceOp) are passed through ``sanitize_k8s_name``.
        ``pipeline.ops`` is then replaced by a dict keyed by sanitized op name.

        Args:
            pipeline: pipeline whose ops are updated in place.
            pipeline_conf: optional configuration object; only its
                ``artifact_location`` attribute is read. When ``None`` (the
                default), no artifact location is injected.
        """
        sanitized_ops = {}
        # Pipeline-level artifact location. ``pipeline_conf`` defaults to None,
        # so guard the attribute access instead of raising AttributeError.
        artifact_location = (pipeline_conf.artifact_location
                             if pipeline_conf is not None else None)

        for op in pipeline.ops.values():
            # Only fill in the pipeline-level location when the op supports it
            # and has no artifact location of its own.
            if (hasattr(op, "artifact_location") and artifact_location
                    and not op.artifact_location):
                op.artifact_location = artifact_location

            sanitized_name = sanitize_k8s_name(op.name)
            op.name = sanitized_name

            for param in op.outputs.values():
                param.name = sanitize_k8s_name(param.name, True)
                if param.op_name:
                    param.op_name = sanitize_k8s_name(param.op_name)

            # ``op.output`` may be a _MultipleOutputsError instance (presumably
            # a placeholder when there is no single output); skip it.
            if op.output is not None and not isinstance(
                    op.output, dsl._container_op._MultipleOutputsError):
                op.output.name = sanitize_k8s_name(op.output.name, True)
                op.output.op_name = sanitize_k8s_name(op.output.op_name)

            if op.dependent_names:
                op.dependent_names = [
                    sanitize_k8s_name(name) for name in op.dependent_names
                ]

            if isinstance(op, dsl.ContainerOp) and op.file_outputs is not None:
                op.file_outputs = {
                    sanitize_k8s_name(key, True): value
                    for key, value in op.file_outputs.items()
                }
            elif isinstance(
                    op, dsl.ResourceOp) and op.attribute_outputs is not None:
                op.attribute_outputs = {
                    sanitize_k8s_name(key, True): value
                    for key, value in op.attribute_outputs.items()
                }

            # NOTE(review): ops whose names sanitize to the same value collide
            # here; the later op silently wins.
            sanitized_ops[sanitized_name] = op

        pipeline.ops = sanitized_ops
예제 #21
0
    def test_basic_recursive_opsgroups(self):
        """Re-entering a graph name that is still open marks it recursive."""
        with Pipeline('somename') as p:
            self.assertEqual(1, len(p.groups))

            # First 'hello' graph: entered manually and never exited here.
            first_graph = dsl._ops_group.Graph('hello')
            first_graph.__enter__()
            self.assertFalse(first_graph.recursive_ref)
            self.assertEqual('graph-hello-1', first_graph.name)

            # A second 'hello' graph opened while the first is still open
            # should point back at the first through recursive_ref.
            second_graph = dsl._ops_group.Graph('hello')
            second_graph.__enter__()
            self.assertTrue(second_graph.recursive_ref)
            self.assertEqual(first_graph, second_graph.recursive_ref)
예제 #22
0
    def test_after_method(self):
        """PipelineVolume.after() yields volumes with the expected dependencies."""
        with Pipeline("somename"):
            op1 = ContainerOp(name="op1", image="image")
            op2 = ContainerOp(name="op2", image="image").after(op1)
            op3 = ContainerOp(name="op3", image="image")
            base_volume = PipelineVolume(name="pipeline-volume")
            after_op1 = base_volume.after(op1)
            after_op2 = after_op1.after(op2)
            after_op1_op2 = after_op2.after(op1, op2)
            after_op3 = after_op1_op2.after(op3)

        self.assertEqual(base_volume.dependent_names, [])
        self.assertEqual(after_op1.dependent_names, [op1.name])
        self.assertEqual(after_op2.dependent_names, [op2.name])
        self.assertEqual(sorted(after_op1_op2.dependent_names),
                         [op1.name, op2.name])
        self.assertEqual(sorted(after_op3.dependent_names),
                         [op1.name, op2.name, op3.name])
예제 #23
0
    def test_basic(self):
        """A VolumeOp built from PipelineParams exposes its outputs and volume."""
        with Pipeline("somename"):
            param1 = PipelineParam("param1")
            param2 = PipelineParam("param2")
            vol = VolumeOp(name="myvol_creation",
                           resource_name=param1,
                           size=param2,
                           annotations={"test": "annotation"})

        self.assertCountEqual([param.name for param in vol.inputs],
                              ["param1", "param2"])
        # The resource name is prefixed with the workflow-name template.
        self.assertEqual(vol.k8s_resource.metadata.name,
                         "{{workflow.name}}-%s" % PipelineParam("param1"))
        self.assertEqual(vol.attribute_outputs, {
            "manifest": "{}",
            "name": "{.metadata.name}",
            "size": "{.status.capacity.storage}",
        })
        self.assertEqual(vol.outputs, {
            key: PipelineParam(name=key, op_name=vol.name)
            for key in ("manifest", "name", "size")
        })
        self.assertEqual(vol.output,
                         PipelineParam(name="name", op_name=vol.name))
        self.assertEqual(vol.dependent_names, [])
        self.assertEqual(
            vol.volume,
            PipelineVolume(
                name="myvol-creation",
                persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
                    claim_name=PipelineParam(name="name", op_name=vol.name))))
예제 #24
0
    def test_basic(self):
        """ResourceOp merges its default attribute outputs with the user's."""
        with Pipeline("somename"):
            param = PipelineParam("param")
            k8s_resource = k8s_client.V1PersistentVolumeClaim(
                api_version="v1",
                kind="PersistentVolumeClaim",
                metadata=k8s_client.V1ObjectMeta(name="my-resource"))
            res = ResourceOp(name="resource",
                             k8s_resource=k8s_resource,
                             success_condition=param,
                             attribute_outputs={"test": "attr"})

        self.assertCountEqual([p.name for p in res.inputs], ["param"])
        self.assertEqual(res.name, "resource")
        self.assertEqual(res.resource.success_condition,
                         PipelineParam("param"))
        self.assertEqual(res.resource.action, "create")
        self.assertEqual(res.resource.failure_condition, None)
        self.assertEqual(res.resource.manifest, None)
        self.assertEqual(res.attribute_outputs, {
            "manifest": "{}",
            "name": "{.metadata.name}",
            "test": "attr",
        })
        self.assertEqual(res.outputs, {
            key: PipelineParam(name=key, op_name=res.name)
            for key in ("manifest", "name", "test")
        })
        self.assertEqual(res.output,
                         PipelineParam(name="test", op_name=res.name))
        self.assertEqual(res.dependent_names, [])
    def test_basic(self):
        """VolumeSnapshotOp built either from a volume or from a PVC param."""
        with Pipeline("somename"):
            param1 = PipelineParam("param1")
            param2 = PipelineParam("param2")
            vol = VolumeOp(name="myvol_creation",
                           resource_name="myvol",
                           size="1Gi")
            snap1 = VolumeSnapshotOp(name="mysnap_creation",
                                     resource_name=param1,
                                     volume=vol.volume)
            snap2 = VolumeSnapshotOp(name="mysnap_creation",
                                     resource_name="mysnap",
                                     pvc=param2,
                                     attribute_outputs={"size": "test"})

        self.assertEqual(sorted(p.name for p in snap1.inputs),
                         ["name", "param1"])
        self.assertEqual(sorted(p.name for p in snap2.inputs), ["param2"])

        default_attributes = {"manifest": "{}", "name": "{.metadata.name}"}
        self.assertEqual(snap1.attribute_outputs,
                         dict(default_attributes, size="{.status.restoreSize}"))
        self.assertEqual(snap2.attribute_outputs,
                         dict(default_attributes, size="test"))

        for snap in (snap1, snap2):
            # NOTE(review): the "size" output is expected to equal a param
            # named "name", not "size". This looks suspicious; confirm against
            # PipelineParam equality semantics.
            self.assertEqual(snap.outputs, {
                "manifest": PipelineParam(name="manifest", op_name=snap.name),
                "name": PipelineParam(name="name", op_name=snap.name),
                "size": PipelineParam(name="name", op_name=snap.name),
            })

        self.assertEqual(snap1.output,
                         PipelineParam(name="name", op_name=snap1.name))
        self.assertEqual(snap2.output,
                         PipelineParam(name="size", op_name=snap2.name))
        self.assertEqual(snap1.dependent_names, [])
        self.assertEqual(snap2.dependent_names, [])

        def snapshot_ref(name):
            # Expected V1TypedLocalObjectReference for a VolumeSnapshot.
            return k8s_client.V1TypedLocalObjectReference(
                api_group="snapshot.storage.k8s.io",
                kind="VolumeSnapshot",
                name=name)

        # NOTE(review): snap1 is expected to reference vol's op, and snap2 to
        # reference param1 (which was passed to snap1, not snap2). Confirm
        # these expectations are intended.
        self.assertEqual(
            snap1.snapshot,
            snapshot_ref(PipelineParam(name="name", op_name=vol.name)))
        self.assertEqual(snap2.snapshot,
                         snapshot_ref(PipelineParam(name="param1")))
예제 #26
0
 def test_nested_pipelines(self):
   """Opening a Pipeline context inside another Pipeline context raises."""
   with self.assertRaises(Exception):
     with Pipeline('somename1'), Pipeline('somename2'):
       pass
예제 #27
0
    def test_type_check_with_inconsistent_types_type_name(self):
        """Test type check at the decorator.

        With type checking enabled, ``b_op`` declares ``field_x`` as
        'customized_type_a', while ``a_op`` produces ``field_n`` as
        'customized_type'. Connecting them must raise InconsistentTypeException.
        """
        # Enable type checking for this test only. The cleanup restores the
        # previous global value so other tests are not affected.
        self.addCleanup(setattr, kfp, 'TYPE_CHECK', kfp.TYPE_CHECK)
        kfp.TYPE_CHECK = True

        @component
        def a_op(field_l: Integer()) -> {
                'field_m': {
                    'ArtifactB': {
                        'path_type': 'file',
                        'file_type': 'tsv'
                    }
                },
                'field_n': {
                    'customized_type': {
                        'property_a': 'value_a',
                        'property_b': 'value_b'
                    }
                },
                'field_o': 'Integer'
        }:
            return ContainerOp(name='operator a',
                               image='gcr.io/ml-pipeline/component-b',
                               arguments=[
                                   '--field-l',
                                   field_l,
                               ],
                               file_outputs={
                                   'field_m': '/schema.txt',
                                   'field_n': '/feature.txt',
                                   'field_o': '/output.txt'
                               })

        @component
        def b_op(
            field_x: {
                'customized_type_a': {
                    'property_a': 'value_a',
                    'property_b': 'value_b'
                }
            }, field_y: Integer(),
            field_z: {'ArtifactB': {
                'path_type': 'file',
                'file_type': 'tsv'
            }}
        ) -> {
                'output_model_uri': 'GcsUri'
        }:
            return ContainerOp(name='operator b',
                               image='gcr.io/ml-pipeline/component-a',
                               command=[
                                   'python3',
                                   field_x,
                               ],
                               arguments=[
                                   '--field-y',
                                   field_y,
                                   '--field-z',
                                   field_z,
                               ],
                               file_outputs={
                                   'output_model_uri': '/schema.txt',
                               })

        with self.assertRaises(InconsistentTypeException):
            with Pipeline('pipeline'):
                a = a_op(field_l=12)
                b_op(field_x=a.outputs['field_n'],
                     field_y=a.outputs['field_o'],
                     field_z=a.outputs['field_m'])