class RelationGCN(torch.jit.ScriptModule):
    def __init__(self, input_dim, hidden_dim, n_relations, n_bases, n_class):
        super(RelationGCN, self).__init__()
        self.conv1 = RGCNConv(input_dim, hidden_dim, n_relations, n_bases)
        self.conv2 = RGCNConv(hidden_dim, hidden_dim, n_relations, n_bases)
        self.weight = Parameter(torch.zeros(hidden_dim, n_class))
        self.bias = Parameter(torch.zeros(n_class))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias)

    @torch.jit.script_method
    def forward_(self, x, first_edge_index, second_edge_index,
                 first_edge_type, second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        x = self.embedding_(x, first_edge_index, second_edge_index,
                            first_edge_type, second_edge_type)
        x = torch.matmul(x, self.weight)
        x = x + self.bias
        return F.log_softmax(x, dim=1)

    @torch.jit.script_method
    def loss(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return F.nll_loss(y_pred, y_true)

    @torch.jit.script_method
    def predict_(self, x, first_edge_index, second_edge_index,
                 first_edge_type, second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        output = self.forward_(x, first_edge_index, second_edge_index,
                               first_edge_type, second_edge_type)
        return output.max(1)[1]

    @torch.jit.script_method
    def embedding_(self, x, first_edge_index, second_edge_index,
                   first_edge_type, second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        x = F.relu(self.conv1(x, second_edge_index, second_edge_type, None))
        x = self.conv2(x, first_edge_index, first_edge_type, None)
        return x
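# A minimal smoke test for RelationGCN. This is a hedged sketch: it assumes
# RGCNConv is a custom relational conv constructed as (in_dim, out_dim,
# n_relations, n_bases) and callable as (x, edge_index, edge_type, None), as
# the class above uses it. All shapes and hyperparameters are illustrative.
import torch

n_nodes, n_feats, n_rel, n_class = 50, 16, 3, 4
model = RelationGCN(input_dim=n_feats, hidden_dim=32, n_relations=n_rel,
                    n_bases=2, n_class=n_class)

x = torch.randn(n_nodes, n_feats)
first_edge_index = torch.randint(0, n_nodes, (2, 200))
second_edge_index = torch.randint(0, n_nodes, (2, 200))
first_edge_type = torch.randint(0, n_rel, (200,))
second_edge_type = torch.randint(0, n_rel, (200,))

log_probs = model.forward_(x, first_edge_index, second_edge_index,
                           first_edge_type, second_edge_type)
labels = torch.randint(0, n_class, (log_probs.size(0),))
loss = model.loss(log_probs, labels)  # NLL over log-softmax outputs
preds = model.predict_(x, first_edge_index, second_edge_index,
                       first_edge_type, second_edge_type)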
def _compress_module_param_dim(
    param: Parameter,
    target_dim: int,
    idxs_to_keep: Tensor,
    module: Optional[Module] = None,
    optimizer: Optional[Optimizer] = None,
):
    if param.dim() == 1:
        target_dim = 0

    if param.size(target_dim) == 1 and idxs_to_keep.numel() > 1:
        # DW Conv
        return

    if param.size(target_dim) % idxs_to_keep.size(0) != 0:
        _LOGGER.debug(
            "skipping compression of parameter due to shape incompatibility")
        return  # nothing to compress for this parameter

    stride = param.data.size(target_dim) // idxs_to_keep.size(0)
    if stride > 1:
        idxs_to_keep = idxs_to_keep.reshape(-1, 1).expand(-1, stride).reshape(-1)

    param.data = (param.data[idxs_to_keep, ...]
                  if target_dim == 0 else param.data[:, idxs_to_keep, ...])
    if param.grad is not None:
        param.grad = (param.grad[idxs_to_keep, ...]
                      if target_dim == 0 else param.grad[:, idxs_to_keep, ...])

    if (optimizer is not None and param in optimizer.state
            and ("momentum_buffer" in optimizer.state[param])):
        optimizer.state[param]["momentum_buffer"] = (
            optimizer.state[param]["momentum_buffer"][idxs_to_keep, ...]
            if target_dim == 0 else
            optimizer.state[param]["momentum_buffer"][:, idxs_to_keep, ...])

    # update module attrs
    if module is not None:
        # Batch Norm
        if param.dim() == 1:
            if hasattr(module, "num_features"):
                module.num_features = param.size(0)
            # BN running mean and var are not stored as Parameters so we must
            # update them here
            if hasattr(module, "running_mean") and (
                    module.running_mean.size(0) == idxs_to_keep.size(0)):
                module.running_mean = module.running_mean[idxs_to_keep]
            if hasattr(module, "running_var") and (
                    module.running_var.size(0) == idxs_to_keep.size(0)):
                module.running_var = module.running_var[idxs_to_keep]
        # Linear
        elif target_dim == 0 and hasattr(module, "out_features"):
            module.out_features = param.size(0)
        elif target_dim == 1 and hasattr(module, "in_features"):
            module.in_features = param.size(1)
        # Conv
        elif target_dim == 0 and hasattr(module, "out_channels"):
            module.out_channels = param.size(0)
        elif target_dim == 1 and hasattr(module, "in_channels"):
            module.in_channels = param.size(1)

        if (hasattr(module, "groups") and module.groups > 1
                and (hasattr(module, "out_channels")
                     and hasattr(module, "in_channels"))):
            module.groups = param.size(0) // param.size(1)
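# A hedged usage sketch for _compress_module_param_dim: prune one output
# feature of a small Linear layer and keep the SGD momentum buffer in sync.
# The stride expansion and modulo check above suggest idxs_to_keep is a
# boolean keep-mask over the target dimension; that reading is an assumption.
import torch
from torch.nn import Linear

layer = Linear(4, 3)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

# one forward/backward/step so that grads and momentum buffers exist
layer(torch.randn(2, 4)).sum().backward()
opt.step()

idxs_to_keep = torch.tensor([True, False, True])  # keep output channels 0 and 2
_compress_module_param_dim(layer.weight, 0, idxs_to_keep,
                           module=layer, optimizer=opt)
_compress_module_param_dim(layer.bias, 0, idxs_to_keep,
                           module=layer, optimizer=opt)
assert layer.weight.shape == (2, 4)
assert layer.bias.shape == (2,)
assert layer.out_features == 2  # module attr updated alongside the tensor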
class ADKL_KRR_net(MetaNetwork):
    TRAIN = 0
    DESCR = 1
    BOTH = 2
    NB_KERNEL_PARAMS = 1

    def __init__(self, input_features_extractor_params,
                 target_features_extractor_params, condition_on='train',
                 task_descr_extractor_params=None, dataset_encoder_params=None,
                 hp_mode='fixe', l2=0.1, device='cuda', task_encoder_reg=0.,
                 n_pseudo_inputs=0, pseudo_inputs_reg=0,
                 stationary_kernel=False):
        """Builds the feature extractors, the task encoder and the kernel network."""
        super(ADKL_KRR_net, self).__init__()
        if condition_on.lower() in ['train', 'train_samples']:
            assert dataset_encoder_params is not None, \
                'dataset_encoder_params must be specified'
            self.condition_on = self.TRAIN
        elif condition_on.lower() in ['descr', 'task_descr']:
            assert task_descr_extractor_params is not None, \
                'task_descr_extractor_params must be specified'
            self.condition_on = self.DESCR
        elif condition_on.lower() in ['both']:
            assert dataset_encoder_params is not None, \
                'dataset_encoder_params must be specified'
            assert task_descr_extractor_params is not None, \
                'task_descr_extractor_params must be specified'
            self.condition_on = self.BOTH
        else:
            raise ValueError('Invalid option for parameter condition_on')

        self.task_encoder_reg = task_encoder_reg

        if input_features_extractor_params.get('pooling_fn', 0) is None:
            pooling = GlobalAvgPool1d(dim=1)
            spectral_kernel = True
        else:
            pooling = None
            spectral_kernel = False

        task_encoder = TaskEncoderNet(
            input_features_extractor_params,
            target_features_extractor_params,
            dataset_encoder_params,
            complement_module_input_fextractor=pooling)

        self.features_extractor = task_encoder.input_fextractor
        fe_dim = self.features_extractor.output_dim

        self.task_descr_extractor = None
        tde_dim, de_dim = 0, 0
        if self.condition_on in [self.DESCR, self.BOTH]:
            self.task_descr_extractor = FeaturesExtractorFactory()(
                **task_descr_extractor_params)
            tde_dim = self.task_descr_extractor.output_dim
        if self.condition_on in [self.TRAIN, self.BOTH]:
            self.dataset_encoder = task_encoder
            de_dim = self.dataset_encoder.output_dim

        self.l2 = l2
        self.pseudo_inputs_reg = pseudo_inputs_reg
        self.hp_mode = hp_mode
        self.device = device

        if not stationary_kernel:
            self.kernel_network = NonStationaryKernel(
                fe_dim, de_dim + tde_dim, fe_dim, spectral_kernel)
        else:
            self.kernel_network = StationaryKernel(
                fe_dim, de_dim + tde_dim, fe_dim, spectral_kernel)

        if n_pseudo_inputs > 0:
            # the two branches of the original `if spectral_kernel` here were
            # identical, so they are collapsed; the tensor is moved to the
            # device before wrapping so the result stays a leaf Parameter
            # (Parameter(...).to(device) returns a plain Tensor)
            self.pseudo_inputs = Parameter(
                torch.Tensor(n_pseudo_inputs, fe_dim).to(device))
        else:
            self.pseudo_inputs = None
        self.phis_train_mean, self.phis_train_std = 0, 0

        if hp_mode.lower() in ['fixe', 'fixed', 'f']:
            self.hp_mode = 'f'
        elif hp_mode.lower() in ['learn', 'learned', 'l']:
            self.hp_mode = 'l'
        elif hp_mode.lower() in ['predicted', 'predict', 'p', 't',
                                 'task-specific', 'per-task']:
            self.hp_mode = 't'
            d = (de_dim if de_dim else 0) + (tde_dim if tde_dim else 0)
            self.hp_net = Linear(d, self.NB_KERNEL_PARAMS)
        else:
            raise Exception('hp_mode should be one of: fixe, learn, predicted')
        self._init_kernel_params(device)

    def _init_kernel_params(self, device):
        if self.pseudo_inputs is not None:
            init.kaiming_uniform_(self.pseudo_inputs, a=math.sqrt(5))
        self.l2 = torch.FloatTensor([self.l2]).to(device)
        if self.hp_mode == 'l':
            self.l2 = Parameter(self.l2)

    def compute_batch_gram_matrix(self, x, y, task_phis):
        k_ = self.kernel_network(x, y, task_phis)
        if self.pseudo_inputs is not None:
            ps = self.pseudo_inputs.unsqueeze(0).expand(
                x.shape[0], *self.pseudo_inputs.shape)
            k_g = self.kernel_network(x, ps, task_phis)
            k_ = torch.cat((k_, k_g), dim=-1)
        return k_

    def set_kernel_params(self, task_phis):
        if self.hp_mode == 't':
            self.l2 = self.hp_net(task_phis).squeeze(-1)
        l2 = hardtanh(self.l2.exp(), 1e-4, 1e1)
        return l2

    def add_pseudo_inputs_loss(self, loss):
        n = self.pseudo_inputs.shape[0]
        d = self.pseudo_inputs.shape[-1]
        p = self.pseudo_inputs.reshape(-1, d)
        # reg = torch.exp(-0.5 * (p.unsqueeze(2) - p.unsqueeze(1)).pow(2).sum(-1))
        # reg = torch.tril(reg).sum() / (n * (n - 1))
        if self.pseudo_inputs.dim() == 2:
            pi_mean = torch.mean(self.pseudo_inputs, dim=0)
            pi_std = torch.std(self.pseudo_inputs, dim=0)
        elif self.pseudo_inputs.dim() == 3:
            # the original had a stray trailing comma that made pi_mean a tuple
            pi_mean = torch.mean(self.pseudo_inputs, dim=(0, 1))
            pi_std = torch.std(self.pseudo_inputs, dim=(0, 1))
        else:
            raise Exception(
                'Pseudo inputs: the number of dimensions is incorrect')
        kl = kl_divergence(
            MultivariateNormal(pi_mean, torch.diag(pi_std + 0.1)),
            MultivariateNormal(self.phis_train_mean,
                               torch.diag(self.phis_train_std + 0.1)))
        res = self.pseudo_inputs_reg * kl
        return loss + res, res

    def add_task_encoder_loss(self, loss):
        reg = self.task_encoder_reg * self.task_encoder_loss
        return loss + reg, reg

    def get_alphas(self, phis, ys, masks, task_phis=None):
        l2 = self.set_kernel_params(task_phis)
        bsize, n_train = phis.shape[:2]
        k_ = self.compute_batch_gram_matrix(phis, phis, task_phis=task_phis)
        k = torch.bmm(k_, k_.transpose(1, 2))
        k_mask = masks[:, None, :] * masks[:, :, None]
        k = k * k_mask
        identity = torch.eye(n_train, device=k.device).unsqueeze(0).expand(
            (bsize, n_train, n_train))
        batch_K_inv = torch.inverse(k + l2.unsqueeze(1).unsqueeze(1) * identity)
        alphas = torch.bmm(batch_K_inv, ys)
        return alphas, k_

    def get_preds(self, alphas, K_train, phis_train, masks_train, phis_test,
                  masks_test, task_phis=None):
        k = self.compute_batch_gram_matrix(phis_test, phis_train,
                                           task_phis=task_phis)
        k = torch.bmm(k, K_train.transpose(1, 2))
        k_mask = masks_test[:, :, None] * masks_train[:, None, :]
        k = k * k_mask
        preds = torch.bmm(k, alphas)
        return preds

    def get_task_phis(self, tasks_descr, xs_train, ys_train, mask_train):
        if self.condition_on == self.DESCR:
            task_phis = self.task_descr_extractor(tasks_descr)
        elif self.condition_on == self.TRAIN:
            task_phis = self.dataset_encoder(None, xs_train, ys_train,
                                             mask_train)
        else:
            task_phis = torch.cat([
                self.task_descr_extractor(tasks_descr),
                self.dataset_encoder(None, xs_train, ys_train, mask_train)
            ], dim=1)
        return task_phis

    def get_phis(self, xs, train=False):
        phis = self.features_extractor(xs.reshape(-1, xs.shape[2]))
        if train:
            # exponential moving average of the train feature statistics,
            # used as the reference distribution for the pseudo-input KL
            alpha = 0.8
            if phis.dim() == 2:
                self.phis_train_mean = (
                    (1 - alpha) * self.phis_train_mean
                    + alpha * torch.mean(phis, dim=0).detach())
                self.phis_train_std = (
                    (1 - alpha) * self.phis_train_std
                    + alpha * torch.std(phis, dim=0).detach())
            elif phis.dim() == 3:
                self.phis_train_mean = (
                    (1 - alpha) * self.phis_train_mean
                    + alpha * torch.mean(phis, dim=(0, 1)).detach())
                self.phis_train_std = (
                    (1 - alpha) * self.phis_train_std
                    + alpha * torch.std(phis, dim=(0, 1)).detach())
        phis = phis.reshape(*xs.shape[:2], *phis.shape[1:])
        return phis

    def forward(self, episodes):
        if self.condition_on == self.TRAIN:
            train, test = pack_episodes(episodes, return_tasks_descr=False)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(None, xs_train, ys_train,
                                           mask_train)
        else:
            train, test, tasks_descr = pack_episodes(episodes,
                                                     return_tasks_descr=True)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(tasks_descr, xs_train, ys_train,
                                           mask_train)

        phis_train, phis_test = (self.get_phis(xs_train, train=True),
                                 self.get_phis(xs_test, train=False))

        # training
        alphas, K_train = self.get_alphas(phis_train, ys_train, mask_train,
                                          task_phis)
        # testing
        preds = self.get_preds(alphas, K_train, phis_train, mask_train,
                               phis_test, mask_test, task_phis)
        self.compute_task_encoder_loss_last_batch(episodes)

        if isinstance(preds, tuple):
            return [tuple(x[:n] for x in pred)
                    for n, pred in zip(lens_test, preds)]
        else:
            return [x[:n] for n, x in zip(lens_test, preds)]

    def compute_task_encoder_loss_last_batch(self, episodes):
        # train x test
        set_code = self.dataset_encoder(episodes)
        y_preds_class = torch.arange(len(set_code))
        if set_code.is_cuda:
            y_preds_class = y_preds_class.to('cuda')
        accuracy = (set_code.argmax(dim=1) ==
                    y_preds_class).sum().item() / len(set_code)

        b = set_code.size(0)
        eye = torch.eye(b, device=set_code.device)  # keep on the same device
        mi = set_code.diagonal().mean() \
            - torch.log((set_code * (1 - eye)).exp().sum() / (b * (b - 1)))
        loss = -mi

        self.accuracy = accuracy
        self.task_encoder_loss = loss
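# The KRR solve inside get_alphas/get_preds, shown in isolation: a minimal
# sketch of batched kernel ridge regression with a plain linear kernel,
# mirroring alphas = (K + l2*I)^-1 y and preds = K_test_train @ alphas.
import torch

bsize, n_train, n_test, d = 4, 10, 5, 8
phis_train = torch.randn(bsize, n_train, d)
phis_test = torch.randn(bsize, n_test, d)
ys = torch.randn(bsize, n_train, 1)
l2 = 0.1

K = torch.bmm(phis_train, phis_train.transpose(1, 2))      # (bsize, n, n)
eye = torch.eye(n_train).expand(bsize, n_train, n_train)
alphas = torch.linalg.solve(K + l2 * eye, ys)              # (bsize, n, 1)
K_test = torch.bmm(phis_test, phis_train.transpose(1, 2))  # (bsize, m, n)
preds = torch.bmm(K_test, alphas)                          # (bsize, m, 1)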
class VariationalGP(GPModel):
    def __init__(self, X, y, kernel, likelihood, mean_function=None,
                 latent_shape=None, whiten=False, jitter=1e-6, use_cuda=False):
        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood

        y_batch_shape = (self.y.shape[:-1] if self.y is not None
                         else torch.Size([]))
        self.latent_shape = (latent_shape if latent_shape is not None
                             else y_batch_shape)

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N,))
        self.f_loc = Parameter(f_loc)

        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = PyroParam(f_scale_tril, constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True

        if use_cuda:
            self.cuda()

    @pyro_method
    def model(self):
        self.set_mode("model")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc, scale_tril=identity).to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

    @pyro_method
    def guide(self):
        self.set_mode("guide")
        self._load_pyro_samples()

        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(
                self.f_loc,
                scale_tril=self.f_scale_tril).to_event(self.f_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian
        Process posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``,
            together with kernel's parameters have been learned from a
            training procedure (MCMC or SVI).

        :param torch.Tensor Xnew: An input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full
            covariance matrix or just variance.
        :returns: loc and covariance matrix (or variance) of
            :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew, self.X, self.kernel, self.f_loc,
                               self.f_scale_tril, full_cov=full_cov,
                               whiten=self.whiten, jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
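# A hedged prediction sketch for the class above, assuming it is wired into
# pyro.contrib.gp (the kernels, likelihoods, and the conditional/eye_like
# utilities it references). Data and hyperparameters are illustrative.
import torch
import pyro.contrib.gp as gp

X = torch.linspace(0, 5, 20).unsqueeze(-1)
y = torch.sin(X).squeeze(-1) + 0.1 * torch.randn(20)

kernel = gp.kernels.RBF(input_dim=1)
likelihood = gp.likelihoods.Gaussian()
vgp = VariationalGP(X, y, kernel, likelihood, whiten=True)

Xnew = torch.linspace(0, 5, 50).unsqueeze(-1)
with torch.no_grad():
    loc, var = vgp(Xnew, full_cov=False)  # posterior mean and variance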
class Table(Module):
    r"""
    Defines a Table of parameters of an insurance contract or an assumption,
    with a padding mechanism supported through `pad_mode` and `pad_value`.

    .. math:: \text{out}_{i, j} = \text{table}_{\text{index}_i, j}

    All Tables inherit from this class.

    Attributes:
        - :attr:`name` (str) the name of the Table.
        - :attr:`table` (Tensor) the table to be indexed; it should have at
          most 2 dimensions. Padding can be used for long tables.
        - :attr:`n_col` (int) the total column number of the table after
          padding.
    """

    def __init__(self, name: str, table: Tensor, n_col: int = None, *,
                 pad_value: float = 0, pad_mode=0):
        """
        :param str name: the name of the Table
        :param Tensor table: raw table
        :param int n_col: if None (default), no padding is applied to the
            input `table`; if provided, n_col is the total column number of
            the table after padding.
        :param float pad_value: value used by `pad_mode`, for example the
            fill value if `pad_mode=PadMode.Constant`.
        :param Union[int, PadMode] pad_mode: how to perform the padding,
            constant padding by default.
        """
        super().__init__()
        self.name = name.strip()
        self.table = Parameter(table)
        self.n_col = n_col
        if self.table.nelement() == 1:
            # scalar table: the same value is broadcast for every index
            self._need_lookup = False
        elif self.table.dim() == 1:
            # single row: broadcast to every index
            self.n_col = self.table.nelement()
            self._need_lookup = False
        elif n_col and n_col > self.table.shape[1]:
            self._need_lookup = True
        else:
            self.n_col = self.table.shape[1]
            # 2-D tables are always indexed by row; the original left
            # _need_lookup unset on this branch, which broke forward()
            self._need_lookup = True
        self.pad_value = pad_value
        self.pad_mode = pad_mode

    def forward(self, index: Tensor):
        """
        :param index: 1-D Tensor, index for index select
        :return: rows at `index`
        """
        table = pad(self.table, self.n_col, self.pad_value, self.pad_mode)
        if self._need_lookup:
            return torch.index_select(table, 0, index.long())
        else:
            return table.expand(index.nelement(), self.n_col)
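# A minimal usage sketch for Table, assuming the package-level `pad` helper
# right-pads columns to n_col and that constant padding is mode 0 (the
# default documented above). Values are illustrative mortality rates.
import torch

mortality = Table("qx",
                  torch.tensor([[0.001, 0.002, 0.003],
                                [0.002, 0.003, 0.004]]),
                  n_col=5, pad_value=0.004, pad_mode=0)
rows = mortality(torch.tensor([1, 0, 1]))  # -> shape (3, 5), rows 1, 0, 1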
class VariationalGP(GPModel):
    r"""
    Variational Gaussian Process model.

    This model deals with both Gaussian and non-Gaussian likelihoods. Given
    inputs :math:`X` and their noisy observations :math:`y`, the model takes
    the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We will use a variational approach in this model by approximating
    :math:`q(f)` to the posterior :math:`p(f\mid y)`. Precisely, :math:`q(f)`
    will be a multivariate normal distribution with two parameters ``f_loc``
    and ``f_scale_tril``, which will be learned during a variational
    inference process.

    .. note:: This model can be seen as a special version of
        :class:`.SparseVariationalGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for
        training, :math:`\mathcal{O}(N^3)` complexity for testing. Here,
        :math:`N` is the number of train inputs. Size of variational
        parameters is :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: An input data for training. Its first dimension
        is the number of data points.
    :param torch.Tensor y: An output data for training. Its last dimension
        is the number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel
        object, which is the covariance function :math:`k`.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A
        likelihood object.
    :param callable mean_function: An optional mean function :math:`m` of
        this Gaussian process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes
        (`batch_shape` of :math:`q(f)`). By default, it equals to output
        batch shape ``y.shape[:-1]``. For the multi-class classification
        problems, ``latent_shape[-1]`` should correspond to the number of
        classes.
    :param bool whiten: A flag to tell if variational parameters ``f_loc``
        and ``f_scale_tril`` are transformed by the inverse of ``Lff``, where
        ``Lff`` is the lower triangular decomposition of
        :math:`kernel(X, X)`. Enabling this flag will help optimization.
    :param float jitter: A small positive term which is added into the
        diagonal part of a covariance matrix to help stabilize its Cholesky
        decomposition.
    """

    def __init__(self, X, y, kernel, likelihood, mean_function=None,
                 latent_shape=None, whiten=False, jitter=1e-6):
        super(VariationalGP, self).__init__(X, y, kernel, mean_function,
                                            jitter)

        self.likelihood = likelihood

        y_batch_shape = (self.y.shape[:-1] if self.y is not None
                         else torch.Size([]))
        self.latent_shape = (latent_shape if latent_shape is not None
                             else y_batch_shape)

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N,))
        self.f_loc = Parameter(f_loc)

        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = Parameter(f_scale_tril)
        self.set_constraint("f_scale_tril", constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True

    @autoname.scope(prefix="VGP")
    def model(self):
        self.set_mode("model")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)

        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

    @autoname.scope(prefix="VGP")
    def guide(self):
        self.set_mode("guide")

        pyro.sample("f",
                    dist.MultivariateNormal(self.f_loc,
                                            scale_tril=self.f_scale_tril)
                        .to_event(self.f_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian
        Process posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``,
            together with kernel's parameters have been learned from a
            training procedure (MCMC or SVI).

        :param torch.Tensor Xnew: An input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full
            covariance matrix or just variance.
        :returns: loc and covariance matrix (or variance) of
            :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew, self.X, self.kernel, self.f_loc,
                               self.f_scale_tril, full_cov=full_cov,
                               whiten=self.whiten, jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
class VariationalSparseGP(GPModel):
    r"""
    Variational Sparse Gaussian Process model.

    In :class:`.VariationalGP` model, when the number of input data :math:`X`
    is large, the covariance matrix :math:`k(X, X)` will require a lot of
    computational steps to compute its inverse (for log likelihood and for
    prediction). This model introduces an additional inducing-input
    parameter :math:`X_u` to solve that problem. Given inputs :math:`X`,
    their noisy observations :math:`y`, and the inducing-input parameters
    :math:`X_u`, the model takes the form:

    .. math::
        [f, u] &\sim \mathcal{GP}(0, k([X, X_u], [X, X_u])),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We will use a variational approach in this model by approximating
    :math:`q(f,u)` to the posterior :math:`p(f,u \mid y)`. Precisely,
    :math:`q(f) = p(f\mid u)q(u)`, where :math:`q(u)` is a multivariate
    normal distribution with two parameters ``u_loc`` and ``u_scale_tril``,
    which will be learned during a variational inference process.

    .. note:: This model can be learned using MCMC method as in reference
        [2]. See also :class:`.GPModel`.

    .. note:: This model has :math:`\mathcal{O}(NM^2)` complexity for
        training, :math:`\mathcal{O}(M^3)` complexity for testing. Here,
        :math:`N` is the number of train inputs, :math:`M` is the number of
        inducing inputs. Size of variational parameters is
        :math:`\mathcal{O}(M^2)`.

    References:

    [1] `Scalable variational Gaussian process classification`,
    James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani

    [2] `MCMC for Variationally Sparse Gaussian Processes`,
    James Hensman, Alexander G. de G. Matthews, Maurizio Filippone,
    Zoubin Ghahramani

    :param torch.Tensor X: An input data for training. Its first dimension
        is the number of data points.
    :param torch.Tensor y: An output data for training. Its last dimension
        is the number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel
        object, which is the covariance function :math:`k`.
    :param torch.Tensor Xu: Initial values for inducing points, which are
        parameters of our model.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A
        likelihood object.
    :param callable mean_function: An optional mean function :math:`m` of
        this Gaussian process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes
        (`batch_shape` of :math:`q(u)`). By default, it equals to output
        batch shape ``y.shape[:-1]``. For the multi-class classification
        problems, ``latent_shape[-1]`` should correspond to the number of
        classes.
    :param int num_data: The size of full training dataset. It is useful
        for training this model with mini-batch.
    :param bool whiten: A flag to tell if variational parameters ``u_loc``
        and ``u_scale_tril`` are transformed by the inverse of ``Luu``,
        where ``Luu`` is the lower triangular decomposition of
        :math:`kernel(X_u, X_u)`. Enabling this flag will help optimization.
    :param float jitter: A small positive term which is added into the
        diagonal part of a covariance matrix to help stabilize its Cholesky
        decomposition.
    """

    def __init__(self, X, y, kernel, Xu, likelihood, mean_function=None,
                 latent_shape=None, num_data=None, whiten=False, jitter=1e-6):
        super(VariationalSparseGP, self).__init__(X, y, kernel, mean_function,
                                                  jitter)

        self.likelihood = likelihood
        self.Xu = Parameter(Xu)

        y_batch_shape = (self.y.shape[:-1] if self.y is not None
                         else torch.Size([]))
        self.latent_shape = (latent_shape if latent_shape is not None
                             else y_batch_shape)

        M = self.Xu.size(0)
        u_loc = self.Xu.new_zeros(self.latent_shape + (M,))
        self.u_loc = Parameter(u_loc)

        identity = eye_like(self.Xu, M)
        u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.u_scale_tril = Parameter(u_scale_tril)
        self.set_constraint("u_scale_tril", constraints.lower_cholesky)

        self.num_data = num_data if num_data is not None else self.X.size(0)
        self.whiten = whiten
        self._sample_latent = True

    @autoname.scope(prefix="VSGP")
    def model(self):
        self.set_mode("model")

        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()

        zero_loc = self.Xu.new_zeros(self.u_loc.shape)
        if self.whiten:
            identity = eye_like(self.Xu, M)
            pyro.sample("u",
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
        else:
            pyro.sample("u",
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .to_event(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc,
                                   self.u_scale_tril, Luu, full_cov=False,
                                   whiten=self.whiten, jitter=self.jitter)
        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            # rescale the likelihood so mini-batches approximate the
            # full-data ELBO
            with poutine.scale(scale=self.num_data / self.X.size(0)):
                return self.likelihood(f_loc, f_var, self.y)

    @autoname.scope(prefix="VSGP")
    def guide(self):
        self.set_mode("guide")

        pyro.sample("u",
                    dist.MultivariateNormal(self.u_loc,
                                            scale_tril=self.u_scale_tril)
                        .to_event(self.u_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian
        Process posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, X_u, u_{loc}, u_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``u_loc``, ``u_scale_tril``, the
            inducing-point parameter ``Xu``, together with kernel's
            parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: An input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full
            covariance matrix or just variance.
        :returns: loc and covariance matrix (or variance) of
            :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew, self.Xu, self.kernel, self.u_loc,
                               self.u_scale_tril, full_cov=full_cov,
                               whiten=self.whiten, jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
class SAGETwoSoftmax(torch.jit.ScriptModule):
    def __init__(self, in_dim, hidden, out_dim):
        super(SAGETwoSoftmax, self).__init__()
        self.weight1 = Parameter(torch.zeros(in_dim * 2, hidden))
        self.bias1 = Parameter(torch.zeros(hidden))
        self.weight2 = Parameter(torch.zeros(hidden * 2, hidden))
        self.bias2 = Parameter(torch.zeros(hidden))
        self.weight3 = Parameter(torch.zeros(hidden, out_dim))
        self.bias3 = Parameter(torch.zeros(out_dim))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight1)
        torch.nn.init.xavier_uniform_(self.weight2)
        torch.nn.init.xavier_uniform_(self.weight3)
        if self.bias1.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias1)
        if self.bias2.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias2)
        if self.bias3.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias3)

    @torch.jit.script_method
    def forward_(self, x, first_edge_index, second_edge_index):
        embedding = self.embedding_(x, first_edge_index, second_edge_index)
        out = torch.matmul(embedding, self.weight3)
        out = out + self.bias3
        return F.log_softmax(out, dim=1)

    @torch.jit.script_method
    def loss(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return F.nll_loss(y_pred, y_true)

    @torch.jit.script_method
    def predict_(self, x, first_edge_index, second_edge_index):
        output = self.forward_(x, first_edge_index, second_edge_index)
        return output.max(1)[1]

    @torch.jit.script_method
    def embedding_predict_(self, embedding):
        out = torch.matmul(embedding, self.weight3)
        out = out + self.bias3
        return F.log_softmax(out, dim=1)

    @torch.jit.script_method
    def embedding_(self, x, first_edge_index, second_edge_index):
        # first layer
        row, col = second_edge_index[0], second_edge_index[1]
        out = scatter_mean(x[col], row, dim=0)  # do not set dim_size
        out = torch.cat([x[0:out.size(0)], out], dim=1)
        out = torch.matmul(out, self.weight1)
        out = out + self.bias1
        out = torch.relu(out)
        out = F.normalize(out, p=2.0, dim=-1)
        # second layer
        row, col = first_edge_index[0], first_edge_index[1]
        neighbors = scatter_mean(out[col], row, dim=0)  # do not set dim_size
        out = torch.cat([out[0:neighbors.size(0)], neighbors], dim=1)
        out = torch.matmul(out, self.weight2)
        out = out + self.bias2
        out = F.normalize(out, p=2.0, dim=-1)
        return out

    def acc(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return y_pred.max(1)[1].eq(y_true).sum().item() / y_pred.size(0)
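# A smoke-test sketch for SAGETwoSoftmax. Assumptions: scatter_mean is
# torch_scatter.scatter_mean, edge_index rows are (target, source), and
# target nodes are ordered first in x, as the x[0:out.size(0)] slicing
# implies. Max indices are pinned so intermediate sizes are deterministic.
import torch

n_nodes, n_targets, in_dim, hidden, n_class = 40, 20, 8, 16, 3
model = SAGETwoSoftmax(in_dim, hidden, n_class)

x = torch.randn(n_nodes, in_dim)
first_edge_index = torch.randint(0, n_targets, (2, 100))   # 1-hop edges
first_edge_index[0, 0] = n_targets - 1
second_edge_index = torch.randint(0, n_nodes, (2, 100))    # 2-hop edges
second_edge_index[0, 0] = n_nodes - 1

log_probs = model.forward_(x, first_edge_index, second_edge_index)
labels = torch.randint(0, n_class, (log_probs.size(0),))
print(model.loss(log_probs, labels), model.acc(log_probs, labels))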