def __init__(
    self,
    n_input: int,
    n_conditions: int,
    lam_scale: float,
    bias_scale: float,
    alpha: float = 1.0,
):
    """Guide parameters for a model with per-feature weights and per-condition biases.

    :param n_input: number of input features (one weight per feature).
    :param n_conditions: number of conditions (one bias per condition).
    :param lam_scale: scale of the prior on the weights; also the initial guide scale.
    :param bias_scale: scale of the prior on the biases; also the initial guide scale.
    :param alpha: weight on the monotonic constraint.
    """
    super().__init__()
    self.n_input = n_input
    self.n_conditions = n_conditions

    # Fixed (non-trainable) scalars live in buffers so they follow the module's device.
    # weight on monotonic constraint
    self.register_buffer("alpha", torch.as_tensor(alpha))
    # scale of priors on weights and bias
    self.register_buffer("lam_p_scale", torch.as_tensor(lam_scale))
    self.register_buffer("bias_p_scale", torch.as_tensor(bias_scale))

    # Variational parameters: randomly-initialized locations, positive scales.
    self.weight_loc = nn.Parameter(torch.nn.init.normal_(torch.Tensor(n_input)))
    self.weight_scale = PyroParam(
        torch.full((n_input,), lam_scale), constraint=constraints.positive
    )
    self.bias_loc = nn.Parameter(torch.nn.init.normal_(torch.Tensor(n_conditions)))
    self.bias_scale = PyroParam(
        torch.full((n_conditions,), bias_scale), constraint=constraints.positive
    )
def _setup_prototype(self, *args, **kwargs):
    """Create one (loc, scale) parameter pair per latent site of the model.

    Runs the base-class prototype trace, then, for each stochastic site,
    stores an unconstrained location initialized from the site value and a
    constant-initialized scale, registered under ``self.locs`` / ``self.scales``.
    """
    super()._setup_prototype(*args, **kwargs)
    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.locs = PyroModule()
    self.scales = PyroModule()
    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # Collect unconstrained event_dims, which may differ from constrained event_dims.
        with helpful_support_errors(site):
            # Map the constrained site value to unconstrained space.
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
        self._event_dims[name] = event_dim

        # Collect independence contexts.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

        # If subsampling, repeat init_value to full size.
        for frame in site["cond_indep_stack"]:
            full_size = getattr(frame, "full_size", frame.size)
            if full_size != frame.size:
                # Batch dim of this plate in the unconstrained value.
                dim = frame.dim - event_dim
                init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
        init_scale = torch.full_like(init_loc, self._init_scale)

        _deep_setattr(self.locs, name, PyroParam(init_loc, constraints.real, event_dim))
        _deep_setattr(self.scales, name, PyroParam(init_scale, self.scale_constraint, event_dim))
def __init__(self, nu=1.5, num_gps=1, length_scale_init=None, kernel_scale_init=None):
    """Matern kernel in its state-space representation.

    :param nu: smoothness parameter; only 0.5, 1.5 and 2.5 are supported.
    :param num_gps: number of independent GPs that share this kernel family.
    :param length_scale_init: optional ``(num_gps,)`` tensor of initial length scales.
    :param kernel_scale_init: optional ``(num_gps,)`` tensor of initial kernel scales.
    """
    if nu not in (0.5, 1.5, 2.5):
        raise NotImplementedError(
            "The only supported values of nu are 0.5, 1.5 and 2.5")
    self.nu = nu
    # The latent state dimension of the SDE representation grows with smoothness.
    self.state_dim = {0.5: 1, 1.5: 2, 2.5: 3}[nu]
    self.num_gps = num_gps

    if length_scale_init is None:
        length_scale_init = torch.ones(num_gps)
    assert length_scale_init.shape == (num_gps,)

    if kernel_scale_init is None:
        kernel_scale_init = torch.ones(num_gps)
    assert kernel_scale_init.shape == (num_gps,)

    super().__init__()

    self.length_scale = PyroParam(length_scale_init, constraint=constraints.positive)
    self.kernel_scale = PyroParam(kernel_scale_init, constraint=constraints.positive)

    if self.state_dim > 1:
        # Register a one-hot mask buffer per matrix entry; used elsewhere to
        # assemble state-space matrices element by element.
        for row in range(self.state_dim):
            for col in range(self.state_dim):
                one_hot = torch.zeros(self.state_dim, self.state_dim)
                one_hot[row, col] = 1.0
                self.register_buffer("mask{}{}".format(row, col), one_hot)
def __init__(self, obs_dim=1, state_dim=2, obs_noise_scale_init=None,
             learnable_observation_loc=False):
    """Generic linear-Gaussian state space model.

    :param obs_dim: dimension of the observations.
    :param state_dim: dimension of the latent state.
    :param obs_noise_scale_init: optional ``(obs_dim,)`` tensor of initial
        observation noise scales; defaults to 0.2 everywhere.
    :param learnable_observation_loc: if True the observation location is a
        learnable parameter, otherwise a fixed zero buffer.
    """
    self.obs_dim = obs_dim
    self.state_dim = state_dim

    if obs_noise_scale_init is None:
        obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
    assert obs_noise_scale_init.shape == (obs_dim,)

    super().__init__()

    self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                     constraint=constraints.positive)
    self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                          constraint=constraints.positive)
    # Transition matrix starts near the identity; observation matrix is small random.
    self.trans_matrix = nn.Parameter(
        torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
    self.obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
    self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                         constraint=constraints.positive)

    if learnable_observation_loc:
        self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
    else:
        self.register_buffer('obs_loc', torch.zeros(obs_dim))
def init_mvn_guide(self):
    """
    Initialize multivariate normal guide
    """
    # Zero mean and a small (0.1 * I) Cholesky factor as the starting point.
    loc_init = torch.full((self.n_params,), 0.0)
    scale_tril_init = eye_like(loc_init, self.n_params) * 0.1
    _deep_setattr(self, "mvn.loc", PyroParam(loc_init, constraints.real))
    _deep_setattr(self, "mvn.scale_tril",
                  PyroParam(scale_tril_init, constraints.lower_cholesky))
def __init__(
    self,
    obs_dim=1,
    state_dim=2,
    nu=1.5,
    obs_noise_scale_init=None,
    length_scale_init=None,
    kernel_scale_init=None,
    learnable_observation_loc=False,
):
    """Linear-Gaussian state space model augmented with Matern GP noise states.

    :param obs_dim: dimension of the observations.
    :param state_dim: dimension of the non-GP latent state.
    :param nu: Matern smoothness parameter for the GP kernel.
    :param obs_noise_scale_init: optional ``(obs_dim,)`` initial noise scales.
    :param length_scale_init: passed through to :class:`MaternKernel`.
    :param kernel_scale_init: passed through to :class:`MaternKernel`.
    :param learnable_observation_loc: learnable vs fixed-zero observation loc.
    """
    self.obs_dim = obs_dim
    self.state_dim = state_dim
    self.nu = nu

    if obs_noise_scale_init is None:
        obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
    assert obs_noise_scale_init.shape == (obs_dim,)

    super().__init__()

    # One independent Matern GP per observed dimension.
    self.kernel = MaternKernel(
        nu=nu,
        num_gps=obs_dim,
        length_scale_init=length_scale_init,
        kernel_scale_init=kernel_scale_init,
    )
    self.dt = 1.0
    # Full state = GP states (kernel.state_dim per obs dim) + plain latent state.
    self.full_state_dim = self.kernel.state_dim * obs_dim + state_dim
    self.full_gp_state_dim = self.kernel.state_dim * obs_dim

    self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                     constraint=constraints.positive)
    self.trans_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                          constraint=constraints.positive)
    self.z_trans_matrix = nn.Parameter(
        torch.eye(state_dim) + 0.03 * torch.randn(state_dim, state_dim))
    self.z_obs_matrix = nn.Parameter(0.3 * torch.randn(state_dim, obs_dim))
    self.init_noise_scale_sq = PyroParam(torch.ones(state_dim),
                                         constraint=constraints.positive)

    # Each GP contributes its first state coordinate to the matching observation.
    gp_obs = torch.zeros(self.kernel.state_dim * obs_dim, obs_dim)
    for dim in range(obs_dim):
        gp_obs[self.kernel.state_dim * dim, dim] = 1.0
    self.register_buffer("gp_obs_matrix", gp_obs)

    # Indices of the observed GP state coordinates (0, state_dim, 2*state_dim, ...).
    self.obs_selector = torch.tensor(
        [self.kernel.state_dim * d for d in range(obs_dim)], dtype=torch.long)

    if learnable_observation_loc:
        self.obs_loc = nn.Parameter(torch.zeros(obs_dim))
    else:
        self.register_buffer("obs_loc", torch.zeros(obs_dim))
def _setup_prototype(self, *args, **kwargs):
    """Set up diagonal-normal guide parameters: a loc vector and a positive scale vector."""
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    self.loc = nn.Parameter(self._init_loc())
    init_scale_vec = self.loc.new_full((self.latent_dim,), self._init_scale)
    self.scale = PyroParam(init_scale_vec, constraints.positive)
def __init__(
    self,
    nu=1.5,
    dt=1.0,
    obs_dim=1,
    length_scale_init=None,
    kernel_scale_init=None,
    obs_noise_scale_init=None,
):
    """Independent Matern GPs, one per observed dimension.

    :param nu: Matern smoothness parameter.
    :param dt: time step between observations.
    :param obs_dim: number of observed dimensions (= number of GPs).
    :param length_scale_init: passed through to :class:`MaternKernel`.
    :param kernel_scale_init: passed through to :class:`MaternKernel`.
    :param obs_noise_scale_init: optional ``(obs_dim,)`` initial noise scales.
    """
    self.nu = nu
    self.dt = dt
    self.obs_dim = obs_dim

    if obs_noise_scale_init is None:
        obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
    assert obs_noise_scale_init.shape == (obs_dim,)

    super().__init__()

    self.kernel = MaternKernel(
        nu=nu,
        num_gps=obs_dim,
        length_scale_init=length_scale_init,
        kernel_scale_init=kernel_scale_init,
    )
    self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                     constraint=constraints.positive)

    # Only the first state coordinate of each GP is observed.
    first_coord_selector = [1.0] + [0.0] * (self.kernel.state_dim - 1)
    self.register_buffer("obs_matrix",
                         torch.tensor(first_coord_selector).unsqueeze(-1))
def _setup_prototype(self, *args, **kwargs):
    """Register one constrained parameter per latent site, seeded at the prototype value."""
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        param = PyroParam(site["value"].detach(), constraint=site["fn"].support)
        _deep_setattr(self, name, param)
def __init__(self, nu=1.5, dt=1.0, obs_dim=2, num_gps=1,
             length_scale_init=None, kernel_scale_init=None,
             obs_noise_scale_init=None):
    """Observations modeled as linear combinations of ``num_gps`` Matern GPs.

    :param nu: Matern smoothness parameter.
    :param dt: time step between observations.
    :param obs_dim: number of observed dimensions; must exceed 1.
    :param num_gps: number of latent GPs mixed into the observations.
    :param length_scale_init: passed through to :class:`MaternKernel`.
    :param kernel_scale_init: passed through to :class:`MaternKernel`.
    :param obs_noise_scale_init: optional ``(obs_dim,)`` initial noise scales.

    Fix: the original body assigned ``self.dt``, ``self.obs_dim`` and
    ``self.num_gps`` twice (a duplicated block); each attribute is now set once.
    """
    self.nu = nu
    self.dt = dt
    assert obs_dim > 1, "If obs_dim==1 you should use IndependentMaternGP"
    self.obs_dim = obs_dim
    self.num_gps = num_gps

    if obs_noise_scale_init is None:
        obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
    assert obs_noise_scale_init.shape == (obs_dim,)

    super().__init__()

    self.kernel = MaternKernel(nu=nu, num_gps=num_gps,
                               length_scale_init=length_scale_init,
                               kernel_scale_init=kernel_scale_init)
    self.full_state_dim = num_gps * self.kernel.state_dim

    self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                     constraint=constraints.positive)
    # Learnable mixing matrix from GP outputs to observed dimensions.
    self.A = nn.Parameter(0.3 * torch.randn(self.num_gps, self.obs_dim))
def _setup_prototype(self, *args, **kwargs):
    """Set up multivariate-normal guide parameters: a loc vector and a Cholesky factor."""
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    self.loc = nn.Parameter(self._init_loc())
    init_tril = eye_like(self.loc, self.latent_dim) * self._init_scale
    self.scale_tril = PyroParam(init_tril, constraints.lower_cholesky)
def __init__(
    self,
    nu=1.5,
    dt=1.0,
    obs_dim=1,
    linearly_coupled=False,
    length_scale_init=None,
    obs_noise_scale_init=None,
):
    """Matern GPs coupled through a shared Wiener noise Cholesky factor.

    :param nu: Matern smoothness parameter; only 1.5 is supported.
    :param dt: time step between observations.
    :param obs_dim: number of observed dimensions (= number of GPs).
    :param linearly_coupled: if True, observations are learnable linear
        combinations of the GP outputs; otherwise each GP maps to one output.
    :param length_scale_init: passed through to :class:`MaternKernel`.
    :param obs_noise_scale_init: optional ``(obs_dim,)`` initial noise scales.
    """
    if nu != 1.5:
        raise NotImplementedError("The only supported value of nu is 1.5")

    self.dt = dt
    self.obs_dim = obs_dim

    if obs_noise_scale_init is None:
        obs_noise_scale_init = 0.2 * torch.ones(obs_dim)
    assert obs_noise_scale_init.shape == (obs_dim,)

    super().__init__()

    self.kernel = MaternKernel(nu=nu, num_gps=obs_dim,
                               length_scale_init=length_scale_init)
    self.full_state_dim = self.kernel.state_dim * obs_dim

    # we demote self.kernel.kernel_scale from being a nn.Parameter
    # since the relevant scales are now encoded in the wiener noise matrix
    del self.kernel.kernel_scale
    self.kernel.register_buffer("kernel_scale", torch.ones(obs_dim))

    self.obs_noise_scale = PyroParam(obs_noise_scale_init,
                                     constraint=constraints.positive)
    # Lower-triangular noise factor, initialized near the identity.
    self.wiener_noise_tril = PyroParam(
        torch.eye(obs_dim) + 0.03 * torch.randn(obs_dim, obs_dim).tril(-1),
        constraint=constraints.lower_cholesky,
    )

    if linearly_coupled:
        self.obs_matrix = nn.Parameter(
            0.3 * torch.randn(self.obs_dim, self.obs_dim))
    else:
        # Fixed selector: GP i's first state coordinate drives observation i.
        selector = torch.zeros(self.full_state_dim, obs_dim)
        for dim in range(obs_dim):
            selector[self.kernel.state_dim * dim, dim] = 1.0
        self.register_buffer("obs_matrix", selector)
def _setup_prototype(self, *args, **kwargs):
    """Set up low-rank multivariate-normal guide parameters (loc, diag scale, cov factor)."""
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    self.loc = nn.Parameter(self._init_loc())
    if self.rank is None:
        # Default rank ~ sqrt(latent dimension).
        self.rank = int(round(self.latent_dim ** 0.5))
    diag_init = self.loc.new_full((self.latent_dim,), 0.5 ** 0.5 * self._init_scale)
    self.scale = PyroParam(diag_init, constraint=constraints.positive)
    # Random factor scaled so the implied covariance stays O(1) in the rank.
    factor_init = self.loc.new_empty(self.latent_dim, self.rank)
    factor_init.normal_(0, 1 / self.rank ** 0.5)
    self.cov_factor = nn.Parameter(factor_init)
def __init__(self, sites, name='',
             init_scale: tp.Union[torch.Tensor, float] = 1.,
             *args, **kwargs):
    """Guide with a per-coordinate (diagonal) positive scale parameter.

    :param sites: sample sites handled by this guide.
    :param name: optional guide name.
    :param init_scale: scalar or tensor used to seed the scale.
    """
    super().__init__(sites, name, *args, **kwargs)
    # Seed the diagonal scale from the Jacobian at the initial location.
    jac = self.jacobian(self.loc)
    self.scale = PyroParam(self._scale_diagonal(init_scale, jac),
                           event_dim=1, constraint=constraints.positive)
def __init__(self, sites, name='', diag=_nomatch,
             init_scale_full: tp.Union[torch.Tensor, float] = 1.,
             init_scale_diag: tp.Union[torch.Tensor, float] = 1.,
             *args, **kwargs):
    """Guide whose covariance is full-rank over some sites and diagonal over others.

    Sites whose names match the ``diag`` regex get independent (diagonal)
    scales; the remaining sites share a full lower-Cholesky factor, with a
    learnable cross block coupling the two groups.

    :param sites: sample sites handled by this guide.
    :param name: optional guide name.
    :param diag: regex pattern selecting the diagonally-treated sites
        (default ``_nomatch`` — presumably matches nothing, so all sites are
        full-rank; TODO confirm against ``_nomatch``'s definition).
    :param init_scale_full: initial scale for the full-rank block.
    :param init_scale_diag: initial scale for the diagonal block.
    """
    self.diag_pattern = re.compile(diag)
    # Partition sites into full-rank vs diagonal groups, keyed by site name.
    # NOTE(review): assumes `partition` returns (non-matching, matching) in an
    # order consistent with (sites_full, sites_diag) — verify against its definition.
    self.sites_full, self.sites_diag = ({
        site['name']: site for site in _
    } for _ in partition(lambda _: self.diag_pattern.match(_['name']), sites))
    # Re-order sites so the full block comes first, then initialize the base guide.
    super().__init__(
        dict_union(self.sites_full, self.sites_diag).values(), name,
        *args, **kwargs)
    # Total flattened sizes of each group.
    self.size_full, self.size_diag = (sum(self.sizes[site] for site in _)
                                      for _ in (self.sites_full, self.sites_diag))
    jac = self.jacobian(self.loc)
    # Full-rank block: lower-Cholesky factor over the first size_full coordinates.
    self.scale_full = PyroParam(self._scale_matrix(init_scale_full,
                                                   jac[:self.size_full]),
                                event_dim=2,
                                constraint=constraints.lower_cholesky)
    # Cross block coupling diagonal coordinates to the full block; starts at zero.
    self.scale_cross = PyroParam(self.loc.new_zeros(
        torch.Size((self.size_diag, self.size_full))), event_dim=2)
    # Diagonal block: independent positive scales for the remaining coordinates.
    self.scale_diag = PyroParam(self._scale_diagonal(
        init_scale_diag, jac[self.size_full:]), event_dim=1,
        constraint=constraints.positive)
    # Auxiliary standard-normal noise sample reparameterized by the blocks above.
    self.guide_z_aux = PyroSample(
        dist.Normal(self.loc.new_zeros(()), 1.).expand(self.event_shape).to_event(1))
def __init__(self, sites, name='',
             init_scale: tp.Union[torch.Tensor, float] = 1.,
             *args, **kwargs):
    """Guide with a full-rank lower-Cholesky scale parameter.

    :param sites: sample sites handled by this guide.
    :param name: optional guide name.
    :param init_scale: scalar or tensor used to seed the scale matrix.
    """
    super().__init__(sites, name, *args, **kwargs)
    # Seed the Cholesky factor from the Jacobian at the initial location.
    jac = self.jacobian(self.loc)
    self.scale_tril = PyroParam(self._scale_matrix(init_scale, jac),
                                event_dim=2,
                                constraint=constraints.lower_cholesky)
def _setup_prototype(self, *args, **kwargs):
    """Trace the model under parallel enumeration and build per-site distribution params.

    Only discrete sites configured with ``infer={"enumerate": "parallel"}`` and
    of type Bernoulli / Categorical / OneHotCategorical are supported; each gets
    a ``{name}_probs`` guide parameter initialized from the model's probs.
    """
    # run the model so we can inspect its structure
    model = config_enumerate(self.model)
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
        *args, **kwargs)
    self.prototype_trace = prune_subsample_sites(self.prototype_trace)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._plates = {}
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError(
                'Expected sample site "{}" to be discrete and '
                'configured for parallel enumeration'.format(name))

        # collect discrete sample sites
        fn = site["fn"]
        Dist = type(fn)
        if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
            # One guide parameter per site: its probability vector/matrix.
            params = [("probs", fn.probs.detach().clone(),
                       fn.arg_constraints["probs"])]
        else:
            raise NotImplementedError("{} is not supported".format(
                Dist.__name__))
        self._discrete_sites.append((site, Dist, params))

        # collect independence contexts
        self._cond_indep_stacks[name] = site["cond_indep_stack"]
        for frame in site["cond_indep_stack"]:
            if frame.vectorized:
                self._plates[frame.name] = frame
            else:
                raise NotImplementedError(
                    "AutoDiscreteParallel does not support sequential pyro.plate"
                )
    # Initialize guide params
    for site, Dist, param_spec in self._discrete_sites:
        name = site["name"]
        for param_name, param_init, param_constraint in param_spec:
            _deep_setattr(
                self, "{}_{}".format(name, param_name),
                PyroParam(param_init, constraint=param_constraint))
def _setup_prototype(self, *args, **kwargs):
    """Register one constrained parameter per latent site, tiled to full plate size."""
    super()._setup_prototype(*args, **kwargs)
    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        init_value = site["value"].detach()
        event_dim = site["fn"].event_dim

        # If subsampling, repeat init_value to full size.
        for frame in site["cond_indep_stack"]:
            full_size = getattr(frame, "full_size", frame.size)
            if full_size == frame.size:
                continue
            batch_dim = frame.dim - event_dim
            init_value = periodic_repeat(init_value, full_size, batch_dim).contiguous()

        _deep_setattr(self, name,
                      PyroParam(init_value, site["fn"].support, event_dim))
def __init__(self, create_plates=None):
    """Hand-written guide with explicitly-shaped Normal parameters.

    The parameter shapes below must stay aligned with the shapes of the
    corresponding sample sites in the model.

    :param create_plates: optional callable producing the guide's plates.
    """
    super().__init__()
    # Moving-average weights: loc and positive scale, event_dim=3.
    self.ma_weight_loc = PyroParam(torch.zeros(10, 1, 1, 2, 3, 7), event_dim=3)
    self.ma_weight_scale = PyroParam(0.1 * torch.ones(10, 1, 1, 2, 3, 7),
                                     dist.constraints.positive, event_dim=3)
    # Snapshot weights: loc and positive scale, event_dim=2.
    self.snap_weight_loc = PyroParam(torch.zeros(10, 1, 1, 2, 7), event_dim=2)
    self.snap_weight_scale = PyroParam(0.1 * torch.ones(10, 1, 1, 2, 7),
                                       dist.constraints.positive, event_dim=2)
    # Seasonal component: loc and positive scale, event_dim=2.
    self.seasonal_loc = PyroParam(torch.zeros(10, 1, 7, 2, 7), event_dim=2)
    self.seasonal_scale = PyroParam(0.1 * torch.ones(10, 1, 7, 2, 7),
                                    dist.constraints.positive, event_dim=2)
    self.create_plates = create_plates
def __init__(self, sites, name='', *args, **kwargs):
    """Guide whose location parameter covers only the masked (active) coordinates."""
    super().__init__(sites, name, *args, **kwargs)
    masked_init = self.init[self.mask]
    self.loc = PyroParam(masked_init, event_dim=1)
def _setup_prototype(self, *args, **kwargs):
    """Build amortised guide parameters: per-site linear heads (and optionally
    per-site encoder networks) mapping encoder hidden units to loc/scale params.
    """
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.hidden2locs = PyroModule()
    self.hidden2scales = PyroModule()

    if "multiple" in self.encoder_mode:
        # create module for collecting multiple encoder NN
        self.multiple_encoders = PyroModule()

    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # Collect unconstrained event_dims, which may differ from constrained event_dims.
        with helpful_support_errors(site):
            # Map the constrained site value to unconstrained space.
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
        self._event_dims[name] = event_dim

        # Collect independence contexts.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

        # determine the number of hidden layers
        # NOTE(review): if encoder_mode contains neither "multiple" nor "single",
        # n_hidden is left unbound here — presumably validated upstream; confirm.
        if "multiple" in self.encoder_mode:
            if "multiple" in self.n_hidden.keys():
                n_hidden = self.n_hidden["multiple"]
            else:
                n_hidden = self.n_hidden[name]
        elif "single" in self.encoder_mode:
            n_hidden = self.n_hidden["single"]

        # add linear layer for locs and scales
        # Weight shape: (hidden units, flattened site dimension); initialized
        # with N(0, init_param_scale / sqrt(n_hidden)) draws, cast to float32.
        param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
        init_param = np.random.normal(
            np.zeros(param_dim),
            (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
        ).astype("float32")
        _deep_setattr(
            self.hidden2locs,
            name,
            PyroParam(torch.tensor(init_param, device=site["value"].device,
                                   requires_grad=True)),
        )

        # Separate, independently-initialized head for the scales.
        init_param = np.random.normal(
            np.zeros(param_dim),
            (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
        ).astype("float32")
        _deep_setattr(
            self.hidden2scales,
            name,
            PyroParam(torch.tensor(init_param, device=site["value"].device,
                                   requires_grad=True)),
        )

        if "multiple" in self.encoder_mode:
            # create multiple encoders
            if self.encoder_instance is not None:
                # copy instances
                encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                # convert to pyro module
                to_pyro_module_(encoder_)
                _deep_setattr(
                    self.multiple_encoders,
                    name,
                    encoder_,
                )
            else:
                # create instances
                _deep_setattr(
                    self.multiple_encoders,
                    name,
                    self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden,
                                       **self.multi_encoder_kwargs).to(
                        site["value"].device
                    ),
                )