def process_noise_cov(self, dt=0.):
    '''
    Compute and return cached process noise covariance (Q).

    :param dt: time interval to integrate over.
    :return: Read-only covariance (Q) of the native state `x` resulting from
        stochastic integration (for use with EKF). (Note that this Q, modulo
        numerical error, has rank `dimension/2`. So, it is only positive
        semi-definite.)
    '''
    if dt not in self._Q_cache:
        with torch.no_grad():
            d = self._dimension
            dt2 = dt * dt
            dt3 = dt2 * dt
            dt4 = dt2 * dt2
            Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device)
            Q[:d // 2, :d // 2] = 0.25 * dt4 * eye_like(self.sa2, d // 2)
            Q[:d // 2, d // 2:] = 0.5 * dt3 * eye_like(self.sa2, d // 2)
            Q[d // 2:, :d // 2] = 0.5 * dt3 * eye_like(self.sa2, d // 2)
            Q[d // 2:, d // 2:] = dt2 * eye_like(self.sa2, d // 2)
            Q = Q * self.sa2
            self._Q_cache[dt] = Q

    return self._Q_cache[dt]

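# Every snippet in this collection calls eye_like, which is not shown here.
# A minimal sketch of such a helper (an assumption about its behavior, not the
# library's own definition): an identity matrix that matches the dtype and
# device of a reference tensor.
import torch


def eye_like(value, n):
    """Return an n x n identity matrix with the dtype/device of ``value``."""
    return torch.eye(n, dtype=value.dtype, device=value.device)


# Example: eye_like(torch.zeros(3, dtype=torch.float64), 2) -> 2x2 float64 identity.
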
def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None,
             latent_shape=None, num_data=None, whiten=False, jitter=1e-6):
    super(VariationalSparseGP, self).__init__(X, y, kernel, mean_function, jitter)

    self.likelihood = likelihood
    self.Xu = Parameter(Xu)

    y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
    self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

    M = self.Xu.size(0)
    u_loc = self.Xu.new_zeros(self.latent_shape + (M,))
    self.u_loc = Parameter(u_loc)

    identity = eye_like(self.Xu, M)
    u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
    self.u_scale_tril = Parameter(u_scale_tril)
    self.set_constraint("u_scale_tril", constraints.lower_cholesky)

    self.num_data = num_data if num_data is not None else self.X.size(0)
    self.whiten = whiten
    self._sample_latent = True

def _initialize_model_properties(self):
    if self.max_plate_nesting is None:
        self._guess_max_plate_nesting()
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    self.model = poutine.enum(config_enumerate(self.model),
                              first_available_dim=-1 - self.max_plate_nesting)
    if self._automatic_transform_enabled:
        self.transforms = {}
    trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
    for name, node in trace.iter_stochastic_nodes():
        if isinstance(node["fn"], _Subsample):
            continue
        if node["fn"].has_enumerate_support:
            self._has_enumerable_sites = True
            continue
        site_value = node["value"]
        if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
            self.transforms[name] = biject_to(node["fn"].support).inv
            site_value = self.transforms[name](node["value"])
        self._r_shapes[name] = site_value.shape
        self._r_numels[name] = site_value.numel()
    self._trace_prob_evaluator = TraceEinsumEvaluator(trace,
                                                      self._has_enumerable_sites,
                                                      self.max_plate_nesting)
    mass_matrix_size = sum(self._r_numels.values())
    if self.full_mass:
        initial_mass_matrix = eye_like(site_value, mass_matrix_size)
    else:
        initial_mass_matrix = site_value.new_ones(mass_matrix_size)
    self._adapter.inverse_mass_matrix = initial_mass_matrix

def _setup_prototype(self, *args, **kwargs):
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    self.loc = nn.Parameter(self._init_loc())
    self.scale_tril = PyroParam(
        eye_like(self.loc, self.latent_dim) * self._init_scale,
        constraints.lower_cholesky)

def backward(ctx, grad_output):
    jitter = 1.0e-8  # do i really need this?
    z, epsilon, L = ctx.saved_tensors

    dim = L.shape[0]
    g = grad_output
    loc_grad = sum_leftmost(grad_output, -1)

    identity = eye_like(g, dim)
    R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0]

    z_ja = z.unsqueeze(-1)
    g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
    epsilon_jb = epsilon.unsqueeze(-2)
    g_ja = g.unsqueeze(-1)
    diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

    Sigma_inv = torch.mm(R_inv, R_inv.t())
    V, D, _ = torch.svd(Sigma_inv + jitter)
    D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

    expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
    z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple)
    g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple)

    Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
    Y = torch.mm(V, torch.mm(Y, V.t()))
    Y = Y + Y.t()

    Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv))
    diff_L_ab += 0.5 * Tr_xi_Y
    L_grad = torch.tril(diff_L_ab)

    return loc_grad, L_grad, None

def model(self): self.set_mode("model") M = self.Xu.size(0) Kuu = self.kernel(self.Xu).contiguous() Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal Luu = Kuu.cholesky() zero_loc = self.Xu.new_zeros(self.u_loc.shape) if self.whiten: identity = eye_like(self.Xu, M) pyro.sample(self._pyro_get_fullname("u"), dist.MultivariateNormal(zero_loc, scale_tril=identity) .to_event(zero_loc.dim() - 1)) else: pyro.sample(self._pyro_get_fullname("u"), dist.MultivariateNormal(zero_loc, scale_tril=Luu) .to_event(zero_loc.dim() - 1)) f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc, self.u_scale_tril, Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter) f_loc = f_loc + self.mean_function(self.X) if self.y is None: return f_loc, f_var else: # we would like to load likelihood's parameters outside poutine.scale context self.likelihood._load_pyro_samples() with poutine.scale(scale=self.num_data / self.X.size(0)): return self.likelihood(f_loc, f_var, self.y)
def __init__(self, X, y, kernel, likelihood, mean_function=None,
             latent_shape=None, whiten=False, jitter=1e-6, use_cuda=False):
    super().__init__(X, y, kernel, mean_function, jitter)

    self.likelihood = likelihood

    y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
    self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

    N = self.X.size(0)
    f_loc = self.X.new_zeros(self.latent_shape + (N,))
    self.f_loc = Parameter(f_loc)

    identity = eye_like(self.X, N)
    f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
    self.f_scale_tril = PyroParam(f_scale_tril, constraints.lower_cholesky)

    self.whiten = whiten
    self._sample_latent = True

    if use_cuda:
        self.cuda()

def model(self): self.set_mode("model") N = self.X.size(0) Kff = self.kernel(self.X).contiguous() Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal Lff = Kff.cholesky() zero_loc = self.X.new_zeros(self.f_loc.shape) if self.whiten: identity = eye_like(self.X, N) pyro.sample( self._pyro_get_fullname("f"), dist.MultivariateNormal( zero_loc, scale_tril=identity).to_event(zero_loc.dim() - 1)) f_scale_tril = Lff.matmul(self.f_scale_tril) f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1) else: pyro.sample( self._pyro_get_fullname("f"), dist.MultivariateNormal( zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1)) f_scale_tril = self.f_scale_tril f_loc = self.f_loc f_loc = f_loc + self.mean_function(self.X) f_var = f_scale_tril.pow(2).sum(dim=-1) if self.y is None: return f_loc, f_var else: return self.likelihood(f_loc, f_var, self.y)
def jacobian(self, dt):
    '''
    Compute and return cached native state transition Jacobian (F) over
    time interval ``dt``.

    :param dt: time interval to integrate over.
    :return: Read-only Jacobian (F) of integration map (f).
    '''
    if dt not in self._F_cache:
        d = self._dimension
        with torch.no_grad():
            F = eye_like(self.sa2, d)
            F[:d // 2, d // 2:] = dt * eye_like(self.sa2, d // 2)
            self._F_cache[dt] = F

    return self._F_cache[dt]

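# Illustration (not part of the original code): for a position/velocity state of
# total dimension 4, the Jacobian built above is the block matrix [[I, dt*I], [0, I]],
# i.e. position is advanced by velocity * dt and velocity is carried over unchanged.
import torch

dt = 0.5
F = torch.eye(4)
F[:2, 2:] = dt * torch.eye(2)
# F @ torch.tensor([x, y, vx, vy]) == [x + dt*vx, y + dt*vy, vx, vy]
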
def autoguide(self, name, dist_constructor):
    """
    Sets an autoguide for an existing parameter with name ``name`` (mimicking
    the behavior of module :mod:`pyro.infer.autoguide`).

    .. note:: `dist_constructor` should be one of
        :class:`~pyro.distributions.Delta`,
        :class:`~pyro.distributions.Normal`, and
        :class:`~pyro.distributions.MultivariateNormal`. More distribution
        constructors will be supported in the future if needed.

    :param str name: Name of the parameter.
    :param dist_constructor: A
        :class:`~pyro.distributions.distribution.Distribution` constructor.
    """
    if name not in self._priors:
        raise ValueError("There is no prior for parameter: {}".format(name))

    if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]:
        raise NotImplementedError(
            "Unsupported distribution type: {}".format(dist_constructor))

    # delete old guide
    if name in self._guides:
        dist_args = self._guides[name][1]
        for arg in dist_args:
            delattr(self, "{}_{}".format(name, arg))

    p = self._priors[name]()  # init_to_sample strategy
    if dist_constructor is dist.Delta:
        support = self._priors[name].support
        if _is_real_support(support):
            p_map = Parameter(p.detach())
        else:
            p_map = PyroParam(p.detach(), support)
        setattr(self, "{}_map".format(name), p_map)
        dist_args = ("map",)
    elif dist_constructor is dist.Normal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        scale = PyroParam(loc.new_ones(loc.shape), constraints.positive)
        setattr(self, "{}_loc".format(name), loc)
        setattr(self, "{}_scale".format(name), scale)
        dist_args = ("loc", "scale")
    elif dist_constructor is dist.MultivariateNormal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        identity = eye_like(loc, loc.size(-1))
        scale_tril = PyroParam(identity.repeat(loc.shape[:-1] + (1, 1)),
                               constraints.lower_cholesky)
        setattr(self, "{}_loc".format(name), loc)
        setattr(self, "{}_scale_tril".format(name), scale_tril)
        dist_args = ("loc", "scale_tril")
    else:
        raise NotImplementedError

    self._guides[name] = (dist_constructor, dist_args)

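# Illustration (not part of the original code): biject_to maps the unconstrained
# reals onto a constrained support, and its .inv maps back; the autoguide above
# stores loc in unconstrained space for exactly this reason.
import torch
from torch.distributions import biject_to, constraints

t = biject_to(constraints.positive)        # exp-like transform onto (0, inf)
unconstrained = t.inv(torch.tensor(2.0))   # log(2)
print(t(unconstrained))                    # tensor(2.)
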
def process_noise_cov(self, dt=0.):
    '''
    Compute and return cached process noise covariance (Q).

    :param dt: time interval to integrate over.
    :return: Read-only covariance (Q) of the native state ``x`` resulting
        from stochastic integration (for use with EKF).
    '''
    if dt not in self._Q_cache:
        with torch.no_grad():
            d = self._dimension
            dt2 = dt * dt
            dt3 = dt2 * dt
            Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device)
            eye = eye_like(self.sa2, d // 2)
            Q[:d // 2, :d // 2] = dt3 * eye / 3.0
            Q[:d // 2, d // 2:] = dt2 * eye / 2.0
            Q[d // 2:, :d // 2] = dt2 * eye / 2.0
            Q[d // 2:, d // 2:] = dt * eye
            # sa2 * dt is an intensity factor for the change in velocity over a
            # sampling period ``dt``; ideally it should be ~``sqrt(q*dt)``.
            Q = Q * (self.sa2 * dt)
            self._Q_cache[dt] = Q

    return self._Q_cache[dt]

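# Illustration (not part of the original code): the block structure assembled above,
# for a single position/velocity pair with dt = 0.1 and unit noise intensity (sa2 = 1).
import torch

dt = 0.1
Q = torch.tensor([[dt ** 3 / 3.0, dt ** 2 / 2.0],
                  [dt ** 2 / 2.0, dt]]) * dt  # times sa2 * dt with sa2 = 1
torch.linalg.cholesky(Q)  # succeeds: this Q is positive definite
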
def __init__(self, dimension, sv2):
    dimension_pv = 2 * dimension
    super().__init__(dimension, dimension_pv, num_process_noise_parameters=1)

    if not isinstance(sv2, torch.Tensor):
        sv2 = torch.tensor(sv2)
    self.sv2 = Parameter(sv2)

    self._F_cache = eye_like(sv2, dimension)  # State transition matrix cache
    self._Q_cache = {}  # Process noise cov cache

def _setup_prototype(self, *args, **kwargs):
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    self.loc = nn.Parameter(self._init_loc())
    self.scale = PyroParam(torch.full_like(self.loc, self._init_scale),
                           self.scale_constraint)
    self.scale_tril = PyroParam(eye_like(self.loc, self.latent_dim),
                                self.scale_tril_constraint)

def init_mvn_guide(self):
    """
    Initialize multivariate normal guide
    """
    init_loc = torch.full((self.n_params,), 0.0)
    init_scale = eye_like(init_loc, self.n_params) * 0.1

    _deep_setattr(self, "mvn.loc", PyroParam(init_loc, constraints.real))
    _deep_setattr(self, "mvn.scale_tril",
                  PyroParam(init_scale, constraints.lower_cholesky))

def get_posterior(self, *args, **kwargs):
    """
    Returns a MultivariateNormal posterior distribution.
    """
    loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)
    scale_tril = pyro.param("{}_scale_tril".format(self.prefix),
                            lambda: eye_like(loc, self.latent_dim),
                            constraint=constraints.lower_cholesky)
    return dist.MultivariateNormal(loc, scale_tril=scale_tril)

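# Illustration (not part of the original code): how a loc vector and a lower-triangular
# scale_tril parameterize a multivariate normal; the covariance is scale_tril @ scale_tril.T.
import torch

loc = torch.zeros(2)
scale_tril = torch.tensor([[1.0, 0.0], [0.5, 0.8]])
mvn = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
print(mvn.covariance_matrix)  # equals scale_tril @ scale_tril.T
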
def guide(self):
    a_locs = pyro.param("a_locs", torch.full((self.n_params,), 0.0))
    a_scales_tril = pyro.param("a_scales",
                               lambda: 0.1 * eye_like(a_locs, self.n_params),
                               constraint=constraints.lower_cholesky)

    dt = dist.MultivariateNormal(a_locs, scale_tril=a_scales_tril)
    states = pyro.sample("states", dt, infer={"is_auxiliary": True})

    result = {}
    for i_poiss in torch.arange(self.n_poiss):
        transform = biject_to(self.poiss_priors[i_poiss].support)
        value = transform(states[i_poiss])
        log_density = transform.inv.log_abs_det_jacobian(value, states[i_poiss])
        log_density = sum_rightmost(
            log_density,
            log_density.dim() - value.dim() + self.poiss_priors[i_poiss].event_dim)
        result[self.labels_poiss[i_poiss]] = pyro.sample(
            self.labels_poiss[i_poiss],
            dist.Delta(value, log_density=log_density,
                       event_dim=self.poiss_priors[i_poiss].event_dim))

    i_param = self.n_poiss
    for i_ps in torch.arange(self.n_ps):
        for i_ps_param in torch.arange(self.n_ps_params):
            transform = biject_to(self.ps_priors[i_ps][i_ps_param].support)
            value = transform(states[i_param])
            log_density = transform.inv.log_abs_det_jacobian(value, states[i_param])
            log_density = sum_rightmost(
                log_density,
                log_density.dim() - value.dim()
                + self.ps_priors[i_ps][i_ps_param].event_dim)
            result[self.labels_ps_params[i_ps_param] + "_" + self.labels_ps[i_ps]] = pyro.sample(
                self.labels_ps_params[i_ps_param] + "_" + self.labels_ps[i_ps],
                dist.Delta(value, log_density=log_density,
                           event_dim=self.ps_priors[i_ps][i_ps_param].event_dim))
            i_param += 1

    return result

def __init__(self, mean, cov, time=None, frame_num=None):
    super().__init__(mean, cov, time=time, frame_num=frame_num)
    self._jacobian = torch.cat([
        eye_like(mean, self.dimension),
        torch.zeros(self.dimension, self.dimension,
                    dtype=mean.dtype, device=mean.device)
    ], dim=1)

def guide(self):
    if self.mygroup is None:
        self.mygroup, self.z_init_loc = self._get_group(match=self.guide_conf['match'])
    z_loc = pyro.param(self.prefix + "guide_z_loc", self.z_init_loc)
    z_scale_tril = pyro.param(self.prefix + "guide_z_scale_tril",
                              0.01 * eye_like(z_loc, len(z_loc)),
                              constraint=constraints.lower_cholesky)
    # TODO: Flexible initial error
    guide_z, model_zs = self.mygroup.sample(
        self.prefix + 'guide_z',
        dist.MultivariateNormal(z_loc, scale_tril=z_scale_tril))
    return guide_z, model_zs

def process_noise_cov(self, dt=0.):
    '''
    Compute and return cached process noise covariance (Q).

    :param dt: time interval to integrate over.
    :return: Read-only covariance (Q) of the native state `x` resulting from
        stochastic integration (for use with EKF).
    '''
    if dt not in self._Q_cache:
        Q = self.sv2 * dt * dt * eye_like(self.sv2, self._dimension)
        self._Q_cache[dt] = Q

    return self._Q_cache[dt]

def process_noise_cov(self, dt=0.):
    '''
    Compute and return cached process noise covariance (Q).

    :param dt: time interval to integrate over.
    :return: Read-only covariance (Q) of the native state ``x`` resulting
        from stochastic integration (for use with EKF).
    '''
    if dt not in self._Q_cache:
        # q: continuous-time process noise intensity with units
        # length^2/time (m^2/s). Choose ``q`` so that changes in position,
        # over a sampling period ``dt``, are roughly ``sqrt(q*dt)``.
        q = self.sv2 * dt
        Q = q * dt * eye_like(self.sv2, self._dimension)
        self._Q_cache[dt] = Q

    return self._Q_cache[dt]

def update(self, measurement):
    """
    Use measurement to update the state estimate and return the updated
    state along with the innovation. The innovation is useful, e.g., for
    evaluating filter consistency or updating model likelihoods when the
    ``EKFState`` is part of an ``IMMFState``.

    :param measurement: measurement to incorporate.
    :returns: Updated EKF state, and the innovation mean and covariance.
    """
    if self._time is not None:
        assert self._time == measurement.time, \
            "State time and measurement time must be aligned!"
    if self._frame_num is not None:
        assert self._frame_num == measurement.frame_num, \
            "State frame number and measurement frame number must be aligned!"

    x = self._mean
    x_pv = self._dynamic_model.mean2pv(x)
    P = self.cov
    H = measurement.jacobian(x_pv)[:, :self.dimension]
    R = measurement.cov
    z = measurement.mean
    z_predicted = measurement(x_pv)
    dz = measurement.geodesic_difference(z, z_predicted)
    S = H.mm(P).mm(H.transpose(-1, -2)) + R  # innovation cov

    K_prefix = self._cov.mm(H.transpose(-1, -2))
    dx = K_prefix.mm(torch.linalg.solve(S, dz.unsqueeze(1))).squeeze(1)  # K*dz
    x = self._dynamic_model.geodesic_difference(x, -dx)

    I = eye_like(x, self._dynamic_model.dimension)  # noqa: E741
    ImKH = I - K_prefix.mm(torch.linalg.solve(S, H))
    # *Joseph form* of covariance update for numerical stability.
    S_inv_R = torch.linalg.solve(S, R)
    P = ImKH.mm(self.cov).mm(ImKH.transpose(-1, -2)) \
        + K_prefix.mm(torch.linalg.solve(S, K_prefix.mm(S_inv_R).transpose(-1, -2)))

    pred_mean = x
    pred_cov = P

    state = EKFState(self._dynamic_model, pred_mean, pred_cov,
                     self._time, self._frame_num)

    return state, (dz, S)

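# Illustration (not part of the original code): the Joseph-form update used above,
# P <- (I - K H) P (I - K H)^T + K R K^T, keeps P symmetric positive semi-definite
# under rounding error, unlike the shorter (I - K H) P form.
import torch

P = torch.eye(2)
H = torch.tensor([[1.0, 0.0]])
R = torch.tensor([[0.5]])
S = H @ P @ H.T + R                       # innovation covariance
K = P @ H.T @ torch.linalg.inv(S)         # Kalman gain
ImKH = torch.eye(2) - K @ H
P_joseph = ImKH @ P @ ImKH.T + K @ R @ K.T
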
def _initialize_adapter(self):
    mass_matrix_size = sum([p.numel() for p in self.initial_params.values()])
    site_value = list(self.initial_params.values())[0]
    if self._adapter.is_diag_mass:
        initial_mass_matrix = torch.ones(mass_matrix_size,
                                         dtype=site_value.dtype,
                                         device=site_value.device)
    else:
        initial_mass_matrix = eye_like(site_value, mass_matrix_size)
    self._adapter.configure(self._warmup_steps,
                            inv_mass_matrix=initial_mass_matrix,
                            find_reasonable_step_size_fn=self._find_reasonable_step_size)
    if self._adapter.adapt_step_size:
        self._adapter.reset_step_size_adaptation(self._initial_params)

def __init__(
    self,
    X,
    y,
    kernel,
    Xu,
    likelihood,
    mean_function=None,
    latent_shape=None,
    num_data=None,
    whiten=False,
    jitter=1e-6,
):
    assert isinstance(
        X, torch.Tensor
    ), "X needs to be a torch Tensor instead of a {}".format(type(X))
    if y is not None:
        assert isinstance(
            y, torch.Tensor
        ), "y needs to be a torch Tensor instead of a {}".format(type(y))
    assert isinstance(
        Xu, torch.Tensor
    ), "Xu needs to be a torch Tensor instead of a {}".format(type(Xu))

    super().__init__(X, y, kernel, mean_function, jitter)

    self.likelihood = likelihood
    self.Xu = Parameter(Xu)

    y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
    self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

    M = self.Xu.size(0)
    u_loc = self.Xu.new_zeros(self.latent_shape + (M,))
    self.u_loc = Parameter(u_loc)

    identity = eye_like(self.Xu, M)
    u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
    self.u_scale_tril = PyroParam(u_scale_tril, constraints.lower_cholesky)

    self.num_data = num_data if num_data is not None else self.X.size(0)
    self.whiten = whiten
    self._sample_latent = True

def GP_sample(name, X, f_loc, f_scale_tril, f_loc_mean, Lff=None, kernel=None,
              jitter=1e-6, whiten=False):
    N = X.size(0)

    if Lff is None:
        Kff = kernel(X).contiguous()
        Kff.view(-1)[::N + 1] += jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

    zero_loc = X.new_zeros(f_loc.shape)
    if whiten:
        identity = eye_like(X, N)
        gp_sample = pyro.sample(
            name,
            dist.MultivariateNormal(zero_loc, scale_tril=identity)
                .to_event(zero_loc.dim() - 1))
        gp_sample = Lff.matmul(gp_sample.unsqueeze(-1)).squeeze(-1)
    else:
        gp_sample = pyro.sample(
            name,
            dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                .to_event(zero_loc.dim() - 1))

    # gp_sample += (f_loc + f_loc_mean)  # the guide already accounts for f_loc
    gp_sample += f_loc_mean
    return gp_sample

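# Illustration (not part of the original code): the whitened parameterization used
# above. If u ~ N(0, I) and Lff is the Cholesky factor of Kff, then Lff @ u has
# covariance Kff, so sampling against an identity scale_tril and multiplying by
# Lff is equivalent to sampling against Kff directly.
import torch

Kff = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
Lff = torch.linalg.cholesky(Kff)
u = torch.randn(100000, 2)
f = u @ Lff.T
print(torch.cov(f.T))  # approximately Kff
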
def _setup_prototype(self, *args, **kwargs):
    super()._setup_prototype(*args, **kwargs)

    self.locs = PyroModule()
    self.scales = PyroModule()
    self.scale_trils = PyroModule()
    self.conds = PyroModule()
    self.deps = PyroModule()
    self._batch_shapes = {}
    self._unconstrained_event_shapes = {}
    sample_sites = OrderedDict(self.prototype_trace.iter_stochastic_nodes())
    self._auto_config(sample_sites, args, kwargs)

    # Collect unconstrained shapes.
    init_locs = {}
    numel = {}
    for name, site in sample_sites.items():
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        self._batch_shapes[name] = site["fn"].batch_shape
        self._unconstrained_event_shapes[name] = init_loc.shape[len(site["fn"].batch_shape):]
        numel[name] = init_loc.numel()
        init_locs[name] = init_loc.reshape(-1)

    # Initialize guide params.
    children = defaultdict(list)
    num_pending = {}
    for name, site in sample_sites.items():
        # Initialize location parameters.
        init_loc = init_locs[name]
        deep_setattr(self.locs, name, PyroParam(init_loc))

        # Initialize parameters of conditional distributions.
        conditional = self.conditionals[name]
        if callable(conditional):
            deep_setattr(self.conds, name, conditional)
        else:
            if conditional not in ("delta", "normal", "mvn"):
                raise ValueError(f"Unsupported conditional type: {conditional}")
            if conditional in ("normal", "mvn"):
                init_scale = torch.full_like(init_loc, self._init_scale)
                deep_setattr(self.scales, name,
                             PyroParam(init_scale, self.scale_constraint))
            if conditional == "mvn":
                init_scale_tril = eye_like(init_loc, init_loc.numel())
                deep_setattr(
                    self.scale_trils, name,
                    PyroParam(init_scale_tril, self.scale_tril_constraint),
                )

        # Initialize dependencies on upstream variables.
        num_pending[name] = 0
        deps = PyroModule()
        deep_setattr(self.deps, name, deps)
        for upstream, dep in self.dependencies.get(name, {}).items():
            assert upstream in sample_sites
            children[upstream].append(name)
            num_pending[name] += 1
            if isinstance(dep, str) and dep == "linear":
                dep = torch.nn.Linear(numel[upstream], numel[name], bias=False)
                dep.weight.data.zero_()
            elif not callable(dep):
                raise ValueError(
                    f"Expected either the string 'linear' or a callable, but got {dep}"
                )
            deep_setattr(deps, upstream, dep)

    # Topologically sort sites.
    # TODO should we choose a more optimal structure?
    self._sorted_sites = []
    while num_pending:
        name, count = min(num_pending.items(), key=lambda kv: (kv[1], kv[0]))
        assert count == 0, f"cyclic dependency: {name}"
        del num_pending[name]
        for child in children[name]:
            num_pending[child] -= 1
        site = self._compress_site(sample_sites[name])
        self._sorted_sites.append((name, site))

    # Prune non-essential parts of the trace to save memory.
    for name, site in self.prototype_trace.nodes.items():
        site.clear()

def autoguide(self, name, dist_constructor):
    """
    Sets an autoguide for an existing parameter with name ``name`` (mimicking
    the behavior of module :mod:`pyro.infer.autoguide`).

    .. note:: `dist_constructor` should be one of
        :class:`~pyro.distributions.Delta`,
        :class:`~pyro.distributions.Normal`, and
        :class:`~pyro.distributions.MultivariateNormal`. More distribution
        constructors will be supported in the future if needed.

    :param str name: Name of the parameter.
    :param dist_constructor: A
        :class:`~pyro.distributions.distribution.Distribution` constructor.
    """
    if name not in self._priors:
        raise ValueError("There is no prior for parameter: {}".format(name))

    if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]:
        raise NotImplementedError(
            "Unsupported distribution type: {}".format(dist_constructor))

    if name in self._guides:
        # delete previous guide's parameters
        dist_args = self._guides[name][1]
        for arg in dist_args:
            arg_name = "{}_{}".format(name, arg)
            if arg_name in self._constraints:
                # delete its unconstrained parameter
                self.set_constraint(arg_name, constraints.real)
            delattr(self, arg_name)

    # TODO: create a new argument `autoguide_args` to store other args for other
    # constructors. For example, in LowRankMVN, we need argument `rank`.
    p = self._buffers[name]
    if dist_constructor is dist.Delta:
        p_map = Parameter(p.detach())
        self.register_parameter("{}_map".format(name), p_map)
        self.set_constraint("{}_map".format(name),
                            _get_independent_support(self._priors[name]))
        dist_args = {"map"}
    elif dist_constructor is dist.Normal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        scale = Parameter(loc.new_ones(loc.shape))
        self.register_parameter("{}_loc".format(name), loc)
        self.register_parameter("{}_scale".format(name), scale)
        dist_args = {"loc", "scale"}
    elif dist_constructor is dist.MultivariateNormal:
        loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
        identity = eye_like(loc, loc.size(-1))
        scale_tril = Parameter(identity.repeat(loc.shape[:-1] + (1, 1)))
        self.register_parameter("{}_loc".format(name), loc)
        self.register_parameter("{}_scale_tril".format(name), scale_tril)
        dist_args = {"loc", "scale_tril"}
    else:
        raise NotImplementedError

    if dist_constructor is not dist.Delta:
        # each arg has a constraint, so we set constraints for them
        for arg in dist_args:
            self.set_constraint("{}_{}".format(name, arg),
                                dist_constructor.arg_constraints[arg])

    self._guides[name] = (dist_constructor, dist_args)

def __init__(self, mean, cov, time=None, frame_num=None):
    super(PositionMeasurement, self).__init__(mean, cov, time=time, frame_num=frame_num)
    self._jacobian = torch.cat([
        eye_like(mean, self.dimension),
        mean.new_zeros((self.dimension, self.dimension))
    ], dim=1)