コード例 #1
0
ファイル: proto_utils_test.py プロジェクト: jay90099/tfx
 def test_dict_to_proto(self):
     """Exercise proto_utils.dict_to_proto on two input dicts."""
     # Case 1: the fixture dict must convert to the fixture proto.
     converted = proto_utils.dict_to_proto(self.test_dict,
                                           foo_pb2.TestProto())
     self.assertEqual(converted, self.test_proto)

     # Case 2: the input also carries an 'obsolete_field' key, while the
     # expected message is built from string_value alone.
     dict_with_obsolete_field = {'obsolete_field': 2, 'string_value': 'x'}
     converted = proto_utils.dict_to_proto(dict_with_obsolete_field,
                                           foo_pb2.TestProto())
     expected = foo_pb2.TestProto(string_value='x')
     self.assertEqual(converted, expected)
コード例 #2
0
ファイル: component_spec.py プロジェクト: Saiprasad16/tfx
 def _type_check_helper(value: Any, declared: Type):  # pylint: disable=g-bare-generic
     """Verify that `value` conforms to the `declared` type.

     Returns None when the check passes. `arg_name` is not a parameter of
     this helper: it is read from the surrounding scope and is only used to
     build error messages.

     Args:
       value: The object being checked.
       declared: The expected type: `Any`, a parameterized dict/list alias, a
         protobuf message class, or anything `isinstance` accepts.

     Raises:
       TypeError: If `value` does not match `declared`, or if `declared` is a
         generic alias whose origin is neither dict nor list.
     """
     # `Any` accepts everything.
     if declared == Any:
         return

     if declared.__class__.__name__ not in ('_GenericAlias', 'GenericMeta'):
         if isinstance(value, dict) and issubclass(declared, message.Message):
             # A dict checked against a pb message type is validated by
             # converting it into an instance of that message.
             proto_utils.dict_to_proto(_make_default(value), declared())
         elif not isinstance(value, declared):
             raise TypeError(
                 'Expected type %s for parameter %r but got %s instead.' %
                 (str(declared), arg_name, value))
         return

     # From here on `declared` is a generic alias; only dict and list
     # origins are supported.
     origin = declared.__origin__  # pylint: disable=protected-access

     if origin in (Dict, dict):
         key_type = declared.__args__[0]
         val_type = declared.__args__[1]
         if not isinstance(value, dict):
             raise TypeError(
                 'Expecting a dict for parameter %r, but got %s instead' %
                 (arg_name, type(value)))
         # An `Any` type argument disables the check for that side.
         keys_constrained = key_type != Any
         vals_constrained = val_type != Any
         for key, val in value.items():
             if keys_constrained and not isinstance(key, key_type):
                 raise TypeError(
                     'Expecting key type %s for parameter %r, '
                     'but got %s instead.' %
                     (str(key_type), arg_name, type(key)))
             if vals_constrained and not isinstance(val, val_type):
                 raise TypeError(
                     'Expecting value type %s for parameter %r, '
                     'but got %s instead.' %
                     (str(val_type), arg_name, type(val)))
         return

     if origin in (List, list):
         item_type = declared.__args__[0]
         if not isinstance(value, list):
             raise TypeError(
                 'Expecting a list for parameter %r, but got %s instead.' %
                 (arg_name, type(value)))
         if item_type == Any:
             return
         for item in value:
             if not isinstance(item, item_type):
                 raise TypeError(
                     'Expecting item type %s for parameter %r, '
                     'but got %s instead.' %
                     (str(item_type), arg_name, type(item)))
         return

     raise TypeError('Unexpected type of parameter: %r' % arg_name)
コード例 #3
0
  def _parse_parameters(self, raw_args: Mapping[str, Any]):
    """Validate `raw_args` against the declared spec and store the results.

    Three passes are made over the declarations in `self.PARAMETERS`,
    `self.INPUTS` and `self.OUTPUTS`:
      1. Every non-optional argument must be present in `raw_args`; each
         supplied argument is type-checked via its own `type_check` method
         (skipped when an optional argument is None).
      2. Parameter values are written to `self.exec_properties`, with
         protobuf-typed values converted first.
      3. Input and output values are collected into `self.inputs` and
         `self.outputs`.

    Args:
      raw_args: Mapping from argument name to the value supplied for it.

    Raises:
      ValueError: If a non-optional argument is missing from `raw_args`.
    """
    pending = set(raw_args.keys())
    self.exec_properties = {}

    # Pass 1: presence and type checks.
    declarations = itertools.chain(self.PARAMETERS.items(),
                                   self.INPUTS.items(),
                                   self.OUTPUTS.items())
    for name, spec in declarations:
      if name not in pending:
        if not spec.optional:
          raise ValueError('Missing argument %r to %s.' %
                           (name, self.__class__))
        continue
      pending.remove(name)

      supplied = raw_args[name]
      # An optional argument explicitly set to None is not type-checked.
      if not (spec.optional and supplied is None):
        spec.type_check(name, supplied)

    # Pass 2: execution properties.
    for name, spec in self.PARAMETERS.items():
      if spec.optional and name not in raw_args:
        continue
      supplied = raw_args[name]

      # Only truthy, non-runtime-parameter values of a protobuf message type
      # are converted; everything else is stored as given.
      needs_conversion = (
          inspect.isclass(spec.type) and
          issubclass(spec.type, message.Message) and supplied and
          not _is_runtime_param(supplied))
      if needs_conversion:
        if spec.use_proto:
          if isinstance(supplied, dict):
            supplied = proto_utils.dict_to_proto(supplied, spec.type())
          elif isinstance(supplied, str):
            supplied = proto_utils.json_to_proto(supplied, spec.type())
        elif isinstance(supplied, dict):
          # Create deterministic json string as it will be stored in
          # metadata for cache check.
          supplied = json_utils.dumps(supplied)
        elif not isinstance(supplied, str):
          supplied = proto_utils.proto_to_json(supplied)

      self.exec_properties[name] = supplied

    # Pass 3: inputs, then outputs. Optional entries whose supplied value is
    # missing or falsy are left out.
    collected = [
        {
            name: raw_args[name]
            for name, spec in declared.items()
            if not (spec.optional and not raw_args.get(name))
        }
        for declared in (self.INPUTS, self.OUTPUTS)
    ]
    self.inputs, self.outputs = collected
コード例 #4
0
    def _type_check_helper(value: Any, declared: Type):  # pylint: disable=g-bare-generic
      """Verify that `value` conforms to the `declared` type.

      Returns None when the check passes. `arg_name` is not a parameter of
      this helper: it is read from the surrounding scope and is only used to
      build error messages.

      Args:
        value: The object being checked. Placeholder instances are handled
          separately from ordinary values.
        declared: The expected type: `Any`, a parameterized dict/list alias,
          a protobuf message class, or anything `isinstance` accepts.

      Raises:
        TypeError: If `value` does not match `declared`, if a placeholder is
          used in an unsupported way, or if `declared` is a generic alias
          whose origin is neither dict nor list.
      """
      if isinstance(value, placeholder.Placeholder):
        # Channel-wrapped placeholders are accepted without further checks.
        if isinstance(value, placeholder.ChannelWrappedPlaceholder):
          return
        involved = value.placeholders_involved()
        is_single_runtime_info = (
            len(involved) == 1 and
            isinstance(involved[0], placeholder.RuntimeInfoPlaceholder))
        if not is_single_runtime_info:
          involved_names = [p.__class__.__name__ for p in involved]
          raise TypeError(
              'Only simple RuntimeInfoPlaceholders are supported, but while '
              'checking parameter %r, the following placeholders were '
              'involved: %s' % (arg_name, involved_names))
        if not issubclass(declared, str):
          raise TypeError(
              'Cannot use Placeholders except for str parameter, but '
              'parameter %r was of type %s' % (arg_name, declared))
        return

      # The runtime-parameter flag is taken from the original value, before
      # `_make_default` replaces it.
      is_runtime_param = _is_runtime_param(value)
      value = _make_default(value)

      # `Any` accepts everything.
      if declared == Any:
        return

      if declared.__class__.__name__ in ('_GenericAlias', 'GenericMeta'):
        # Only dict and list origins are supported for generic aliases.
        origin = declared.__origin__  # pylint: disable=protected-access

        if origin in (Dict, dict):
          key_type = declared.__args__[0]
          val_type = declared.__args__[1]
          if not isinstance(value, dict):
            raise TypeError(
                'Expecting a dict for parameter %r, but got %s instead' %
                (arg_name, type(value)))
          # An `Any` type argument disables the check for that side.
          keys_constrained = key_type != Any
          vals_constrained = val_type != Any
          for key, val in value.items():
            if keys_constrained and not isinstance(key, key_type):
              raise TypeError(
                  'Expecting key type %s for parameter %r, '
                  'but got %s instead.' %
                  (str(key_type), arg_name, type(key)))
            if vals_constrained and not isinstance(val, val_type):
              raise TypeError(
                  'Expecting value type %s for parameter %r, '
                  'but got %s instead.' %
                  (str(val_type), arg_name, type(val)))
          return

        if origin in (List, list):
          item_type = declared.__args__[0]
          if not isinstance(value, list):
            raise TypeError(
                'Expecting a list for parameter %r, but got %s instead.' %
                (arg_name, type(value)))
          if item_type == Any:
            return
          for item in value:
            if not isinstance(item, item_type):
              raise TypeError(
                  'Expecting item type %s for parameter %r, '
                  'but got %s instead.' %
                  (str(item_type), arg_name, type(item)))
          return

        raise TypeError('Unexpected type of parameter: %r' % arg_name)

      if isinstance(value, dict) and issubclass(declared, message.Message):
        # A dict checked against a pb message type is validated by
        # converting it into an instance of that message.
        proto_utils.dict_to_proto(value, declared())
        return

      if (isinstance(value, str) and not isinstance(declared, tuple) and
          issubclass(declared, message.Message)):
        # Text checked against a pb message type is validated by parsing it
        # (as json) into that message, except for runtime parameters, which
        # are skipped.
        if not is_runtime_param:
          proto_utils.json_to_proto(value, declared())
        return

      if not isinstance(value, declared):
        raise TypeError(
            'Expected type %s for parameter %r but got %s instead.' %
            (str(declared), arg_name, value))