def setup(
    cls: Type[Dataclass],
    arguments: Optional[str] = "",
    dest: Optional[str] = None,
    default: Optional[Dataclass] = None,
    conflict_resolution_mode: ConflictResolution = ConflictResolution.AUTO,
    add_option_string_dash_variants: bool = False,
    parse_known_args: bool = False,
    attempt_to_reorder: bool = False,
) -> Dataclass:
    """Basic setup for a test: parse ``arguments`` into an instance of ``cls``.

    A fresh ``simple_parsing.ArgumentParser`` is built, ``cls`` is registered
    on it under ``dest``, the arguments are parsed, and the parsed instance is
    returned.

    Keyword Arguments:
        arguments {Optional[str]} -- The command-line string to parse; it is
            tokenized with ``shlex.split``. When ``None``, no explicit argument
            list is passed, so the parser falls back to its default source
            (``sys.argv[1:]`` for argparse-based parsers). (default: {""})
        dest {Optional[str]} -- The attribute where the parsed instance is
            stored. When ``None``, ``camel_case(cls.__name__)`` is used.
            (default: {None})
        default {Optional[Dataclass]} -- Default instance forwarded to
            ``parser.add_arguments``. (default: {None})
        conflict_resolution_mode {ConflictResolution} -- Conflict resolution
            strategy for the parser. (default: {ConflictResolution.AUTO})
        add_option_string_dash_variants {bool} -- Forwarded to the parser
            constructor. (default: {False})
        parse_known_args {bool} -- Use ``parse_known_args`` instead of
            ``parse_args``; any unrecognized arguments are discarded.
            (default: {False})
        attempt_to_reorder {bool} -- Forwarded to ``parse_known_args``; only
            used when ``parse_known_args`` is True. (default: {False})

    Returns:
        {Dataclass} -- The parsed instance of ``cls``.

    Raises:
        AssertionError -- If ``dest`` is missing from the parsed namespace, or
            if the namespace holds any attribute other than ``dest``.
    """
    parser = simple_parsing.ArgumentParser(
        conflict_resolution=conflict_resolution_mode,
        add_option_string_dash_variants=add_option_string_dash_variants,
    )
    if dest is None:
        dest = camel_case(cls.__name__)
    parser.add_arguments(cls, dest=dest, default=default)

    # ``None`` means "let the parser pick its default argument source".
    splits = None if arguments is None else shlex.split(arguments)

    if parse_known_args:
        # parse_known_args returns a ``(namespace, unknown_args)`` tuple. It
        # must be unpacked on every path. Previously the ``arguments is None``
        # path bound the whole tuple to ``args``, which broke the checks below.
        args, _unknown_args = parser.parse_known_args(
            splits, attempt_to_reorder=attempt_to_reorder
        )
    else:
        args = parser.parse_args(splits)

    assert hasattr(args, dest), f"attribute '{dest}' not found in args {args}"
    instance: Dataclass = getattr(args, dest)  # type: ignore
    delattr(args, dest)
    # After removing ``dest``, the namespace must be empty. Anything left over
    # means the parser produced attributes the test did not expect.
    assert args == argparse.Namespace(), f"Namespace has leftover garbage values: {args}"
    return instance
def setup_multiple(
    cls: Type[Dataclass], num_to_parse: int, arguments: Optional[str] = ""
) -> Tuple[Dataclass, ...]:
    """Parse ``num_to_parse`` instances of ``cls`` from a single argument string.

    Each instance is registered on one parser, which uses
    ``ConflictResolution.ALWAYS_MERGE``, under the destination
    ``"<camel_case(cls.__name__)>_<index>"``.

    Keyword Arguments:
        cls {Type[Dataclass]} -- The dataclass type to register repeatedly.
        num_to_parse {int} -- How many copies of ``cls`` to register and return.
        arguments {Optional[str]} -- Command-line string tokenized with
            ``shlex.split``. When ``None``, no explicit argument list is passed
            and the parser uses its default source. (default: {""})

    Returns:
        {Tuple[Dataclass, ...]} -- The parsed instances, in registration order.
    """
    parser = simple_parsing.ArgumentParser(
        conflict_resolution=ConflictResolution.ALWAYS_MERGE
    )
    prefix = camel_case(cls.__name__)
    destinations = [f"{prefix}_{index}" for index in range(num_to_parse)]
    for destination in destinations:
        parser.add_arguments(cls, destination)

    argv = None if arguments is None else shlex.split(arguments)
    namespace = parser.parse_args(argv)
    return tuple(getattr(namespace, destination) for destination in destinations)