Пример #1
0
    def from_params(cls, model: BaseFairseqModel, task_list: List[Task],
                    serialization_dir: str, params: Params) -> 'Trainer':
        """
        Build an instance of ``cls`` from a ``Params`` configuration.

        Scalar hyper-parameters are popped from ``params`` (each with a fallback default), an
        optimizer is created over the parameters of ``model`` that require gradients, and
        everything is handed to the class constructor.

        Parameters
        ----------
        model :
            The model whose ``named_parameters()`` with ``requires_grad`` set are optimized.
        task_list :
            Forwarded unchanged to the constructor.
        serialization_dir :
            Forwarded unchanged to the constructor.
        params :
            Configuration that is consumed (mutated) here.  Keys read, with their defaults:
            ``cuda_device`` (-1), ``grad_clipping`` (0.1), ``grad_norm`` (5.0),
            ``min_lr`` (1e-7), ``num_epochs`` (100), ``patience`` (5) and ``optimizer``
            (``None``, passed through to ``Optimizer.from_params`` as-is).

        Returns
        -------
        The object produced by calling ``cls`` with the collected keyword arguments.
        """
        # A dict literal evaluates its values in source order, so the keys are popped from
        # ``params`` in exactly the order they are listed here.
        hyperparameters = {
            "cuda_device": params.pop_int("cuda_device", -1),
            "grad_clipping": params.pop_float("grad_clipping", 0.1),
            "grad_norm": params.pop_float("grad_norm", 5.0),
            "min_lr": params.pop_float("min_lr", 1e-7),
            "num_epochs": params.pop_int("num_epochs", 100),
            "patience": params.pop_int("patience", 5),
        }

        optimizer_config = params.pop("optimizer", None)

        # Only parameters that take part in gradient updates are given to the optimizer.
        trainable = []
        for parameter_name, parameter in model.named_parameters():
            if parameter.requires_grad:
                trainable.append((parameter_name, parameter))

        optimizer_ = Optimizer.from_params(model_parameters=trainable,
                                           params=optimizer_config)

        return cls(model=model,
                   task_list=task_list,
                   serialization_dir=serialization_dir,
                   **hyperparameters,
                   optimizer=optimizer_)
Пример #2
0
def construct_arg(
        cls: Type[T],  # pylint: disable=inconsistent-return-statements,too-many-return-statements
        param_name: str,
        annotation: Type,
        default: Any,
        params: Params,
        **extras) -> Any:
    """
    Does the work of actually constructing an individual argument for :func:`create_kwargs`.

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of ``Params`` for
    constructing the object (which we may mutate), and any ``extras`` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    ``inspect.Parameter`` object directly, so that we can handle ``Union`` types using recursion on
    this method, trying the different annotation types in the union in turn.

    Parameters
    ----------
    cls :
        The class whose constructor argument is being built; only used in error messages.
    param_name :
        Name of the constructor argument (and of the key looked up in ``params`` / ``extras``).
    annotation :
        The type annotation of the argument; it selects which construction strategy is used.
    default :
        The argument's default value, or the ``_NO_DEFAULT`` sentinel when it has none.
    params :
        The configuration to read from.  The entry for ``param_name`` is popped when used.
    extras :
        Already-constructed objects; if one is keyed by ``param_name`` it is returned directly.

    Returns
    -------
    The constructed (or passed-through) value for the argument.

    Raises
    ------
    ConfigurationError
        If a ``_pretrained`` module is not an instance of ``annotation``, if a required
        ``from_params``-constructible argument is missing, or if no member of a ``Union``
        annotation could be constructed.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `param_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back
    # to using `name`.
    name = param_name
    origin = getattr(annotation, '__origin__', None)
    args = getattr(annotation, '__args__', [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    # Identity (not `!=`) is used so that we never invoke the default value's own `__ne__`,
    # which for array-like defaults may return a non-boolean or raise.
    # NOTE(review): assumes callers pass the `_NO_DEFAULT` object itself -- confirm.
    optional = default is not _NO_DEFAULT

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # Next case is when argument should be loaded from pretrained archive.
    elif name in params and isinstance(
            params.get(name), Params) and "_pretrained" in params.get(name):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)  # pylint: disable=no-member
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result
    # The next case is when the parameter type is itself constructible from_params.
    elif hasattr(annotation, 'from_params'):
        if name in params:
            # Our params have an entry for this, so we use that.
            subparams = params.pop(name)

            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(subparams, str):
                return annotation.by_name(subparams)()
            else:
                return annotation.from_params(params=subparams, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {name} for {cls.__name__}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return params.pop(name, default) if optional else params.pop(name)
    elif annotation == int:
        return params.pop_int(name,
                              default) if optional else params.pop_int(name)
    elif annotation == bool:
        return params.pop_bool(name,
                               default) if optional else params.pop_bool(name)
    elif annotation == float:
        return params.pop_float(
            name, default) if optional else params.pop_float(name)

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    #
    # NOTE: in all four container branches a missing key is replaced by `Params({})`, so
    # `default` and `optional` are not consulted here.
    elif origin in (Dict, dict) and len(args) == 2 and hasattr(
            args[-1], 'from_params'):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in params.pop(name, Params({})).items():
            subextras = create_extras(value_cls, extras)
            value_dict[key] = value_cls.from_params(params=value_params,
                                                    **subextras)

        return value_dict

    elif origin in (List, list) and len(args) == 1 and hasattr(
            args[0], 'from_params'):
        value_cls = annotation.__args__[0]

        value_list = []

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return value_list

    elif origin in (Tuple, tuple) and all(
            hasattr(arg, 'from_params') for arg in args):
        value_list = []

        # Each tuple position has its own element type; `zip` pairs them with the supplied
        # entries and stops at the shorter of the two.
        for value_cls, value_params in zip(annotation.__args__,
                                           params.pop(name, Params({}))):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and hasattr(
            args[0], 'from_params'):
        value_cls = annotation.__args__[0]

        value_set = set()

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_set.add(
                value_cls.from_params(params=value_params, **subextras))

        return value_set

    elif origin == Union:
        # Remember whether the key was really supplied.  If it was not, there is nothing to
        # restore after a failed attempt; writing a placeholder back would inject a key the
        # user never gave, which a later union member could then pop and return as the value.
        was_supplied = name in params

        # Storing this so we can recover it later if we need to.
        param_value = params.get(name) if was_supplied else None
        if isinstance(param_value, Params):
            param_value = param_value.duplicate()

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg in args:
            try:
                return construct_arg(cls, name, arg, default, params, **extras)
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have popped `params[name]`, so we
                # restore it here (keeping a fresh copy for the next attempt).
                if was_supplied:
                    params[name] = param_value
                    if isinstance(param_value, Params):
                        param_value = param_value.duplicate()
                continue

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {name} with type {annotation}")
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if optional:
            return params.pop(name, default)
        else:
            return params.pop(name)