コード例 #1
0
def _function_languages_from_arg(
        arg_proto: v2.program_pb2.Arg) -> Iterator[str]:
    """Yields the function-language names required by a serialized argument.

    The argument proto is walked recursively. Every `add` or `mul` function
    node contributes one 'linear' entry, followed by whatever its own
    arguments contribute. Any other kind of argument contributes nothing.

    Args:
        arg_proto: The proto containing a serialized argument.

    Yields:
        The string 'linear' once per `add`/`mul` node, in depth-first
        pre-order.
    """
    # Only function nodes can require a non-trivial function language.
    if arg_proto.WhichOneof('arg') != 'func':
        return
    func = arg_proto.func
    if func.type not in ('add', 'mul'):
        return
    yield 'linear'
    for sub_arg in func.args:
        yield from _function_languages_from_arg(sub_arg)
コード例 #2
0
def _arg_from_proto(
    arg_proto: v2.program_pb2.Arg,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Float values that are integral
        are returned as `int`. Non-finite floats (inf, nan) are returned
        unchanged as `float`.

    Raises:
        ValueError: If `arg_function_language` is not a key of
            `SUPPORTED_FUNCTIONS_FOR_LANGUAGE`; if an `arg_value` holds an
            unrecognized (or unset) value type; if a function type is not
            supported for `arg_function_language`; or if `required_arg_name`
            is set and the value is missing or of an unrecognized kind.
    """
    supported = SUPPORTED_FUNCTIONS_FOR_LANGUAGE.get(arg_function_language)
    if supported is None:
        raise ValueError(f'Unrecognized arg_function_language: '
                         f'{arg_function_language!r}')

    which = arg_proto.WhichOneof('arg')
    if which == 'arg_value':
        arg_value = arg_proto.arg_value
        which_val = arg_value.WhichOneof('arg_value')
        if which_val == 'float_value':
            result = float(arg_value.float_value)
            # Integral values are handed back as `int`. `float.is_integer()`
            # is False for inf and nan, so those stay floats. The previous
            # `math.ceil(...) == math.floor(...)` test raised OverflowError /
            # ValueError on them.
            if result.is_integer():
                return int(result)
            return result
        if which_val == 'bool_values':
            return list(arg_value.bool_values.values)
        if which_val == 'string_value':
            return str(arg_value.string_value)
        raise ValueError(f'Unrecognized value type: {which_val!r}')

    if which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)

    if which == 'func':
        func = arg_proto.func

        if func.type not in cast(Set[str], supported):
            raise ValueError(
                f'Unrecognized function type {func.type!r} '
                f'for arg_function_language={arg_function_language!r}')

        # Sub-arguments are deserialized with a required name so that a
        # missing operand raises instead of passing None into sympy.
        if func.type == 'add':
            return sympy.Add(*[
                _arg_from_proto(a,
                                arg_function_language=arg_function_language,
                                required_arg_name='An addition argument')
                for a in func.args
            ])

        if func.type == 'mul':
            return sympy.Mul(*[
                _arg_from_proto(a,
                                arg_function_language=arg_function_language,
                                required_arg_name='A multiplication argument')
                for a in func.args
            ])

        # NOTE(review): a function type that is listed in `supported` but not
        # handled above falls through to the missing-value handling below.

    if required_arg_name is not None:
        raise ValueError(
            f'{required_arg_name} is missing or has an unrecognized '
            f'argument type (WhichOneof("arg")={which!r}).')

    return None