def testProtoExecPropertyInvalidField(self):
  # Access a field that does not exist on ServingSpec; resolving the
  # expression is expected to raise AttributeError.
  placeholder_expression = """
    operator {
      proto_op {
        expression {
          placeholder {
            type: EXEC_PROPERTY
            key: "proto_property"
          }
        }
        proto_schema {
          message_type: "tfx.components.infra_validator.ServingSpec"
        }
        proto_field_path: ".some_invalid_field"
      }
    }
  """
  pb = text_format.Parse(placeholder_expression,
                         placeholder_pb2.PlaceholderExpression())
  # Prepare FileDescriptorSet
  fd = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(fd)
  pb.operator.proto_op.proto_schema.file_descriptors.file.append(fd)
  with self.assertRaises(AttributeError):
    placeholder_utils.resolve_placeholder_expression(
        pb, self._resolution_context)
def testProtoWithoutSerializationFormat(self):
  # Resolving the whole ServingSpec message (no proto_field_path) with no
  # serialization_format set is expected to raise ValueError.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
        }
      }
      """, placeholder_pb2.PlaceholderExpression())

  # Attach the ServingSpec file descriptor to the proto schema.
  serving_spec_file = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
      serving_spec_file)
  schema = expression.operator.proto_op.proto_schema
  schema.file_descriptors.file.append(serving_spec_file)

  with self.assertRaises(ValueError):
    placeholder_utils.resolve_placeholder_expression(
        expression, self._resolution_context)
def testProtoExecPropertyPrimitiveField(self):
  # Index into the repeated `tags` field of the tensorflow_serving
  # sub-message, so the resolved value is a non-message (primitive) value.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
          proto_field_path: "[1]"
        }
      }
      """, placeholder_pb2.PlaceholderExpression())

  # Attach the ServingSpec file descriptor to the proto schema.
  serving_spec_file = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
      serving_spec_file)
  schema = expression.operator.proto_op.proto_schema
  schema.file_descriptors.file.append(serving_spec_file)

  resolved = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(resolved, "1.15.0-gpu")
def _assert_serialized_proto_b64encode_eq(self, serialize_format, expected):
  """Asserts that the base64-encoded serialized proto_property decodes as expected.

  Builds a base64_encode_op wrapping a proto_op over the "proto_property"
  exec property, resolves it, then urlsafe-base64-decodes and UTF-8-decodes
  the result before comparing.

  Args:
    serialize_format: Text spliced verbatim after `serialization_format:` in
      the text-format expression (presumably an enum name such as JSON —
      confirm with callers).
    expected: The decoded string the resolution must produce.
  """
  expression_head = """
    operator {
      base64_encode_op {
        expression {
          operator {
            proto_op {
              expression {
                placeholder {
                  type: EXEC_PROPERTY
                  key: "proto_property"
                }
              }
              proto_schema {
                message_type: "tfx.components.infra_validator.ServingSpec"
              }
              serialization_format: """
  expression_tail = """
            }
          }
        }
      }
    }
  """
  expression = text_format.Parse(
      expression_head + serialize_format + expression_tail,
      placeholder_pb2.PlaceholderExpression())
  encoded = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(base64.urlsafe_b64decode(encoded).decode(), expected)
def testProtoRuntimeInfoNoneAccess(self):
  # The platform config is missing in the "none" resolution context, so
  # drilling into it is expected to resolve to None.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: RUNTIME_INFO
              key: "platform_config"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
        }
      }
      """, placeholder_pb2.PlaceholderExpression())

  # Attach the ServingSpec file descriptor to the proto schema.
  serving_spec_file = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
      serving_spec_file)
  schema = expression.operator.proto_op.proto_schema
  schema.file_descriptors.file.append(serving_spec_file)

  resolved = placeholder_utils.resolve_placeholder_expression(
      expression, self._none_resolution_context)
  self.assertIsNone(resolved)
def testBase64EncodeOperator(self):
  # base64_encode_op wraps the first tag of the serving spec; the result must
  # be its urlsafe base64 encoding as an ASCII string.
  expected = base64.urlsafe_b64encode(b"latest").decode("ASCII")
  expression = text_format.Parse(
      """
      operator {
        base64_encode_op {
          expression {
            operator {
              proto_op {
                expression {
                  placeholder {
                    type: EXEC_PROPERTY
                    key: "proto_property"
                  }
                }
                proto_schema {
                  message_type: "tfx.components.infra_validator.ServingSpec"
                }
                proto_field_path: ".tensorflow_serving"
                proto_field_path: ".tags"
                proto_field_path: "[0]"
              }
            }
          }
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, expected)
def testProtoExecPropertyMessageFieldTextFormat(self):
  # The field path ends on a message-typed field (tensorflow_serving), which
  # is rendered with TEXT_FORMAT serialization.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          serialization_format: TEXT_FORMAT
        }
      }
      """, placeholder_pb2.PlaceholderExpression())

  serving_spec_file = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
      serving_spec_file)
  schema = expression.operator.proto_op.proto_schema
  schema.file_descriptors.file.append(serving_spec_file)

  rendered = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(rendered, 'tags: "latest"\ntags: "1.15.0-gpu"\n')
def testArtifactUriNoneAccess(self):
  # The optional "examples" channel is missing in the "none" resolution
  # context, so the split URI of its first artifact resolves to None.
  expression = text_format.Parse(
      """
      operator {
        artifact_uri_op {
          expression {
            operator {
              index_op {
                expression {
                  placeholder {
                    type: INPUT_ARTIFACT
                    key: "examples"
                  }
                }
                index: 0
              }
            }
          }
          split: "train"
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  resolved = placeholder_utils.resolve_placeholder_expression(
      expression, self._none_resolution_context)
  self.assertIsNone(resolved)
def testRuntimeInfoPlaceholderSimple(self):
  # A bare RUNTIME_INFO placeholder resolves directly to the stored value.
  expression = text_format.Parse(
      """
      placeholder {
        type: RUNTIME_INFO
        key: "executor_output_uri"
      }
      """, placeholder_pb2.PlaceholderExpression())
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, "test_executor_output_uri")
def testSerializeDoubleValue(self):
  # A literal double value resolves to the same float.
  expression = text_format.Parse(
      """
      value {
        double_value: 1.000000009
      }
      """, placeholder_pb2.PlaceholderExpression())
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, 1.000000009)
def run_executor(
    self, execution_info: data_types.ExecutionInfo
) -> execution_result_pb2.ExecutorOutput:
  """Execute underlying component implementation.

  Resolves the placeholder commands of the container executor spec, starts
  the container in detached mode through the Docker SDK, forwards its log
  stream to the logger and waits for it to exit.

  Args:
    execution_info: Execution information used to resolve the command
      placeholders.

  Returns:
    An empty ExecutorOutput.

  Raises:
    RuntimeError: If the container exits with a non-zero status code.
  """
  context = placeholder_utils.ResolutionContext(
      exec_info=execution_info,
      executor_spec=self._executor_spec,
      platform_config=self._platform_config)

  component_executor_spec = executor_specs.TemplatedExecutorContainerSpec(
      image=self._container_executor_spec.image,
      command=[
          placeholder_utils.resolve_placeholder_expression(cmd, context)
          for cmd in self._container_executor_spec.commands
      ])

  docker_config = docker_component_config.DockerComponentConfig()

  logging.info('Container spec: %s', vars(component_executor_spec))
  logging.info('Docker config: %s', vars(docker_config))

  # Call client.containers.run and wait for completion.
  # ExecutorContainerSpec follows k8s container spec which has different
  # names to Docker's container spec. It's intended to set command to docker's
  # entrypoint and args to docker's command.
  # NOTE(review): the call below passes the resolved commands as Docker's
  # `command` (not `entrypoint`) and forwards no args — confirm this matches
  # the intent stated above.
  if docker_config.docker_server_url:
    client = docker.DockerClient(base_url=docker_config.docker_server_url)
  else:
    client = docker.from_env()

  try:
    run_args = docker_config.to_run_args()
    container = client.containers.run(
        image=component_executor_spec.image,
        command=component_executor_spec.command,
        detach=True,
        **run_args)

    # Streaming logs. Chunks are decoded leniently so that non-UTF-8 bytes
    # (or a multi-byte character split across chunks) cannot abort the run
    # before the container's exit code has been collected.
    for log in container.logs(stream=True):
      logging.info('Docker: %s', log.decode('utf-8', errors='replace'))

    exit_code = container.wait()['StatusCode']
  finally:
    # Release the client's connections whether or not the run succeeded.
    client.close()

  if exit_code != 0:
    raise RuntimeError(
        'Container exited with error code "{}"'.format(exit_code))
  # TODO(b/141192583): Report data to publisher
  # - report container digest
  # - report replaced command line entrypoints
  # - report docker run args
  return execution_result_pb2.ExecutorOutput()
def resolve_artifacts(
    self, metadata_handler: metadata.Metadata,
    input_dict: Dict[str, List[types.Artifact]]
) -> Optional[Dict[str, List[types.Artifact]]]:
  """Evaluates every predicate against the inputs and passes them through.

  Args:
    metadata_handler: Not used by this implementation.
    input_dict: Input artifacts keyed by name; the predicates are resolved
      against an ExecutionInfo built from them.

  Returns:
    input_dict, unchanged, when every predicate evaluates to True.

  Raises:
    ValueError: If a predicate resolves to a non-boolean value.
    exceptions.SkipSignal: If a predicate resolves to False.
  """
  # The resolution context depends only on input_dict, so build it once
  # rather than once per predicate.
  context = placeholder_utils.ResolutionContext(
      exec_info=portable_data_types.ExecutionInfo(input_dict=input_dict))
  for placeholder_pb in self._predicates:
    predicate_result = placeholder_utils.resolve_placeholder_expression(
        placeholder_pb, context)
    if not isinstance(predicate_result, bool):
      raise ValueError("Predicate evaluates to a non-boolean result.")
    if not predicate_result:
      raise exceptions.SkipSignal("Predicate evaluates to False.")
  return input_dict
def testExecutionInvocationPlaceholderSimple(self):
  # TODO(b/170469176): Update when proto encoding Operator is available.
  # NOTE(review): a JSON-serialization test with this same method name may
  # also be defined in this class; if so, the later definition shadows this
  # one and only one of them runs — confirm and rename one of them.
  expression = text_format.Parse(
      """
      placeholder {
        type: EXEC_INVOCATION
      }
      """, placeholder_pb2.PlaceholderExpression())
  encoded = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  # The bare EXEC_INVOCATION placeholder yields a base64 string of the
  # binary-serialized ExecutorInvocation.
  actual = executor_invocation_pb2.ExecutorInvocation.FromString(
      base64.b64decode(encoded))
  expected = text_format.Parse(_WANT_EXEC_INVOCATION,
                               executor_invocation_pb2.ExecutorInvocation())
  self.assertProtoEquals(expected, actual)
def testProtoRuntimeInfoPlaceholderMessageField(self):
  # Read the class_path field of the executor_spec runtime info; no
  # proto_schema is needed for this runtime-info message.
  # NOTE(review): another test with this same method name (reading node_info)
  # may be defined in this class; if so, the later definition shadows this
  # one — confirm and rename one of them.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: RUNTIME_INFO
              key: "executor_spec"
            }
          }
          proto_field_path: ".class_path"
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, "test_class_path")
def testProtoRuntimeInfoPlaceholderMessageField(self):
  # Walk node_info -> type -> name on the runtime info.
  # NOTE(review): another test with this same method name (reading
  # executor_spec) may be defined in this class; if so, whichever is defined
  # later shadows the other — confirm and rename one of them.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: RUNTIME_INFO
              key: "node_info"
            }
          }
          proto_field_path: ".type"
          proto_field_path: ".name"
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, "infra_validator")
def testExecPropertyIndex(self):
  # index_op picks element 1 out of a list-valued exec property.
  expression = text_format.Parse(
      """
      operator {
        index_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "double_list_property"
            }
          }
          index: 1
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  element = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(element, "0.8")
def testExecutionInvocationPlaceholderSimple(self):
  # The EXEC_INVOCATION placeholder, wrapped in a proto_op with JSON
  # serialization, must round-trip to the expected ExecutionInvocation.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_INVOCATION
            }
          }
          serialization_format: JSON
        }
      }
      """, placeholder_pb2.PlaceholderExpression())
  as_json = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  actual = json_format.Parse(as_json,
                             execution_invocation_pb2.ExecutionInvocation())
  expected = text_format.Parse(_WANT_EXEC_INVOCATION,
                               execution_invocation_pb2.ExecutionInvocation())
  self.assertProtoEquals(expected, actual)
def testProtoSerializationJSON(self):
  # The whole ServingSpec message serialized as (indented) JSON.
  expression = text_format.Parse(
      """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          serialization_format: JSON
        }
      }
      """, placeholder_pb2.PlaceholderExpression())

  # Attach the ServingSpec file descriptor to the proto schema.
  serving_spec_file = descriptor_pb2.FileDescriptorProto()
  infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
      serving_spec_file)
  schema = expression.operator.proto_op.proto_schema
  schema.file_descriptors.file.append(serving_spec_file)

  expected_json_serialization = "\n".join([
      "{",
      '  "tensorflow_serving": {',
      '    "tags": [',
      '      "latest",',
      '      "1.15.0-gpu"',
      "    ]",
      "  }",
      "}",
  ])
  serialized = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(serialized, expected_json_serialization)
def testConcatArtifactUri(self):
  # _CONCAT_SPLIT_URI_EXPRESSION should resolve to the concatenated split URI.
  expression = placeholder_pb2.PlaceholderExpression()
  text_format.Parse(_CONCAT_SPLIT_URI_EXPRESSION, expression)
  actual = placeholder_utils.resolve_placeholder_expression(
      expression, self._resolution_context)
  self.assertEqual(actual, "/tmp/train/1")
def run_executor(
    self, execution_info: data_types.ExecutionInfo
) -> execution_result_pb2.ExecutorOutput:
  """Execute underlying component implementation.

  Runs executor container in a Kubernetes Pod and wait until it goes into
  `Succeeded` or `Failed` state.

  Args:
    execution_info: All the information that the launcher provides.

  Raises:
    RuntimeError: when the pod is in `Failed` state or unexpected failure from
      Kubernetes API.

  Returns:
    An ExecutorOutput instance
  """
  context = placeholder_utils.ResolutionContext(
      exec_info=execution_info,
      executor_spec=self._executor_spec,
      platform_config=self._platform_config)

  # Empty command/args lists are turned into None via `or None`.
  container_spec = executor_specs.TemplatedExecutorContainerSpec(
      image=self._container_executor_spec.image,
      command=[
          placeholder_utils.resolve_placeholder_expression(cmd, context)
          for cmd in self._container_executor_spec.commands
      ] or None,
      args=[
          placeholder_utils.resolve_placeholder_expression(arg, context)
          for arg in self._container_executor_spec.args
      ] or None,
  )

  pod_name = self._build_pod_name(execution_info)
  # TODO(hongyes): replace the default value from component config.
  try:
    namespace = kube_utils.get_kfp_namespace()
  except RuntimeError:
    namespace = 'kubeflow'

  pod_manifest = self._build_pod_manifest(pod_name, container_spec)
  core_api = kube_utils.make_core_v1_api()

  if kube_utils.is_inside_kfp():
    # Reuse the launcher pod's service account and owner references.
    launcher_pod = kube_utils.get_current_kfp_pod(core_api)
    pod_manifest['spec'][
        'serviceAccount'] = launcher_pod.spec.service_account
    pod_manifest['spec'][
        'serviceAccountName'] = launcher_pod.spec.service_account_name
    pod_manifest['metadata'][
        'ownerReferences'] = container_common.to_swagger_dict(
            launcher_pod.metadata.owner_references)
  else:
    pod_manifest['spec']['serviceAccount'] = kube_utils.TFX_SERVICE_ACCOUNT
    pod_manifest['spec'][
        'serviceAccountName'] = kube_utils.TFX_SERVICE_ACCOUNT

  logging.info('Looking for pod "%s:%s".', namespace, pod_name)
  resp = kube_utils.get_pod(core_api, pod_name, namespace)
  if not resp:
    logging.info('Pod "%s:%s" does not exist. Creating it...',
                 namespace, pod_name)
    logging.info('Pod manifest: %s', pod_manifest)
    try:
      resp = core_api.create_namespaced_pod(
          namespace=namespace, body=pod_manifest)
    except client.rest.ApiException as e:
      raise RuntimeError(
          'Failed to create container executor pod!\nReason: %s\nBody: %s' %
          (e.reason, e.body)) from e

  # Wait up to 300 seconds for the pod to move from pending to another status.
  logging.info('Waiting for pod "%s:%s" to start.', namespace, pod_name)
  kube_utils.wait_pod(
      core_api,
      pod_name,
      namespace,
      exit_condition_lambda=kube_utils.pod_is_not_pending,
      condition_description='non-pending status',
      timeout_sec=300)

  logging.info('Start log streaming for pod "%s:%s".', namespace, pod_name)
  try:
    logs = core_api.read_namespaced_pod_log(
        name=pod_name,
        namespace=namespace,
        container=kube_utils.ARGO_MAIN_CONTAINER_NAME,
        follow=True,
        _preload_content=False).stream()
  except client.rest.ApiException as e:
    raise RuntimeError(
        'Failed to stream the logs from the pod!\nReason: %s\nBody: %s' %
        (e.reason, e.body)) from e

  for log in logs:
    # Decode leniently: a raw chunk may contain non-UTF-8 bytes or split a
    # multi-byte character, which must not abort the wait for the pod.
    logging.info(log.decode(errors='replace').rstrip('\n'))

  # Wait indefinitely for the pod to complete.
  resp = kube_utils.wait_pod(
      core_api,
      pod_name,
      namespace,
      exit_condition_lambda=kube_utils.pod_is_done,
      condition_description='done state')

  if resp.status.phase == kube_utils.PodPhase.FAILED.value:
    raise RuntimeError('Pod "%s:%s" failed with status "%s".' %
                       (namespace, pod_name, resp.status))

  logging.info('Pod "%s:%s" is done.', namespace, pod_name)

  return execution_result_pb2.ExecutorOutput()