예제 #1
0
    def testProtoExecPropertyInvalidField(self):
        """Expects AttributeError when the proto field path is `.some_invalid_field`."""
        # The field path below is the one this test expects to be rejected.
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".some_invalid_field"
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        # Copy the file descriptor of ServingSpec into the expression's schema.
        file_descriptor = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            file_descriptor)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(file_descriptor)

        with self.assertRaises(AttributeError):
            placeholder_utils.resolve_placeholder_expression(
                expression_pb, self._resolution_context)
예제 #2
0
    def testProtoWithoutSerializationFormat(self):
        """Expects ValueError for a proto_op with no field path and no serialization_format."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        # Copy the file descriptor of ServingSpec into the expression's schema.
        file_descriptor = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            file_descriptor)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(file_descriptor)

        with self.assertRaises(ValueError):
            placeholder_utils.resolve_placeholder_expression(
                expression_pb, self._resolution_context)
예제 #3
0
    def testProtoExecPropertyPrimitiveField(self):
        """Resolves `.tensorflow_serving.tags[1]` and expects "1.15.0-gpu"."""
        # The field path ends at an indexed element rather than a message.
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
          proto_field_path: "[1]"
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        # Copy the file descriptor of ServingSpec into the expression's schema.
        file_descriptor = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            file_descriptor)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(file_descriptor)

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._resolution_context)
        self.assertEqual(resolved, "1.15.0-gpu")
예제 #4
0
 def _assert_serialized_proto_b64encode_eq(self, serialize_format,
                                           expected):
     """Assert the base64-encoded serialized proto decodes to `expected`.

     Args:
       serialize_format: Text spliced verbatim into the expression as the
         value of `serialization_format`.
       expected: String the url-safe base64-decoded result must equal.
     """
     expression_text = """
     operator {
       base64_encode_op {
         expression {
           operator {
             proto_op {
               expression {
                 placeholder {
                   type: EXEC_PROPERTY
                   key: "proto_property"
                 }
               }
               proto_schema {
                 message_type: "tfx.components.infra_validator.ServingSpec"
               }
               serialization_format: """ + serialize_format + """
             }
           }
         }
       }
     }
   """
     expression_pb = text_format.Parse(
         expression_text, placeholder_pb2.PlaceholderExpression())
     encoded = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     # Undo the url-safe base64 layer, then turn the bytes into text.
     decoded_bytes = base64.urlsafe_b64decode(encoded)
     self.assertEqual(decoded_bytes.decode(), expected)
예제 #5
0
    def testProtoRuntimeInfoNoneAccess(self):
        """Expects None when `platform_config` is resolved against the none-context."""
        # Presumably self._none_resolution_context carries no platform config;
        # the test only checks that the resolution result is None.
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: RUNTIME_INFO
              key: "platform_config"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        # Copy the file descriptor of ServingSpec into the expression's schema.
        file_descriptor = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            file_descriptor)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(file_descriptor)

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._none_resolution_context)
        self.assertIsNone(resolved)
예제 #6
0
 def testBase64EncodeOperator(self):
     """Expects base64_encode_op over `.tensorflow_serving.tags[0]` to encode b"latest"."""
     expression_text = """
   operator {
     base64_encode_op {
       expression {
         operator {
           proto_op {
             expression {
               placeholder {
                 type: EXEC_PROPERTY
                 key: "proto_property"
               }
             }
             proto_schema {
               message_type: "tfx.components.infra_validator.ServingSpec"
             }
             proto_field_path: ".tensorflow_serving"
             proto_field_path: ".tags"
             proto_field_path: "[0]"
           }
         }
       }
     }
   }
 """
     expression_pb = text_format.Parse(
         expression_text, placeholder_pb2.PlaceholderExpression())
     resolved = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     # The reference value is the url-safe base64 encoding of b"latest".
     expected_encoding = base64.urlsafe_b64encode(b"latest").decode("ASCII")
     self.assertEqual(resolved, expected_encoding)
예제 #7
0
    def testProtoExecPropertyMessageFieldTextFormat(self):
        """Resolves `.tensorflow_serving` with TEXT_FORMAT and checks the rendered text."""
        # The field path stops at a message-typed field.
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          serialization_format: TEXT_FORMAT
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        # Copy the file descriptor of ServingSpec into the expression's schema.
        file_descriptor = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            file_descriptor)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(file_descriptor)

        # When proto_field_path lands on a message field, the message is
        # expected to come back rendered via text_format.
        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._resolution_context)
        self.assertEqual(resolved,
                         "tags: \"latest\"\ntags: \"1.15.0-gpu\"\n")
예제 #8
0
    def testArtifactUriNoneAccess(self):
        """Expects None for artifact_uri_op over `examples[0]` in the none-context."""
        # Presumably the "examples" channel is absent from
        # self._none_resolution_context; the test only checks for None.
        expression_text = """
      operator {
        artifact_uri_op {
          expression {
            operator {
              index_op{
                expression {
                  placeholder {
                    type: INPUT_ARTIFACT
                    key: "examples"
                  }
                }
                index: 0
              }
            }
          }
          split: "train"
        }
      }
    """
        expression_pb = text_format.Parse(
            expression_text, placeholder_pb2.PlaceholderExpression())

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._none_resolution_context)
        self.assertIsNone(resolved)
예제 #9
0
 def testRuntimeInfoPlaceholderSimple(self):
   """Resolves RUNTIME_INFO `executor_output_uri` and checks the returned string."""
   expression_text = """
     placeholder {
       type: RUNTIME_INFO
       key: "executor_output_uri"
     }
   """
   expression_pb = text_format.Parse(
       expression_text, placeholder_pb2.PlaceholderExpression())
   resolved = placeholder_utils.resolve_placeholder_expression(
       expression_pb, self._resolution_context)
   self.assertEqual(resolved, "test_executor_output_uri")
예제 #10
0
 def testSerializeDoubleValue(self):
     """Resolves a literal `double_value` expression and expects the same float back."""
     # The expression holds a primitive value rather than a placeholder.
     expression_text = """
   value {
     double_value: 1.000000009
   }
 """
     expression_pb = text_format.Parse(
         expression_text, placeholder_pb2.PlaceholderExpression())
     resolved = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     self.assertEqual(resolved, 1.000000009)
예제 #11
0
    def run_executor(
        self, execution_info: data_types.ExecutionInfo
    ) -> execution_result_pb2.ExecutorOutput:
        """Run the component's container through Docker and wait for it to exit.

        Args:
          execution_info: Execution details used as `exec_info` when resolving
            the placeholder expressions in the container commands.

        Returns:
          A newly constructed, empty ExecutorOutput.

        Raises:
          RuntimeError: If the container's reported status code is not 0.
        """
        resolution_context = placeholder_utils.ResolutionContext(
            exec_info=execution_info,
            executor_spec=self._executor_spec,
            platform_config=self._platform_config)

        # Only the `commands` of the container spec are resolved here.
        container_image = self._container_executor_spec.image
        resolved_commands = [
            placeholder_utils.resolve_placeholder_expression(
                command_expr, resolution_context)
            for command_expr in self._container_executor_spec.commands
        ]
        component_executor_spec = (
            executor_specs.TemplatedExecutorContainerSpec(
                image=container_image, command=resolved_commands))

        docker_config = docker_component_config.DockerComponentConfig()

        logging.info('Container spec: %s', vars(component_executor_spec))
        logging.info('Docker config: %s', vars(docker_config))

        # Call client.containers.run and wait for completion.
        # ExecutorContainerSpec follows the k8s container spec, whose field
        # names differ from Docker's container spec. The stated intent is to
        # map command to docker's entrypoint and args to docker's command.
        server_url = docker_config.docker_server_url
        client = (docker.DockerClient(base_url=server_url)
                  if server_url else docker.from_env())

        run_args = docker_config.to_run_args()
        container = client.containers.run(
            image=component_executor_spec.image,
            command=component_executor_spec.command,
            detach=True,
            **run_args)

        # Forward the container's log stream into our own log.
        for log_chunk in container.logs(stream=True):
            logging.info('Docker: %s', log_chunk.decode('utf-8'))

        exit_code = container.wait()['StatusCode']
        if exit_code != 0:
            raise RuntimeError(
                'Container exited with error code "{}"'.format(exit_code))
        # TODO(b/141192583): Report data to publisher
        # - report container digest
        # - report replaced command line entrypoints
        # - report docker run args
        return execution_result_pb2.ExecutorOutput()
예제 #12
0
    def resolve_artifacts(
        self, metadata_handler: metadata.Metadata,
        input_dict: Dict[str, List[types.Artifact]]
    ) -> Optional[Dict[str, List[types.Artifact]]]:
        """Return `input_dict` unchanged once every predicate resolves to True.

        Args:
          metadata_handler: Not used by this method.
          input_dict: Artifacts by key; used as the `input_dict` of the
            ExecutionInfo each predicate is resolved against.

        Returns:
          The same `input_dict` object that was passed in.

        Raises:
          ValueError: If a predicate resolves to something other than a bool.
          exceptions.SkipSignal: If a predicate resolves to False.
        """
        for predicate_pb in self._predicates:
            # A fresh resolution context is built for every predicate.
            exec_info = portable_data_types.ExecutionInfo(
                input_dict=input_dict)
            resolution_context = placeholder_utils.ResolutionContext(
                exec_info=exec_info)
            outcome = placeholder_utils.resolve_placeholder_expression(
                predicate_pb, resolution_context)

            if not isinstance(outcome, bool):
                raise ValueError(
                    "Predicate evaluates to a non-boolean result.")
            if outcome:
                continue
            raise exceptions.SkipSignal("Predicate evaluates to False.")
        return input_dict
예제 #13
0
  def testExecutionInvocationPlaceholderSimple(self):
    """Resolves EXEC_INVOCATION and compares the decoded proto to the expected one."""
    # TODO(b/170469176): Update when proto encoding Operator is available.
    expression_text = """
      placeholder {
        type: EXEC_INVOCATION
      }
    """
    expression_pb = text_format.Parse(
        expression_text, placeholder_pb2.PlaceholderExpression())
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)

    # The resolved value is base64-decoded and parsed as a serialized proto.
    serialized_invocation = base64.b64decode(resolved)
    got_exec_invocation = executor_invocation_pb2.ExecutorInvocation.FromString(
        serialized_invocation)

    want_exec_invocation = text_format.Parse(
        _WANT_EXEC_INVOCATION, executor_invocation_pb2.ExecutorInvocation())
    self.assertProtoEquals(want_exec_invocation, got_exec_invocation)
예제 #14
0
 def testProtoRuntimeInfoPlaceholderMessageField(self):
     """Resolves `.class_path` of RUNTIME_INFO `executor_spec`; expects "test_class_path"."""
     expression_text = """
   operator {
     proto_op {
       expression {
         placeholder {
           type: RUNTIME_INFO
           key: "executor_spec"
         }
       }
       proto_field_path: ".class_path"
     }
   }
 """
     expression_pb = text_format.Parse(
         expression_text, placeholder_pb2.PlaceholderExpression())
     resolved = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     self.assertEqual(resolved, "test_class_path")
예제 #15
0
 def testProtoRuntimeInfoPlaceholderMessageField(self):
   """Resolves `.type.name` of RUNTIME_INFO `node_info`; expects "infra_validator"."""
   expression_text = """
     operator {
       proto_op {
         expression {
           placeholder {
             type: RUNTIME_INFO
             key: "node_info"
           }
         }
         proto_field_path: ".type"
         proto_field_path: ".name"
       }
     }
   """
   expression_pb = text_format.Parse(
       expression_text, placeholder_pb2.PlaceholderExpression())
   resolved = placeholder_utils.resolve_placeholder_expression(
       expression_pb, self._resolution_context)
   self.assertEqual(resolved, "infra_validator")
예제 #16
0
  def testExecPropertyIndex(self):
    """Applies index_op (index 1) to `double_list_property`; expects "0.8"."""
    # Picks one element out of a list-valued exec property.
    expression_text = """
      operator {
        index_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "double_list_property"
            }
          }
          index: 1
        }
      }
    """
    expression_pb = text_format.Parse(
        expression_text, placeholder_pb2.PlaceholderExpression())

    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)
    self.assertEqual(resolved, "0.8")
예제 #17
0
 def testExecutionInvocationPlaceholderSimple(self):
   """Resolves EXEC_INVOCATION as JSON and compares the parsed proto to the expected one."""
   expression_text = """
     operator {
       proto_op {
         expression {
           placeholder {
             type: EXEC_INVOCATION
           }
         }
         serialization_format: JSON
       }
     }
   """
   expression_pb = text_format.Parse(
       expression_text, placeholder_pb2.PlaceholderExpression())
   resolved_json = placeholder_utils.resolve_placeholder_expression(
       expression_pb, self._resolution_context)

   # Parse the JSON result back into a proto so the comparison is structural.
   got_exec_invocation = json_format.Parse(
       resolved_json, execution_invocation_pb2.ExecutionInvocation())
   want_exec_invocation = text_format.Parse(
       _WANT_EXEC_INVOCATION, execution_invocation_pb2.ExecutionInvocation())
   self.assertProtoEquals(want_exec_invocation, got_exec_invocation)
예제 #18
0
    def testProtoSerializationJSON(self):
        placeholder_expression = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          serialization_format: JSON
        }
      }
    """
        pb = text_format.Parse(placeholder_expression,
                               placeholder_pb2.PlaceholderExpression())

        # Prepare FileDescriptorSet
        fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(fd)
        pb.operator.proto_op.proto_schema.file_descriptors.file.append(fd)

        expected_json_serialization = """\
{
  "tensorflow_serving": {
    "tags": [
      "latest",
      "1.15.0-gpu"
    ]
  }
}"""

        self.assertEqual(
            placeholder_utils.resolve_placeholder_expression(
                pb, self._resolution_context), expected_json_serialization)
예제 #19
0
 def testConcatArtifactUri(self):
     """Resolves `_CONCAT_SPLIT_URI_EXPRESSION` and expects "/tmp/train/1"."""
     expression_pb = text_format.Parse(
         _CONCAT_SPLIT_URI_EXPRESSION,
         placeholder_pb2.PlaceholderExpression())
     resolved_uri = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     self.assertEqual(resolved_uri, "/tmp/train/1")
예제 #20
0
    def run_executor(
        self, execution_info: data_types.ExecutionInfo
    ) -> execution_result_pb2.ExecutorOutput:
        """Execute underlying component implementation.

        Runs executor container in a Kubernetes Pod and wait until it goes
        into `Succeeded` or `Failed` state.

        Args:
          execution_info: All the information that the launcher provides.

        Raises:
          RuntimeError: when the pod is in `Failed` state or unexpected
            failure from Kubernetes API. API failures are chained to the
            underlying `ApiException` as their cause.

        Returns:
          An ExecutorOutput instance
        """

        context = placeholder_utils.ResolutionContext(
            exec_info=execution_info,
            executor_spec=self._executor_spec,
            platform_config=self._platform_config)

        # An empty resolved list is replaced by None (`[] or None`).
        container_spec = executor_specs.TemplatedExecutorContainerSpec(
            image=self._container_executor_spec.image,
            command=[
                placeholder_utils.resolve_placeholder_expression(cmd, context)
                for cmd in self._container_executor_spec.commands
            ] or None,
            args=[
                placeholder_utils.resolve_placeholder_expression(arg, context)
                for arg in self._container_executor_spec.args
            ] or None,
        )

        pod_name = self._build_pod_name(execution_info)
        # TODO(hongyes): replace the default value from component config.
        try:
            namespace = kube_utils.get_kfp_namespace()
        except RuntimeError:
            # Fall back to a fixed namespace when the lookup fails.
            namespace = 'kubeflow'

        pod_manifest = self._build_pod_manifest(pod_name, container_spec)
        core_api = kube_utils.make_core_v1_api()

        if kube_utils.is_inside_kfp():
            # Reuse the launcher pod's service account and owner references.
            launcher_pod = kube_utils.get_current_kfp_pod(core_api)
            pod_manifest['spec'][
                'serviceAccount'] = launcher_pod.spec.service_account
            pod_manifest['spec'][
                'serviceAccountName'] = launcher_pod.spec.service_account_name
            pod_manifest['metadata'][
                'ownerReferences'] = container_common.to_swagger_dict(
                    launcher_pod.metadata.owner_references)
        else:
            pod_manifest['spec'][
                'serviceAccount'] = kube_utils.TFX_SERVICE_ACCOUNT
            pod_manifest['spec'][
                'serviceAccountName'] = kube_utils.TFX_SERVICE_ACCOUNT

        # Only create the pod when a lookup by name returns nothing.
        logging.info('Looking for pod "%s:%s".', namespace, pod_name)
        resp = kube_utils.get_pod(core_api, pod_name, namespace)
        if not resp:
            logging.info('Pod "%s:%s" does not exist. Creating it...',
                         namespace, pod_name)
            logging.info('Pod manifest: %s', pod_manifest)
            try:
                resp = core_api.create_namespaced_pod(namespace=namespace,
                                                      body=pod_manifest)
            except client.rest.ApiException as e:
                # Chain explicitly so the API error is reported as the cause.
                raise RuntimeError(
                    'Failed to created container executor pod!\nReason: %s\nBody: %s'
                    % (e.reason, e.body)) from e

        # Wait up to 300 seconds for the pod to move from pending to another status.
        logging.info('Waiting for pod "%s:%s" to start.', namespace, pod_name)
        kube_utils.wait_pod(
            core_api,
            pod_name,
            namespace,
            exit_condition_lambda=kube_utils.pod_is_not_pending,
            condition_description='non-pending status',
            timeout_sec=300)

        logging.info('Start log streaming for pod "%s:%s".', namespace,
                     pod_name)
        try:
            logs = core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container=kube_utils.ARGO_MAIN_CONTAINER_NAME,
                follow=True,
                _preload_content=False).stream()
        except client.rest.ApiException as e:
            # Chain explicitly so the API error is reported as the cause.
            raise RuntimeError(
                'Failed to stream the logs from the pod!\nReason: %s\nBody: %s'
                % (e.reason, e.body)) from e

        for log in logs:
            logging.info(log.decode().rstrip('\n'))

        # Wait for the pod to complete; no timeout is passed to this call.
        resp = kube_utils.wait_pod(
            core_api,
            pod_name,
            namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            condition_description='done state')

        if resp.status.phase == kube_utils.PodPhase.FAILED.value:
            raise RuntimeError('Pod "%s:%s" failed with status "%s".' %
                               (namespace, pod_name, resp.status))

        logging.info('Pod "%s:%s" is done.', namespace, pod_name)

        return execution_result_pb2.ExecutorOutput()