class BaseExecutorOperator(abc.ABC):
  """Abstract base class shared by all executor operators.

  The constructor checks the given executor spec (and, when one is provided,
  the platform config) against the types listed in the class-level
  SUPPORTED_* attributes.
  """

  # Collection of executor spec types this operator accepts.
  SUPPORTED_EXECUTOR_SPEC_TYPE = abc_utils.abstract_property()
  # Collection of platform config types this operator accepts.
  SUPPORTED_PLATFORM_CONFIG_TYPE = abc_utils.abstract_property()

  def __init__(self,
               executor_spec: message.Message,
               platform_config: Optional[message.Message] = None):
    """Validates and records the executor spec and platform config.

    Args:
      executor_spec: Describes how the executor should be initialized.
      platform_config: Describes how resources are allocated for the executor.
        A falsy value (e.g. None) is stored without being type checked.

    Raises:
      RuntimeError: If executor_spec, or a truthy platform_config, is not an
        instance of one of the supported types.
    """
    accepted_spec_types = tuple(self.SUPPORTED_EXECUTOR_SPEC_TYPE)
    if not isinstance(executor_spec, accepted_spec_types):
      raise RuntimeError('Executor spec not supported: %s' % executor_spec)

    if platform_config:
      accepted_config_types = tuple(self.SUPPORTED_PLATFORM_CONFIG_TYPE)
      if not isinstance(platform_config, accepted_config_types):
        raise RuntimeError('Platform spec not supported: %s' % platform_config)

    self._executor_spec = executor_spec
    self._platform_config = platform_config
    # Stays None until with_execution_watcher() is called.
    self._execution_watcher_address = None

  @abc.abstractmethod
  def run_executor(
      self,
      execution_info: data_types.ExecutionInfo,
  ) -> execution_result_pb2.ExecutorOutput:
    """Runs the executor using the inputs supplied by the Launcher.

    Args:
      execution_info: A wrapper of the info needed by this execution.

    Returns:
      The output produced by the executor.
    """

  def with_execution_watcher(
      self, execution_watcher_address: str) -> 'BaseExecutorOperator':
    """Attaches an execution watcher to this executor operator.

    Args:
      execution_watcher_address: Address of an executor watcher gRPC service
        which can be used to update execution properties.

    Returns:
      This executor operator, so that calls can be chained.
    """
    self._execution_watcher_address = execution_watcher_address
    return self
class BaseExecutorOperator(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract base class shared by all executor operators.

  The constructor checks the given executor spec (and, when one is provided,
  the platform spec) against the types listed in the class-level
  SUPPORTED_* attributes.
  """

  # Collection of executor spec types this operator accepts.
  SUPPORTED_EXECUTOR_SPEC_TYPE = abc_utils.abstract_property()
  # Collection of platform spec types this operator accepts.
  SUPPORTED_PLATFORM_SPEC_TYPE = abc_utils.abstract_property()

  def __init__(self,
               executor_spec: message.Message,
               platform_spec: Optional[message.Message] = None):
    """Validates and records the executor spec and platform spec.

    Args:
      executor_spec: Describes how the executor should be initialized.
      platform_spec: Describes how resources are allocated for the executor.
        A falsy value (e.g. None) is stored without being type checked.

    Raises:
      RuntimeError: If executor_spec, or a truthy platform_spec, is not an
        instance of one of the supported types.
    """
    accepted_executor_types = tuple(self.SUPPORTED_EXECUTOR_SPEC_TYPE)
    if not isinstance(executor_spec, accepted_executor_types):
      raise RuntimeError('Executor spec not supported: %s' % executor_spec)

    if platform_spec:
      accepted_platform_types = tuple(self.SUPPORTED_PLATFORM_SPEC_TYPE)
      if not isinstance(platform_spec, accepted_platform_types):
        raise RuntimeError('Platform spec not supported: %s' % platform_spec)

    self._executor_spec = executor_spec
    self._platform_spec = platform_spec

  @abc.abstractmethod
  def run_executor(
      self,
      execution_info: ExecutionInfo,
  ) -> execution_result_pb2.ExecutorOutput:
    """Runs the executor using the inputs supplied by the Launcher.

    Args:
      execution_info: A wrapper of the info needed by this execution.

    Returns:
      The output produced by the executor.
    """
class ArtifactVisualization(with_metaclass(abc.ABCMeta)):
  """Renders one particular type of Artifact for display."""

  # The artifact class (a `Type[types.Artifact]`) that this visualization
  # handles.
  ARTIFACT_TYPE = abc_utils.abstract_property()

  @abc.abstractmethod
  def display(self, artifact: types.Artifact) -> Text:
    """Produces the HTML string shown for `artifact` in a notebook.

    Args:
      artifact: The artifact to render.

    Returns:
      An HTML string rendering the artifact.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
class BaseDriverOperator(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract base class shared by all driver operators."""

  # Collection of driver spec types this operator accepts.
  SUPPORTED_EXECUTABLE_SPEC_TYPE = abc_utils.abstract_property()

  def __init__(self,
               driver_spec: message.Message,
               mlmd_connection: metadata.Metadata,
               pipeline_info: pipeline_pb2.PipelineInfo,
               pipeline_node: pipeline_pb2.PipelineNode):
    """Validates the driver spec and records the constructor arguments.

    Args:
      driver_spec: Describes how the driver should be initialized.
      mlmd_connection: ML metadata connection.
      pipeline_info: Information about the pipeline that owns this driver.
      pipeline_node: Specification of the node that owns this driver.

    Raises:
      RuntimeError: If driver_spec is not an instance of one of the supported
        types.
    """
    accepted_spec_types = tuple(self.SUPPORTED_EXECUTABLE_SPEC_TYPE)
    if not isinstance(driver_spec, accepted_spec_types):
      raise RuntimeError('Driver spec not supported: %s' % driver_spec)

    self._driver_spec = driver_spec
    self._mlmd_connection = mlmd_connection
    self._pipeline_info = pipeline_info
    self._pipeline_node = pipeline_node

  @abc.abstractmethod
  def run_driver(
      self,
      input_dict: Dict[Text, List[types.Artifact]],
      output_dict: Dict[Text, List[types.Artifact]],
      exec_properties: Dict[Text, Any]) -> driver_output_pb2.DriverOutput:
    """Runs the driver using the inputs supplied by the Launcher.

    Args:
      input_dict: The default input_dict resolved by the launcher.
      output_dict: The default output_dict resolved by the launcher.
      exec_properties: The default exec_properties resolved by the launcher.

    Returns:
      A DriverOutput instance.
    """
class BaseDriverOperator(six.with_metaclass(abc.ABCMeta, object)):
  """Abstract base class shared by all driver operators."""

  # Collection of driver spec types this operator accepts.
  SUPPORTED_EXECUTABLE_SPEC_TYPE = abc_utils.abstract_property()

  def __init__(self,
               driver_spec: message.Message,
               mlmd_connection: metadata.Metadata):
    """Validates the driver spec and records the constructor arguments.

    Args:
      driver_spec: Describes how the driver should be initialized.
      mlmd_connection: ML metadata connection.

    Raises:
      RuntimeError: If driver_spec is not an instance of one of the supported
        types.
    """
    accepted_spec_types = tuple(self.SUPPORTED_EXECUTABLE_SPEC_TYPE)
    if not isinstance(driver_spec, accepted_spec_types):
      raise RuntimeError('Driver spec not supported: %s' % driver_spec)

    self._driver_spec = driver_spec
    self._mlmd_connection = mlmd_connection

  @abc.abstractmethod
  def run_driver(
      self, execution_info: data_types.ExecutionInfo
  ) -> driver_output_pb2.DriverOutput:
    """Runs the driver using the inputs supplied by the Launcher.

    Args:
      execution_info: data_types.ExecutionInfo containing information needed
        for driver execution.

    Returns:
      A DriverOutput instance.
    """
class BaseComponent(with_metaclass(abc.ABCMeta, json_utils.Jsonable)):
  """Base class for a TFX pipeline component.

  Each instance of a BaseComponent subclass holds the parameters for one
  execution of that TFX pipeline component.

  Every subclass has to override SPEC_CLASS with the ComponentSpec subclass
  describing the component's interface.

  Attributes:
    SPEC_CLASS: a subclass of types.ComponentSpec used by this component
      (required).
    EXECUTOR_SPEC: an instance of executor_spec.ExecutorSpec which describes
      how to execute this component (required).
    DRIVER_CLASS: a subclass of base_driver.BaseDriver as a custom driver for
      this component (optional, defaults to base_driver.BaseDriver).
  """

  # Must be overridden by subclasses with a types.ComponentSpec class, e.g.
  # "SPEC_CLASS = MyComponentSpec".
  SPEC_CLASS = abc_utils.abstract_property()
  # Must also be overridden by subclasses.
  #
  # Note: EXECUTOR_SPEC supersedes the older EXECUTOR_CLASS attribute. A custom
  # component that used to declare "EXECUTOR_CLASS = MyExecutor" should now
  # declare "EXECUTOR_SPEC = ExecutorClassSpec(MyExecutor)".
  EXECUTOR_SPEC = abc_utils.abstract_property()
  # Most subclasses keep the default driver, but it may be overridden too.
  DRIVER_CLASS = base_driver.BaseDriver

  def __init__(
      self,
      spec: types.ComponentSpec,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
      instance_name: Optional[Text] = None):
    """Initialize a component.

    Args:
      spec: types.ComponentSpec object for this component instance.
      custom_executor_spec: Optional custom executor spec that takes the place
        of the class-level EXECUTOR_SPEC.
      instance_name: Optional unique identifying name for this instance of the
        component in the pipeline. Required if two instances of the same
        component are used in the pipeline.

    Raises:
      TypeError: If custom_executor_spec is truthy but not an ExecutorSpec, or
        if the class-level SPEC_CLASS / EXECUTOR_SPEC / DRIVER_CLASS are
        invalid.
      ValueError: If spec is not an instance of SPEC_CLASS.
    """
    component_cls = self.__class__
    self.spec = spec
    if custom_executor_spec and not isinstance(custom_executor_spec,
                                               executor_spec.ExecutorSpec):
      raise TypeError(
          ('Custom executor spec override %s for %s should be an instance of '
           'ExecutorSpec') % (custom_executor_spec, component_cls))
    # A falsy override falls back to the class-level executor spec.
    self.executor_spec = custom_executor_spec or component_cls.EXECUTOR_SPEC
    self.driver_class = component_cls.DRIVER_CLASS
    # TODO(b/139540680): consider making instance_name private.
    self.instance_name = instance_name
    self._upstream_nodes = set()
    self._downstream_nodes = set()
    self._validate_component_class()
    self._validate_spec(spec)

  @classmethod
  def _validate_component_class(cls):
    """Checks the class-level SPEC_CLASS, EXECUTOR_SPEC and DRIVER_CLASS.

    Raises:
      TypeError: If SPEC_CLASS is not a types.ComponentSpec subclass, if
        EXECUTOR_SPEC is not an ExecutorSpec instance, or if DRIVER_CLASS is
        not a base_driver.BaseDriver subclass.
    """
    spec_class = cls.SPEC_CLASS
    if not inspect.isclass(spec_class) or not issubclass(
        spec_class, types.ComponentSpec):
      raise TypeError(
          ('Component class %s expects SPEC_CLASS property to be a subclass '
           'of types.ComponentSpec; got %s instead.') % (cls, spec_class))

    class_executor_spec = cls.EXECUTOR_SPEC
    if not isinstance(class_executor_spec, executor_spec.ExecutorSpec):
      raise TypeError((
          'Component class %s expects EXECUTOR_SPEC property to be an instance '
          'of ExecutorSpec; got %s instead.') %
                      (cls, type(class_executor_spec)))

    driver_class = cls.DRIVER_CLASS
    if not inspect.isclass(driver_class) or not issubclass(
        driver_class, base_driver.BaseDriver):
      raise TypeError(
          ('Component class %s expects DRIVER_CLASS property to be a subclass '
           'of base_driver.BaseDriver; got %s instead.') % (cls, driver_class))

  def _validate_spec(self, spec):
    """Checks that `spec` is an instance of this component's SPEC_CLASS.

    Args:
      spec: The object passed as the component's spec.

    Raises:
      ValueError: If spec is not a types.ComponentSpec, or not an instance of
        the class-level SPEC_CLASS.
    """
    component_cls = self.__class__
    if not isinstance(spec, types.ComponentSpec):
      raise ValueError((
          'BaseComponent (parent class of %s) expects "spec" argument to be an '
          'instance of types.ComponentSpec, got %s instead.') %
                       (component_cls, spec))
    expected_spec_class = component_cls.SPEC_CLASS
    if not isinstance(spec, expected_spec_class):
      raise ValueError(
          ('%s expects the "spec" argument to be an instance of %s; '
           'got %s instead.') % (component_cls, expected_spec_class, spec))

  def __repr__(self):
    template = ('%s(spec: %s, executor_spec: %s, driver_class: %s, '
                'component_id: %s, inputs: %s, outputs: %s)')
    fields = (self.__class__.__name__, self.spec, self.executor_spec,
              self.driver_class, self.component_id, self.inputs, self.outputs)
    return template % fields

  def to_json_dict(self) -> Dict[Text, Any]:
    """Returns the dict holding this component's driver, executor, name, spec."""
    json_dict = {}
    json_dict[_DRIVER_CLASS_PATH_KEY] = self.driver_class
    json_dict[_EXECUTOR_SPEC_KEY] = self.executor_spec
    json_dict[_INSTANCE_NAME_KEY] = self.instance_name
    json_dict[_SPEC_KEY] = self.spec
    return json_dict

  @property
  def component_type(self) -> Text:
    """Dotted `<module>.<class name>` of this component's class."""
    component_cls = self.__class__
    return '.'.join([component_cls.__module__, component_cls.__name__])

  @property
  def inputs(self) -> component_spec._PropertyDictWrapper:
    """The `inputs` of the underlying spec."""
    return self.spec.inputs

  @property
  def outputs(self) -> component_spec._PropertyDictWrapper:
    """The `outputs` of the underlying spec."""
    return self.spec.outputs

  @property
  def exec_properties(self) -> Dict[Text, Any]:
    """The `exec_properties` of the underlying spec."""
    return self.spec.exec_properties

  # TODO(ruoyu): Consolidate the usage of component identifier. Moving forward,
  # we will have two component level keys:
  # - component_type: the path of the python executor or the image uri of the
  #   executor.
  # - component_id: <component_class_name>.<unique_name>
  @property
  def component_id(self):
    """Component id, unique across all component instances in a pipeline.

    If unique name is available, component_id will be:
      <component_class_name>.<instance_name>
    otherwise, component_id will be:
      <component_class_name>

    Returns:
      component id.
    """
    class_name = self.__class__.__name__
    if not self.instance_name:
      return class_name
    return '{}.{}'.format(class_name, self.instance_name)

  @property
  def upstream_nodes(self):
    """The set of nodes registered through add_upstream_node()."""
    return self._upstream_nodes

  def add_upstream_node(self, upstream_node):
    """Records `upstream_node` in the upstream node set."""
    self._upstream_nodes.add(upstream_node)

  @property
  def downstream_nodes(self):
    """The set of nodes registered through add_downstream_node()."""
    return self._downstream_nodes

  def add_downstream_node(self, downstream_node):
    """Records `downstream_node` in the downstream node set."""
    self._downstream_nodes.add(downstream_node)
class ComponentSpec(with_metaclass(abc.ABCMeta, json_utils.Jsonable)):
  """A specification of the inputs, outputs and parameters for a component.

  Components should have a corresponding ComponentSpec inheriting from this
  class and must override:

    - PARAMETERS (as a dict of string keys and ExecutionParameter values),
    - INPUTS (as a dict of string keys and ChannelParameter values) and
    - OUTPUTS (also a dict of string keys and ChannelParameter values).

  Here is an example of how a ComponentSpec may be defined:

  class MyCustomComponentSpec(ComponentSpec):
    PARAMETERS = {
        'internal_option': ExecutionParameter(type=str),
    }
    INPUTS = {
        'input_examples': ChannelParameter(type=standard_artifacts.Examples),
    }
    OUTPUTS = {
        'output_examples': ChannelParameter(type=standard_artifacts.Examples),
    }

  To create an instance of a subclass, call it directly with any execution
  parameters / inputs / outputs as kwargs.  For example:

  spec = MyCustomComponentSpec(
      internal_option='abc',
      input_examples=input_examples_channel,
      output_examples=output_examples_channel)

  Attributes:
    PARAMETERS: a dict of string keys and ExecutionParameter values.
    INPUTS: a dict of string keys and ChannelParameter values.
    OUTPUTS: a dict of string keys and ChannelParameter values.
  """

  PARAMETERS = abc_utils.abstract_property()
  INPUTS = abc_utils.abstract_property()
  OUTPUTS = abc_utils.abstract_property()

  def __init__(self, **kwargs):
    """Initialize a ComponentSpec.

    Args:
      **kwargs: Any inputs, outputs and execution parameters for this instance
        of the component spec.

    Raises:
      TypeError: If PARAMETERS / INPUTS / OUTPUTS are not dicts, or hold
        values of the wrong parameter kind.
      ValueError: If the class declares malformed or duplicate arguments, or a
        non-optional argument is missing from kwargs.
    """
    self._raw_args = kwargs
    self._validate_spec()
    self._verify_parameter_types()
    self._parse_parameters()

  def __eq__(self, other):
    """Compares two specs by type and by their JSON dict representation.

    Args:
      other: The object to compare against.

    Returns:
      True if `other` is an instance of this spec's class and both objects
      produce equal `to_json_dict()` results; False otherwise.
    """
    # The check must be on `other` itself: `other.__class__` is a class
    # object, which is never an instance of the spec class, so testing it made
    # every comparison return False.
    # Note: as this class defines __eq__ without __hash__, Python leaves its
    # instances unhashable.
    return (isinstance(other, self.__class__) and
            self.to_json_dict() == other.to_json_dict())

  def _validate_spec(self):
    """Check the parameters and types passed to this ComponentSpec.

    Raises:
      TypeError: If PARAMETERS, INPUTS or OUTPUTS is not a dict.
      ValueError: If a declared value is not a _ComponentParameter, or an
        argument name appears more than once across the three dicts.
    """
    for param_name, param in [('PARAMETERS', self.PARAMETERS),
                              ('INPUTS', self.INPUTS),
                              ('OUTPUTS', self.OUTPUTS)]:
      if not isinstance(param, dict):
        raise TypeError(
            ('Subclass %s of ComponentSpec must override %s with a '
             'dict; got %s instead.') % (self.__class__, param_name, param))

    # Validate that the ComponentSpec class is well-formed.
    seen_arg_names = set()
    for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                         self.INPUTS.items(),
                                         self.OUTPUTS.items()):
      if not isinstance(arg, _ComponentParameter):
        raise ValueError((
            'The ComponentSpec subclass %s expects that the values of its '
            'PARAMETERS, INPUTS, and OUTPUTS dicts are _ComponentParameter '
            'objects (i.e. ChannelParameter or ExecutionParameter objects); '
            'got %s (for argument %s) instead.') %
                         (self.__class__, arg, arg_name))
      if arg_name in seen_arg_names:
        raise ValueError((
            'The ComponentSpec subclass %s has a duplicate argument with '
            'name %s. Argument names should be unique across the PARAMETERS, '
            'INPUTS and OUTPUTS dicts.') % (self.__class__, arg_name))
      seen_arg_names.add(arg_name)

  def _verify_parameter_types(self):
    """Verify spec parameter types.

    Raises:
      TypeError: If a PARAMETERS value is not an ExecutionParameter, or an
        INPUTS / OUTPUTS value is not a ChannelParameter.
    """
    for arg in self.PARAMETERS.values():
      if not isinstance(arg, ExecutionParameter):
        raise TypeError((
            'PARAMETERS dict expects values of type ExecutionParameter, '
            'got {}.').format(arg))
    for arg in itertools.chain(self.INPUTS.values(), self.OUTPUTS.values()):
      if not isinstance(arg, ChannelParameter):
        raise TypeError((
            'INPUTS and OUTPUTS dicts expect values of type ChannelParameter, '
            ' got {}.').format(arg))

  def _parse_parameters(self):
    """Parse arguments to ComponentSpec.

    Populates `self.exec_properties`, `self.inputs` and `self.outputs` from
    the keyword arguments given to the constructor.

    Raises:
      ValueError: If a non-optional declared argument was not supplied.
    """
    unparsed_args = set(self._raw_args.keys())
    inputs = {}
    outputs = {}
    self.exec_properties = {}

    # First, check that the arguments are set.
    for arg_name, arg in itertools.chain(self.PARAMETERS.items(),
                                         self.INPUTS.items(),
                                         self.OUTPUTS.items()):
      if arg_name not in unparsed_args:
        if arg.optional:
          continue
        raise ValueError('Missing argument %r to %s.' %
                         (arg_name, self.__class__))
      unparsed_args.remove(arg_name)

      # Type check the argument.
      value = self._raw_args[arg_name]
      if arg.optional and value is None:
        continue
      arg.type_check(arg_name, value)

    # Populate the appropriate dictionary for each parameter type.
    for arg_name, arg in self.PARAMETERS.items():
      if arg.optional and arg_name not in self._raw_args:
        continue
      value = self._raw_args[arg_name]

      if (inspect.isclass(arg.type) and
          issubclass(arg.type, message.Message) and value):
        # Create deterministic json string as it will be stored in metadata for
        # cache check.
        if isinstance(value, dict):
          value = json_utils.dumps(value)
        else:
          value = json_format.MessageToJson(
              message=value,
              sort_keys=True,
              preserving_proto_field_name=True)

      self.exec_properties[arg_name] = value

    for arg_dict, param_dict in ((self.INPUTS, inputs),
                                 (self.OUTPUTS, outputs)):
      for arg_name, arg in arg_dict.items():
        # Optional channels that are absent or falsy are left out entirely.
        if arg.optional and not self._raw_args.get(arg_name):
          continue
        param_dict[arg_name] = self._raw_args[arg_name]

    # Note: for forwards compatibility, ComponentSpec objects may provide an
    # attribute mapping virtual keys to physical keys in the outputs dictionary,
    # and when the value for a virtual key is accessed, the value for the
    # physical key will be returned instead. This is intended to provide
    # forwards compatibility. This feature will be removed once attribute
    # renaming is completed and *should not* be used by ComponentSpec authors
    # outside the TFX package.
    #
    # TODO(b/139281215): remove this functionality.
    self.inputs = _PropertyDictWrapper(
        inputs,
        compat_aliases=getattr(self, '_INPUT_COMPATIBILITY_ALIASES', None))
    self.outputs = _PropertyDictWrapper(
        outputs,
        compat_aliases=getattr(self, '_OUTPUT_COMPATIBILITY_ALIASES', None))

  def to_json_dict(self) -> Dict[Text, Any]:
    """Convert from an object to a JSON serializable dictionary."""
    return {
        'inputs': self.inputs,
        'outputs': self.outputs,
        'exec_properties': self.exec_properties,
    }
class BaseComponent(base_node.BaseNode, abc.ABC):
  """Base class for a TFX pipeline component.

  Each instance of a BaseComponent subclass holds the parameters for one
  execution of that TFX pipeline component.

  Every subclass has to override SPEC_CLASS with the ComponentSpec subclass
  describing the component's interface.

  Attributes:
    SPEC_CLASS: a subclass of types.ComponentSpec used by this component
      (required). This is a class level value.
    EXECUTOR_SPEC: an instance of executor_spec.ExecutorSpec which describes
      how to execute this component (required). This is a class level value.
    DRIVER_CLASS: a subclass of base_driver.BaseDriver as a custom driver for
      this component (optional, defaults to base_driver.BaseDriver). This is a
      class level value.
    spec: an instance of `SPEC_CLASS`. See types.ComponentSpec for more
      details.
    platform_config: a protobuf message representing platform config for a
      component instance.
  """

  # Must be overridden by subclasses with a types.ComponentSpec class, e.g.
  # "SPEC_CLASS = MyComponentSpec".
  SPEC_CLASS = abc_utils.abstract_property()
  doc_controls.do_not_doc_in_subclasses(SPEC_CLASS)
  # Must also be overridden by subclasses.
  #
  # Note: EXECUTOR_SPEC supersedes the older EXECUTOR_CLASS attribute. A custom
  # component that used to declare "EXECUTOR_CLASS = MyExecutor" should now
  # declare "EXECUTOR_SPEC = ExecutorClassSpec(MyExecutor)".
  EXECUTOR_SPEC = abc_utils.abstract_property()
  doc_controls.do_not_doc_in_subclasses(EXECUTOR_SPEC)
  # Most subclasses keep the default driver, but it may be overridden too.
  DRIVER_CLASS = base_driver.BaseDriver
  doc_controls.do_not_doc_in_subclasses(DRIVER_CLASS)

  def __init__(
      self,
      spec: types.ComponentSpec,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Initialize a component.

    Args:
      spec: types.ComponentSpec object for this component instance.
      custom_executor_spec: Optional custom executor spec that takes the place
        of the class-level EXECUTOR_SPEC.

    Raises:
      TypeError: If custom_executor_spec is truthy but not an ExecutorSpec, or
        if the class-level SPEC_CLASS / EXECUTOR_SPEC / DRIVER_CLASS are
        invalid.
      ValueError: If the chosen executor spec cannot be copied, or if spec is
        not an instance of SPEC_CLASS.
    """
    component_cls = self.__class__
    if custom_executor_spec and not isinstance(custom_executor_spec,
                                               executor_spec.ExecutorSpec):
      raise TypeError(
          ('Custom executor spec override %s for %s should be an instance of '
           'ExecutorSpec') % (custom_executor_spec, component_cls))

    # A falsy override falls back to the class-level executor spec.
    chosen_executor_spec = custom_executor_spec or component_cls.EXECUTOR_SPEC
    # TODO(b/171742415): Remove this try-catch block once we migrate Beam
    # DAG runner to IR-based stack. The deep copy will only fail for function
    # based components due to pickle workaround we created in ExecutorClassSpec.
    try:
      executor_spec_copy = chosen_executor_spec.copy()
    except Exception as e:  # pylint:disable = bare-except
      # This will only happen for function based components, which is fine.
      raise ValueError(f'The executor spec of {self.__class__} class is '
                       f'not copyable.') from e

    super().__init__(
        executor_spec=executor_spec_copy,
        driver_class=component_cls.DRIVER_CLASS,
    )
    self.spec = spec
    self._validate_component_class()
    self._validate_spec(spec)
    # Stays None until with_platform_config() is called.
    self.platform_config = None

  @classmethod
  def _validate_component_class(cls):
    """Checks the class-level SPEC_CLASS, EXECUTOR_SPEC and DRIVER_CLASS.

    Raises:
      TypeError: If SPEC_CLASS is not a types.ComponentSpec subclass, if
        EXECUTOR_SPEC is not an ExecutorSpec instance, or if DRIVER_CLASS is
        not a base_driver.BaseDriver subclass.
    """
    spec_class = cls.SPEC_CLASS
    if not inspect.isclass(spec_class) or not issubclass(
        spec_class, types.ComponentSpec):
      raise TypeError(
          ('Component class %s expects SPEC_CLASS property to be a subclass '
           'of types.ComponentSpec; got %s instead.') % (cls, spec_class))

    class_executor_spec = cls.EXECUTOR_SPEC
    if not isinstance(class_executor_spec, executor_spec.ExecutorSpec):
      raise TypeError((
          'Component class %s expects EXECUTOR_SPEC property to be an instance '
          'of ExecutorSpec; got %s instead.') %
                      (cls, type(class_executor_spec)))

    driver_class = cls.DRIVER_CLASS
    if not inspect.isclass(driver_class) or not issubclass(
        driver_class, base_driver.BaseDriver):
      raise TypeError(
          ('Component class %s expects DRIVER_CLASS property to be a subclass '
           'of base_driver.BaseDriver; got %s instead.') % (cls, driver_class))

  def _validate_spec(self, spec):
    """Checks that `spec` is an instance of this component's SPEC_CLASS.

    Args:
      spec: The object passed as the component's spec.

    Raises:
      ValueError: If spec is not a types.ComponentSpec, or not an instance of
        the class-level SPEC_CLASS.
    """
    component_cls = self.__class__
    if not isinstance(spec, types.ComponentSpec):
      raise ValueError((
          'BaseComponent (parent class of %s) expects "spec" argument to be an '
          'instance of types.ComponentSpec, got %s instead.') %
                       (component_cls, spec))
    expected_spec_class = component_cls.SPEC_CLASS
    if not isinstance(spec, expected_spec_class):
      raise ValueError(
          ('%s expects the "spec" argument to be an instance of %s; '
           'got %s instead.') % (component_cls, expected_spec_class, spec))

  # TODO(b/170682320): This function is not widely available until we migrate
  # the entire stack to IR-based.
  @doc_controls.do_not_doc_in_subclasses
  def with_platform_config(self, config: message.Message) -> 'BaseComponent':
    """Attaches a proto-form platform config to a component.

    The config will be a per-node platform-specific config.

    Args:
      config: platform config to attach to the component.

    Returns:
      This component, so that calls can be chained.
    """
    self.platform_config = config
    return self

  def __repr__(self):
    template = ('%s(spec: %s, executor_spec: %s, driver_class: %s, '
                'component_id: %s, inputs: %s, outputs: %s)')
    fields = (self.__class__.__name__, self.spec, self.executor_spec,
              self.driver_class, self.id, self.inputs, self.outputs)
    return template % fields

  @property
  @doc_controls.do_not_doc_in_subclasses
  def inputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access, g-missing-from-attributes
    return self.spec.inputs

  @property
  def outputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
    """Component's output channel dict."""
    return self.spec.outputs

  @property
  @doc_controls.do_not_doc_in_subclasses
  def exec_properties(self) -> Dict[Text, Any]:  # pylint: disable=g-missing-from-attributes
    return self.spec.exec_properties
class BaseComponent(with_metaclass(abc.ABCMeta, base_node.BaseNode)):
  """Base class for a TFX pipeline component.

  Each instance of a BaseComponent subclass holds the parameters for one
  execution of that TFX pipeline component.

  Every subclass has to override SPEC_CLASS with the ComponentSpec subclass
  describing the component's interface.

  Attributes:
    SPEC_CLASS: a subclass of types.ComponentSpec used by this component
      (required).
    EXECUTOR_SPEC: an instance of executor_spec.ExecutorSpec which describes
      how to execute this component (required).
    DRIVER_CLASS: a subclass of base_driver.BaseDriver as a custom driver for
      this component (optional, defaults to base_driver.BaseDriver).
  """

  # Must be overridden by subclasses with a types.ComponentSpec class, e.g.
  # "SPEC_CLASS = MyComponentSpec".
  SPEC_CLASS = abc_utils.abstract_property()
  # Must also be overridden by subclasses.
  #
  # Note: EXECUTOR_SPEC supersedes the older EXECUTOR_CLASS attribute. A custom
  # component that used to declare "EXECUTOR_CLASS = MyExecutor" should now
  # declare "EXECUTOR_SPEC = ExecutorClassSpec(MyExecutor)".
  EXECUTOR_SPEC = abc_utils.abstract_property()
  # Most subclasses keep the default driver, but it may be overridden too.
  DRIVER_CLASS = base_driver.BaseDriver

  def __init__(
      self,
      spec: types.ComponentSpec,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
      instance_name: Optional[Text] = None):
    """Initialize a component.

    Args:
      spec: types.ComponentSpec object for this component instance.
      custom_executor_spec: Optional custom executor spec that takes the place
        of the class-level EXECUTOR_SPEC.
      instance_name: Optional unique identifying name for this instance of the
        component in the pipeline. Required if two instances of the same
        component are used in the pipeline.

    Raises:
      TypeError: If custom_executor_spec is truthy but not an ExecutorSpec, or
        if the class-level SPEC_CLASS / EXECUTOR_SPEC / DRIVER_CLASS are
        invalid.
      ValueError: If spec is not an instance of SPEC_CLASS.
    """
    component_cls = self.__class__
    # The base node is initialized first; the override's type is only checked
    # afterwards, exactly as the statement order below shows.
    super(BaseComponent, self).__init__(
        instance_name=instance_name,
        # A falsy override falls back to the class-level executor spec.
        executor_spec=custom_executor_spec or component_cls.EXECUTOR_SPEC,
        driver_class=component_cls.DRIVER_CLASS,
    )
    self.spec = spec
    if custom_executor_spec and not isinstance(custom_executor_spec,
                                               executor_spec.ExecutorSpec):
      raise TypeError((
          'Custom executor spec override %s for %s should be an instance of '
          'ExecutorSpec') % (custom_executor_spec, component_cls))
    self._validate_component_class()
    self._validate_spec(spec)

  @classmethod
  def _validate_component_class(cls):
    """Checks the class-level SPEC_CLASS, EXECUTOR_SPEC and DRIVER_CLASS.

    Raises:
      TypeError: If SPEC_CLASS is not a types.ComponentSpec subclass, if
        EXECUTOR_SPEC is not an ExecutorSpec instance, or if DRIVER_CLASS is
        not a base_driver.BaseDriver subclass.
    """
    spec_class = cls.SPEC_CLASS
    if not inspect.isclass(spec_class) or not issubclass(
        spec_class, types.ComponentSpec):
      raise TypeError((
          'Component class %s expects SPEC_CLASS property to be a subclass '
          'of types.ComponentSpec; got %s instead.') % (cls, spec_class))

    class_executor_spec = cls.EXECUTOR_SPEC
    if not isinstance(class_executor_spec, executor_spec.ExecutorSpec):
      raise TypeError((
          'Component class %s expects EXECUTOR_SPEC property to be an instance '
          'of ExecutorSpec; got %s instead.') %
                      (cls, type(class_executor_spec)))

    driver_class = cls.DRIVER_CLASS
    if not inspect.isclass(driver_class) or not issubclass(
        driver_class, base_driver.BaseDriver):
      raise TypeError((
          'Component class %s expects DRIVER_CLASS property to be a subclass '
          'of base_driver.BaseDriver; got %s instead.') % (cls, driver_class))

  def _validate_spec(self, spec):
    """Checks that `spec` is an instance of this component's SPEC_CLASS.

    Args:
      spec: The object passed as the component's spec.

    Raises:
      ValueError: If spec is not a types.ComponentSpec, or not an instance of
        the class-level SPEC_CLASS.
    """
    component_cls = self.__class__
    if not isinstance(spec, types.ComponentSpec):
      raise ValueError((
          'BaseComponent (parent class of %s) expects "spec" argument to be an '
          'instance of types.ComponentSpec, got %s instead.') %
                       (component_cls, spec))
    expected_spec_class = component_cls.SPEC_CLASS
    if not isinstance(spec, expected_spec_class):
      raise ValueError(
          ('%s expects the "spec" argument to be an instance of %s; '
           'got %s instead.') % (component_cls, expected_spec_class, spec))

  def __repr__(self):
    template = ('%s(spec: %s, executor_spec: %s, driver_class: %s, '
                'component_id: %s, inputs: %s, outputs: %s)')
    fields = (self.__class__.__name__, self.spec, self.executor_spec,
              self.driver_class, self.id, self.inputs, self.outputs)
    return template % fields

  @property
  def inputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
    """The `inputs` of the underlying spec."""
    return self.spec.inputs

  @property
  def outputs(self) -> node_common._PropertyDictWrapper:  # pylint: disable=protected-access
    """The `outputs` of the underlying spec."""
    return self.spec.outputs

  @property
  def exec_properties(self) -> Dict[Text, Any]:
    """The `exec_properties` of the underlying spec."""
    return self.spec.exec_properties
class BaseComponent(base_node.BaseNode, abc.ABC):
  """Base class for a TFX pipeline component.

  Each instance of a BaseComponent subclass holds the parameters for one
  execution of that TFX pipeline component.

  Every subclass has to override SPEC_CLASS with the ComponentSpec subclass
  describing the component's interface.

  Attributes:
    SPEC_CLASS: a subclass of types.ComponentSpec used by this component
      (required). This is a class level value.
    EXECUTOR_SPEC: an instance of executor_spec.ExecutorSpec which describes
      how to execute this component (required). This is a class level value.
    DRIVER_CLASS: a subclass of base_driver.BaseDriver as a custom driver for
      this component (optional, defaults to base_driver.BaseDriver). This is a
      class level value.
    spec: an instance of `SPEC_CLASS`. See types.ComponentSpec for more
      details.
    platform_config: a protobuf message representing platform config for a
      component instance.
  """

  # Must be overridden by subclasses with a types.ComponentSpec class, e.g.
  # "SPEC_CLASS = MyComponentSpec".
  SPEC_CLASS = abc_utils.abstract_property()
  doc_controls.do_not_doc_in_subclasses(SPEC_CLASS)
  # Must also be overridden by subclasses.
  #
  # Note: EXECUTOR_SPEC supersedes the older EXECUTOR_CLASS attribute. A custom
  # component that used to declare "EXECUTOR_CLASS = MyExecutor" should now
  # declare "EXECUTOR_SPEC = ExecutorClassSpec(MyExecutor)".
  EXECUTOR_SPEC = abc_utils.abstract_property()
  doc_controls.do_not_doc_in_subclasses(EXECUTOR_SPEC)
  # Most subclasses keep the default driver, but it may be overridden too.
  DRIVER_CLASS = base_driver.BaseDriver
  doc_controls.do_not_doc_in_subclasses(DRIVER_CLASS)

  def __init__(
      self,
      spec: types.ComponentSpec,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Initialize a component.

    Args:
      spec: types.ComponentSpec object for this component instance.
      custom_executor_spec: Optional custom executor spec that takes the place
        of the class-level EXECUTOR_SPEC.

    Raises:
      TypeError: If custom_executor_spec is truthy but not an ExecutorSpec, or
        if the class-level SPEC_CLASS / EXECUTOR_SPEC / DRIVER_CLASS are
        invalid.
      ValueError: If the chosen executor spec cannot be copied, or if spec is
        not an instance of SPEC_CLASS.
    """
    component_cls = self.__class__
    if custom_executor_spec and not isinstance(custom_executor_spec,
                                               executor_spec.ExecutorSpec):
      raise TypeError((
          'Custom executor spec override %s for %s should be an instance of '
          'ExecutorSpec') % (custom_executor_spec, component_cls))

    # A falsy override falls back to the class-level executor spec.
    chosen_executor_spec = custom_executor_spec or component_cls.EXECUTOR_SPEC
    # TODO(b/171742415): Remove this try-catch block once we migrate Beam
    # DAG runner to IR-based stack. The deep copy will only fail for function
    # based components due to pickle workaround we created in ExecutorClassSpec.
    try:
      executor_spec_copy = chosen_executor_spec.copy()
    except Exception as e:  # pylint:disable = bare-except
      # This will only happen for function based components, which is fine.
      raise ValueError(f'The executor spec of {self.__class__} class is '
                       f'not copyable.') from e

    # Set self.spec before super.__init__() where node registration happens.
    # This enables node input checking on node context registration.
    self.spec = spec
    super().__init__(
        executor_spec=executor_spec_copy,
        driver_class=component_cls.DRIVER_CLASS,
    )
    self._validate_component_class()
    self._validate_spec(spec)
    # Stays None until with_platform_config() is called.
    self.platform_config = None
    self._pip_dependencies = []

  @classmethod
  def _validate_component_class(cls):
    """Checks the class-level SPEC_CLASS, EXECUTOR_SPEC and DRIVER_CLASS.

    Raises:
      TypeError: If SPEC_CLASS is not a types.ComponentSpec subclass, if
        EXECUTOR_SPEC is not an ExecutorSpec instance, or if DRIVER_CLASS is
        not a base_driver.BaseDriver subclass.
    """
    spec_class = cls.SPEC_CLASS
    if not inspect.isclass(spec_class) or not issubclass(
        spec_class, types.ComponentSpec):
      raise TypeError((
          'Component class %s expects SPEC_CLASS property to be a subclass '
          'of types.ComponentSpec; got %s instead.') % (cls, spec_class))

    class_executor_spec = cls.EXECUTOR_SPEC
    if not isinstance(class_executor_spec, executor_spec.ExecutorSpec):
      raise TypeError((
          'Component class %s expects EXECUTOR_SPEC property to be an instance '
          'of ExecutorSpec; got %s instead.') %
                      (cls, type(class_executor_spec)))

    driver_class = cls.DRIVER_CLASS
    if not inspect.isclass(driver_class) or not issubclass(
        driver_class, base_driver.BaseDriver):
      raise TypeError((
          'Component class %s expects DRIVER_CLASS property to be a subclass '
          'of base_driver.BaseDriver; got %s instead.') % (cls, driver_class))

  def _validate_spec(self, spec):
    """Checks that `spec` is an instance of this component's SPEC_CLASS.

    Args:
      spec: The object passed as the component's spec.

    Raises:
      ValueError: If spec is not a types.ComponentSpec, or not an instance of
        the class-level SPEC_CLASS.
    """
    component_cls = self.__class__
    if not isinstance(spec, types.ComponentSpec):
      raise ValueError((
          'BaseComponent (parent class of %s) expects "spec" argument to be an '
          'instance of types.ComponentSpec, got %s instead.') %
                       (component_cls, spec))
    expected_spec_class = component_cls.SPEC_CLASS
    if not isinstance(spec, expected_spec_class):
      raise ValueError(
          ('%s expects the "spec" argument to be an instance of %s; '
           'got %s instead.') % (component_cls, expected_spec_class, spec))

  # TODO(b/170682320): This function is not widely available until we migrate
  # the entire stack to IR-based.
  @doc_controls.do_not_doc_in_subclasses
  def with_platform_config(self, config: message.Message) -> 'BaseComponent':
    """Attaches a proto-form platform config to a component.

    The config will be a per-node platform-specific config.

    Args:
      config: platform config to attach to the component.

    Returns:
      This component, so that calls can be chained.
    """
    self.platform_config = config
    return self

  def __repr__(self):
    template = ('%s(spec: %s, executor_spec: %s, driver_class: %s, '
                'component_id: %s, inputs: %s, outputs: %s)')
    fields = (self.__class__.__name__, self.spec, self.executor_spec,
              self.driver_class, self.id, self.inputs, self.outputs)
    return template % fields

  @property
  @doc_controls.do_not_doc_in_subclasses
  def inputs(self) -> Dict[str, Any]:
    return self.spec.inputs

  @property
  def outputs(self) -> Dict[str, Any]:
    """Component's output channel dict."""
    return self.spec.outputs

  @property
  @doc_controls.do_not_doc_in_subclasses
  def type_annotation(self) -> Optional[Type[SystemExecution]]:
    annotation = self.__class__.SPEC_CLASS.TYPE_ANNOTATION
    # A falsy annotation (e.g. None) is returned as is, without the check.
    if not annotation:
      return annotation
    if not issubclass(annotation, SystemExecution):
      raise TypeError(
          'TYPE_ANNOTATION %s is not a subclass of SystemExecution.' %
          annotation)
    return annotation

  @property
  @doc_controls.do_not_doc_in_subclasses
  def exec_properties(self) -> Dict[str, Any]:  # pylint: disable=g-missing-from-attributes
    return self.spec.exec_properties

  def _add_pip_dependency(
      self, dependency: Union[str, '_PipDependencyFuture']) -> None:
    """Internal use only: add pip dependency to current component."""
    # TODO(b/187122662): Provide separate Python component hierarchy and remove
    # logic from this class.
    self._pip_dependencies.append(dependency)

  def _resolve_pip_dependencies(self, pipeline_root: str) -> None:
    """Experimental: resolve pip dependencies into specifiers.

    Strings are kept unchanged; _PipDependencyFuture entries are replaced by
    the result of their `resolve(pipeline_root)` call, and dropped when that
    result is falsy.

    Args:
      pipeline_root: Passed through to each future's `resolve()`.

    Raises:
      ValueError: If an entry is neither a str nor a _PipDependencyFuture.
    """
    if not hasattr(self, '_pip_dependencies'):
      return

    resolved_specifiers = []
    for dependency in self._pip_dependencies:
      if isinstance(dependency, str):
        resolved_specifiers.append(dependency)
        continue
      if not isinstance(dependency, _PipDependencyFuture):
        raise ValueError('Invalid pip dependency object: %s.' % dependency)
      specifier = dependency.resolve(pipeline_root)
      if specifier:
        resolved_specifiers.append(specifier)
    self._pip_dependencies = resolved_specifiers
class TfxRunner(with_metaclass(abc.ABCMeta, object)):
  """Base runner class for TFX.

  This is the base class for every TFX runner.
  """

  # A list of component launcher classes that are supported by the current
  # runner. List sequence determines the order in which launchers are chosen
  # for each component being run.
  # Subclasses must override this property by specifying a list of supported
  # launcher classes, e.g.
  # `SUPPORTED_LAUNCHER_CLASSES = [InProcessComponentLauncher]`.
  SUPPORTED_LAUNCHER_CLASSES = abc_utils.abstract_property()

  def __init__(self):
    """Initializes a TfxRunner instance.

    Raises:
      ValueError: if SUPPORTED_LAUNCHER_CLASSES is None or empty.
      TypeError: if any entry of SUPPORTED_LAUNCHER_CLASSES is not a subclass
        of base_component_launcher.BaseComponentLauncher.
    """
    self._supported_launcher_classes = self.__class__.SUPPORTED_LAUNCHER_CLASSES
    self._validate_supported_launcher_classes()

  def _validate_supported_launcher_classes(self):
    """Checks that the configured launcher classes are usable.

    Raises:
      ValueError: if no launcher classes are configured.
      TypeError: if an entry is not a class, or is a class that does not
        derive from base_component_launcher.BaseComponentLauncher.
    """
    if not self._supported_launcher_classes:
      raise ValueError(
          'component_launcher_classes must not be None or empty.')

    launcher_base = base_component_launcher.BaseComponentLauncher
    # `issubclass` itself raises a generic TypeError when handed a non-class
    # (e.g. an instance or a string), so test for "is a class" first; that way
    # every bad entry is reported with the descriptive message below.
    if not all(
        isinstance(cls, type) and issubclass(cls, launcher_base)
        for cls in self._supported_launcher_classes):
      raise TypeError(
          'Each item in supported_launcher_classes must be type of '
          'base_component_launcher.BaseComponentLauncher.')

  def find_component_launcher_class(
      self, component: base_component.BaseComponent
  ) -> Type[base_component_launcher.BaseComponentLauncher]:
    """Find a launcher in the runner which can launch the component.

    The default lookup logic goes through the self._supported_launcher_classes
    in order and returns the first one which can launch the executor_spec of
    the component. Subclass may customize the logic by overriding the method.

    Args:
      component: the component to launch.

    Returns:
      The found component launcher for the component.

    Raises:
      RuntimeError: if no supported launcher is found.
    """
    for component_launcher_class in self._supported_launcher_classes:
      if component_launcher_class.can_launch(component.executor_spec):
        return component_launcher_class
    raise RuntimeError('No launcher can launch component "%s".' %
                       component.component_id)

  @abc.abstractmethod
  def run(self, pipeline) -> Optional[Any]:
    """Runs logical TFX pipeline on specific platform.

    Args:
      pipeline: logical TFX pipeline definition.

    Returns:
      Platform-specific object.
    """
    pass
class MaterializedArtifact(abc.ABC):
  """TFX artifact used for artifact analysis and visualization."""

  # Artifact type (of type `Type[types.Artifact]`).
  ARTIFACT_TYPE = abc_utils.abstract_property()

  def __init__(self, artifact: types.Artifact):
    """Wraps `artifact`; every accessor below reads from it."""
    self._artifact = artifact

  def __str__(self):
    summary_fields = (self.type_name, self.uri)
    return 'Type: %s, URI: %s' % summary_fields

  def __repr__(self):
    return f'<{self.__str__()}>'

  @property
  def uri(self) -> str:
    """Artifact URI."""
    return self._artifact.uri

  @property
  def id(self) -> int:
    """Artifact id."""
    return self._artifact.id

  @property
  def name(self) -> str:
    """Artifact name; the MLMD artifact's name wins when it is non-empty."""
    mlmd_name = self._artifact.mlmd_artifact.name
    return mlmd_name if mlmd_name else self._artifact.name

  @property
  def type_name(self) -> str:
    """Artifact type name."""
    return self._artifact.type_name

  @property
  def producer_component(self) -> str:
    """The producer component of this artifact."""
    return self._artifact.producer_component

  @property
  def properties(self) -> Dict[str, str]:
    """Returns dictionary of custom and default properties of the artifact.

    Default properties are added first, then custom properties; a custom
    property therefore overrides a default property that has the same key.
    """
    mlmd_artifact = self._artifact.mlmd_artifact
    # NOTE(review): only `string_value` is read, so values stored under a
    # different field presumably show up as empty strings -- confirm.
    merged = {
        key: value.string_value
        for key, value in mlmd_artifact.properties.items()
    }
    merged.update(
        (key, value.string_value)
        for key, value in mlmd_artifact.custom_properties.items())
    return merged

  def _validate_payload(self):
    """Raises error if the artifact uri is not readable.

    Raises:
      IOError: if `gfile.exists` reports that the artifact URI does not exist.
    """
    if gfile.exists(self.uri):
      return
    raise IOError(f'Artifact URI {self.uri} not readable.')

  @abc.abstractmethod
  def show(self) -> None:
    """Displays respective visualization for artifact type."""
    raise NotImplementedError()