Example #1
    def process_noise_cov(self, dt=0.):
        '''
        Compute and return cached process noise covariance (Q).

        :param dt: time interval to integrate over.
        :return: Read-only covariance (Q) of the native state `x` resulting from
            stochastic integration (for use with EKF). (Note that this Q, modulo
            numerical error, has rank `dimension/2`. So, it is only positive
            semi-definite.)
        '''
        if dt not in self._Q_cache:
            with torch.no_grad():
                d = self._dimension
                dt2 = dt * dt
                dt3 = dt2 * dt
                dt4 = dt2 * dt2
                Q = torch.zeros(d,
                                d,
                                dtype=self.sa2.dtype,
                                device=self.sa2.device)
                Q[:d // 2, :d // 2] = 0.25 * dt4 * eye_like(self.sa2, d // 2)
                Q[:d // 2, d // 2:] = 0.5 * dt3 * eye_like(self.sa2, d // 2)
                Q[d // 2:, :d // 2] = 0.5 * dt3 * eye_like(self.sa2, d // 2)
                Q[d // 2:, d // 2:] = dt2 * eye_like(self.sa2, d // 2)
            Q = Q * self.sa2
            self._Q_cache[dt] = Q

        return self._Q_cache[dt]
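
Every snippet in this collection leans on the same small helper: ``eye_like(value, n)`` builds an identity matrix that matches the dtype and device of ``value``. A minimal sketch consistent with how it is called here (in Pyro the helper lives in ``pyro.distributions.util``; the body below is an illustration, not necessarily the library's exact code):

import torch

def eye_like(value, m, n=None):
    # identity-like (m, n) matrix on the same dtype/device as ``value``
    n = m if n is None else n
    eye = torch.zeros(m, n, dtype=value.dtype, device=value.device)
    eye.view(-1)[::n + 1] = 1  # write ones along the main diagonal
    return eye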
Example #2
File: vsgp.py Project: zyxue/pyro
    def __init__(self,
                 X,
                 y,
                 kernel,
                 Xu,
                 likelihood,
                 mean_function=None,
                 latent_shape=None,
                 num_data=None,
                 whiten=False,
                 jitter=1e-6):
        super(VariationalSparseGP, self).__init__(X, y, kernel, mean_function,
                                                  jitter)

        self.likelihood = likelihood
        self.Xu = Parameter(Xu)

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size(
            [])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        M = self.Xu.size(0)
        u_loc = self.Xu.new_zeros(self.latent_shape + (M, ))
        self.u_loc = Parameter(u_loc)

        identity = eye_like(self.Xu, M)
        u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.u_scale_tril = Parameter(u_scale_tril)
        self.set_constraint("u_scale_tril", constraints.lower_cholesky)

        self.num_data = num_data if num_data is not None else self.X.size(0)
        self.whiten = whiten
        self._sample_latent = True
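
For context, a hedged sketch of how a constructor with this signature is typically called; the toy data and the RBF kernel / Gaussian likelihood choices are assumptions of the sketch, not part of the example:

import torch
import pyro.contrib.gp as gp

X = torch.linspace(0, 1, 100).unsqueeze(-1)   # inputs, shape (N, 1)
y = torch.sin(6 * X).squeeze(-1)              # targets, shape (N,)
Xu = X[::10].clone()                          # M inducing inputs

kernel = gp.kernels.RBF(input_dim=1)
likelihood = gp.likelihoods.Gaussian()
vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood, whiten=True)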
Example #3
 def _initialize_model_properties(self):
     if self.max_plate_nesting is None:
         self._guess_max_plate_nesting()
     # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
     # No-op if model does not have any discrete latents.
     self.model = poutine.enum(config_enumerate(self.model),
                               first_available_dim=-1 -
                               self.max_plate_nesting)
     if self._automatic_transform_enabled:
         self.transforms = {}
     trace = poutine.trace(self.model).get_trace(*self._args,
                                                 **self._kwargs)
     for name, node in trace.iter_stochastic_nodes():
         if isinstance(node["fn"], _Subsample):
             continue
         if node["fn"].has_enumerate_support:
             self._has_enumerable_sites = True
             continue
         site_value = node["value"]
         if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
             self.transforms[name] = biject_to(node["fn"].support).inv
             site_value = self.transforms[name](node["value"])
         self._r_shapes[name] = site_value.shape
         self._r_numels[name] = site_value.numel()
     self._trace_prob_evaluator = TraceEinsumEvaluator(
         trace, self._has_enumerable_sites, self.max_plate_nesting)
     mass_matrix_size = sum(self._r_numels.values())
     if self.full_mass:
         initial_mass_matrix = eye_like(site_value, mass_matrix_size)
     else:
         initial_mass_matrix = site_value.new_ones(mass_matrix_size)
     self._adapter.inverse_mass_matrix = initial_mass_matrix
Example #4
 def _setup_prototype(self, *args, **kwargs):
     super()._setup_prototype(*args, **kwargs)
     # Initialize guide params
     self.loc = nn.Parameter(self._init_loc())
     self.scale_tril = PyroParam(
         eye_like(self.loc, self.latent_dim) * self._init_scale,
         constraints.lower_cholesky)
Example #5
    def backward(ctx, grad_output):
        jitter = 1.0e-8  # do i really need this?
        z, epsilon, L = ctx.saved_tensors

        dim = L.shape[0]
        g = grad_output
        loc_grad = sum_leftmost(grad_output, -1)

        identity = eye_like(g, dim)
        R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0]

        z_ja = z.unsqueeze(-1)
        g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

        Sigma_inv = torch.mm(R_inv, R_inv.t())
        V, D, _ = torch.svd(Sigma_inv + jitter)
        D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

        expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
        z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple)
        g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple)

        Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
        Y = torch.mm(V, torch.mm(Y, V.t()))
        Y = Y + Y.t()

        Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv))
        diff_L_ab += 0.5 * Tr_xi_Y
        L_grad = torch.tril(diff_L_ab)

        return loc_grad, L_grad, None
Example #6
    def model(self):
        self.set_mode("model")

        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()

        zero_loc = self.Xu.new_zeros(self.u_loc.shape)
        if self.whiten:
            identity = eye_like(self.Xu, M)
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
        else:
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .to_event(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc, self.u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            # we would like to load the likelihood's parameters outside the poutine.scale context
            self.likelihood._load_pyro_samples()
            with poutine.scale(scale=self.num_data / self.X.size(0)):
                return self.likelihood(f_loc, f_var, self.y)
Example #7
    def __init__(self,
                 X,
                 y,
                 kernel,
                 likelihood,
                 mean_function=None,
                 latent_shape=None,
                 whiten=False,
                 jitter=1e-6,
                 use_cuda=False):
        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood
        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size(
            [])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N, ))
        self.f_loc = Parameter(f_loc)

        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = PyroParam(f_scale_tril, constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True
        if use_cuda:
            self.cuda()
Example #8
    def model(self):
        self.set_mode("model")
        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc,
                    scale_tril=identity).to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)
Example #9
    def jacobian(self, dt):
        '''
        Compute and return cached native state transition Jacobian (F) over
        time interval ``dt``.

        :param dt: time interval to integrate over.
        :return: Read-only Jacobian (F) of integration map (f).
        '''
        if dt not in self._F_cache:
            d = self._dimension
            with torch.no_grad():
                F = eye_like(self.sa2, d)
                F[:d // 2, d // 2:] = dt * eye_like(self.sa2, d // 2)
            self._F_cache[dt] = F

        return self._F_cache[dt]
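
The cached Jacobian has the usual constant-velocity block structure F = [[I, dt*I], [0, I]]. A standalone sketch of that structure, assuming a 4-dimensional position/velocity state (plain ``torch.eye`` stands in for ``eye_like``):

import torch

d, dt = 4, 0.1
F = torch.eye(d)
F[:d // 2, d // 2:] = dt * torch.eye(d // 2)
# F == [[1, 0, 0.1, 0.0],
#       [0, 1, 0.0, 0.1],
#       [0, 0, 1.0, 0.0],
#       [0, 0, 0.0, 1.0]]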
Example #10
    def autoguide(self, name, dist_constructor):
        """
        Sets an autoguide for an existing parameter with name ``name``
        (mimicking the behavior of the :mod:`pyro.infer.autoguide` module).

        .. note:: `dist_constructor` should be one of
            :class:`~pyro.distributions.Delta`,
            :class:`~pyro.distributions.Normal`, and
            :class:`~pyro.distributions.MultivariateNormal`. More distribution
            constructors will be supported in the future if needed.

        :param str name: Name of the parameter.
        :param dist_constructor: A
            :class:`~pyro.distributions.distribution.Distribution` constructor.
        """
        if name not in self._priors:
            raise ValueError(
                "There is no prior for parameter: {}".format(name))

        if dist_constructor not in [
                dist.Delta, dist.Normal, dist.MultivariateNormal
        ]:
            raise NotImplementedError(
                "Unsupported distribution type: {}".format(dist_constructor))

        # delete old guide
        if name in self._guides:
            dist_args = self._guides[name][1]
            for arg in dist_args:
                delattr(self, "{}_{}".format(name, arg))

        p = self._priors[name]()  # init_to_sample strategy
        if dist_constructor is dist.Delta:
            support = self._priors[name].support
            if _is_real_support(support):
                p_map = Parameter(p.detach())
            else:
                p_map = PyroParam(p.detach(), support)
            setattr(self, "{}_map".format(name), p_map)
            dist_args = ("map", )
        elif dist_constructor is dist.Normal:
            loc = Parameter(
                biject_to(self._priors[name].support).inv(p).detach())
            scale = PyroParam(loc.new_ones(loc.shape), constraints.positive)
            setattr(self, "{}_loc".format(name), loc)
            setattr(self, "{}_scale".format(name), scale)
            dist_args = ("loc", "scale")
        elif dist_constructor is dist.MultivariateNormal:
            loc = Parameter(
                biject_to(self._priors[name].support).inv(p).detach())
            identity = eye_like(loc, loc.size(-1))
            scale_tril = PyroParam(identity.repeat(loc.shape[:-1] + (1, 1)),
                                   constraints.lower_cholesky)
            setattr(self, "{}_loc".format(name), loc)
            setattr(self, "{}_scale_tril".format(name), scale_tril)
            dist_args = ("loc", "scale_tril")
        else:
            raise NotImplementedError

        self._guides[name] = (dist_constructor, dist_args)
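
For context, a hedged sketch of how ``autoguide`` is typically used on a ``Parameterized`` module; the ``Scaler`` class and its ``weight`` prior are invented for illustration, and declaring the prior with ``PyroSample`` assumes a recent Pyro version:

import pyro.distributions as dist
from pyro.contrib.gp.parameterized import Parameterized
from pyro.nn import PyroSample

class Scaler(Parameterized):
    def __init__(self):
        super().__init__()
        self.weight = PyroSample(dist.Normal(0.0, 1.0))  # sample site with a prior

scaler = Scaler()
scaler.autoguide("weight", dist.Normal)  # registers weight_loc / weight_scale guide params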
Example #11
    def process_noise_cov(self, dt=0.):
        '''
        Compute and return cached process noise covariance (Q).

        :param dt: time interval to integrate over.

        :return: Read-only covariance (Q) of the native state ``x`` resulting from
            stochastic integration (for use with EKF).
        '''
        if dt not in self._Q_cache:

            with torch.no_grad():
                d = self._dimension
                dt2 = dt * dt
                dt3 = dt2 * dt
                Q = torch.zeros(d,
                                d,
                                dtype=self.sa2.dtype,
                                device=self.sa2.device)
                eye = eye_like(self.sa2, d // 2)
                Q[:d // 2, :d // 2] = dt3 * eye / 3.0
                Q[:d // 2, d // 2:] = dt2 * eye / 2.0
                Q[d // 2:, :d // 2] = dt2 * eye / 2.0
                Q[d // 2:, d // 2:] = dt * eye
            # ``sa2 * dt`` plays the role of the noise intensity ``q``: the change in
            # velocity over a sampling period ``dt`` should ideally be ~``sqrt(q*dt)``.
            Q = Q * (self.sa2 * dt)
            self._Q_cache[dt] = Q

        return self._Q_cache[dt]
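
Written out, this assembles Q = sa2 * dt * [[dt**3/3 * I, dt**2/2 * I], [dt**2/2 * I, dt * I]]. A small standalone check of that block structure (the scalar values for ``sa2`` and ``dt`` are arbitrary):

import torch

d, dt = 2, 0.5
sa2 = torch.tensor(2.0)
eye = torch.eye(d // 2)
Q = torch.zeros(d, d)
Q[:d // 2, :d // 2] = dt ** 3 * eye / 3.0
Q[:d // 2, d // 2:] = dt ** 2 * eye / 2.0
Q[d // 2:, :d // 2] = dt ** 2 * eye / 2.0
Q[d // 2:, d // 2:] = dt * eye
Q = Q * (sa2 * dt)
# with sa2=2, dt=0.5 this gives [[0.0417, 0.1250], [0.1250, 0.5000]]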
Example #12
 def __init__(self, dimension, sv2):
     dimension_pv = 2 * dimension
     super().__init__(dimension, dimension_pv, num_process_noise_parameters=1)
     if not isinstance(sv2, torch.Tensor):
         sv2 = torch.tensor(sv2)
     self.sv2 = Parameter(sv2)
     self._F_cache = eye_like(sv2, dimension)  # State transition matrix cache
     self._Q_cache = {}  # Process noise cov cache
Example #13
 def _setup_prototype(self, *args, **kwargs):
     super()._setup_prototype(*args, **kwargs)
     # Initialize guide params
     self.loc = nn.Parameter(self._init_loc())
     self.scale = PyroParam(torch.full_like(self.loc, self._init_scale),
                            self.scale_constraint)
     self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim),
                                 self.scale_tril_constraint)
Example #14
    def init_mvn_guide(self):
        """ Initialize multivariate normal guide
        """
        init_loc = torch.full((self.n_params, ), 0.0)
        init_scale = eye_like(init_loc, self.n_params) * 0.1

        _deep_setattr(self, "mvn.loc", PyroParam(init_loc, constraints.real))
        _deep_setattr(self, "mvn.scale_tril",
                      PyroParam(init_scale, constraints.lower_cholesky))
Example #15
 def get_posterior(self, *args, **kwargs):
     """
     Returns a MultivariateNormal posterior distribution.
     """
     loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)
     scale_tril = pyro.param("{}_scale_tril".format(self.prefix),
                             lambda: eye_like(loc, self.latent_dim),
                             constraint=constraints.lower_cholesky)
     return dist.MultivariateNormal(loc, scale_tril=scale_tril)
Example #16
    def guide(self):

        a_locs = pyro.param("a_locs", torch.full((self.n_params, ), 0.0))
        a_scales_tril = pyro.param(
            "a_scales",
            lambda: 0.1 * eye_like(a_locs, self.n_params),
            constraint=constraints.lower_cholesky)

        dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
        states = pyro.sample("states", dt, infer={"is_auxiliary": True})

        result = {}

        for i_poiss in torch.arange(self.n_poiss):
            transform = biject_to(self.poiss_priors[i_poiss].support)
            value = transform(states[i_poiss])
            log_density = transform.inv.log_abs_det_jacobian(
                value, states[i_poiss])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim() +
                self.poiss_priors[i_poiss].event_dim)

            result[self.labels_poiss[i_poiss]] = pyro.sample(
                self.labels_poiss[i_poiss],
                dist.Delta(value,
                           log_density=log_density,
                           event_dim=self.poiss_priors[i_poiss].event_dim))

        i_param = self.n_poiss

        for i_ps in torch.arange(self.n_ps):
            for i_ps_param in torch.arange(self.n_ps_params):

                transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)

                value = transform(states[i_param])

                log_density = transform.inv.log_abs_det_jacobian(
                    value, states[i_param])
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() +
                    self.ps_priors[i_ps][i_ps_param].event_dim)

                result[self.labels_ps_params[i_ps_param] + "_" +
                       self.labels_ps[i_ps]] = pyro.sample(
                           self.labels_ps_params[i_ps_param] + "_" +
                           self.labels_ps[i_ps],
                           dist.Delta(value,
                                      log_density=log_density,
                                      event_dim=self.ps_priors[i_ps]
                                      [i_ps_param].event_dim))
                i_param += 1

        return result
Example #17
 def __init__(self, mean, cov, time=None, frame_num=None):
     super().__init__(mean, cov, time=time, frame_num=frame_num)
     self._jacobian = torch.cat([
         eye_like(mean, self.dimension),
         torch.zeros(self.dimension,
                     self.dimension,
                     dtype=mean.dtype,
                     device=mean.device)
     ],
                                dim=1)
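
This builds the measurement Jacobian of a position-only measurement of a position/velocity state, H = [I | 0]: the measurement map simply reads off the position block and ignores velocity. A standalone sketch of that structure, assuming a 2-dimensional position:

import torch

dim = 2                                  # measurement (position) dimension
mean = torch.zeros(dim)                  # supplies dtype and device, as in the example
H = torch.cat([
    torch.eye(dim, dtype=mean.dtype, device=mean.device),
    torch.zeros(dim, dim, dtype=mean.dtype, device=mean.device),
], dim=1)
# H == [[1, 0, 0, 0],
#       [0, 1, 0, 0]]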
Example #18
 def guide(self):
     if self.mygroup is None:
         self.mygroup, self.z_init_loc = self._get_group(
             match=self.guide_conf['match'])
     z_loc = pyro.param(self.prefix + "guide_z_loc", self.z_init_loc)
     z_scale_tril = pyro.param(self.prefix + "guide_z_scale_tril",
                               0.01 * eye_like(z_loc, len(z_loc)),
                               constraint=constraints.lower_cholesky)
     # TODO: Flexible initial error
     guide_z, model_zs = self.mygroup.sample(
         self.prefix + 'guide_z',
         dist.MultivariateNormal(z_loc, scale_tril=z_scale_tril))
     return guide_z, model_zs
Example #19
    def process_noise_cov(self, dt=0.):
        '''
        Compute and return cached process noise covariance (Q).

        :param dt: time interval to integrate over.
        :return: Read-only covariance (Q) of the native state `x` resulting from
            stochastic integration (for use with EKF).
        '''
        if dt not in self._Q_cache:
            Q = self.sv2 * dt * dt * eye_like(self.sv2, self._dimension)
            self._Q_cache[dt] = Q

        return self._Q_cache[dt]
Example #20
    def process_noise_cov(self, dt=0.):
        '''
        Compute and return cached process noise covariance (Q).

        :param dt: time interval to integrate over.
        :return: Read-only covariance (Q) of the native state ``x`` resulting from
            stochastic integration (for use with EKF).
        '''
        if dt not in self._Q_cache:
            # q: continuous-time process noise intensity with units
            #   length^2/time (m^2/s). Choose ``q`` so that changes in position,
            #   over a sampling period ``dt``, are roughly ``sqrt(q*dt)``.
            q = self.sv2 * dt
            Q = q * dt * eye_like(self.sv2, self._dimension)
            self._Q_cache[dt] = Q

        return self._Q_cache[dt]
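
This is the same matrix Example #19 builds; only the bookkeeping differs, since here the continuous-time intensity ``q = sv2 * dt`` is named explicitly before being discretized. A tiny check of that equivalence (arbitrary values):

import torch

sv2, dt, dim = torch.tensor(0.3), 0.1, 2
Q_direct = sv2 * dt * dt * torch.eye(dim)    # Example #19's factoring
q = sv2 * dt                                 # continuous-time intensity
Q_via_q = q * dt * torch.eye(dim)            # this example's factoring
assert torch.allclose(Q_direct, Q_via_q)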
Example #21
    def update(self, measurement):
        """
        Use the measurement to compute an updated state estimate and return it
        together with the innovation. The innovation is useful, e.g., for
        evaluating filter consistency or updating model likelihoods when the
        ``EKFState`` is part of an ``IMMFState``.

        :param measurement: measurement used for the update.
        :returns: updated ``EKFState``, and the innovation mean and covariance.
        """
        if self._time is not None:
            assert (self._time == measurement.time
                    ), "State time and measurement time must be aligned!"
        if self._frame_num is not None:
            assert (self._frame_num == measurement.frame_num
                    ), "State time and measurement time must be aligned!"

        x = self._mean
        x_pv = self._dynamic_model.mean2pv(x)
        P = self.cov
        H = measurement.jacobian(x_pv)[:, :self.dimension]
        R = measurement.cov
        z = measurement.mean
        z_predicted = measurement(x_pv)
        dz = measurement.geodesic_difference(z, z_predicted)
        S = H.mm(P).mm(H.transpose(-1, -2)) + R  # innovation cov

        K_prefix = self._cov.mm(H.transpose(-1, -2))
        dx = K_prefix.mm(torch.linalg.solve(S, dz.unsqueeze(1))).squeeze(
            1)  # K*dz
        x = self._dynamic_model.geodesic_difference(x, -dx)

        I = eye_like(x, self._dynamic_model.dimension)  # noqa: E741
        ImKH = I - K_prefix.mm(torch.linalg.solve(S, H))
        # *Joseph form* of covariance update for numerical stability.
        S_inv_R = torch.linalg.solve(S, R)
        P = ImKH.mm(self.cov).mm(ImKH.transpose(-1, -2)) + K_prefix.mm(
            torch.linalg.solve(S,
                               K_prefix.mm(S_inv_R).transpose(-1, -2)))

        pred_mean = x
        pred_cov = P
        state = EKFState(self._dynamic_model, pred_mean, pred_cov, self._time,
                         self._frame_num)

        return state, (dz, S)
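
The covariance update above is the Joseph form, P_new = (I - K H) P (I - K H)^T + K R K^T with gain K = P H^T S^-1; the code keeps the gain factored as ``K_prefix = P @ H.T`` and solves against S rather than inverting it. A hedged sketch of the same update with explicit matrices (no caching, geodesic differences, or solve-based factoring):

import torch

def joseph_update(P, H, R, S):
    # Joseph-form covariance update: symmetric and positive semi-definite by construction
    K = P @ H.T @ torch.linalg.inv(S)    # explicit Kalman gain
    I = torch.eye(P.shape[0], dtype=P.dtype, device=P.device)
    ImKH = I - K @ H
    return ImKH @ P @ ImKH.T + K @ R @ K.T

# usage: P_new = joseph_update(P, H, R, H @ P @ H.T + R)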
Example #22
    def _initialize_adapter(self):
        mass_matrix_size = sum(
            [p.numel() for p in self.initial_params.values()])
        site_value = list(self.initial_params.values())[0]
        if self._adapter.is_diag_mass:
            initial_mass_matrix = torch.ones(mass_matrix_size,
                                             dtype=site_value.dtype,
                                             device=site_value.device)
        else:
            initial_mass_matrix = eye_like(site_value, mass_matrix_size)
        self._adapter.configure(
            self._warmup_steps,
            inv_mass_matrix=initial_mass_matrix,
            find_reasonable_step_size_fn=self._find_reasonable_step_size)

        if self._adapter.adapt_step_size:
            self._adapter.reset_step_size_adaptation(self._initial_params)
Example #23
    def __init__(
        self,
        X,
        y,
        kernel,
        Xu,
        likelihood,
        mean_function=None,
        latent_shape=None,
        num_data=None,
        whiten=False,
        jitter=1e-6,
    ):
        assert isinstance(
            X, torch.Tensor
        ), "X needs to be a torch Tensor instead of a {}".format(type(X))
        if y is not None:
            assert isinstance(
                y, torch.Tensor
            ), "y needs to be a torch Tensor instead of a {}".format(type(y))
        assert isinstance(
            Xu, torch.Tensor
        ), "Xu needs to be a torch Tensor instead of a {}".format(type(Xu))

        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood
        self.Xu = Parameter(Xu)

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        M = self.Xu.size(0)
        u_loc = self.Xu.new_zeros(self.latent_shape + (M,))
        self.u_loc = Parameter(u_loc)

        identity = eye_like(self.Xu, M)
        u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.u_scale_tril = PyroParam(u_scale_tril, constraints.lower_cholesky)

        self.num_data = num_data if num_data is not None else self.X.size(0)
        self.whiten = whiten
        self._sample_latent = True
Example #24
def GP_sample(name,
              X,
              f_loc,
              f_scale_tril,
              f_loc_mean,
              Lff=None,
              kernel=None,
              jitter=1e-6,
              whiten=False):

    N = X.size(0)

    if Lff is None:
        Kff = kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

    zero_loc = X.new_zeros(f_loc.shape)

    if whiten:
        identity = eye_like(X, N)
        gp_sample = pyro.sample(
            name,
            dist.MultivariateNormal(
                zero_loc, scale_tril=identity).to_event(zero_loc.dim() - 1))
        gp_sample = Lff.matmul(gp_sample.unsqueeze(-1)).squeeze(-1)
    else:
        gp_sample = pyro.sample(
            name,
            dist.MultivariateNormal(
                zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))

    # gp_sample += (f_loc + f_loc_mean) # the guide already accounts for f_loc
    gp_sample += f_loc_mean

    return gp_sample
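
In the whitened branch the sample ``u ~ N(0, I)`` is pushed through the Cholesky factor, so ``Lff @ u`` has covariance ``Lff @ Lff.T == Kff``; this is the standard whitened parameterization that decouples the latent sample from the kernel hyperparameters. A tiny standalone check of that identity (toy 2x2 kernel matrix):

import torch

torch.manual_seed(0)
Kff = torch.tensor([[1.0, 0.6], [0.6, 1.0]])
Lff = torch.linalg.cholesky(Kff)
u = torch.randn(2, 100000)               # whitened draws, u ~ N(0, I)
f = Lff @ u                              # un-whitened draws
emp_cov = f @ f.t() / u.shape[1]         # empirical covariance, close to Kff
print(emp_cov)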
Example #25
    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        self.locs = PyroModule()
        self.scales = PyroModule()
        self.scale_trils = PyroModule()
        self.conds = PyroModule()
        self.deps = PyroModule()
        self._batch_shapes = {}
        self._unconstrained_event_shapes = {}
        sample_sites = OrderedDict(
            self.prototype_trace.iter_stochastic_nodes())
        self._auto_config(sample_sites, args, kwargs)

        # Collect unconstrained shapes.
        init_locs = {}
        numel = {}
        for name, site in sample_sites.items():
            with helpful_support_errors(site):
                init_loc = (biject_to(site["fn"].support).inv(
                    site["value"].detach()).detach())
            self._batch_shapes[name] = site["fn"].batch_shape
            self._unconstrained_event_shapes[name] = init_loc.shape[
                len(site["fn"].batch_shape):]
            numel[name] = init_loc.numel()
            init_locs[name] = init_loc.reshape(-1)

        # Initialize guide params.
        children = defaultdict(list)
        num_pending = {}
        for name, site in sample_sites.items():
            # Initialize location parameters.
            init_loc = init_locs[name]
            deep_setattr(self.locs, name, PyroParam(init_loc))

            # Initialize parameters of conditional distributions.
            conditional = self.conditionals[name]
            if callable(conditional):
                deep_setattr(self.conds, name, conditional)
            else:
                if conditional not in ("delta", "normal", "mvn"):
                    raise ValueError(
                        f"Unsupported conditional type: {conditional}")
                if conditional in ("normal", "mvn"):
                    init_scale = torch.full_like(init_loc, self._init_scale)
                    deep_setattr(self.scales, name,
                                 PyroParam(init_scale, self.scale_constraint))
                if conditional == "mvn":
                    init_scale_tril = eye_like(init_loc, init_loc.numel())
                    deep_setattr(
                        self.scale_trils,
                        name,
                        PyroParam(init_scale_tril, self.scale_tril_constraint),
                    )

            # Initialize dependencies on upstream variables.
            num_pending[name] = 0
            deps = PyroModule()
            deep_setattr(self.deps, name, deps)
            for upstream, dep in self.dependencies.get(name, {}).items():
                assert upstream in sample_sites
                children[upstream].append(name)
                num_pending[name] += 1
                if isinstance(dep, str) and dep == "linear":
                    dep = torch.nn.Linear(numel[upstream],
                                          numel[name],
                                          bias=False)
                    dep.weight.data.zero_()
                elif not callable(dep):
                    raise ValueError(
                        f"Expected either the string 'linear' or a callable, but got {dep}"
                    )
                deep_setattr(deps, upstream, dep)

        # Topologically sort sites.
        # TODO should we choose a more optimal structure?
        self._sorted_sites = []
        while num_pending:
            name, count = min(num_pending.items(),
                              key=lambda kv: (kv[1], kv[0]))
            assert count == 0, f"cyclic dependency: {name}"
            del num_pending[name]
            for child in children[name]:
                num_pending[child] -= 1
            site = self._compress_site(sample_sites[name])
            self._sorted_sites.append((name, site))

        # Prune non-essential parts of the trace to save memory.
        for name, site in self.prototype_trace.nodes.items():
            site.clear()
Example #26
    def autoguide(self, name, dist_constructor):
        """
        Sets an autoguide for an existing parameter with name ``name``
        (mimicking the behavior of the :mod:`pyro.infer.autoguide` module).

        .. note:: `dist_constructor` should be one of
            :class:`~pyro.distributions.Delta`,
            :class:`~pyro.distributions.Normal`, and
            :class:`~pyro.distributions.MultivariateNormal`. More distribution
            constructors will be supported in the future if needed.

        :param str name: Name of the parameter.
        :param dist_constructor: A
            :class:`~pyro.distributions.distribution.Distribution` constructor.
        """
        if name not in self._priors:
            raise ValueError(
                "There is no prior for parameter: {}".format(name))

        if dist_constructor not in [
                dist.Delta, dist.Normal, dist.MultivariateNormal
        ]:
            raise NotImplementedError(
                "Unsupported distribution type: {}".format(dist_constructor))

        if name in self._guides:
            # delete previous guide's parameters
            dist_args = self._guides[name][1]
            for arg in dist_args:
                arg_name = "{}_{}".format(name, arg)
                if arg_name in self._constraints:
                    # delete its unconstrained parameter
                    self.set_constraint(arg_name, constraints.real)
                delattr(self, arg_name)

        # TODO: create a new argument `autoguide_args` to store other args for other
        # constructors. For example, in LowRankMVN, we need argument `rank`.
        p = self._buffers[name]
        if dist_constructor is dist.Delta:
            p_map = Parameter(p.detach())
            self.register_parameter("{}_map".format(name), p_map)
            self.set_constraint("{}_map".format(name),
                                _get_independent_support(self._priors[name]))
            dist_args = {"map"}
        elif dist_constructor is dist.Normal:
            loc = Parameter(
                biject_to(self._priors[name].support).inv(p).detach())
            scale = Parameter(loc.new_ones(loc.shape))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale".format(name), scale)
            dist_args = {"loc", "scale"}
        elif dist_constructor is dist.MultivariateNormal:
            loc = Parameter(
                biject_to(self._priors[name].support).inv(p).detach())
            identity = eye_like(loc, loc.size(-1))
            scale_tril = Parameter(identity.repeat(loc.shape[:-1] + (1, 1)))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale_tril".format(name), scale_tril)
            dist_args = {"loc", "scale_tril"}
        else:
            raise NotImplementedError

        if dist_constructor is not dist.Delta:
            # each arg has a constraint, so we set constraints for them
            for arg in dist_args:
                self.set_constraint("{}_{}".format(name, arg),
                                    dist_constructor.arg_constraints[arg])
        self._guides[name] = (dist_constructor, dist_args)
Example #27
 def __init__(self, mean, cov, time=None, frame_num=None):
     super(PositionMeasurement, self).__init__(mean, cov, time=time, frame_num=frame_num)
     self._jacobian = torch.cat([
         eye_like(mean, self.dimension),
         mean.new_zeros((self.dimension, self.dimension))], dim=1)