Ejemplo n.º 1
0
  def _parse_parameter_from_component(
      self, component: base_component.BaseComponent) -> None:
    """Extract embedded RuntimeParameter placeholders from a component.

    Serializes the component to JSON and scans it for placeholders matching
    data_types.RUNTIME_PARAMETER_PATTERN. Each placeholder is cleaned and
    decoded back into a parameter object. A dsl.PipelineParam is then appended
    to self._params for every parameter name not yet recorded in
    self._deduped_parameter_names. The pipeline-root parameter is skipped here.

    Args:
      component: a TFX component.
    """

    serialized_component = json_utils.dumps(component)
    placeholders = re.findall(data_types.RUNTIME_PARAMETER_PATTERN,
                              serialized_component)
    for placeholder in placeholders:
      placeholder = placeholder.replace('\\', '')  # Clean escapes.
      placeholder = utils.fix_brackets(placeholder)  # Fix brackets if needed.
      parameter = json_utils.loads(placeholder)
      # Skip the pipeline root because it will be added later.
      if parameter.name == tfx_pipeline.ROOT_PARAMETER.name:
        continue
      # The same parameter may appear in several placeholders; register once.
      if parameter.name in self._deduped_parameter_names:
        continue
      self._deduped_parameter_names.add(parameter.name)
      # TODO(b/178436919): Create a test to cover default value rendering
      # and move the external code reference over there.
      # The default needs to be serialized then passed to dsl.PipelineParam.
      # See
      # https://github.com/kubeflow/pipelines/blob/f65391309650fdc967586529e79af178241b4c2c/sdk/python/kfp/dsl/_pipeline_param.py#L154
      # A missing default is kept as None; str(None) would otherwise turn it
      # into the literal string default 'None'.
      default = parameter.default
      dsl_parameter = dsl.PipelineParam(
          name=parameter.name,
          value=None if default is None else str(default))
      self._params.append(dsl_parameter)
Ejemplo n.º 2
0
    def _parse_parameter_from_component(
            self, component: base_component.BaseComponent) -> None:
        """Extract embedded RuntimeParameter placeholders from a component.

        Serializes the component to JSON and scans it for placeholders
        matching data_types.RUNTIME_PARAMETER_PATTERN. Each placeholder is
        cleaned and decoded back into a parameter object. A dsl.PipelineParam
        is then appended to self._params for every parameter name not yet
        recorded in self._deduped_parameter_names. The pipeline-root parameter
        is skipped here.

        Args:
          component: a TFX component.
        """

        serialized_component = json_utils.dumps(component)
        placeholders = re.findall(data_types.RUNTIME_PARAMETER_PATTERN,
                                  serialized_component)
        for placeholder in placeholders:
            placeholder = placeholder.replace('\\', '')  # Clean escapes.
            # Fix brackets if needed.
            placeholder = utils.fix_brackets(placeholder)
            parameter = json_utils.loads(placeholder)
            # Skip the pipeline root because it will be added later.
            if parameter.name == tfx_pipeline.ROOT_PARAMETER.name:
                continue
            # The same parameter may appear in several placeholders; register
            # it only once.
            if parameter.name in self._deduped_parameter_names:
                continue
            self._deduped_parameter_names.add(parameter.name)
            # The default must be serialized to a string before it is passed
            # to dsl.PipelineParam (KFP v1 PipelineParam values are strings).
            # A missing default is kept as None rather than becoming the
            # literal string 'None'.
            default = parameter.default
            dsl_parameter = dsl.PipelineParam(
                name=parameter.name,
                value=None if default is None else str(default))
            self._params.append(dsl_parameter)