Esempio n. 1
0
    def id(self) -> str:
        """Returns the node id, which is unique among TFX nodes in a pipeline.

        A user-assigned `id` takes precedence. When none was assigned (or it
        is falsy), the name of the first non-deprecated class in this node's
        class hierarchy is used instead.

        Returns:
          The node id.
        """
        user_id = self._id
        if not user_id:
            # Skip over deprecated class aliases when deriving the default id.
            canonical_cls = deprecation_utils.get_first_nondeprecated_class(
                self.__class__)
            return canonical_cls.__name__
        return user_id
Esempio n. 2
0
def _convert_to_resolver_steps(resolver_node: base_node.BaseNode):
    """Builds one ResolverStep proto per strategy configured on a Resolver node.

    Args:
      resolver_node: The node to convert; it must be a Resolver (asserted).

    Returns:
      A list of `pipeline_pb2.ResolverConfig.ResolverStep`, in the same order as
      `resolver_node.strategy_class_and_configs`. Every step lists all of the
      node's input keys.
    """
    assert compiler_utils.is_resolver(resolver_node)
    node = cast(resolver.Resolver, resolver_node)

    def _make_step(strategy_cls, config):
        # Record the first non-deprecated class so deprecated aliases are not
        # written out as the class path.
        canonical_cls = deprecation_utils.get_first_nondeprecated_class(
            strategy_cls)
        step = pipeline_pb2.ResolverConfig.ResolverStep()
        step.class_path = f"{canonical_cls.__module__}.{canonical_cls.__name__}"
        step.config_json = json_utils.dumps(config)
        step.input_keys.extend(node.inputs.keys())
        return step

    return [
        _make_step(strategy_cls, config)
        for strategy_cls, config in node.strategy_class_and_configs
    ]
Esempio n. 3
0
    def default(self, obj: Any) -> Any:
        """Encodes TFX-specific objects into JSON-compatible values.

        Cases are checked in this order:
          * str-typed `RuntimeParameter`: its header plus `to_json_dict()`
            content is encoded with `dumps` and that encoded value returned.
          * `Jsonable` instances: a dict with object-type/module/class keys
            plus the object's `to_json_dict()` content; any str-typed
            `RuntimeParameter` values in that content are encoded with `dumps`.
          * classes: a dict naming the first non-deprecated class.
          * protobuf messages: a dict naming the message class plus its JSON.
        Anything else is delegated to the parent encoder's `default`.

        Args:
          obj: The object the JSON encoder could not serialize natively.

        Returns:
          A JSON-serializable value representing `obj`.
        """

        def is_str_runtime_parameter(value: Any) -> bool:
            # Matched by class name rather than isinstance, as in the original.
            return (value.__class__.__name__ == 'RuntimeParameter' and
                    value.ptype == Text)

        def jsonable_header() -> Any:
            return {
                _TFX_OBJECT_TYPE_KEY: _ObjectType.JSONABLE,
                _MODULE_KEY: obj.__class__.__module__,
                _CLASS_KEY: obj.__class__.__name__,
            }

        # If obj is a str-typed RuntimeParameter, serialize it in place.
        if is_str_runtime_parameter(obj):
            dict_data = jsonable_header()
            dict_data.update(obj.to_json_dict())
            return dumps(dict_data)

        if isinstance(obj, Jsonable):
            dict_data = jsonable_header()
            # Build a fresh dict instead of patching the result of
            # to_json_dict() in place. That result may alias the object's own
            # state (NOTE(review): e.g. an implementation returning
            # self.__dict__ — confirm). Patching it would then replace the
            # object's RuntimeParameter attributes with strings as a side
            # effect of serialization.
            data_patch = {
                k: dumps(v) if is_str_runtime_parameter(v) else v
                for k, v in obj.to_json_dict().items()
            }
            dict_data.update(data_patch)
            return dict_data

        if inspect.isclass(obj):
            # When serializing, skip over deprecated class aliases in the class
            # hierarchy.
            obj = deprecation_utils.get_first_nondeprecated_class(obj)
            return {
                _TFX_OBJECT_TYPE_KEY: _ObjectType.CLASS,
                _MODULE_KEY: obj.__module__,
                _CLASS_KEY: obj.__name__,
            }

        if isinstance(obj, message.Message):
            return {
                _TFX_OBJECT_TYPE_KEY: _ObjectType.PROTO,
                _MODULE_KEY: obj.__class__.__module__,
                _CLASS_KEY: obj.__class__.__name__,
                _PROTO_VALUE_KEY: proto_utils.proto_to_json(obj)
            }

        return super(_DefaultEncoder, self).default(obj)
Esempio n. 4
0
    def id(self) -> Text:
        """Returns the node id, unique among TFX nodes in a pipeline.

        Resolution order:
          1. The user-assigned `id`, if one is set.
          2. `<node_class_name>.<instance_name>` when the (deprecated) instance
             name is set.
          3. `<node_class_name>` otherwise.

        `<node_class_name>` is the name of the first non-deprecated class in
        this node's class hierarchy.

        Returns:
          The node id.
        """
        if self._id:
            return self._id
        class_name = deprecation_utils.get_first_nondeprecated_class(
            self.__class__).__name__
        instance_name = self._instance_name
        if not instance_name:
            return class_name
        return '{}.{}'.format(class_name, instance_name)
Esempio n. 5
0
    def get_id(cls, instance_name: Optional[Text] = None):
        """Computes the id a node of this class would receive.

        Useful while authoring a pipeline, before the node instance exists.
        For example:

          from tfx.components import Trainer

          resolver = ResolverNode(..., model=Channel(
              type=Model, producer_component_id=Trainer.get_id('my_trainer')))

        Args:
          instance_name: (Optional) instance name of a node. When truthy, it is
            appended to the class name, separated by a dot.

        Returns:
          `<node_class_name>.<instance_name>` if `instance_name` is given,
          otherwise `<node_class_name>`. Here `<node_class_name>` is the name
          of the first non-deprecated class in the hierarchy of `cls`.
        """
        class_name = deprecation_utils.get_first_nondeprecated_class(cls).__name__
        if not instance_name:
            return class_name
        return '{}.{}'.format(class_name, instance_name)
Esempio n. 6
0
  def _build_resolver_spec(self) -> Dict[str, pipeline_pb2.PipelineTaskSpec]:
    """Validates and builds ResolverSpec for this node.

    Only two resolver strategies are supported: LatestBlessedModelStrategy and
    LatestArtifactStrategy. The configured strategy class is first resolved
    with `get_first_nondeprecated_class`, so deprecated aliases of these
    classes are matched as well.

    Returns:
      A dict of PipelineTaskSpec representing the (potentially multiple)
      resolver task(s), as produced by the strategy-specific builder.
    Raises:
      TypeError: When the configured resolver strategy is not supported.
    """
    # The node is expected to be a Resolver; this is not input validation.
    assert isinstance(self._node, resolver.Resolver)

    strategy_cls = deprecation_utils.get_first_nondeprecated_class(
        self._exec_properties[resolver.RESOLVER_STRATEGY_CLASS])
    if strategy_cls == latest_blessed_model_strategy.LatestBlessedModelStrategy:
      return self._build_latest_blessed_model_resolver()
    if strategy_cls == latest_artifact_strategy.LatestArtifactStrategy:
      return self._build_latest_artifact_resolver()
    raise TypeError(
        'Unexpected resolver policy encountered. Currently '
        'only support LatestArtifactStrategy and LatestBlessedModelStrategy '
        f'but got: {strategy_cls}.')
Esempio n. 7
0
 def get_class_type(cls) -> str:
     """Returns the dotted `<module>.<ClassName>` path of `cls`.

     Deprecated class aliases are skipped: the path names the first
     non-deprecated class in the hierarchy of `cls`.
     """
     resolved = deprecation_utils.get_first_nondeprecated_class(cls)
     return '.'.join((resolved.__module__, resolved.__name__))
Esempio n. 8
0
 def type(self) -> Text:
     """Returns this node's type as `<module>.<ClassName>`.

     The class named is the first non-deprecated class in this node's class
     hierarchy, so deprecated aliases do not appear in the result.
     """
     resolved_cls = deprecation_utils.get_first_nondeprecated_class(
         self.__class__)
     module_name, class_name = resolved_cls.__module__, resolved_cls.__name__
     return '.'.join((module_name, class_name))
Esempio n. 9
0
def _fully_qualified_name(cls: Type[Any]):
  """Returns `<module>.<qualname>` for the first non-deprecated class of `cls`.

  `__qualname__` is used rather than `__name__`, so a nested class includes
  the names of its enclosing scopes.
  """
  resolved = deprecation_utils.get_first_nondeprecated_class(cls)
  return '{}.{}'.format(resolved.__module__, resolved.__qualname__)