def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
) -> None:
    """Register the fields of a dataclass instance as parser arguments.

    Every attribute returned by ``dataclass_instance._get_all_attributes()``
    is translated into one ``parser.add_argument`` call. Attributes whose
    type is a ``FairseqDataclass`` subclass are expanded recursively from a
    default-constructed instance of that type.

    Args:
        parser: argument parser that receives the generated arguments.
        dataclass_instance: dataclass whose attributes are converted.
        delete_default: when True, no ``default`` is passed to
            ``add_argument`` for any generated argument.

    Raises:
        NotImplementedError: for a List/Tuple field whose type string
            contains none of ``int``, ``float`` or ``str``.
    """

    def argparse_name(name: str):
        """Map an attribute name to its option string (None means skip)."""
        if name == "data":
            # normally data is positional args
            return name
        if name == "_name":
            # private member, skip
            return None
        return "--" + name.replace("_", "-")

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """Build the ``add_argument`` keyword arguments for attribute ``k``.

        k: dataclass attributes
        """
        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        # Only a concrete Enum class can enumerate its allowed values.
        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            # A "${"-prefixed default is passed through verbatim; the caller
            # decides what to do with it (presumably a config interpolation
            # placeholder -- confirm).
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices
            if (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                # The element type is detected from the type's string form;
                # the checks run in the order int, float, str.
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if field_default is not MISSING:
                    # Sequence defaults are handed over as a comma-joined string.
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                # A boolean becomes a flag that flips its default value.
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            # Nested dataclass: add its fields to the same parser.
            gen_parser_from_dataclass(parser, field_type(), delete_default)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            # The "${" branch above may already have removed the default, so
            # check again before deleting to avoid a KeyError.
            if delete_default and "default" in kwargs:
                del kwargs["default"]
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            # Best effort: an argument that argparse rejects (e.g. an option
            # string that conflicts with an existing one) is skipped.
            pass
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """
    Register the attributes of a dataclass instance as parser arguments.

    If `with_prefix` is provided, every generated option name is prefixed
    with it, i.e. a flat namespace is built from a structured dataclass
    (see transformer_config.py for example).

    Args:
        parser: argument parser that receives the generated arguments.
        dataclass_instance: dataclass whose attributes are converted.
        delete_default: when True, no ``default`` is passed to
            ``add_argument`` for any generated argument.
        with_prefix: option-name prefix; None or "" disables name prefixing.
    """
    # Naming and help text are only prefixed for a non-empty prefix.
    has_prefix = with_prefix is not None and with_prefix != ""

    # Substring looked for in the type's string form -> element converter.
    # Order matters: the first match wins.
    element_types = (("int", int), ("float", float), ("str", str))

    def make_list_parser(element_type):
        """Return the argparse ``type`` callable for a list of ``element_type``."""
        return lambda x: eval_str_list(x, element_type)

    def argparse_name(name: str):
        """Map an attribute name to its option string (None means skip)."""
        if name == "_name":
            # private member, skip
            return None
        if name == "data" and not has_prefix:
            # normally data is positional args, so neither "--" nor the
            # prefix is added
            return name
        dashed = name.replace("_", "-")
        if has_prefix:
            # the prefix replaces the leading "--" of the plain option name
            return with_prefix + "-" + dashed
        return "--" + dashed

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """Build the ``add_argument`` keyword arguments for attribute ``k``."""
        spec = {}

        resolved = interpret_dc_type(dataclass_instance._get_type(k))
        default = dataclass_instance._get_default(k)

        is_enum_class = isinstance(resolved, type) and issubclass(resolved, Enum)
        choices = [member.value for member in resolved] if is_enum_class else None

        help_text = dataclass_instance._get_help(k)
        const = dataclass_instance._get_argparse_const(k)

        if isinstance(default, str) and default.startswith("${"):
            # "${"-prefixed defaults are forwarded untouched to the caller.
            spec["default"] = default
        else:
            if default is MISSING:
                spec["required"] = True
            if choices is not None:
                spec["choices"] = choices

            type_repr = str(resolved)
            is_sequence_class = isinstance(resolved, type) and (
                issubclass(resolved, List) or issubclass(resolved, Tuple)
            )

            if is_sequence_class or "List" in type_repr or "Tuple" in type_repr:
                for marker, element_type in element_types:
                    if marker in type_repr:
                        spec["type"] = make_list_parser(element_type)
                        break
                else:
                    raise NotImplementedError(
                        "parsing of type " + type_repr + " is not implemented"
                    )
                if default is not MISSING:
                    # sequence defaults are handed over as a comma-joined string
                    spec["default"] = (
                        None if default is None else ",".join(map(str, default))
                    )
            elif is_enum_class or "Enum" in type_repr:
                spec["type"] = str
                if default is not MISSING:
                    spec["default"] = (
                        default.value if isinstance(default, Enum) else default
                    )
            elif resolved is bool:
                # a boolean becomes a flag that flips its default value
                spec["action"] = "store_false" if default is True else "store_true"
                spec["default"] = default
            else:
                spec["type"] = resolved
                if default is not MISSING:
                    spec["default"] = default

        # build the help with the hierarchical prefix (minus its first two
        # characters)
        if has_prefix and help_text is not None:
            help_text = with_prefix[2:] + ": " + help_text

        spec["help"] = help_text
        if const is not None:
            spec["const"] = const
            spec["nargs"] = "?"

        return spec

    for k in dataclass_instance._get_all_attributes():
        option_name = argparse_name(dataclass_instance._get_name(k))
        attr_type = dataclass_instance._get_type(k)

        if option_name is None:
            continue

        if inspect.isclass(attr_type) and issubclass(attr_type, FairseqDataclass):
            # Nested dataclass: recurse so its fields land on the same parser.
            # Without a prefix they are added under their own names; with one
            # (even an empty string) they are prefixed with this field's
            # option name.
            child_prefix = option_name if with_prefix is not None else None
            gen_parser_from_dataclass(
                parser, attr_type(), delete_default, child_prefix
            )
            continue

        spec = get_kwargs_from_dc(dataclass_instance, k)

        names = [option_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            names.append(alias)

        if "default" in spec:
            pending = spec["default"]
            if isinstance(pending, str) and pending.startswith("${"):
                if spec["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                spec.pop("default")
            if delete_default:
                # may already be gone after the "${" handling above
                spec.pop("default", None)

        try:
            parser.add_argument(*names, **spec)
        except ArgumentError:
            pass