def _function_languages_from_arg(
        arg_proto: v2.program_pb2.Arg) -> Iterator[str]:
    """Yields the function-language names referenced by an argument proto.

    Walks `arg_proto` and yields 'linear' once for every 'add' or 'mul'
    function node encountered, recursing into that node's arguments. Values
    and symbols (any non-'func' oneof) contribute nothing. The same name may
    be yielded more than once.

    NOTE(review): presumably callers use these names to choose the
    `arg_function_language` needed for serialization — confirm at call site.

    Args:
        arg_proto: The argument proto to inspect.

    Yields:
        'linear' for each add/mul function node in the tree.
    """
    if arg_proto.WhichOneof('arg') != 'func':
        return
    func = arg_proto.func
    # Only add/mul nodes are inspected; their children are walked recursively.
    if func.type not in ('add', 'mul'):
        return
    yield 'linear'
    for child in func.args:
        yield from _function_languages_from_arg(child)
def _arg_from_proto(
        arg_proto: v2.program_pb2.Arg,
        *,
        arg_function_language: str,
        required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Float values that are integral
        (e.g. 2.0) are returned as `int`; non-finite floats (inf, NaN) are
        returned unchanged as `float`. Bool lists come back as a `list`,
        symbols as `sympy.Symbol`, and add/mul functions as `sympy.Add` /
        `sympy.Mul` of their recursively deserialized arguments.

    Raises:
        ValueError: If `arg_function_language` is not a key of
            `SUPPORTED_FUNCTIONS_FOR_LANGUAGE`, if the arg value has an
            unrecognized type, if a function type is not supported by the
            language, or if the value is unset/unrecognized while
            `required_arg_name` is given.
    """
    supported = SUPPORTED_FUNCTIONS_FOR_LANGUAGE.get(arg_function_language)
    if supported is None:
        raise ValueError(f'Unrecognized arg_function_language: '
                         f'{arg_function_language!r}')

    which = arg_proto.WhichOneof('arg')
    if which == 'arg_value':
        arg_value = arg_proto.arg_value
        which_val = arg_value.WhichOneof('arg_value')
        if which_val == 'float_value':
            result = float(arg_value.float_value)
            # Integral floats are returned as ints (2.0 -> 2). `is_integer()`
            # is used instead of comparing math.ceil/math.floor because those
            # raise OverflowError on inf and ValueError on NaN.
            if result.is_integer():
                result = int(result)
            return result
        if which_val == 'bool_values':
            return list(arg_value.bool_values.values)
        if which_val == 'string_value':
            return str(arg_value.string_value)
        raise ValueError(f'Unrecognized value type: {which_val!r}')

    if which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)

    if which == 'func':
        func = arg_proto.func

        # Reject functions the declared language does not allow.
        if func.type not in cast(Set[str], supported):
            raise ValueError(
                f'Unrecognized function type {func.type!r} '
                f'for arg_function_language={arg_function_language!r}')

        if func.type == 'add':
            return sympy.Add(*[
                _arg_from_proto(a,
                                arg_function_language=arg_function_language,
                                required_arg_name='An addition argument')
                for a in func.args
            ])

        if func.type == 'mul':
            return sympy.Mul(*[
                _arg_from_proto(a,
                                arg_function_language=arg_function_language,
                                required_arg_name='A multiplication argument')
                for a in func.args
            ])

    # Unset oneof (or a func type accepted above but not handled): either
    # complain or fall back to None depending on the caller's request.
    if required_arg_name is not None:
        raise ValueError(
            f'{required_arg_name} is missing or has an unrecognized '
            f'argument type (WhichOneof("arg")={which!r}).')
    return None