Example #1
def module(name, nn_module, update_module_params=False):
    """
    Takes a torch.nn.Module and registers its parameters with the ParamStore.
    In conjunction with the ParamStore save() and load() functionality, this
    allows the user to save and load modules.

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name: must not contain %s" %\
        _MODULE_NAMESPACE_DIVIDER

    if isclass(nn_module):
        raise NotImplementedError(
            "pyro.module does not support class constructors for " +
            "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        else:
            warnings.warn("{} was not registered in the param store "
                          "because requires_grad=False".format(param_name))

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit('.', 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(
                        nn_module, mod_name
                    )._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
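
A minimal usage sketch for the function above, assuming pyro.module is called inside a model as usual; the names "regressor" and "regressor_params.pt" are illustrative, not taken from the source:

import torch
import pyro

net = torch.nn.Linear(10, 1)

def model(x):
    # registers net's parameters in the ParamStore under the module name
    # "regressor"; pyro.module returns the module itself
    lifted_net = pyro.module("regressor", net)
    return lifted_net(x)

model(torch.randn(5, 10))
pyro.get_param_store().save("regressor_params.pt")   # persist to disk
pyro.get_param_store().load("regressor_params.pt")   # restore later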
Example #2
def module(name, nn_module, tags="default", update_module_params=False):
    """
    Takes a torch.nn.Module and registers its parameters with the ParamStore.
    In conjunction with the ParamStore save() and load() functionality, this
    allows the user to save and load modules.

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param tags: optional; tags to associate with any parameters inside the module
    :type tags: string or iterable of strings
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name: must not contain %s" %\
        _MODULE_NAMESPACE_DIVIDER

    if isclass(nn_module):
        raise NotImplementedError("pyro.module does not support class constructors for " +
                                  "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        # register the parameter in the module with pyro
        # this only does something substantive if the parameter hasn't been seen before
        full_param_name = param_with_module_name(name, param_name)
        returned_param = param(full_param_name, param_value, tags=tags)

        if get_tensor_data(param_value)._cdata != get_tensor_data(returned_param)._cdata:
            target_state_dict[param_name] = returned_param

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit('.', 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(nn_module, mod_name)._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
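
This older variant additionally accepts a tags keyword that is forwarded to pyro.param(); update_module_params behaves the same in every version. A hedged sketch of the save-then-restore round trip it enables, with illustrative names ("decoder", "decoder.pt"):

import torch
import pyro

net = torch.nn.Linear(3, 1)
pyro.module("decoder", net)                    # register net's current values
pyro.get_param_store().save("decoder.pt")

pyro.clear_param_store()                       # simulate a fresh process
fresh_net = torch.nn.Linear(3, 1)              # new random initialization
pyro.get_param_store().load("decoder.pt")
# rebind fresh_net's Parameters to the values restored from disk
pyro.module("decoder", fresh_net, update_module_params=True)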
Example #3
File: primitives.py  Project: pyro-ppl/pyro
def module(name, nn_module, update_module_params=False):
    """
    Registers all parameters of a :class:`torch.nn.Module` with Pyro's
    :mod:`~pyro.params.param_store`.  In conjunction with the
    :class:`~pyro.params.param_store.ParamStoreDict`
    :meth:`~pyro.params.param_store.ParamStoreDict.save` and
    :meth:`~pyro.params.param_store.ParamStoreDict.load` functionality, this
    allows the user to save and load modules.

    .. note:: Consider instead using :class:`~pyro.nn.module.PyroModule`, a
        newer alternative to ``pyro.module()`` that has better support for:
        jitting, serving in C++, and converting parameters to random variables.
        For details see the `Modules Tutorial
        <https://pyro.ai/examples/modules.html>`_ .

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, (
        "improper module name: must not contain %s" % _MODULE_NAMESPACE_DIVIDER)

    if isclass(nn_module):
        raise NotImplementedError(
            "pyro.module does not support class constructors for " +
            "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        elif nn_module.training:
            warnings.warn(
                f"{param_name} was not registered in the param store "
                "because requires_grad=False. You can silence this "
                "warning by calling my_module.train(False)")

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
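        # replace each Parameter object in the owning (sub)module's
        # _parameters dict with the Parameter returned by the ParamStore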
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit(".", 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(
                        nn_module, mod_name
                    )._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
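
As the note in this docstring suggests, newer code can use PyroModule instead of calling pyro.module() explicitly. A minimal sketch; the class name Regressor is illustrative:

import torch
from pyro.nn import PyroModule

# A PyroModule is an nn.Module whose parameters are registered with Pyro
# automatically, so no explicit pyro.module() call is needed in the model.
class Regressor(PyroModule):
    def __init__(self):
        super().__init__()
        # PyroModule[cls] wraps an existing nn.Module class, here nn.Linear
        self.linear = PyroModule[torch.nn.Linear](10, 1)

    def forward(self, x):
        return self.linear(x)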
Example #4
def get_scale(self, site_name):
    return pyutil.deep_getattr(self, site_name + ".scale")
Example #5
def get_loc(self, site_name):
    return pyutil.deep_getattr(self, site_name + ".loc")
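
The two getters above use pyutil.deep_getattr to follow a dotted attribute path from the guide object down to a per-site parameter. A minimal re-implementation for illustration, assuming only that the helper chains getattr calls:

from functools import reduce

def deep_getattr(obj, name):
    # resolve a dotted path, e.g. deep_getattr(guide, "z.loc") == guide.z.loc
    return reduce(getattr, name.split("."), obj)

With that helper, get_loc(self, "z") resolves to self.z.loc and get_scale(self, "z") to self.z.scale.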