예제 #1
0
파일: compiler_utils.py 프로젝트: kp425/tfx
def set_runtime_parameter_pb(
    pb: pipeline_pb2.RuntimeParameter,
    name: Text,
    ptype: Type[types.Property],
    default_value: Optional[types.Property] = None
) -> pipeline_pb2.RuntimeParameter:
    """Helper function to fill a RuntimeParameter proto.

    Args:
      pb: A RuntimeParameter proto to be filled in (modified in place).
      name: Name to be set at pb.name.
      ptype: The Python type to be set at pb.type. Supported types are `int`
        (mapped to INT), `float` (DOUBLE) and `str` (STRING).
      default_value: Optional. If provided (i.e. not None), it will be set as
        pb.default_value. Falsy values such as 0, 0.0 and '' are kept as real
        defaults rather than being treated as "not provided".

    Returns:
      A RuntimeParameter proto filled with provided values. This is the same
      object that was passed in as `pb`.

    Raises:
      ValueError: If `ptype` is not one of the supported types. pb.name has
        already been assigned when this is raised.
    """
    pb.name = name
    # Compare against None explicitly: a truthiness test would silently drop
    # legitimate defaults such as 0, 0.0 or the empty string.
    if ptype == int:
        pb.type = pipeline_pb2.RuntimeParameter.Type.INT
        if default_value is not None:
            pb.default_value.int_value = default_value
    elif ptype == float:
        pb.type = pipeline_pb2.RuntimeParameter.Type.DOUBLE
        if default_value is not None:
            pb.default_value.double_value = default_value
    elif ptype == str:
        pb.type = pipeline_pb2.RuntimeParameter.Type.STRING
        if default_value is not None:
            pb.default_value.string_value = default_value
    else:
        raise ValueError(
            "Got unsupported runtime parameter type: {}".format(ptype))
    return pb
예제 #2
0
def _get_runtime_parameter_value(
    runtime_parameter: pipeline_pb2.RuntimeParameter,
    parameter_bindings: Mapping[str,
                                types.Property]) -> Optional[types.Property]:
    """Populates the value for a RuntimeParameter when possible.

    An external binding in `parameter_bindings` takes precedence. If there is
    no binding for this parameter, its default value is used when one is set.

    Args:
      runtime_parameter: RuntimeParameter as the template.
      parameter_bindings: Parameter bindings to substitute runtime parameter
        placeholders in the RuntimeParameter, keyed by parameter name.

    Returns:
      Resolved value for the RuntimeParameter if available. Returns None if the
      RuntimeParameter cannot be resolved, i.e. there is no binding for it and
      no populated default value.

    Raises:
      RuntimeError: When the resolved value (either the provided binding or
        the default value) does not match the RuntimeParameter type
        requirement.
    """
    name = runtime_parameter.name
    if name in parameter_bindings:
        # External binding found; it wins over any default value.
        value = parameter_bindings[name]
    elif runtime_parameter.HasField('default_value'):
        default_pb = runtime_parameter.default_value
        field_name = default_pb.WhichOneof('value')
        if field_name is None:
            # default_value is present but none of its `value` oneof fields is
            # set, so it carries no value. Treat it as unresolvable instead of
            # letting getattr(default_pb, None) raise an opaque TypeError.
            return None
        value = getattr(default_pb, field_name)
    else:
        return None

    # Single type check shared by the binding and default-value paths.
    if not _is_type_match(runtime_parameter.type, value):
        raise RuntimeError(
            'Runtime parameter type %s does not match with %s.' %
            (type(value), runtime_parameter))
    return value