Example #1
def test_hmc(model_class, X, y, kernel, likelihood):
    if model_class is SparseGPRegression or model_class is VariationalSparseGP:
        gp = model_class(X, y, kernel, X, likelihood)
    else:
        gp = model_class(X, y, kernel, likelihood)

    kernel.set_prior("variance",
                     dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
    kernel.set_prior("lengthscale",
                     dist.Uniform(torch.tensor(1.0), torch.tensor(3.0)))

    hmc_kernel = HMC(gp.model, step_size=1)
    mcmc_run = MCMC(hmc_kernel, num_samples=10)

    post_trace = defaultdict(list)
    for trace, _ in mcmc_run._traces():
        variance_name = param_with_module_name(kernel.name, "variance")
        post_trace["variance"].append(trace.nodes[variance_name]["value"])
        lengthscale_name = param_with_module_name(kernel.name, "lengthscale")
        post_trace["lengthscale"].append(
            trace.nodes[lengthscale_name]["value"])
        if model_class is VariationalGP:
            f_name = param_with_module_name(gp.name, "f")
            post_trace["f"].append(trace.nodes[f_name]["value"])
        if model_class is VariationalSparseGP:
            u_name = param_with_module_name(gp.name, "u")
            post_trace["u"].append(trace.nodes[u_name]["value"])

    for param in post_trace:
        param_mean = torch.mean(torch.stack(post_trace[param]), 0)
        logger.info("Posterior mean - {}".format(param))
        logger.info(param_mean)
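A minimal driver for the test above might look like the following sketch. It assumes the legacy pyro.contrib.gp API used throughout these examples and exercises the VariationalGP branch; the data and kernel choice are illustrative only.

import torch
import pyro.contrib.gp as gp
from pyro.contrib.gp.models import VariationalGP

# Toy 1-D regression data (illustrative)
X = torch.linspace(0.0, 5.0, 20)
y = torch.sin(X) + 0.1 * torch.randn(20)

kernel = gp.kernels.RBF(input_dim=1)
likelihood = gp.likelihoods.Gaussian()
test_hmc(VariationalGP, X, y, kernel, likelihood)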
Example #2
    def _register_param(self, param, mode="model"):
        """
        Registers a parameter with Pyro. It can be seen as a wrapper around the
        :func:`pyro.param` and :func:`pyro.sample` primitives.

        :param str param: Name of the parameter.
        :param str mode: Either "model" or "guide".
        """
        if param in self._fixed_params:
            self._registered_params[param] = self._fixed_params[param]
            return
        prior = self._priors.get(param)
        if self.name is None:
            param_name = param
        else:
            param_name = param_with_module_name(self.name, param)

        if prior is None:
            constraint = self._constraints.get(param)
            default_value = getattr(self, param)
            if constraint is None:
                p = pyro.param(param_name, default_value)
            else:
                p = pyro.param(param_name,
                               default_value,
                               constraint=constraint)
        elif mode == "model":
            p = pyro.sample(param_name, prior)
        else:  # prior is not None and mode == "guide"
            MAP_param_name = param_name + "_MAP"
            # TODO: consider initializing the parameter from a prior sample
            # instead of the prior mean
            MAP_param = pyro.param(MAP_param_name, prior.mean.detach())
            p = pyro.sample(param_name, dist.Delta(MAP_param))

        self._registered_params[param] = p
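To see the three branches concretely, the hypothetical walkthrough below registers a kernel parameter under each regime. It assumes the surrounding Parameterized API from pyro.contrib.gp (set_prior, set_mode, as used in Example #1); the prior and site names are illustrative.

import torch
import pyro.contrib.gp as gp
import pyro.distributions as dist

kernel = gp.kernels.RBF(input_dim=1)

# Branch 1 - no prior: "variance" is registered as a plain pyro.param
# (subject to its constraint, if one is declared).
kernel.set_mode("model")

# Branch 2 - prior and mode="model": the value is drawn via pyro.sample.
kernel.set_prior("variance", dist.LogNormal(torch.tensor(0.0),
                                            torch.tensor(1.0)))
kernel.set_mode("model")

# Branch 3 - prior and mode="guide": a learnable "<site>_MAP" point is
# created with pyro.param and the site is sampled from a Delta at it.
kernel.set_mode("guide")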
Example #3
File: vgp.py Project: lewisKit/pyro
    def model(self):
        self.set_mode("model")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        N = self.X.shape[0]
        Kff = self.kernel(self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) *
                                     self.jitter)
        Lff = Kff.potrf(upper=False)

        zero_loc = self.X.new_zeros(f_loc.shape)
        f_name = param_with_module_name(self.name, "f")

        if self.whiten:
            Id = torch.eye(N, out=self.X.new_empty(N, N))
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(f_scale_tril)
        else:
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .independent(zero_loc.dim() - 1))

        f_var = f_scale_tril.pow(2).sum(dim=-1)

        if self.whiten:
            f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1)
        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)
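The whitened branch relies on the standard reparameterization: if w ~ N(0, I) and Kff = Lff @ Lff.T, then Lff @ w ~ N(0, Kff), which is why the code samples f against an identity scale and folds Lff into f_loc and f_scale_tril afterwards. A quick numerical check of the underlying factorization identity, using the same (now-deprecated) potrf call as the code above:

import torch

N = 4
A = torch.randn(N, N)
Kff = A.matmul(A.t()) + 1e-3 * torch.eye(N)  # a random SPD matrix
Lff = Kff.potrf(upper=False)                 # lower Cholesky factor
assert torch.allclose(Lff.matmul(Lff.t()), Kff, atol=1e-4)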
Example #4
    def model(self):
        self.set_mode("model")

        Xu = self.get_param("Xu")
        u_loc = self.get_param("u_loc")
        u_scale_tril = self.get_param("u_scale_tril")

        M = Xu.shape[0]
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)

        zero_loc = Xu.new_zeros(u_loc.shape)
        u_name = param_with_module_name(self.name, "u")
        if self.whiten:
            Id = torch.eye(M, out=Xu.new_empty(M, M))
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
        else:
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .independent(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten,
                                   jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            with poutine.scale(None, self.num_data / self.X.shape[0]):
                return self.likelihood(f_loc, f_var, self.y)
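Note the poutine.scale handler in the observation branch: when X holds a minibatch, it multiplies the batch's log-likelihood by num_data / batch_size so that the ELBO remains an unbiased estimate for the full dataset. Schematically:

# (num_data / X.shape[0]) * sum_i log p(y_i | f_i)  ~=  full-data log-likelihood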
Example #5
def module(name, nn_module, update_module_params=False):
    """
    Takes a torch.nn.Module and registers its parameters with the ParamStore.
    In conjunction with the ParamStore save() and load() functionality, this
    allows the user to save and load modules.

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name: contains %s" %\
        _MODULE_NAMESPACE_DIVIDER

    if isclass(nn_module):
        raise NotImplementedError(
            "pyro.module does not support class constructors for " +
            "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        else:
            warnings.warn("{} was not registered in the param store because".
                          format(param_name) + " requires_grad=False")

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit('.', 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(
                        nn_module, mod_name
                    )._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
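For orientation, typical usage is a single call inside a model or guide. The sketch below assumes a hypothetical regression network; the module name and shapes are illustrative.

import torch
import torch.nn as nn
import pyro

net = nn.Linear(10, 1)

def guide(data):
    # Registers all of net's parameters under the "my_net" prefix in the
    # ParamStore so that Pyro optimizers can update them.
    lifted_net = pyro.module("my_net", net)
    return lifted_net(data)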
Example #6
def module(name, nn_module, tags="default", update_module_params=False):
    """
    Takes a torch.nn.Module and registers its parameters with the ParamStore.
    In conjunction with the ParamStore save() and load() functionality, this
    allows the user to save and load modules.

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param tags: optional; tags to associate with any parameters inside the module
    :type tags: string or iterable of strings
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name: contains %s" %\
        _MODULE_NAMESPACE_DIVIDER

    if isclass(nn_module):
        raise NotImplementedError("pyro.module does not support class constructors for " +
                                  "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        # register the parameter in the module with pyro
        # this only does something substantive if the parameter hasn't been seen before
        full_param_name = param_with_module_name(name, param_name)
        returned_param = param(full_param_name, param_value, tags=tags)

        if get_tensor_data(param_value)._cdata != get_tensor_data(returned_param)._cdata:
            target_state_dict[param_name] = returned_param

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit('.', 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(nn_module, mod_name)._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
Example #7
File: vgp.py Project: lewisKit/pyro
    def guide(self):
        self.set_mode("guide")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        if self._sample_latent:
            f_name = param_with_module_name(self.name, "f")
            pyro.sample(f_name,
                        dist.MultivariateNormal(f_loc, scale_tril=f_scale_tril)
                            .independent(f_loc.dim()-1))
        return f_loc, f_scale_tril
Example #8
    def guide(self):
        self.set_mode("guide")

        Xu = self.get_param("Xu")
        u_loc = self.get_param("u_loc")
        u_scale_tril = self.get_param("u_scale_tril")

        if self._sample_latent:
            u_name = param_with_module_name(self.name, "u")
            pyro.sample(u_name,
                        dist.MultivariateNormal(u_loc, scale_tril=u_scale_tril)
                            .independent(u_loc.dim()-1))
        return Xu, u_loc, u_scale_tril
Example #9
    def model(self):
        self.set_mode("model", recursive=False)

        # sample X from unit multivariate normal distribution
        zero_loc = self.X_loc.new_zeros(self.X_loc.shape)
        C = self.X_loc.shape[1]
        Id = torch.eye(C, out=self.X_loc.new_empty(C, C))
        X_name = param_with_module_name(self.name, "X")
        X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id)
                                    .independent(zero_loc.dim()-1))

        self.base_model.set_data(X, self.y)
        self.base_model.model()
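This is the GPLVM construction: the inputs X are themselves latent variables with a unit multivariate-normal prior, matched by a guide (the next example) that samples X from a learned Normal parameterized by X_loc and X_scale_tril. A minimal shape sketch with illustrative values:

import torch

N, C = 100, 2                  # number of data points, latent dimension
X_loc = torch.zeros(N, C)
zero_loc = X_loc.new_zeros(X_loc.shape)
Id = torch.eye(C, out=X_loc.new_empty(C, C))
# the prior treats each of the N rows as an independent draw over R^C
assert zero_loc.shape == (N, C) and Id.shape == (C, C)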
Example #10
    def guide(self):
        self.set_mode("guide", recursive=False)

        # sample X from variational multivariate normal distribution
        X_loc = self.get_param("X_loc")
        X_scale_tril = self.get_param("X_scale_tril")
        X_name = param_with_module_name(self.name, "X")
        X = pyro.sample(X_name,
                        dist.MultivariateNormal(X_loc, scale_tril=X_scale_tril)
                            .independent(X_loc.dim()-1))

        self.base_model.set_data(X, self.y)
        if self._call_base_model_guide:
            self.base_model.guide()
Example #11
    def model(self):
        self.set_mode("model")

        Xu = self.get_param("Xu")
        noise = self.get_param("noise")

        # W = inv(Luu) @ Kuf
        # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W
        # Formulas for each approximation method are:
        # DTC:  y_cov = Qff + noise,                   trace_term = 0
        # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0
        # VFE:  y_cov = Qff + noise,                   trace_term = tr(Kff-Qff) / noise
        # y_cov = W.T @ W + D
        # trace_term is added into log_prob

        M = Xu.shape[0]
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)
        Kuf = self.kernel(Xu, self.X)
        W = matrix_triangular_solve_compat(Kuf, Luu, upper=False)

        D = noise.expand(W.shape[1])
        trace_term = 0
        if self.approx == "FITC" or self.approx == "VFE":
            Kffdiag = self.kernel(self.X, diag=True)
            Qffdiag = W.pow(2).sum(dim=0)
            if self.approx == "FITC":
                D = D + Kffdiag - Qffdiag
            else:  # approx = "VFE"
                trace_term += (Kffdiag - Qffdiag).sum() / noise

        zero_loc = self.X.new_zeros(self.X.shape[0])
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = D + W.pow(2).sum(dim=0)
            return f_loc, f_var
        else:
            y_name = param_with_module_name(self.name, "y")
            return pyro.sample(
                y_name,
                dist.LowRankMultivariateNormal(
                    f_loc, W, D, trace_term).expand_by(
                        self.y.shape[:-1]).independent(self.y.dim() - 1),
                obs=self.y)
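The payoff of the W/D decomposition is computational: y_cov = W.T @ W + diag(D) is never materialized, so evaluating the likelihood costs O(N * M^2) with M inducing points rather than the O(N^3) of a dense multivariate normal. An illustrative shape check:

import torch

M, N = 5, 50
W = torch.randn(M, N)        # stands in for inv(Luu) @ Kuf
D = torch.rand(N) + 0.1      # diagonal term (noise, plus the FITC correction)
y_cov = W.t().matmul(W) + D.diag()  # (N, N); the low-rank MVN avoids this
assert y_cov.shape == (N, N)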
Example #12
File: gpr.py Project: lewisKit/pyro
    def model(self):
        self.set_mode("model")

        noise = self.get_param("noise")

        Kff = self.kernel(self.X) + noise.expand(self.X.shape[0]).diag()
        Lff = Kff.potrf(upper=False)

        zero_loc = self.X.new_zeros(self.X.shape[0])
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = Lff.pow(2).sum(dim=-1)
            return f_loc, f_var
        else:
            y_name = param_with_module_name(self.name, "y")
            return pyro.sample(y_name,
                               dist.MultivariateNormal(f_loc, scale_tril=Lff)
                                   .expand_by(self.y.shape[:-1])
                                   .independent(self.y.dim() - 1),
                               obs=self.y)
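A short note on the f_var branch above: since Kff = Lff @ Lff.T, each diagonal entry satisfies Kff[i, i] = sum_j Lff[i, j]**2, so the squared row-norms of the Cholesky factor recover diag(Kff) without rebuilding Kff. Illustrative check:

import torch

Kff = torch.eye(3) + 0.5     # a small SPD matrix
Lff = Kff.potrf(upper=False)
assert torch.allclose(Lff.pow(2).sum(dim=-1), Kff.diag())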
Example #13
def module(name, nn_module, update_module_params=False):
    """
    Registers all parameters of a :class:`torch.nn.Module` with Pyro's
    :mod:`~pyro.params.param_store`.  In conjunction with the
    :class:`~pyro.params.param_store.ParamStoreDict`
    :meth:`~pyro.params.param_store.ParamStoreDict.save` and
    :meth:`~pyro.params.param_store.ParamStoreDict.load` functionality, this
    allows the user to save and load modules.

    .. note:: Consider instead using :class:`~pyro.nn.module.PyroModule`, a
        newer alternative to ``pyro.module()`` that has better support for:
        jitting, serving in C++, and converting parameters to random variables.
        For details see the `Modules Tutorial
        <https://pyro.ai/examples/modules.html>`_ .

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters
                                 in the PyTorch module get overridden with the values found in the
                                 ParamStore (if any). Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, (
        "improper module name, since contains %s" % _MODULE_NAMESPACE_DIVIDER)

    if isclass(nn_module):
        raise NotImplementedError(
            "pyro.module does not support class constructors for " +
            "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        elif nn_module.training:
            warnings.warn(
                f"{param_name} was not registered in the param store "
                "because requires_grad=False. You can silence this "
                "warning by calling my_module.train(False)")

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit(".", 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(
                        nn_module, mod_name
                    )._parameters[param_name] = target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
Example #14
    def __init__(self, name=None):
        super(Likelihood, self).__init__(name)
        self.y_name = (param_with_module_name(name, "y") if name is not None
                       else "y")