Example #1
0
 def merge(preferred_value: Any, fallback_value: Any) -> Any:
     """
     Combine one value from the preferred config with the matching value from the fallback.

     * dict over dict: delegated to ``with_fallback`` (defined in the enclosing scope).
     * dict over list: the dict is treated as a sparse list override -- each key must be an
       integer index into the fallback list, and that element is merged recursively.
     * anything else: the preferred value wins and is deep-copied.

     The fallback list is deep-copied before any element is overridden, so the caller's
     ``fallback_value`` is never mutated and untouched elements are copies, not aliases.

     Raises ``ConfigurationError`` if a sparse-list key is not an integer or is out of range.
     """
     if isinstance(preferred_value, dict) and isinstance(fallback_value, dict):
         return with_fallback(preferred_value, fallback_value)
     elif isinstance(preferred_value, dict) and isinstance(fallback_value, list):
         # Treat preferred_value as a sparse list, where each key is an index to be overridden.
         # Work on a deep copy: assigning into ``fallback_value`` directly would mutate the
         # caller's config and alias its untouched elements into the result.
         merged_list = copy.deepcopy(fallback_value)
         for elem_key, preferred_element in preferred_value.items():
             try:
                 index = int(elem_key)
                 merged_list[index] = merge(preferred_element,
                                            fallback_value[index])
             except ValueError as err:
                 raise ConfigurationError(
                     "could not merge dicts - the preferred dict contains "
                     f"invalid keys (key {elem_key} is not a valid list index)"
                 ) from err
             except IndexError as err:
                 raise ConfigurationError(
                     "could not merge dicts - the preferred dict contains "
                     f"invalid keys (key {index} is out of bounds)") from err
         return merged_list
     else:
         return copy.deepcopy(preferred_value)
Example #2
0
def unflatten(flat_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a nested dict from one whose keys are dot-separated paths, e.g.
        {"a.b": 0}
    becomes
        {"a": {"b": 0}}

    Raises ``ConfigurationError`` when the keys conflict: a path runs through a key that
    already holds a non-dict value, or a key is both a leaf and a prefix of another key
    (e.g. {"a": 1, "a.b": 2}, in either order).
    """
    result: Dict[str, Any] = {}

    for compound_key, value in flat_dict.items():
        *parents, leaf = compound_key.split(".")
        node = result
        for part in parents:
            # Descend one level, creating an empty dict when this prefix is new.
            node = node.setdefault(part, {})
            if not isinstance(node, dict):
                raise ConfigurationError("flattened dictionary is invalid")
        if leaf in node:
            raise ConfigurationError("flattened dictionary is invalid")
        node[leaf] = value

    return result
Example #3
0
    def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
        """
        Remove ``key`` from the parameters and return its value, like ``dict.pop(key)``.

        Dictionary values are passed through ``self._check_is_dict`` (which wraps them as
        Param objects with an updated history) unless ``keep_as_dict`` is True. Values that
        are returned as-is are logged together with their history path.
        When ``key`` is missing and no ``default`` was supplied, a ``ConfigurationError`` is
        raised in place of the usual ``KeyError``.
        """
        if default is not self.DEFAULT:
            value = self.params.pop(key, default)
        else:
            try:
                value = self.params.pop(key)
            except KeyError:
                raise ConfigurationError(
                    'key "{}" is required at location "{}"'.format(key, self.history))

        # Only values that still contain dicts (and were not asked to stay raw) get wrapped.
        if not keep_as_dict and not _is_dict_free(value):
            return self._check_is_dict(key, value)
        logger.info(f"{self.history}{key} = {value}")
        return value
 def step(self, metric: float = None, epoch: int = None) -> None:
     """
     Advance the wrapped ``self.lr_scheduler`` with the given validation ``metric``.

     Raises ``ConfigurationError`` when ``metric`` is None, because this scheduler needs a
     validation metric to compute its schedule.
     """
     if metric is not None:
         self.lr_scheduler.step(metric, epoch)
         return
     raise ConfigurationError(
         "This learning rate scheduler requires "
         "a validation metric to compute the schedule and therefore "
         "must be used with a validation dataset."
     )
Example #5
0
 def assert_empty(self, class_name: str):
     """
     Fail loudly if any parameters were left unconsumed.

     If ``self.params`` is non-empty, raise a ``ConfigurationError`` listing the leftovers.
     ``class_name`` should name the calling class -- the one that received the extra
     parameters -- so the error message points at where the problem happened.
     """
     leftover = self.params
     if not leftover:
         return
     raise ConfigurationError(
         "Extra parameters passed to {}: {}".format(class_name, leftover))
Example #6
0
    def list_available(cls) -> List[str]:
        """
        Return the names registered for ``cls``, with ``cls.default_implementation`` first.

        When no default implementation is set, the names come back in registry order.
        Raises ``ConfigurationError`` if a default is set but was never registered.
        """
        registered = list(Registrable._registry[cls].keys())
        default = cls.default_implementation

        if default is None:
            return registered
        if default not in registered:
            raise ConfigurationError(
                "Default implementation %s is not registered" % default)

        others = [candidate for candidate in registered if candidate != default]
        return [default] + others
Example #7
0
def takes_arg(obj, arg: str) -> bool:
    """
    Report whether ``obj`` accepts a parameter named ``arg``.

    For a class the signature of its ``__init__`` is inspected; for a plain function or a
    method the object's own signature is used. Any other object raises
    ``ConfigurationError``.
    """
    if inspect.isclass(obj):
        target = obj.__init__
    elif inspect.isfunction(obj) or inspect.ismethod(obj):
        target = obj
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    params = inspect.signature(target).parameters
    return arg in params
Example #8
0
 def get(self, key: str, default: Any = DEFAULT):
     """
     Performs the functionality associated with dict.get(key) but also checks for returned
     dicts and returns a Params object in their place with an updated history.

     As with ``dict.get``, a missing ``key`` yields ``default`` when one is supplied and
     ``None`` otherwise; a missing key is not an error here.
     """
     # ``self.params.get`` follows dict semantics and never raises ``KeyError``, so there is
     # no missing-key error to translate: without an explicit default we fall back to None.
     if default is self.DEFAULT:
         value = self.params.get(key)
     else:
         value = self.params.get(key, default)
     return self._check_is_dict(key, value)
Example #9
0
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """
        Resolve ``name`` to a class.

        ``name`` is first looked up among the names registered for ``cls``. Failing that, a
        dotted ``name`` is treated as a fully-qualified path ``package.module.ClassName``: the
        module part is imported and the final component is fetched from it.

        Raises ``ConfigurationError`` if ``name`` is neither registered nor resolvable as an
        absolute class path.
        """
        logger.debug(f"instantiating registered subclass {name} of {cls}")
        if name in Registrable._registry[cls]:
            return Registrable._registry[cls].get(name)
        elif "." in name:
            # This might be a fully qualified class name, so we'll try importing its "module"
            # and finding it there.
            submodule, _, class_name = name.rpartition(".")

            # importlib raises ValueError for an empty module name and TypeError for a
            # relative one (no anchor package); report both as configuration errors instead.
            if not submodule or submodule.startswith("."):
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but {submodule!r} is not an absolute module path"
                )

            try:
                module = importlib.import_module(submodule)
            except ModuleNotFoundError as err:
                # Chain the cause: the missing module may be a dependency imported *inside*
                # ``submodule`` rather than ``submodule`` itself.
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to import module {submodule}"
                ) from err

            try:
                return getattr(module, class_name)
            except AttributeError as err:
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to find class {class_name} in {submodule}"
                ) from err

        else:
            # is not a qualified class name
            raise ConfigurationError(
                f"{name} is not a registered name for {cls.__name__}. "
                "You probably need to use the --include-package flag "
                "to load your custom code. Alternatively, you can specify your choices "
                """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                "in which case they will be automatically imported correctly."
            )
Example #10
0
 def pop_choice(
     self,
     key: str,
     choices: List[Any],
     default_to_first_choice: bool = False,
     allow_class_names: bool = True,
 ) -> Any:
     """
     Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
     the given choices. Note that this `pops` the key from params, modifying the dictionary,
     consistent with how parameters are processed in this codebase.
     Parameters
     ----------
     key: str
         Key to get the value from in the param dictionary
     choices: List[Any]
         A list of valid options for values corresponding to ``key``.  For example, if you're
         specifying the type of encoder to use for some part of your model, the choices might be
         the list of encoder classes we know about and can instantiate.  If the value we find in
         the param dictionary is not in ``choices``, we raise a ``ConfigurationError``, because
         the user specified an invalid value in their parameter file.
     default_to_first_choice: bool, optional (default=False)
         If this is ``True``, we allow the ``key`` to not be present in the parameter
         dictionary.  If the key is not present, we return the first choice in the
         ``choices`` list as the value.  If this is ``False``, we raise a
         ``ConfigurationError``, because specifying the ``key`` is required (e.g., you `have` to
         specify your model class when running an experiment, but you can feel free to use
         default settings for encoders if you want).
     allow_class_names : bool, optional (default = True)
         If this is `True`, then we allow unknown choices that look like fully-qualified class names.
         This is to allow e.g. specifying a model type as my_library.my_model.MyModel
         and importing it on the fly. Our check for "looks like" is extremely lenient
         and consists of checking that the value is a string containing a '.'.
     Returns
     -------
     The popped value (a member of ``choices``, or an accepted class-name string).
     Raises
     ------
     ConfigurationError
         If the value is neither in ``choices`` nor an accepted class name.
     """
     default = choices[0] if default_to_first_choice else self.DEFAULT
     value = self.pop(key, default)
     if value in choices:
         return value
     # Only strings can be class paths. Testing ``"." in value`` on a non-string (e.g. an int)
     # would raise TypeError, so membership is checked first and the type before the substring.
     ok_because_class_name = (
         allow_class_names and isinstance(value, str) and "." in value)
     if not ok_because_class_name:
         key_str = self.history + key
         message = (
             f"{value} not in acceptable choices for {key_str}: {choices}. "
             "You should either use the --include-package flag to make sure the correct module "
             "is loaded, or use a fully qualified class name in your config file like "
             """{"model": "my_module.models.MyModel"} to have it imported automatically."""
         )
         raise ConfigurationError(message)
     return value
Example #11
0
 def add_subclass_to_registry(subclass: Type[T]):
     """
     Store ``subclass`` in ``registry`` under ``name`` and return ``subclass`` unchanged.

     ``name``, ``registry`` and ``exist_ok`` come from the enclosing scope (not shown here).
     If ``name`` is already taken, the existing entry is overwritten (with an info log) when
     ``exist_ok`` is true; otherwise a ``ConfigurationError`` is raised.
     """
     # Add to registry, raise an error if key has already been used.
     # The messages name ``subclass`` -- the class actually being stored -- rather than the
     # enclosing base class.
     if name in registry:
         if exist_ok:
             message = (
                 f"{name} has already been registered as {registry[name].__name__}, but "
                 f"exist_ok=True, so overwriting with {subclass.__name__}"
             )
             logger.info(message)
         else:
             message = (
                 f"Cannot register {name} as {subclass.__name__}; "
                 f"name already in use for {registry[name].__name__}"
             )
             raise ConfigurationError(message)
     registry[name] = subclass
     return subclass
Example #12
0
def takes_kwargs(obj) -> bool:
    """
    Checks whether a provided object accepts arbitrary keyword arguments, i.e. whether its
    signature has a ``**kwargs``-style (``VAR_KEYWORD``) parameter.
    As with takes_arg, for a class we inspect its ``__init__``; for a function or method we
    inspect the object itself.
    Otherwise, we raise a ``ConfigurationError``.
    """
    if inspect.isclass(obj):
        signature = inspect.signature(obj.__init__)
    elif inspect.ismethod(obj) or inspect.isfunction(obj):
        signature = inspect.signature(obj)
    else:
        raise ConfigurationError(f"object {obj} is not callable")
    return any(
        p.kind == inspect.Parameter.VAR_KEYWORD  # type: ignore
        for p in signature.parameters.values()
    )
Example #13
0
def construct_arg(cls: Type[T], param_name: str, annotation: Type,
                  default: Any, params: Params, **extras) -> Any:
    """
    Does the work of actually constructing an individual argument for :func:`create_kwargs`.
    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of ``Params`` for
    constructing the object (which we may mutate), and any ``extras`` that the constructor might
    need.
    We take the type annotation and default value here separately, instead of using an
    ``inspect.Parameter`` object directly, so that we can handle ``Union`` types using recursion on
    this method, trying the different annotation types in the union in turn.

    Resolution order: a matching entry in ``extras``; a ``_pretrained`` archive reference; a type
    with ``from_params``; the primitives ``str``/``int``/``bool``/``float``; ``Dict``/``List``/
    ``Tuple``/``Set`` of ``from_params`` types; each member of a ``Union`` in turn; otherwise the
    raw value is popped as-is.

    Raises ``ConfigurationError`` when a required ``from_params`` argument is missing, when a
    pretrained module has an unexpected type, or when no member of a ``Union`` can be built.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `param_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back
    # to using `name`.
    name = param_name
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    # Compare by identity: ``!=`` would call the default's own comparison operator, which for
    # array-like defaults can return an element-wise result with an ambiguous truth value.
    optional = default is not _NO_DEFAULT

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # Next case is when argument should be loaded from pretrained archive.
    elif (name in params and isinstance(params.get(name), Params)
          and "_pretrained" in params.get(name)):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result
    # The next case is when the parameter type is itself constructible from_params.
    elif hasattr(annotation, "from_params"):
        if name in params:
            # Our params have an entry for this, so we use that.
            subparams = params.pop(name)

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(subparams, str):
                return annotation.by_name(subparams)()
            else:
                return annotation.from_params(params=subparams, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {name} for {cls.__name__}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return params.pop(name, default) if optional else params.pop(name)
    elif annotation == int:
        return params.pop_int(name,
                              default) if optional else params.pop_int(name)
    elif annotation == bool:
        return params.pop_bool(name,
                               default) if optional else params.pop_bool(name)
    elif annotation == float:
        return params.pop_float(
            name, default) if optional else params.pop_float(name)

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in (Dict, dict) and len(args) == 2 and hasattr(
            args[-1], "from_params"):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in params.pop(name, Params({})).items():
            subextras = create_extras(value_cls, extras)
            value_dict[key] = value_cls.from_params(params=value_params,
                                                    **subextras)

        return value_dict

    elif origin in (List, list) and len(args) == 1 and hasattr(
            args[0], "from_params"):
        value_cls = annotation.__args__[0]

        value_list = []

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return value_list

    elif origin in (Tuple, tuple) and all(
            hasattr(arg, "from_params") for arg in args):
        value_list = []

        for value_cls, value_params in zip(annotation.__args__,
                                           params.pop(name, Params({}))):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and hasattr(
            args[0], "from_params"):
        value_cls = annotation.__args__[0]

        value_set = set()

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_set.add(
                value_cls.from_params(params=value_params, **subextras))

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to. Also remember whether the key
        # was present at all: if it was not, a failed attempt must leave it absent, otherwise
        # later union members would see a placeholder instead of the real (missing) input.
        had_value = name in params
        param_value = params.get(name) if had_value else None
        if isinstance(param_value, Params):
            param_value = param_value.duplicate()

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg in args:
            try:
                return construct_arg(cls, name, arg, default, params, **extras)
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have popped `params[name]`, so we
                # restore it here (with a fresh copy kept back for the next attempt).
                if had_value:
                    params[name] = param_value
                    if isinstance(param_value, Params):
                        param_value = param_value.duplicate()
                continue

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {name} with type {annotation}")
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if optional:
            return params.pop(name, default, keep_as_dict=True)
        else:
            return params.pop(name, keep_as_dict=True)