Exemplo n.º 1
0
    def from_params(cls: Type[T], params: Params, **extras) -> T:
        """
        Default ``from_params`` inherited by every subclass of ``FromParams`` (and therefore of
        ``Registrable``, which itself subclasses ``FromParams``).

        It covers the "obvious" construction strategy: pop parameters off ``params`` and hand
        them to the constructor under the same names. Classes needing anything more elaborate
        must override this method with their own implementation.

        Behaviour, in order:

        * ``params`` being ``None`` yields ``None``.
        * If ``cls`` has an entry in ``Registrable._registry``, the ``"type"`` key is popped from
          ``params`` to pick a registered subclass, and construction is delegated to that
          subclass's ``from_params``.
        * Otherwise ``cls`` is instantiated directly with keyword arguments produced by
          ``create_kwargs`` (or with none at all when ``cls`` has no constructor of its own).
        """
        # pylint: disable=protected-access
        from srl_model.common.registrable import Registrable  # deferred to dodge a circular import

        logger.info(
            f"instantiating class {cls} from params {getattr(params, 'params', params)} "
            f"and extras {extras}")

        if params is None:
            return None

        subclass_registry = Registrable._registry.get(cls)

        if subclass_registry is None:
            # ``cls`` is not a registered base class: build it directly from params + extras.
            # A class without an explicit constructor must receive no kwargs at all; otherwise
            # create_kwargs would inspect object.__init__, see *args / **kwargs, and go looking
            # for parameters with those names.
            kwargs: Dict[str, Any] = (
                {} if cls.__init__ == object.__init__
                else create_kwargs(cls, params, **extras)
            )
            return cls(**kwargs)  # type: ignore

        # ``cls`` is known to inherit from Registrable here; the cast keeps mypy quiet and the
        # disable keeps pylint quiet.
        # pylint: disable=no-member
        registrable_cls = cast(Type[Registrable], cls)
        has_default = registrable_cls.default_implementation is not None
        chosen_name = params.pop_choice(
            "type",
            choices=registrable_cls.list_available(),
            default_to_first_choice=has_default)
        subclass = subclass_registry[chosen_name]
        build = subclass.from_params

        # ``build`` is either this very "free" implementation -- which accepts ``**extras`` and
        # about whose needs we can assume nothing -- or a hand-written one. A hand-written one
        # does not take ``**extras``, so it must be given only the extras it declares by name.
        if takes_arg(build, 'extras'):
            forwarded = extras
        else:
            forwarded = {
                name: value
                for name, value in extras.items()
                if takes_arg(build, name)
            }
        return build(params=params, **forwarded)
Exemplo n.º 2
0
    def from_params(cls, optimizer: torch.optim.Optimizer, params: Params):  # type: ignore
        """
        Build a learning-rate scheduler for ``optimizer`` from ``params`` and wrap it.

        The ``"type"`` entry of ``params`` is popped and must be one of
        ``LearningRateScheduler.list_available()``. The class registered under that name is
        called with ``optimizer`` plus whatever ``params.as_dict()`` yields afterwards as
        keyword arguments.

        Returns a ``LearningRateWithMetricsWrapper`` when the constructed scheduler is a
        ``torch.optim.lr_scheduler.ReduceLROnPlateau`` and a
        ``LearningRateWithoutMetricsWrapper`` for every other scheduler.
        """
        # pylint: disable=arguments-differ
        scheduler_name = params.pop_choice("type", LearningRateScheduler.list_available())
        scheduler_factory = LearningRateScheduler.by_name(scheduler_name)
        built_scheduler = scheduler_factory(optimizer, **params.as_dict())  # type: ignore

        # The two wrapper classes are kept separate so ReduceLROnPlateau gets the metrics one.
        is_plateau = isinstance(built_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
        wrapper_cls = (
            LearningRateWithMetricsWrapper if is_plateau
            else LearningRateWithoutMetricsWrapper
        )
        return wrapper_cls(built_scheduler)