def _parse_parameters(self):
  """Parse arguments to ComponentSpec.

  Validates `self._raw_args` against the declared `PARAMETERS`, `INPUTS` and
  `OUTPUTS`, then populates `self.exec_properties`, `self.inputs` and
  `self.outputs` from them.

  Raises:
    ValueError: If a non-optional argument is missing from `self._raw_args`.
      Anything raised by an argument's `type_check` also propagates.
  """
  unparsed_args = set(self._raw_args.keys())
  inputs = {}
  outputs = {}
  self.exec_properties = {}

  # First, check that the arguments are set.
  for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                       self.INPUTS.items(),
                                       self.OUTPUTS.items()):
    if arg_name not in unparsed_args:
      if arg.optional:
        continue
      raise ValueError('Missing argument %r to %s.' %
                       (arg_name, self.__class__))
    # Each supplied argument may satisfy only one declaration.
    unparsed_args.remove(arg_name)

    # Type check the argument; an optional argument explicitly set to None is
    # accepted without a type check.
    value = self._raw_args[arg_name]
    if arg.optional and value is None:
      continue
    arg.type_check(arg_name, value)

  # Populate the appropriate dictionary for each parameter type.
  for arg_name, arg in self.PARAMETERS.items():
    if arg.optional and arg_name not in self._raw_args:
      continue
    value = self._raw_args[arg_name]
    if (inspect.isclass(arg.type) and
        issubclass(arg.type, message.Message) and value):
      # Create deterministic json string as it will be stored in metadata for
      # cache check.
      value = json_format.MessageToJson(value, sort_keys=True)
    self.exec_properties[arg_name] = value

  for arg_name, arg in self.INPUTS.items():
    # Optional inputs that are absent or falsy are left out entirely.
    if arg.optional and not self._raw_args.get(arg_name):
      continue
    inputs[arg_name] = self._raw_args[arg_name]

  for arg_name, arg in self.OUTPUTS.items():
    # The validation loop above allows an optional output to be omitted, so
    # skip it here instead of failing with a KeyError on the lookup.
    if arg.optional and arg_name not in self._raw_args:
      continue
    outputs[arg_name] = self._raw_args[arg_name]

  # Note: for forwards compatibility, ComponentSpec objects may provide an
  # attribute mapping virtual keys to physical keys in the inputs / outputs
  # dictionaries; when the value for a virtual key is accessed, the value for
  # the physical key is returned instead. This feature will be removed once
  # attribute renaming is completed and *should not* be used by ComponentSpec
  # authors outside the TFX package.
  #
  # TODO(b/139281215): remove this functionality.
  self.inputs = _PropertyDictWrapper(
      inputs,
      compat_aliases=getattr(self, '_INPUT_COMPATIBILITY_ALIASES', None))
  self.outputs = _PropertyDictWrapper(
      outputs,
      compat_aliases=getattr(self, '_OUTPUT_COMPATIBILITY_ALIASES', None))
def outputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
  """Return this node's `_output_dict` wrapped in a `_PropertyDictWrapper`."""
  wrapped_outputs = node_common._PropertyDictWrapper(  # pylint: disable=protected-access
      self._output_dict)
  return wrapped_outputs
def inputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
  """Return a `_PropertyDictWrapper` around an empty dict (no inputs)."""
  no_inputs = {}
  return node_common._PropertyDictWrapper(no_inputs)  # pylint: disable=protected-access
def outputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
  """Output Channel dict containing the imported artifacts.

  Returns:
    `self._output_dict` wrapped in a `_PropertyDictWrapper`.
  """
  imported_outputs = node_common._PropertyDictWrapper(  # pylint: disable=protected-access
      self._output_dict)
  return imported_outputs
def _parse_parameters(self):
  """Parse arguments to ComponentSpec.

  Validates `self._raw_args` against the declared `PARAMETERS`, `INPUTS` and
  `OUTPUTS`, then populates `self.exec_properties`, `self.inputs` and
  `self.outputs` from them. Proto-typed parameters are stored as JSON
  strings; when such a parameter is supplied as a dict, it is first checked
  against the proto type with any RuntimeParameter placeholders replaced by
  default values of their `ptype`.

  Raises:
    ValueError: If a non-optional argument is missing from `self._raw_args`.
      Anything raised by an argument's `type_check` or by
      `json_format.ParseDict` also propagates.
  """
  unparsed_args = set(self._raw_args.keys())
  inputs = {}
  outputs = {}
  self.exec_properties = {}

  # The following three helper functions replace RuntimeParameters with
  # default values of their declared type, so that the json_format library
  # can later be used to type check the resulting plain dict.
  def _make_default_dict(dict_data: Dict[Text, Any]) -> Dict[Text, Any]:
    """Generates a dict with parameters replaced by default values.

    Works on a deep copy, so `dict_data` itself is left untouched.
    """
    copy_dict = copy.deepcopy(dict_data)
    _put_default_dict(copy_dict)
    return copy_dict

  def _put_default_dict(dict_data: Dict[Text, Any]) -> None:
    """Replaces RuntimeParameter values in `dict_data` in place, recursively."""
    # Only values of existing keys are reassigned, so the dict does not change
    # size while it is being iterated.
    for k, v in dict_data.items():
      if isinstance(v, dict):
        _put_default_dict(v)
      elif isinstance(v, list):
        _put_default_list(v)
      elif v.__class__.__name__ == 'RuntimeParameter':
        # Currently supporting int, float, bool, Text
        ptype = v.ptype
        dict_data[k] = ptype.__new__(ptype)

  def _put_default_list(list_data: List[Any]) -> None:
    """Replaces RuntimeParameter items in `list_data` in place, recursively."""
    for index, item in enumerate(list_data):
      if isinstance(item, dict):
        _put_default_dict(item)
      elif isinstance(item, list):
        _put_default_list(item)
      elif item.__class__.__name__ == 'RuntimeParameter':
        # Currently supporting int, float, bool, Text
        ptype = item.ptype
        list_data[index] = ptype.__new__(ptype)

  # First, check that the arguments are set.
  for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                       self.INPUTS.items(),
                                       self.OUTPUTS.items()):
    if arg_name not in unparsed_args:
      if arg.optional:
        continue
      raise ValueError('Missing argument %r to %s.' %
                       (arg_name, self.__class__))
    # Each supplied argument may satisfy only one declaration.
    unparsed_args.remove(arg_name)

    # Type check the argument; an optional argument explicitly set to None is
    # accepted without a type check.
    value = self._raw_args[arg_name]
    if arg.optional and value is None:
      continue
    arg.type_check(arg_name, value)

  # Populate the appropriate dictionary for each parameter type.
  for arg_name, arg in self.PARAMETERS.items():
    if arg.optional and arg_name not in self._raw_args:
      continue
    value = self._raw_args[arg_name]
    if (inspect.isclass(arg.type) and
        issubclass(arg.type, message.Message) and value):
      # Create deterministic json string as it will be stored in metadata for
      # cache check.
      if isinstance(value, dict):
        # If a dict is passed in, it might contain RuntimeParameters. Given
        # the argument type is specified as a proto, do the type check by
        # parsing a placeholder-free copy into a fresh message; the parsed
        # message is discarded and the original dict is what gets serialized.
        dict_with_default = _make_default_dict(value)
        json_format.ParseDict(dict_with_default, arg.type())
        value = json_utils.dumps(value)
      else:
        value = json_format.MessageToJson(
            message=value, sort_keys=True, preserving_proto_field_name=True)
    self.exec_properties[arg_name] = value

  for arg_name, arg in self.INPUTS.items():
    # Optional inputs that are absent or falsy are left out entirely.
    if arg.optional and not self._raw_args.get(arg_name):
      continue
    inputs[arg_name] = self._raw_args[arg_name]

  for arg_name, arg in self.OUTPUTS.items():
    # The validation loop above allows an optional output to be omitted, so
    # skip it here instead of failing with a KeyError on the lookup.
    if arg.optional and arg_name not in self._raw_args:
      continue
    outputs[arg_name] = self._raw_args[arg_name]

  # Note: for forwards compatibility, ComponentSpec objects may provide an
  # attribute mapping virtual keys to physical keys in the inputs / outputs
  # dictionaries; when the value for a virtual key is accessed, the value for
  # the physical key is returned instead. This feature will be removed once
  # attribute renaming is completed and *should not* be used by ComponentSpec
  # authors outside the TFX package.
  #
  # TODO(b/139281215): remove this functionality.
  self.inputs = _PropertyDictWrapper(
      inputs,
      compat_aliases=getattr(self, '_INPUT_COMPATIBILITY_ALIASES', None))
  self.outputs = _PropertyDictWrapper(
      outputs,
      compat_aliases=getattr(self, '_OUTPUT_COMPATIBILITY_ALIASES', None))