Ejemplo n.º 1
0
    def testDebugPlaceholder(self):
        """Verifies debug_str output for a concat expression and a proto op."""
        # Case 1: the module-level concat / split-uri expression.
        concat_expr = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(_CONCAT_SPLIT_URI_EXPRESSION, concat_expr)
        rendered_concat = placeholder_utils.debug_str(concat_expr)
        self.assertEqual(
            rendered_concat,
            "(input(\"examples\")[0].split_uri(\"train\") + \"/\" + \"1\")")

        # Case 2: a proto operator over an exec property, serialized as
        # TEXT_FORMAT.
        serving_spec_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "serving_spec"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          serialization_format: TEXT_FORMAT
        }
      }
    """
        serving_spec_expr = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(serving_spec_text, serving_spec_expr)
        rendered_proto_op = placeholder_utils.debug_str(serving_spec_expr)
        self.assertEqual(
            rendered_proto_op,
            "exec_property(\"serving_spec\").tensorflow_serving.serialize(TEXT_FORMAT)"
        )
Ejemplo n.º 2
0
  def encode(
      self,
      component_spec: Optional[types.ComponentSpec] = None) -> message.Message:
    """Encodes ExecutorSpec into an IR proto for compiling.

    This method will be used by DSL compiler to generate the corresponding IR.
    Every entry of `self.command` and `self.args` is passed through
    `self._recursively_encode` exactly once. A plain `str` result is wrapped
    in a string-valued PlaceholderExpression; any other result is converted
    by calling its own `encode()` method.

    Args:
      component_spec: Optional. The ComponentSpec to help with the encoding.
        It is not read by this implementation.

    Returns:
      An executor spec proto.
    """
    result = executable_spec_pb2.ContainerExecutableSpec()
    result.image = self.image

    def _append_encoded(repeated_field, item):
      """Encodes `item` once and appends the expression to `repeated_field`."""
      str_or_placeholder = self._recursively_encode(item)
      if isinstance(str_or_placeholder, str):
        expression = placeholder_pb2.PlaceholderExpression()
        expression.value.string_value = str_or_placeholder
      else:
        # Reuse the value computed above instead of re-running the recursive
        # encoding a second time.
        expression = str_or_placeholder.encode()
      repeated_field.add().CopyFrom(expression)

    for command in self.command:
      _append_encoded(result.commands, command)

    for arg in self.args:
      _append_encoded(result.args, arg)
    return result
Ejemplo n.º 3
0
    def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
    ) -> placeholder_pb2.PlaceholderExpression:
        """Concatenates the other operand onto `sub_expression_pb`.

        The proto form of the concat operator holds a list of operand
        expressions, while this Python operator carries a single extra operand:
        either `self._right` (appended after the sub expression) or
        `self._left` (placed before it). `self._right` takes precedence when
        both are truthy.

        If `sub_expression_pb` already is a concat operator, the operand is
        added to its expression list in place and the same proto is returned.
        Otherwise a new concat expression with two operands is returned.

        NOTE(review): both operands are tested by truthiness, so an empty
        string operand is treated the same as a missing one.

        Args:
          sub_expression_pb: The already encoded expression to concat with.
          component_spec: Unused by ConcatOperator.

        Returns:
          The concat expression proto.

        Raises:
          RuntimeError: If neither `self._right` nor `self._left` is truthy.
        """
        del component_spec  # Unused by ConcatOperator

        # Step 1: encode the other operand and remember on which side it goes.
        if self._right:
            operand_goes_last = True
            if isinstance(self._right, Placeholder):
                operand_pb = cast(Placeholder, self._right).encode()
            else:
                operand_pb = placeholder_pb2.PlaceholderExpression()
                operand_pb.value.string_value = self._right
        elif self._left:
            # The left operand must be a str.
            operand_goes_last = False
            operand_pb = placeholder_pb2.PlaceholderExpression()
            operand_pb.value.string_value = self._left
        else:
            raise RuntimeError(
                'ConcatOperator does not have the other expression to concat.')

        # Step 2: fold into an existing concat operator when there is one.
        is_concat = (sub_expression_pb.HasField('operator') and
                     sub_expression_pb.operator.HasField('concat_op'))
        if is_concat:
            operands = sub_expression_pb.operator.concat_op.expressions
            if operand_goes_last:
                operands.append(operand_pb)
            else:
                operands.insert(0, operand_pb)
            return sub_expression_pb

        # Step 3: otherwise start a fresh two-operand concat expression.
        if operand_goes_last:
            ordered_operands = [sub_expression_pb, operand_pb]
        else:
            ordered_operands = [operand_pb, sub_expression_pb]
        combined = placeholder_pb2.PlaceholderExpression()
        combined.operator.concat_op.expressions.extend(ordered_operands)
        return combined
Ejemplo n.º 4
0
 def testBase64EncodeOperator(self):
     """Resolves a base64-encoded proto field and checks the encoded value."""
     expression_text = """
   operator {
     base64_encode_op {
       expression {
         operator {
           proto_op {
             expression {
               placeholder {
                 type: EXEC_PROPERTY
                 key: "proto_property"
               }
             }
             proto_schema {
               message_type: "tfx.components.infra_validator.ServingSpec"
             }
             proto_field_path: ".tensorflow_serving"
             proto_field_path: ".tags"
             proto_field_path: "[0]"
           }
         }
       }
     }
   }
 """
     expression_pb = placeholder_pb2.PlaceholderExpression()
     text_format.Parse(expression_text, expression_pb)

     resolved = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     # The expected value is the URL-safe base64 form of b"latest".
     expected = base64.urlsafe_b64encode(b"latest").decode("ASCII")
     self.assertEqual(resolved, expected)
Ejemplo n.º 5
0
    def testProtoRuntimeInfoNoneAccess(self):
        """Accessing a missing platform config is expected to resolve to None."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: RUNTIME_INFO
              key: "platform_config"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        # Attach the ServingSpec file descriptor to the proto schema.
        serving_spec_fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            serving_spec_fd)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(serving_spec_fd)

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._none_resolution_context)
        self.assertIsNone(resolved)
Ejemplo n.º 6
0
    def testProtoExecPropertyMessageFieldTextFormat(self):
        """A message-typed proto field is expected to render as text format."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          serialization_format: TEXT_FORMAT
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        # Attach the ServingSpec file descriptor to the proto schema.
        serving_spec_fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            serving_spec_fd)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(serving_spec_fd)

        # proto_field_path points at a message type field here, so the
        # expected value is that message rendered with text_format.
        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._resolution_context)
        self.assertEqual(
            resolved,
            "tags: \"latest\"\ntags: \"1.15.0-gpu\"\n")
Ejemplo n.º 7
0
    def testArtifactUriNoneAccess(self):
        """Accessing a missing optional channel is expected to resolve to None."""
        expression_text = """
      operator {
        artifact_uri_op {
          expression {
            operator {
              index_op{
                expression {
                  placeholder {
                    type: INPUT_ARTIFACT
                    key: "examples"
                  }
                }
                index: 0
              }
            }
          }
          split: "train"
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._none_resolution_context)
        self.assertIsNone(resolved)
Ejemplo n.º 8
0
    def testProtoExecPropertyInvalidField(self):
        """A field path naming ".some_invalid_field" must raise AttributeError."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".some_invalid_field"
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        # Attach the ServingSpec file descriptor to the proto schema.
        serving_spec_fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            serving_spec_fd)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(serving_spec_fd)

        with self.assertRaises(AttributeError):
            placeholder_utils.resolve_placeholder_expression(
                expression_pb, self._resolution_context)
Ejemplo n.º 9
0
  def encode(
      self,
      sub_expression_pb: placeholder_pb2.PlaceholderExpression,
      component_spec: Optional[types.ComponentSpec] = None
  ) -> placeholder_pb2.PlaceholderExpression:
    """Wraps `sub_expression_pb` in a proto operator expression.

    The returned expression copies `sub_expression_pb` as its operand and
    carries `self._proto_field_path`. When a component spec is given and the
    operand is an EXEC_PROPERTY placeholder, the file descriptors of the exec
    property's proto type are attached to the operator's proto schema.

    Args:
      sub_expression_pb: The already encoded expression this operator is
        applied to.
      component_spec: Optional. Used to look up the exec property's declared
        type through `component_spec.PARAMETERS`.

    Returns:
      A new PlaceholderExpression holding the proto operator.

    Raises:
      ValueError: If the placeholder key is not in the component spec's
        PARAMETERS, or if the exec property's declared type is not a
        `message.Message` subclass.
    """
    result = placeholder_pb2.PlaceholderExpression()
    result.operator.proto_op.expression.CopyFrom(sub_expression_pb)
    result.operator.proto_op.proto_field_path.extend(self._proto_field_path)

    # Attach proto descriptor if available through component spec.
    if (component_spec and sub_expression_pb.placeholder.type ==
        placeholder_pb2.Placeholder.EXEC_PROPERTY):
      exec_property_name = sub_expression_pb.placeholder.key
      if exec_property_name not in component_spec.PARAMETERS:
        raise ValueError(
            f"Can't find provided placeholder key {exec_property_name} in "
            "component spec's exec properties. "
            f"Available exec property keys: {component_spec.PARAMETERS.keys()}."
        )
      execution_param = component_spec.PARAMETERS[exec_property_name]
      param_type = execution_param.type
      # issubclass() raises TypeError when its first argument is not a class
      # (e.g. a tuple of types or a typing generic), so check that first and
      # report every non-proto type through the same ValueError.
      is_proto_type = (
          isinstance(param_type, type) and
          issubclass(param_type, message.Message))
      if not is_proto_type:
        raise ValueError(
            "Can't apply placeholder proto operator on non-proto type "
            f"exec property. Got {param_type}.")
      fd_set = result.operator.proto_op.proto_schema.file_descriptors
      for fd in proto_utils.gather_file_descriptors(param_type.DESCRIPTOR):
        fd.CopyToProto(fd_set.file.add())

    return result
Ejemplo n.º 10
0
    def testProtoWithoutSerializationFormat(self):
        """A proto op with no field path and no format must raise ValueError."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        # Attach the ServingSpec file descriptor to the proto schema.
        serving_spec_fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            serving_spec_fd)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(serving_spec_fd)

        with self.assertRaises(ValueError):
            placeholder_utils.resolve_placeholder_expression(
                expression_pb, self._resolution_context)
Ejemplo n.º 11
0
 def encode(
     self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
 ) -> placeholder_pb2.PlaceholderExpression:
     """Returns a new expression applying artifact_value_op to the operand.

     Args:
       sub_expression_pb: The already encoded operand; it is copied into the
         new expression.

     Returns:
       A new PlaceholderExpression holding the artifact_value_op operator.
     """
     wrapped = placeholder_pb2.PlaceholderExpression()
     value_op = wrapped.operator.artifact_value_op
     value_op.expression.CopyFrom(sub_expression_pb)
     return wrapped
Ejemplo n.º 12
0
 def _assert_serialized_proto_b64encode_eq(self, serialize_format,
                                           expected):
     """Asserts the base64-decoded resolution of a serialized proto.

     Builds a base64_encode_op over a proto_op whose serialization_format is
     `serialize_format`, resolves it, URL-safe base64-decodes the result and
     compares the decoded text with `expected`.

     Args:
       serialize_format: Text spliced into the expression as the value of
         `serialization_format`.
       expected: The expected decoded string.
     """
     expression_head = """
     operator {
       base64_encode_op {
         expression {
           operator {
             proto_op {
               expression {
                 placeholder {
                   type: EXEC_PROPERTY
                   key: "proto_property"
                 }
               }
               proto_schema {
                 message_type: "tfx.components.infra_validator.ServingSpec"
               }
               serialization_format: """
     expression_tail = """
             }
           }
         }
       }
     }
   """
     expression_text = expression_head + serialize_format + expression_tail
     expression_pb = placeholder_pb2.PlaceholderExpression()
     text_format.Parse(expression_text, expression_pb)

     encoded_result = placeholder_utils.resolve_placeholder_expression(
         expression_pb, self._resolution_context)
     decoded_bytes = base64.urlsafe_b64decode(encoded_result)
     self.assertEqual(decoded_bytes.decode(), expected)
Ejemplo n.º 13
0
 def testEncodeWithKeys(self):
     """encode_with_keys uses the supplied key and leaves _key as None."""
     source_channel = Channel(type=_MyType)
     value_future = source_channel.future()[0].value

     def _key_from_type_name(channel):
         return channel.type_name

     actual_pb = value_future.encode_with_keys(_key_from_type_name)

     expected_text = """
   operator {
     artifact_value_op {
       expression {
         operator {
           index_op {
             expression {
               placeholder {
                 key: "MyTypeName"
               }
             }
           }
         }
       }
     }
   }
 """
     expected_pb = placeholder_pb2.PlaceholderExpression()
     text_format.Parse(expected_text, expected_pb)

     self.assertProtoEquals(actual_pb, expected_pb)
     # The key of the future is checked to still be None after encoding.
     self.assertIsNone(value_future._key)
Ejemplo n.º 14
0
 def encode(self) -> placeholder_pb2.PlaceholderExpression:
     """Encodes this placeholder and its operator chain into a proto.

     Starts from a leaf expression carrying `self._type` and `self._key`,
     then lets each operator in `self._operators` wrap the expression built
     so far, in order.

     Returns:
       The resulting PlaceholderExpression.
     """
     expression = placeholder_pb2.PlaceholderExpression()
     leaf = expression.placeholder
     leaf.type = self._type
     leaf.key = self._key
     for step in self._operators:
         expression = step.encode(expression)
     return expression
Ejemplo n.º 15
0
 def encode(
     self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
 ) -> placeholder_pb2.PlaceholderExpression:
     """Returns a new expression applying index_op to the operand.

     Args:
       sub_expression_pb: The already encoded operand; it is copied into the
         new expression.

     Returns:
       A new PlaceholderExpression whose index_op carries `self._index`.
     """
     wrapped = placeholder_pb2.PlaceholderExpression()
     index_op = wrapped.operator.index_op
     index_op.expression.CopyFrom(sub_expression_pb)
     index_op.index = self._index
     return wrapped
Ejemplo n.º 16
0
    def testProtoExecPropertyPrimitiveField(self):
        """A non-message proto field resolves to the expected primitive value."""
        expression_text = """
      operator {
        proto_op {
          expression {
            placeholder {
              type: EXEC_PROPERTY
              key: "proto_property"
            }
          }
          proto_schema {
            message_type: "tfx.components.infra_validator.ServingSpec"
          }
          proto_field_path: ".tensorflow_serving"
          proto_field_path: ".tags"
          proto_field_path: "[1]"
        }
      }
    """
        expression_pb = placeholder_pb2.PlaceholderExpression()
        text_format.Parse(expression_text, expression_pb)

        # Attach the ServingSpec file descriptor to the proto schema.
        serving_spec_fd = descriptor_pb2.FileDescriptorProto()
        infra_validator_pb2.ServingSpec().DESCRIPTOR.file.CopyToProto(
            serving_spec_fd)
        schema = expression_pb.operator.proto_op.proto_schema
        schema.file_descriptors.file.append(serving_spec_fd)

        resolved = placeholder_utils.resolve_placeholder_expression(
            expression_pb, self._resolution_context)
        self.assertEqual(resolved, "1.15.0-gpu")
Ejemplo n.º 17
0
def _encode_value_like(
    x: _ValueLikeType,
    channel_to_key_fn: Optional[Callable[['types.Channel'], str]] = None
) -> placeholder_pb2.PlaceholderExpression:
    """Encodes x to a placeholder expression proto.

    Args:
      x: A ChannelWrappedPlaceholder, or an int, float or str literal.
      channel_to_key_fn: Optional. When given and `x` is a
        ChannelWrappedPlaceholder, the key it returns for `x.channel` is used
        as the placeholder key while encoding. The original key of `x` is
        restored afterwards, including when encoding raises.

    Returns:
      The encoded PlaceholderExpression. Literals are stored as int_value,
      double_value or string_value respectively.

    Raises:
      ValueError: If `x` is none of the supported types.
    """

    if isinstance(x, ChannelWrappedPlaceholder):
        if not channel_to_key_fn:
            return x.encode()
        # pylint: disable=protected-access
        old_key = x._key
        x._key = channel_to_key_fn(x.channel)
        try:
            return x.encode()
        finally:
            # Always undo the temporary key override so a failing encode()
            # cannot leave the placeholder with a foreign key.
            x._key = old_key
        # pylint: enable=protected-access

    result = placeholder_pb2.PlaceholderExpression()
    if isinstance(x, int):
        result.value.int_value = x
    elif isinstance(x, float):
        result.value.double_value = x
    elif isinstance(x, str):
        result.value.string_value = x
    else:
        raise ValueError(
            f'x must be an int, float, str, or ChannelWrappedPlaceholder. x: {x}'
        )
    return result
Ejemplo n.º 18
0
 def encode(
     self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
 ) -> placeholder_pb2.PlaceholderExpression:
     """Returns a new expression applying proto_op to the operand.

     Args:
       sub_expression_pb: The already encoded operand; it is copied into the
         new expression.

     Returns:
       A new PlaceholderExpression whose proto_op carries the entries of
       `self._proto_field_path`.
     """
     wrapped = placeholder_pb2.PlaceholderExpression()
     proto_op = wrapped.operator.proto_op
     proto_op.expression.CopyFrom(sub_expression_pb)
     proto_op.proto_field_path.extend(self._proto_field_path)
     return wrapped
Ejemplo n.º 19
0
 def encode(
     self, sub_expression_pb: placeholder_pb2.PlaceholderExpression
 ) -> placeholder_pb2.PlaceholderExpression:
     """Returns a new expression applying artifact_uri_op to the operand.

     Args:
       sub_expression_pb: The already encoded operand; it is copied into the
         new expression.

     Returns:
       A new PlaceholderExpression holding the artifact_uri_op operator. Its
       `split` field is only set when `self._split` is truthy.
     """
     wrapped = placeholder_pb2.PlaceholderExpression()
     uri_op = wrapped.operator.artifact_uri_op
     uri_op.expression.CopyFrom(sub_expression_pb)
     if not self._split:
         return wrapped
     uri_op.split = self._split
     return wrapped
Ejemplo n.º 20
0
 def testProtoOperatorDescriptor(self):
     """Encoding with TransformSpec must match the golden pbtxt in testdata."""
     testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     golden_path = os.path.join(testdata_dir,
                                'proto_placeholder_operator.pbtxt')
     expected_pb = placeholder_pb2.PlaceholderExpression()
     with open(golden_path) as golden_file:
         text_format.ParseLines(golden_file, expected_pb)

     splits_placeholder = ph.exec_property('splits_config').analyze[0]
     spec_cls = standard_component_specs.TransformSpec
     actual_pb = splits_placeholder.encode(spec_cls)
     self.assertProtoEquals(actual_pb, expected_pb)
Ejemplo n.º 21
0
    def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
    ) -> placeholder_pb2.PlaceholderExpression:
        """Returns a new expression applying base64_encode_op to the operand.

        Args:
          sub_expression_pb: The already encoded operand; it is copied into
            the new expression.
          component_spec: Unused by B64EncodeOperator.

        Returns:
          A new PlaceholderExpression holding the base64_encode_op operator.
        """
        del component_spec  # Unused by B64EncodeOperator

        wrapped = placeholder_pb2.PlaceholderExpression()
        b64_op = wrapped.operator.base64_encode_op
        b64_op.expression.CopyFrom(sub_expression_pb)
        return wrapped
Ejemplo n.º 22
0
  def encode(
      self,
      sub_expression_pb: placeholder_pb2.PlaceholderExpression,
      component_spec: Optional[types.ComponentSpec] = None
  ) -> placeholder_pb2.PlaceholderExpression:
    """Returns a new expression applying artifact_value_op to the operand.

    Args:
      sub_expression_pb: The already encoded operand; it is copied into the
        new expression.
      component_spec: Unused by ArtifactValueOperator.

    Returns:
      A new PlaceholderExpression holding the artifact_value_op operator.
    """
    del component_spec  # Unused by ArtifactValueOperator

    wrapped = placeholder_pb2.PlaceholderExpression()
    value_op = wrapped.operator.artifact_value_op
    value_op.expression.CopyFrom(sub_expression_pb)
    return wrapped
Ejemplo n.º 23
0
 def encode(
     self,
     component_spec: Optional[types.ComponentSpec] = None
 ) -> placeholder_pb2.PlaceholderExpression:
   """Encodes this placeholder and its operator chain into a proto.

   Starts from a leaf expression carrying `self._type` and `self._key`, then
   lets each operator in `self._operators` wrap the expression built so far,
   in order.

   Args:
     component_spec: Optional. Forwarded unchanged to every operator's
       `encode` call.

   Returns:
     The resulting PlaceholderExpression.
   """
   expression = placeholder_pb2.PlaceholderExpression()
   leaf = expression.placeholder
   leaf.type = self._type
   leaf.key = self._key
   for step in self._operators:
     expression = step.encode(expression, component_spec)
   return expression
Ejemplo n.º 24
0
 def testDoubleNegation(self):
     """Treat `not(not(a))` as `a`."""
     left_channel = Channel(type=_MyType)
     right_channel = Channel(type=_MyType)
     less_than = left_channel.future().value < right_channel.future().value
     double_negated = ph.logical_not(ph.logical_not(less_than))
     keys_by_channel = {
         left_channel: 'channel_1_key',
         right_channel: 'channel_2_key',
     }

     def _lookup_key(channel):
         return keys_by_channel[channel]

     actual_pb = double_negated.encode_with_keys(_lookup_key)

     # The expected proto is the bare comparison, without any negation.
     expected_text = """
   operator {
     compare_op {
       lhs {
         operator {
           artifact_value_op {
             expression {
               operator {
                 index_op {
                   expression {
                     placeholder {
                       key: "channel_1_key"
                     }
                   }
                 }
               }
             }
           }
         }
       }
       rhs {
         operator {
           artifact_value_op {
             expression {
               operator {
                 index_op {
                   expression {
                     placeholder {
                       key: "channel_2_key"
                     }
                   }
                 }
               }
             }
           }
         }
       }
       op: LESS_THAN
     }
   }
 """
     expected_pb = placeholder_pb2.PlaceholderExpression()
     text_format.Parse(expected_text, expected_pb)
     self.assertProtoEquals(actual_pb, expected_pb)
Ejemplo n.º 25
0
 def _assert_placeholder_pb_equal_and_deepcopyable(self, placeholder,
                                                   expected_pb_str):
     """Asserts that a deep copy of `placeholder` encodes to the expected proto.

     The local reference to the original placeholder is deleted before the
     comparison, so only the deep copy is encoded.

     Args:
       placeholder: The placeholder to deep-copy and check.
       expected_pb_str: Text-format PlaceholderExpression to compare against.
     """
     cloned = copy.deepcopy(placeholder)
     expected_pb = placeholder_pb2.PlaceholderExpression()
     text_format.Parse(expected_pb_str, expected_pb)
     # Drop the original reference to verify deepcopy works. If a caller
     # needs a placeholder instance after this call, returning the copy
     # could be considered.
     del placeholder
     encoded_clone = cloned.encode()
     self.assertProtoEquals(encoded_clone, expected_pb)
Ejemplo n.º 26
0
    def testStrategy_IrMode_PredicateFalse(self):
        """resolve_artifacts is expected to raise SkipSignal for these inputs."""
        first_artifact = standard_artifacts.Integer()
        first_artifact.uri = self.create_tempfile().full_path
        first_artifact.value = 0

        second_artifact = standard_artifacts.Integer()
        second_artifact.uri = self.create_tempfile().full_path
        second_artifact.value = 42

        # Parse both module-level predicate texts, in order.
        predicates = [
            text_format.Parse(predicate_text,
                              placeholder_pb2.PlaceholderExpression())
            for predicate_text in (_TEST_PREDICATE_1, _TEST_PREDICATE_2)
        ]
        strategy = conditional_strategy.ConditionalStrategy(predicates)

        input_dict = {
            'channel_1_key': [first_artifact],
            'channel_2_key': [second_artifact]
        }
        with self.assertRaises(exceptions.SkipSignal):
            strategy.resolve_artifacts(self._store, input_dict)
Ejemplo n.º 27
0
 def testProtoFutureValueOperator(self):
     """A channel future's value must encode to the golden pbtxt in testdata."""
     testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     golden_path = os.path.join(
         testdata_dir, 'proto_placeholder_future_value_operator.pbtxt')
     expected_pb = placeholder_pb2.PlaceholderExpression()
     with open(golden_path) as golden_file:
         text_format.ParseLines(golden_file, expected_pb)

     int_channel = Channel(type=standard_artifacts.Integer)
     value_future = int_channel.future()[0].value
     # The key is set directly on the future before encoding.
     value_future._key = '_component.num'
     actual_pb = value_future.encode()
     self.assertProtoEquals(actual_pb, expected_pb)
Ejemplo n.º 28
0
    def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type[types.ComponentSpec]] = None
    ) -> placeholder_pb2.PlaceholderExpression:
        """Returns a new expression applying index_op to the operand.

        Args:
          sub_expression_pb: The already encoded operand; it is copied into
            the new expression.
          component_spec: Unused by IndexOperator.

        Returns:
          A new PlaceholderExpression whose index_op carries `self._index`.
        """
        del component_spec  # Unused by IndexOperator

        wrapped = placeholder_pb2.PlaceholderExpression()
        index_op = wrapped.operator.index_op
        index_op.expression.CopyFrom(sub_expression_pb)
        index_op.index = self._index
        return wrapped
Ejemplo n.º 29
0
    def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
    ) -> placeholder_pb2.PlaceholderExpression:
        """Returns a new expression applying artifact_uri_op to the operand.

        Args:
          sub_expression_pb: The already encoded operand; it is copied into
            the new expression.
          component_spec: Unused by ArtifactUriOperator.

        Returns:
          A new PlaceholderExpression holding the artifact_uri_op operator.
          Its `split` field is only set when `self._split` is truthy.
        """
        del component_spec  # Unused by ArtifactUriOperator

        wrapped = placeholder_pb2.PlaceholderExpression()
        uri_op = wrapped.operator.artifact_uri_op
        uri_op.expression.CopyFrom(sub_expression_pb)
        if not self._split:
            return wrapped
        uri_op.split = self._split
        return wrapped
Ejemplo n.º 30
0
    def encode(
        self,
        sub_expression_pb: placeholder_pb2.PlaceholderExpression,
        component_spec: Optional[Type['types.ComponentSpec']] = None
    ) -> placeholder_pb2.PlaceholderExpression:
        """Returns a new expression applying list_serialization_op.

        Args:
          sub_expression_pb: The already encoded operand; it is copied into
            the new expression.
          component_spec: Unused by this operator.

        Returns:
          A new PlaceholderExpression whose list_serialization_op carries the
          `.value` of `self._serialization_format`.
        """
        del component_spec

        wrapped = placeholder_pb2.PlaceholderExpression()
        serialization_op = wrapped.operator.list_serialization_op
        serialization_op.expression.CopyFrom(sub_expression_pb)
        format_value = self._serialization_format.value
        serialization_op.serialization_format = format_value
        return wrapped