def merge(preferred_value: Any, fallback_value: Any) -> Any:
    if isinstance(preferred_value, dict) and isinstance(fallback_value, dict):
        return with_fallback(preferred_value, fallback_value)
    elif isinstance(preferred_value, dict) and isinstance(fallback_value, list):
        # Treat preferred_value as a sparse list, where each key is an index to be overridden.
        merged_list = fallback_value
        for elem_key, preferred_element in preferred_value.items():
            try:
                index = int(elem_key)
                merged_list[index] = merge(preferred_element, fallback_value[index])
            except ValueError:
                raise ConfigurationError(
                    "could not merge dicts - the preferred dict contains "
                    f"invalid keys (key {elem_key} is not a valid list index)"
                )
            except IndexError:
                raise ConfigurationError(
                    "could not merge dicts - the preferred dict contains "
                    f"invalid keys (key {index} is out of bounds)"
                )
        return merged_list
    else:
        return copy.deepcopy(preferred_value)
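# Illustrative usage sketch (not part of the original source); assumes `merge` and
# `with_fallback` from this module are in scope. A dict whose keys are integer strings
# acts as a sparse override of a fallback list; everything else prefers the first argument.
example_merged = merge({"0": {"lr": 0.1}}, [{"lr": 0.01}, {"lr": 0.001}])
assert example_merged == [{"lr": 0.1}, {"lr": 0.001}]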
def unflatten(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Given a "flattened" dict with compound keys, e.g.
        {"a.b": 0}
    unflatten it:
        {"a": {"b": 0}}
    """
    unflat: Dict[str, Any] = {}

    for compound_key, value in flat_dict.items():
        curr_dict = unflat
        parts = compound_key.split(".")
        for key in parts[:-1]:
            curr_value = curr_dict.get(key)
            if key not in curr_dict:
                curr_dict[key] = {}
                curr_dict = curr_dict[key]
            elif isinstance(curr_value, dict):
                curr_dict = curr_value
            else:
                raise ConfigurationError("flattened dictionary is invalid")
        if not isinstance(curr_dict, dict) or parts[-1] in curr_dict:
            raise ConfigurationError("flattened dictionary is invalid")
        curr_dict[parts[-1]] = value

    return unflat
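# Illustrative usage sketch (not part of the original source); assumes `unflatten`
# above is in scope. Compound keys sharing a prefix are grouped into nested dicts;
# conflicting keys such as {"a": 0, "a.b": 1} raise a ConfigurationError.
example_nested = unflatten({"a.b": 0, "a.c": 1, "d": 2})
assert example_nested == {"a": {"b": 0, "c": 1}, "d": 2}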
def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
    """
    Performs the functionality associated with dict.pop(key), along with checking for
    returned dictionaries, replacing them with Param objects with an updated history
    (unless keep_as_dict is True, in which case we leave them as dictionaries).

    If `key` is not present in the dictionary, and no default was specified, we raise a
    `ConfigurationError`, instead of the typical `KeyError`.
    """
    if default is self.DEFAULT:
        try:
            value = self.params.pop(key)
        except KeyError:
            msg = f'key "{key}" is required'
            if self.history:
                msg += f' at location "{self.history}"'
            raise ConfigurationError(msg)
    else:
        value = self.params.pop(key, default)

    if keep_as_dict or _is_dict_free(value):
        logger.info(f"{self.history}{key} = {value}")
        return value
    else:
        return self._check_is_dict(key, value)
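# Illustrative usage sketch (not part of the original source); assumes the `Params`
# class this method belongs to. Missing keys raise a ConfigurationError (with the
# param history in the message) unless a default is supplied.
#
#     params = Params({"type": "lstm", "hidden_size": 128})
#     params.pop("hidden_size")        # -> 128
#     params.pop("dropout", 0.0)       # missing key, default returned -> 0.0
#     params.pop("num_layers")         # no default -> ConfigurationError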
def resolve_class_name(cls: Type[T], name: str) -> Tuple[Type[T], Optional[str]]:
    """
    Returns the subclass that corresponds to the given `name`, along with the name of the
    method that was registered as a constructor for that `name`, if any.

    This method also allows `name` to be a fully-specified module name, instead of a name
    that was already added to the `Registry`.  In that case, you cannot use a separate
    function as a constructor (as you need to call `cls.register()` in order to tell us what
    separate function to use).
    """
    if name in Registrable._registry[cls]:
        subclass, constructor = Registrable._registry[cls].get(name)
        return subclass, constructor
    elif "." in name:
        # This might be a fully qualified class name, so we'll try importing its "module"
        # and finding it there.
        parts = name.split(".")
        submodule = ".".join(parts[:-1])
        class_name = parts[-1]

        try:
            module = importlib.import_module(submodule)
        except ModuleNotFoundError:
            raise ConfigurationError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to import module {submodule}"
            )

        try:
            subclass = getattr(module, class_name)
            constructor = None
            return subclass, constructor
        except AttributeError:
            raise ConfigurationError(
                f"tried to interpret {name} as a path to a class "
                f"but unable to find class {class_name} in {submodule}"
            )
    else:
        # is not a qualified class name
        raise ConfigurationError(
            f"{name} is not a registered name for {cls.__name__}. "
            "You probably need to use the --include-package flag "
            "to load your custom code. Alternatively, you can specify your choices "
            """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
            "in which case they will be automatically imported correctly."
        )
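# Illustrative usage sketch (not part of the original source); assumes a hypothetical
# `Registrable` subclass (`Encoder`) with an "lstm" registration. Registered names
# resolve through the registry; dotted names fall back to importing the module and
# looking up the class by attribute.
#
#     Encoder.resolve_class_name("lstm")                          # registered name
#     Encoder.resolve_class_name("my_library.modules.MyEncoder")  # fully-qualified path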
def pop_and_construct_arg(
    class_name: str,
    argument_name: str,
    annotation: Type,
    default: Any,
    params: Params,
    **extras,
) -> Any:
    """
    Does the work of actually constructing an individual argument for
    [`create_kwargs`](./#create_kwargs).

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its
    name, its type annotation, and its default value; we also get the full set of `Params` for
    constructing the object (which we may mutate), and any `extras` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    `inspect.Parameter` object directly, so that we can handle `Union` types using recursion
    on this method, trying the different annotation types in the union in turn.
    """
    # from allenai_common.models.archival import load_archive  # import here to avoid circular imports

    # We used `argument_name` as the method argument to avoid conflicts with 'name' being a key
    # in `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can
    # switch back to using `name`.
    name = argument_name

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        if name not in params:
            return extras[name]
        else:
            logger.warning(
                f"Parameter {name} for class {class_name} was found in both "
                "**extras and in params. Using the specification found in params, "
                "but you probably put a key in a config file that you didn't need, "
                "and if it is different from what we get from **extras, you might "
                "get unexpected behavior."
            )
    # Next case is when argument should be loaded from pretrained archive.
    elif (
        name in params
        and isinstance(params.get(name), Params)
        and "_pretrained" in params.get(name)
    ):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result

    popped_params = params.pop(name, default) if default != _NO_DEFAULT else params.pop(name)
    if popped_params is None:
        return None

    return construct_arg(class_name, name, popped_params, annotation, default, **extras)
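# Illustrative config sketch (not part of the original source): the `_pretrained`
# branch above expects a nested block shaped like the following (the "encoder" key
# and paths are hypothetical), in which case the argument is filled by extracting a
# submodule from a saved archive instead of being constructed from params.
#
#     "encoder": {
#         "_pretrained": {
#             "archive_file": "/path/to/model.tar.gz",
#             "module_path": "_encoder",
#             "freeze": true
#         }
#     }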
def pop_choice(
    self,
    key: str,
    choices: List[Any],
    default_to_first_choice: bool = False,
    allow_class_names: bool = True,
) -> Any:
    """
    Gets the value of `key` in the `params` dictionary, ensuring that the value is one of
    the given choices. Note that this `pops` the key from params, modifying the dictionary,
    consistent with how parameters are processed in this codebase.

    # Parameters

    key: `str`
        Key to get the value from in the param dictionary

    choices: `List[Any]`
        A list of valid options for values corresponding to `key`.  For example, if you're
        specifying the type of encoder to use for some part of your model, the choices might
        be the list of encoder classes we know about and can instantiate.  If the value we
        find in the param dictionary is not in `choices`, we raise a `ConfigurationError`,
        because the user specified an invalid value in their parameter file.

    default_to_first_choice: `bool`, optional (default = `False`)
        If this is `True`, we allow the `key` to not be present in the parameter dictionary.
        If the key is not present, we use the first choice in the `choices` list as the
        return value.  If this is `False`, we raise a `ConfigurationError`, because
        specifying the `key` is required (e.g., you `have` to specify your model class when
        running an experiment, but you can feel free to use default settings for encoders
        if you want).

    allow_class_names: `bool`, optional (default = `True`)
        If this is `True`, then we allow unknown choices that look like fully-qualified
        class names.  This is to allow e.g. specifying a model type as
        my_library.my_model.MyModel and importing it on the fly.  Our check for "looks like"
        is extremely lenient and consists of checking that the value contains a '.'.
    """
    default = choices[0] if default_to_first_choice else self.DEFAULT
    value = self.pop(key, default)

    ok_because_class_name = allow_class_names and "." in value
    if value not in choices and not ok_because_class_name:
        key_str = self.history + key
        message = (
            f"{value} not in acceptable choices for {key_str}: {choices}. "
            "You should either use the --include-package flag to make sure the correct module "
            "is loaded, or use a fully qualified class name in your config file like "
            """{"model": "my_module.models.MyModel"} to have it imported automatically."""
        )
        raise ConfigurationError(message)
    return value
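# Illustrative usage sketch (not part of the original source); assumes the `Params`
# class this method belongs to. A value containing a "." passes even when it is not
# in `choices`, so fully-qualified class names can be used when allow_class_names=True.
#
#     Params({"type": "lstm"}).pop_choice("type", choices=["lstm", "gru"])   # -> "lstm"
#     Params({"type": "rnn"}).pop_choice("type", choices=["lstm", "gru"])    # -> ConfigurationError
#     Params({"type": "my_lib.MyRnn"}).pop_choice("type", choices=["lstm"])  # -> "my_lib.MyRnn"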
def assert_empty(self, class_name: str):
    """
    Raises a `ConfigurationError` if `self.params` is not empty.  We take `class_name` as an
    argument so that the error message gives some idea of where an error happened, if there
    was one.  `class_name` should be the name of the `calling` class, the one that got extra
    parameters (if there are any).
    """
    if self.params:
        raise ConfigurationError(
            "Extra parameters passed to {}: {}".format(class_name, self.params)
        )
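# Illustrative usage sketch (not part of the original source): a constructor that has
# consumed everything it needs calls this to surface typos or leftover config keys.
#
#     params = Params({"hidden_size": 128, "hiden_dropout": 0.1})
#     params.pop("hidden_size")
#     params.assert_empty("MyEncoder")
#     # raises: Extra parameters passed to MyEncoder: {'hiden_dropout': 0.1}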
def list_available(cls) -> List[str]:
    """List default first if it exists"""
    keys = list(Registrable._registry[cls].keys())
    default = cls.default_implementation

    if default is None:
        return keys
    elif default not in keys:
        raise ConfigurationError(f"Default implementation {default} is not registered")
    else:
        return [default] + [k for k in keys if k != default]
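# Illustrative usage sketch (not part of the original source); assumes a hypothetical
# registered `Registrable` subclass. When `default_implementation` is set, that name
# is listed first, followed by the other registered names.
#
#     Tokenizer.list_available()   # -> ["whitespace", "character", ...] if the
#                                  #    default_implementation is "whitespace"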
def takes_arg(obj, arg: str) -> bool:
    """
    Checks whether the provided obj takes a certain arg.
    If it's a class, we're really checking whether its constructor does.
    If it's a function or method, we're checking the object itself.
    Otherwise, we raise an error.
    """
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return arg in signature.parameters
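# Illustrative usage sketch (not part of the original source): `takes_arg` inspects
# `__init__` for classes and the callable itself otherwise. `_example_fn` is a
# hypothetical helper defined only for this example.
def _example_fn(alpha: float, beta: int = 0) -> float:
    return alpha + beta

assert takes_arg(_example_fn, "alpha")
assert not takes_arg(_example_fn, "gamma")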
def takes_kwargs(obj) -> bool:
    """
    Checks whether the provided object accepts `**kwargs` (a variadic keyword parameter).
    As in `takes_arg`, we inspect the `__init__` method for classes and the object itself
    for functions / methods; otherwise, we raise an error.
    """
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return any(
        p.kind == inspect.Parameter.VAR_KEYWORD  # type: ignore
        for p in signature.parameters.values()
    )
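# Illustrative usage sketch (not part of the original source): only callables that
# declare a `**kwargs` (VAR_KEYWORD) parameter report True. The helpers below are
# hypothetical and exist only for this example.
def _with_kwargs(x, **kwargs):
    return x

def _without_kwargs(x, y=1):
    return x + y

assert takes_kwargs(_with_kwargs)
assert not takes_kwargs(_without_kwargs)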
def add_subclass_to_registry(subclass: Type[T]):
    # Add to registry, raise an error if key has already been used.
    if name in registry:
        if exist_ok:
            message = (
                f"{name} has already been registered as {registry[name][0].__name__}, but "
                f"exist_ok=True, so overwriting with {cls.__name__}"
            )
            logger.info(message)
        else:
            message = (
                f"Cannot register {name} as {cls.__name__}; "
                f"name already in use for {registry[name][0].__name__}"
            )
            raise ConfigurationError(message)
    registry[name] = (subclass, constructor)
    return subclass
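# Illustrative usage sketch (not part of the original source): this closure appears to
# be what the enclosing `register(...)` classmethod returns, so it is normally applied
# as a decorator. `Encoder` / `LstmEncoder` are hypothetical classes.
#
#     @Encoder.register("lstm")
#     class LstmEncoder(Encoder):
#         ...
#
#     # Re-registering the same name raises ConfigurationError unless exist_ok=True,
#     # in which case the earlier entry is overwritten (with an info-level log message).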
def from_params(
    cls: Type[T],
    params: Params,
    constructor_to_call: Callable[..., T] = None,
    constructor_to_inspect: Union[Callable[..., T], Callable[[T], None]] = None,
    **extras,
) -> T:
    """
    This is the automatic implementation of `from_params`.  Any class that subclasses
    `FromParams` (or `Registrable`, which itself subclasses `FromParams`) gets this
    implementation for free.  If you want your class to be instantiated from params in the
    "obvious" way -- pop off parameters and hand them to your constructor with the same
    names -- this provides that functionality.

    If you need more complex logic in your `from_params` method, you'll have to implement
    your own method that overrides this one.

    The `constructor_to_call` and `constructor_to_inspect` arguments deal with a bit of
    redirection that we do.  We allow you to register particular `@classmethods` on a class
    as the constructor to use for a registered name.  This lets you, e.g., have a single
    `Vocabulary` class that can be constructed in two different ways, with different names
    registered to each constructor.  In order to handle this, we need to know not just the
    class we're trying to construct (`cls`), but also what method we should inspect to find
    its arguments (`constructor_to_inspect`), and what method to call when we're done
    constructing arguments (`constructor_to_call`).  These two methods are the same when
    you've used a `@classmethod` as your constructor, but they are `different` when you use
    the default constructor (because you inspect `__init__`, but call `cls()`).
    """
    # import here to avoid circular imports
    from allenai_common.registrable import Registrable

    logger.debug(
        f"instantiating class {cls} from params {getattr(params, 'params', params)} "
        f"and extras {set(extras.keys())}"
    )

    if params is None:
        return None

    if isinstance(params, str):
        params = Params({"type": params})

    if not isinstance(params, Params):
        raise ConfigurationError(
            "from_params was passed a `params` object that was not a `Params`. This probably "
            "indicates malformed parameters in a configuration file, where something that "
            "should have been a dictionary was actually a list, or something else. "
            f"This happened when constructing an object of type {cls}."
        )

    registered_subclasses = Registrable._registry.get(cls)

    if is_base_registrable(cls) and registered_subclasses is None:
        # NOTE(mattg): There are some potential corner cases in this logic if you have nested
        # Registrable types.  We don't currently have any of those, but if we ever get them,
        # adding some logic to check `constructor_to_call` should solve the issue.  Not
        # bothering to add that unnecessary complexity for now.
        raise ConfigurationError(
            "Tried to construct an abstract Registrable base class that has no registered "
            "concrete types. This might mean that you need to use --include-package to get "
            "your concrete classes actually registered."
        )

    if registered_subclasses is not None and not constructor_to_call:
        # We know `cls` inherits from Registrable, so we'll use a cast to make mypy happy.
        as_registrable = cast(Type[Registrable], cls)
        default_to_first_choice = as_registrable.default_implementation is not None
        choice = params.pop_choice(
            "type",
            choices=as_registrable.list_available(),
            default_to_first_choice=default_to_first_choice,
        )
        subclass, constructor_name = as_registrable.resolve_class_name(choice)

        # See the docstring for an explanation of what's going on here.
        if not constructor_name:
            constructor_to_inspect = subclass.__init__
            constructor_to_call = subclass  # type: ignore
        else:
            constructor_to_inspect = cast(Callable[..., T], getattr(subclass, constructor_name))
            constructor_to_call = constructor_to_inspect

        if hasattr(subclass, "from_params"):
            # We want to call subclass.from_params.
            extras = create_extras(subclass, extras)
            # mypy can't follow the typing redirection that we do, so we explicitly cast here.
            retyped_subclass = cast(Type[T], subclass)
            return retyped_subclass.from_params(
                params=params,
                constructor_to_call=constructor_to_call,
                constructor_to_inspect=constructor_to_inspect,
                **extras,
            )
        else:
            # In some rare cases, we get a registered subclass that does _not_ have a
            # from_params method (this happens with Activations, for instance, where we
            # register pytorch modules directly).  This is a bit of a hack to make those work,
            # instead of adding a `from_params` method for them somehow.  We just trust that
            # you've done the right thing in passing your parameters, and nothing else needs
            # to be recursively constructed.
            return subclass(**params)  # type: ignore
    else:
        # This is not a base class, so convert our params and extras into a dict of kwargs.

        # See the docstring for an explanation of what's going on here.
        if not constructor_to_inspect:
            constructor_to_inspect = cls.__init__
        if not constructor_to_call:
            constructor_to_call = cls

        if constructor_to_inspect == object.__init__:
            # This class does not have an explicit constructor, so don't give it any kwargs.
            # Without this logic, create_kwargs will look at object.__init__ and see that
            # it takes *args and **kwargs and look for those.
            kwargs: Dict[str, Any] = {}
            params.assert_empty(cls.__name__)
        else:
            # This class has a constructor, so create kwargs for it.
            constructor_to_inspect = cast(Callable[..., T], constructor_to_inspect)
            kwargs = create_kwargs(constructor_to_inspect, cls, params, **extras)

        return constructor_to_call(**kwargs)  # type: ignore
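# Illustrative usage sketch (not part of the original source); assumes `Params` and a
# hypothetical registered `Registrable` base class (`Encoder` with an "lstm"
# registration). The "type" key selects the subclass, and the remaining keys become
# constructor arguments, recursively constructed where needed.
#
#     encoder = Encoder.from_params(Params({"type": "lstm", "hidden_size": 128}))
#
# A bare string is also accepted and treated as {"type": "<string>"}:
#
#     encoder = Encoder.from_params("lstm")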
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    **extras,
) -> Any:
    """
    The first two parameters here are only used for logging if we encounter an error.
    """
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if hasattr(annotation, "from_params"):
        if popped_params is default:
            return default
        elif popped_params is not None:
            # Our params have an entry for this, so we use that.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            return annotation.from_params(params=popped_params, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(f"Expected {argument_name} to be a {annotation.__name__}.")
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if type(popped_params) == str or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(f"Expected {argument_name} to be a string.")
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting
    # structure.
    elif (
        origin in {collections.abc.Mapping, Mapping, Dict, dict}
        and len(args) == 2
        and can_construct_from_params(args[-1])
    ):
        value_cls = annotation.__args__[-1]
        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object)."
            )

        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )

        return value_dict

    elif origin in (Tuple, tuple) and all(can_construct_from_params(arg) for arg in args):
        value_list = []

        for i, (value_cls, value_params) in enumerate(zip(annotation.__args__, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and can_construct_from_params(args[0]):
        value_cls = annotation.__args__[0]

        value_set = set()

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_set.add(value)

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one
        # that succeeds.
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                    **extras,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}"
        )
    elif origin == Lazy:
        if popped_params is default:
            return default

        value_cls = args[0]
        subextras = create_extras(value_cls, extras)  # type: ignore
        return Lazy(value_cls, params=deepcopy(popped_params), contructor_extras=subextras)

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif (
        origin in {collections.abc.Iterable, Iterable, List, list}
        and len(args) == 1
        and can_construct_from_params(args[0])
    ):
        value_cls = annotation.__args__[0]

        value_list = []

        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
                **extras,
            )
            value_list.append(value)

        return value_list

    else:
        # Pass it on as is and hope for the best. ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
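# Illustrative usage sketch (not part of the original source): for a constructor
# parameter annotated with a parameterized container of FromParams types,
# `construct_arg` recurses into each element. E.g. for a hypothetical parameter
# `token_indexers: Dict[str, TokenIndexer]`, a params entry like
#
#     {"tokens": {"type": "single_id"}, "characters": {"type": "characters"}}
#
# is turned into a dict whose values are each constructed via TokenIndexer.from_params,
# with nested argument names such as "token_indexers.tokens" used in error messages.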