def testDebugPlaceholder(self):
    """Checks debug_str output for a concat and a proto-serialize expression."""
    concat_pb = text_format.Parse(_CONCAT_SPLIT_URI_EXPRESSION,
                                  placeholder_pb2.PlaceholderExpression())
    concat_debug = placeholder_utils.debug_str(concat_pb)
    self.assertEqual(
        concat_debug,
        "(input(\"examples\")[0].split_uri(\"train\") + \"/\" + \"1\")")

    # Proto field access on an exec property, serialized in text format.
    serialize_text = """ operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "serving_spec" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".tensorflow_serving" serialization_format: TEXT_FORMAT } } """
    serialize_pb = text_format.Parse(serialize_text,
                                     placeholder_pb2.PlaceholderExpression())
    serialize_debug = placeholder_utils.debug_str(serialize_pb)
    self.assertEqual(
        serialize_debug,
        "exec_property(\"serving_spec\").tensorflow_serving.serialize(TEXT_FORMAT)"
    )
def encode(
        self,
        component_spec: Optional[types.ComponentSpec] = None) -> message.Message:
    """Encodes ExecutorSpec into an IR proto for compiling.

    This method will be used by DSL compiler to generate the corresponding IR.
    Each entry of `self.command` and `self.args` is passed through
    `self._recursively_encode`. A plain `str` result becomes a
    PlaceholderExpression whose `value.string_value` holds the string. Any
    other result is converted with its own `encode()` method.

    Args:
      component_spec: Optional. The ComponentSpec to help with the encoding.
        Not read by this implementation.

    Returns:
      An executor spec proto.
    """

    def _to_expression(item):
        # Encode once and reuse the result rather than re-running
        # _recursively_encode for the placeholder case.
        str_or_placeholder = self._recursively_encode(item)
        if isinstance(str_or_placeholder, str):
            expression = placeholder_pb2.PlaceholderExpression()
            expression.value.string_value = str_or_placeholder
            return expression
        return str_or_placeholder.encode()

    result = executable_spec_pb2.ContainerExecutableSpec()
    result.image = self.image
    for command in self.command:
        result.commands.add().CopyFrom(_to_expression(command))
    for arg in self.args:
        result.args.add().CopyFrom(_to_expression(arg))
    return result
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Concatenates sub_expression_pb with this operator's other operand.

    ConcatOperator's proto version contains multiple placeholder expressions
    as operands. For convenience, the Python version is implemented taking
    only two operands: `self._right` (a Placeholder or str) appended after
    sub_expression_pb, or `self._left` (a str) prepended before it.

    If sub_expression_pb already is a concat_op, the new operand is added to it
    in place and that same proto is returned (the input is mutated). Otherwise
    a new two-operand concat_op expression is returned.

    Args:
      sub_expression_pb: The expression encoded so far.
      component_spec: Unused by ConcatOperator.

    Returns:
      The concatenated PlaceholderExpression.

    Raises:
      RuntimeError: If neither a right nor a left operand is set.
    """
    del component_spec  # Unused by ConcatOperator

    def _is_concat(expression_pb):
        return (expression_pb.HasField('operator') and
                expression_pb.operator.HasField('concat_op'))

    # `is not None` (rather than truthiness) so that an empty-string operand
    # is still concatenated instead of falling through to the RuntimeError.
    # NOTE(review): assumes unset operands are stored as None — confirm
    # against the constructor.
    if self._right is not None:
        # Resolve other expression
        if isinstance(self._right, Placeholder):
            other_expression_pb = self._right.encode()
        else:
            other_expression_pb = placeholder_pb2.PlaceholderExpression()
            other_expression_pb.value.string_value = self._right
        # Try combining with existing concat operator
        if _is_concat(sub_expression_pb):
            sub_expression_pb.operator.concat_op.expressions.append(
                other_expression_pb)
            return sub_expression_pb
        result = placeholder_pb2.PlaceholderExpression()
        result.operator.concat_op.expressions.extend(
            [sub_expression_pb, other_expression_pb])
        return result

    if self._left is not None:
        # Resolve other expression: left operand must be str
        other_expression_pb = placeholder_pb2.PlaceholderExpression()
        other_expression_pb.value.string_value = self._left
        # Try combining with existing concat operator
        if _is_concat(sub_expression_pb):
            sub_expression_pb.operator.concat_op.expressions.insert(
                0, other_expression_pb)
            return sub_expression_pb
        result = placeholder_pb2.PlaceholderExpression()
        result.operator.concat_op.expressions.extend(
            [other_expression_pb, sub_expression_pb])
        return result

    raise RuntimeError(
        'ConcatOperator does not have the other expression to concat.')
def testBase64EncodeOperator(self):
    """base64_encode_op over tags[0] resolves to its url-safe base64 form."""
    expression_text = """ operator { base64_encode_op { expression { operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".tensorflow_serving" proto_field_path: ".tags" proto_field_path: "[0]" } } } } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)
    expected = base64.urlsafe_b64encode(b"latest").decode("ASCII")
    self.assertEqual(resolved, expected)
def testProtoRuntimeInfoNoneAccess(self):
    """A proto field read on a missing RUNTIME_INFO platform config is None."""
    expression_text = """ operator { proto_op { expression { placeholder { type: RUNTIME_INFO key: "platform_config" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".tensorflow_serving" proto_field_path: ".tags" } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    # Embed the ServingSpec .proto file descriptor into the proto_schema.
    descriptor_set = expression_pb.operator.proto_op.proto_schema.file_descriptors
    infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
        descriptor_set.file.add())
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._none_resolution_context)
    self.assertIsNone(resolved)
def testProtoExecPropertyMessageFieldTextFormat(self):
    """A message-typed proto field with TEXT_FORMAT renders as text proto."""
    expression_text = """ operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".tensorflow_serving" serialization_format: TEXT_FORMAT } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    descriptor_set = expression_pb.operator.proto_op.proto_schema.file_descriptors
    infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
        descriptor_set.file.add())
    # proto_field_path ends on a message field, so the expected output is the
    # text_format rendering of that sub-message.
    expected_text = "tags: \"latest\"\ntags: \"1.15.0-gpu\"\n"
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)
    self.assertEqual(resolved, expected_text)
def testArtifactUriNoneAccess(self):
    """The split URI of a missing optional input channel resolves to None."""
    expression_text = """ operator { artifact_uri_op { expression { operator { index_op{ expression { placeholder { type: INPUT_ARTIFACT key: "examples" } } index: 0 } } } split: "train" } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._none_resolution_context)
    self.assertIsNone(resolved)
def testProtoExecPropertyInvalidField(self):
    """Resolving a field path that ServingSpec does not define raises."""
    # Access a field that does not exist on the ServingSpec message.
    placeholder_expression = """ operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".some_invalid_field" } } """
    pb = text_format.Parse(placeholder_expression,
                           placeholder_pb2.PlaceholderExpression())

    # Prepare FileDescriptorSet
    fd = descriptor_pb2.FileDescriptorProto()
    infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(fd)
    pb.operator.proto_op.proto_schema.file_descriptors.file.append(fd)

    with self.assertRaises(AttributeError):
        placeholder_utils.resolve_placeholder_expression(
            pb, self._resolution_context)
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[types.ComponentSpec] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Wraps sub_expression_pb in a proto_op with this operator's field path.

    When component_spec is given and sub_expression_pb is an EXEC_PROPERTY
    placeholder, the exec property is validated against
    component_spec.PARAMETERS. The file descriptors of its proto type are then
    attached to proto_schema.file_descriptors.

    Args:
      sub_expression_pb: The expression the proto operator applies to.
      component_spec: Optional component spec used to look up the exec
        property's proto type.

    Returns:
      The PlaceholderExpression holding the proto_op.

    Raises:
      ValueError: If the exec property key is not among the component spec's
        PARAMETERS, or its declared type is not a proto Message class.
    """
    result = placeholder_pb2.PlaceholderExpression()
    result.operator.proto_op.expression.CopyFrom(sub_expression_pb)
    result.operator.proto_op.proto_field_path.extend(self._proto_field_path)

    # Attach proto descriptor if available through component spec.
    if (component_spec and sub_expression_pb.placeholder.type
            == placeholder_pb2.Placeholder.EXEC_PROPERTY):
        exec_property_name = sub_expression_pb.placeholder.key
        if exec_property_name not in component_spec.PARAMETERS:
            raise ValueError(
                f"Can't find provided placeholder key {exec_property_name} in "
                "component spec's exec properties. "
                f"Available exec property keys: {component_spec.PARAMETERS.keys()}."
            )
        execution_param = component_spec.PARAMETERS[exec_property_name]
        param_type = execution_param.type
        # issubclass() raises TypeError for non-class types (e.g. typing
        # generics); check isinstance(..., type) first so such properties get
        # the intended ValueError instead.
        if not (isinstance(param_type, type) and
                issubclass(param_type, message.Message)):
            raise ValueError(
                "Can't apply placeholder proto operator on non-proto type "
                f"exec property. Got {param_type}.")
        fd_set = result.operator.proto_op.proto_schema.file_descriptors
        for fd in proto_utils.gather_file_descriptors(param_type.DESCRIPTOR):
            fd.CopyToProto(fd_set.file.add())

    return result
def testProtoWithoutSerializationFormat(self):
    """A whole-message proto_op with no serialization_format raises."""
    expression_text = """ operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    descriptor_set = expression_pb.operator.proto_op.proto_schema.file_descriptors
    infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
        descriptor_set.file.add())
    self.assertRaises(ValueError,
                      placeholder_utils.resolve_placeholder_expression,
                      expression_pb, self._resolution_context)
def encode(
        self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
) -> placeholder_pb2.PlaceholderExpression:
    """Returns sub_expression_pb wrapped in an artifact_value_op operator."""
    wrapped = placeholder_pb2.PlaceholderExpression()
    value_op = wrapped.operator.artifact_value_op
    value_op.expression.CopyFrom(sub_expression_pb)
    return wrapped
def _assert_serialized_proto_b64encode_eq(self, serialize_format, expected):
    """Asserts base64(serialize(proto_property)) decodes back to `expected`.

    Args:
      serialize_format: Text of the serialization_format enum value spliced
        into the proto_op (e.g. "TEXT_FORMAT").
      expected: The string the resolved value must decode to.
    """
    expression_head = """ operator { base64_encode_op { expression { operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } serialization_format: """
    expression_tail = """ } } } } } """
    expression_text = expression_head + serialize_format + expression_tail
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    encoded = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)
    self.assertEqual(base64.urlsafe_b64decode(encoded).decode(), expected)
def testEncodeWithKeys(self):
    """encode_with_keys substitutes the key temporarily, then restores it."""
    source_channel = Channel(type=_MyType)
    future_value = source_channel.future()[0].value
    encoded_pb = future_value.encode_with_keys(lambda ch: ch.type_name)
    expected_text = """ operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "MyTypeName" } } } } } } } """
    expected_pb = text_format.Parse(expected_text,
                                    placeholder_pb2.PlaceholderExpression())
    self.assertProtoEquals(encoded_pb, expected_pb)
    # The substituted key must not leak onto the placeholder afterwards.
    self.assertIsNone(future_value._key)
def encode(self) -> placeholder_pb2.PlaceholderExpression:
    """Builds the base placeholder proto, then applies each operator in order.

    Returns:
      The expression produced by the last operator, or the bare placeholder
      expression when there are no operators.
    """
    encoded = placeholder_pb2.PlaceholderExpression()
    leaf = encoded.placeholder
    leaf.type = self._type
    leaf.key = self._key
    for step in self._operators:
        encoded = step.encode(encoded)
    return encoded
def encode(
        self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
) -> placeholder_pb2.PlaceholderExpression:
    """Returns an index_op selecting element `self._index` of sub_expression_pb."""
    indexed = placeholder_pb2.PlaceholderExpression()
    index_op = indexed.operator.index_op
    index_op.expression.CopyFrom(sub_expression_pb)
    index_op.index = self._index
    return indexed
def testProtoExecPropertyPrimitiveField(self):
    """A path ending on a primitive element (tags[1]) resolves to its value."""
    expression_text = """ operator { proto_op { expression { placeholder { type: EXEC_PROPERTY key: "proto_property" } } proto_schema { message_type: "tfx.components.infra_validator.ServingSpec" } proto_field_path: ".tensorflow_serving" proto_field_path: ".tags" proto_field_path: "[1]" } } """
    expression_pb = text_format.Parse(expression_text,
                                      placeholder_pb2.PlaceholderExpression())
    descriptor_set = expression_pb.operator.proto_op.proto_schema.file_descriptors
    infra_validator_pb2.ServingSpec.DESCRIPTOR.file.CopyToProto(
        descriptor_set.file.add())
    resolved = placeholder_utils.resolve_placeholder_expression(
        expression_pb, self._resolution_context)
    self.assertEqual(resolved, "1.15.0-gpu")
def _encode_value_like(
        x: _ValueLikeType,
        channel_to_key_fn: Optional[Callable[['types.Channel'], str]] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Encodes x to a placeholder expression proto.

    Args:
      x: A ChannelWrappedPlaceholder, or a literal int, float or str.
      channel_to_key_fn: Optional. If given and x is a ChannelWrappedPlaceholder,
        x is encoded with its key temporarily replaced by
        `channel_to_key_fn(x.channel)`. The original key is restored
        afterwards, even if encoding raises.

    Returns:
      The encoded PlaceholderExpression. Literals are stored in its `value`
      field as int_value / double_value / string_value.

    Raises:
      ValueError: If x is none of the supported types.
    """
    if isinstance(x, ChannelWrappedPlaceholder):
        if not channel_to_key_fn:
            return x.encode()
        # pylint: disable=protected-access
        old_key = x._key
        x._key = channel_to_key_fn(x.channel)
        try:
            return x.encode()
        finally:
            # Restore even on failure so x never keeps the substituted key.
            x._key = old_key
            # pylint: enable=protected-access

    result = placeholder_pb2.PlaceholderExpression()
    # NOTE: bool is a subclass of int, so True/False land in int_value.
    if isinstance(x, int):
        result.value.int_value = x
    elif isinstance(x, float):
        result.value.double_value = x
    elif isinstance(x, str):
        result.value.string_value = x
    else:
        raise ValueError(
            f'x must be an int, float, str, or ChannelWrappedPlaceholder. x: {x}'
        )
    return result
def encode(
        self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
) -> placeholder_pb2.PlaceholderExpression:
    """Returns a proto_op reading `self._proto_field_path` from sub_expression_pb."""
    wrapped = placeholder_pb2.PlaceholderExpression()
    proto_op = wrapped.operator.proto_op
    proto_op.expression.CopyFrom(sub_expression_pb)
    proto_op.proto_field_path.extend(self._proto_field_path)
    return wrapped
def encode(
        self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
) -> placeholder_pb2.PlaceholderExpression:
    """Returns an artifact_uri_op over sub_expression_pb.

    The `split` field is set only when `self._split` is truthy.
    """
    wrapped = placeholder_pb2.PlaceholderExpression()
    uri_op = wrapped.operator.artifact_uri_op
    uri_op.expression.CopyFrom(sub_expression_pb)
    if not self._split:
        return wrapped
    uri_op.split = self._split
    return wrapped
def testProtoOperatorDescriptor(self):
    """Encoding with TransformSpec matches the checked-in testdata proto."""
    testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
    golden_path = os.path.join(testdata_dir, 'proto_placeholder_operator.pbtxt')
    with open(golden_path) as golden_file:
        expected_pb = text_format.ParseLines(
            golden_file, placeholder_pb2.PlaceholderExpression())
    splits_placeholder = ph.exec_property('splits_config').analyze[0]
    transform_spec = standard_component_specs.TransformSpec
    self.assertProtoEquals(splits_placeholder.encode(transform_spec),
                           expected_pb)
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Returns sub_expression_pb wrapped in a base64_encode_op operator.

    Args:
      sub_expression_pb: The expression whose resolved value gets encoded.
      component_spec: Ignored by this operator.
    """
    del component_spec  # B64EncodeOperator does not need the spec.
    wrapped = placeholder_pb2.PlaceholderExpression()
    b64_op = wrapped.operator.base64_encode_op
    b64_op.expression.CopyFrom(sub_expression_pb)
    return wrapped
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[types.ComponentSpec] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Returns sub_expression_pb wrapped in an artifact_value_op operator.

    Args:
      sub_expression_pb: The artifact expression to read the value from.
      component_spec: Ignored by this operator.
    """
    del component_spec  # ArtifactValueOperator does not need the spec.
    wrapped = placeholder_pb2.PlaceholderExpression()
    value_op = wrapped.operator.artifact_value_op
    value_op.expression.CopyFrom(sub_expression_pb)
    return wrapped
def encode(
        self, component_spec: Optional[types.ComponentSpec] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Builds the base placeholder proto, then applies each operator in order.

    Args:
      component_spec: Forwarded unchanged to every operator's encode().

    Returns:
      The expression produced by the last operator, or the bare placeholder
      expression when there are no operators.
    """
    encoded = placeholder_pb2.PlaceholderExpression()
    leaf = encoded.placeholder
    leaf.type = self._type
    leaf.key = self._key
    for step in self._operators:
        encoded = step.encode(encoded, component_spec)
    return encoded
def testDoubleNegation(self):
    """Treat `not(not(a))` as `a`."""
    first_channel = Channel(type=_MyType)
    second_channel = Channel(type=_MyType)
    predicate = first_channel.future().value < second_channel.future().value
    double_negated = ph.logical_not(ph.logical_not(predicate))
    key_by_channel = {
        first_channel: 'channel_1_key',
        second_channel: 'channel_2_key',
    }
    encoded_pb = double_negated.encode_with_keys(key_by_channel.__getitem__)
    expected_text = """ operator { compare_op { lhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_1_key" } } } } } } } } rhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_2_key" } } } } } } } } op: LESS_THAN } } """
    expected_pb = text_format.Parse(expected_text,
                                    placeholder_pb2.PlaceholderExpression())
    self.assertProtoEquals(encoded_pb, expected_pb)
def _assert_placeholder_pb_equal_and_deepcopyable(self, placeholder,
                                                  expected_pb_str):
    """Asserts that a deep copy of `placeholder` encodes to `expected_pb_str`.

    The placeholder is deep-copied and only the copy is encoded, so a passing
    assertion also shows that `copy.deepcopy` produces a working placeholder.

    Note: `del placeholder` below only unbinds this function's local name. The
    caller still holds its own reference, so the original object is NOT
    destroyed and remains usable after this call.

    Args:
      placeholder: The placeholder to deep-copy and encode.
      expected_pb_str: Text-format PlaceholderExpression the copy must encode
        to.
    """
    placeholder_copy = copy.deepcopy(placeholder)
    expected_pb = text_format.Parse(
        expected_pb_str, placeholder_pb2.PlaceholderExpression())
    # Drop the local name so the rest of this function can only use the copy.
    # This does not free the caller's object.
    del placeholder
    self.assertProtoEquals(placeholder_copy.encode(), expected_pb)
def testStrategy_IrMode_PredicateFalse(self):
    """A false IR predicate makes resolve_artifacts raise SkipSignal."""

    def _integer_artifact(value):
        artifact = standard_artifacts.Integer()
        artifact.uri = self.create_tempfile().full_path
        artifact.value = value
        return artifact

    first_artifact = _integer_artifact(0)
    second_artifact = _integer_artifact(42)

    predicates = [
        text_format.Parse(predicate_text,
                          placeholder_pb2.PlaceholderExpression())
        for predicate_text in (_TEST_PREDICATE_1, _TEST_PREDICATE_2)
    ]
    strategy = conditional_strategy.ConditionalStrategy(predicates)
    input_dict = {
        'channel_1_key': [first_artifact],
        'channel_2_key': [second_artifact],
    }
    self.assertRaises(exceptions.SkipSignal, strategy.resolve_artifacts,
                      self._store, input_dict)
def testProtoFutureValueOperator(self):
    """A future value placeholder with a fixed key matches the testdata proto."""
    testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
    golden_path = os.path.join(
        testdata_dir, 'proto_placeholder_future_value_operator.pbtxt')
    with open(golden_path) as golden_file:
        expected_pb = text_format.ParseLines(
            golden_file, placeholder_pb2.PlaceholderExpression())
    integer_channel = Channel(type=standard_artifacts.Integer)
    future_value = integer_channel.future()[0].value
    future_value._key = '_component.num'
    self.assertProtoEquals(future_value.encode(), expected_pb)
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type[types.ComponentSpec]] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Returns an index_op selecting element `self._index` of sub_expression_pb.

    Args:
      sub_expression_pb: The expression being indexed.
      component_spec: Ignored by this operator.
    """
    del component_spec  # IndexOperator does not need the spec.
    indexed = placeholder_pb2.PlaceholderExpression()
    index_op = indexed.operator.index_op
    index_op.expression.CopyFrom(sub_expression_pb)
    index_op.index = self._index
    return indexed
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Returns an artifact_uri_op over sub_expression_pb.

    The `split` field is set only when `self._split` is truthy.

    Args:
      sub_expression_pb: The artifact expression whose URI is requested.
      component_spec: Ignored by this operator.
    """
    del component_spec  # ArtifactUriOperator does not need the spec.
    wrapped = placeholder_pb2.PlaceholderExpression()
    uri_op = wrapped.operator.artifact_uri_op
    uri_op.expression.CopyFrom(sub_expression_pb)
    if not self._split:
        return wrapped
    uri_op.split = self._split
    return wrapped
def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Returns a list_serialization_op over sub_expression_pb.

    The op's serialization_format is taken from
    `self._serialization_format.value`.

    Args:
      sub_expression_pb: The list-valued expression to serialize.
      component_spec: Ignored by this operator.
    """
    del component_spec  # ListSerializationOperator does not need the spec.
    serialized = placeholder_pb2.PlaceholderExpression()
    list_op = serialized.operator.list_serialization_op
    list_op.expression.CopyFrom(sub_expression_pb)
    list_op.serialization_format = self._serialization_format.value
    return serialized