class RelationGCN(torch.jit.ScriptModule):
    def __init__(self, input_dim, hidden_dim, n_relations, n_bases, n_class):
        super(RelationGCN, self).__init__()
        self.conv1 = RGCNConv(input_dim, hidden_dim, n_relations, n_bases)
        self.conv2 = RGCNConv(hidden_dim, hidden_dim, n_relations, n_bases)
        self.weight = Parameter(torch.zeros(hidden_dim, n_class))
        self.bias = Parameter(torch.zeros(n_class))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias)

    @torch.jit.script_method
    def forward_(self, x, first_edge_index, second_edge_index, first_edge_type,
                 second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        x = self.embedding_(x, first_edge_index, second_edge_index,
                            first_edge_type, second_edge_type)
        x = torch.matmul(x, self.weight)
        x = x + self.bias
        return F.log_softmax(x, dim=1)

    @torch.jit.script_method
    def loss(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return F.nll_loss(y_pred, y_true)

    @torch.jit.script_method
    def predict_(self, x, first_edge_index, second_edge_index, first_edge_type,
                 second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        output = self.forward_(x, first_edge_index, second_edge_index,
                               first_edge_type, second_edge_type)
        return output.max(1)[1]

    @torch.jit.script_method
    def embedding_(self, x, first_edge_index, second_edge_index,
                   first_edge_type, second_edge_type):
        # type: (Optional[Tensor], Tensor, Tensor, Tensor, Tensor) -> Tensor
        x = F.relu(self.conv1(x, second_edge_index, second_edge_type, None))
        x = self.conv2(x, first_edge_index, first_edge_type, None)
        return x
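
# A usage sketch added for illustration (not part of the original listing). It
# assumes the surrounding module's RGCNConv with the
# (x, edge_index, edge_type, edge_norm) call signature used above, and that the
# two edge_index/edge_type pairs come from a two-hop sampled subgraph; all
# shapes and counts below are placeholders.
import torch

model = RelationGCN(input_dim=16, hidden_dim=32, n_relations=4, n_bases=2,
                    n_class=3)
x = torch.randn(100, 16)                                 # sampled node features
first_edge_index = torch.randint(0, 100, (2, 400))       # consumed by conv2
first_edge_type = torch.randint(0, 4, (400,))
second_edge_index = torch.randint(0, 100, (2, 400))      # consumed by conv1
second_edge_type = torch.randint(0, 4, (400,))

log_probs = model.forward_(x, first_edge_index, second_edge_index,
                           first_edge_type, second_edge_type)
labels = model.predict_(x, first_edge_index, second_edge_index,
                        first_edge_type, second_edge_type)
loss = model.loss(log_probs, torch.randint(0, 3, (100,)))
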
def _compress_module_param_dim(
    param: Parameter,
    target_dim: int,
    idxs_to_keep: Tensor,
    module: Optional[Module] = None,
    optimizer: Optional[Optimizer] = None,
):
    if param.dim() == 1:
        target_dim = 0

    if param.size(target_dim) == 1 and idxs_to_keep.numel() > 1:
        # DW Conv
        return

    if param.size(target_dim) % idxs_to_keep.size(0) != 0:
        _LOGGER.debug(
            "skipping compression of parameter due to shape incompatibility")
        return

    stride = param.data.size(target_dim) // idxs_to_keep.size(0)
    if stride > 1:
        idxs_to_keep = idxs_to_keep.reshape(-1, 1).expand(-1,
                                                          stride).reshape(-1)

    param.data = (param.data[idxs_to_keep, ...]
                  if target_dim == 0 else param.data[:, idxs_to_keep, ...])

    if param.grad is not None:
        param.grad = (param.grad[idxs_to_keep, ...]
                      if target_dim == 0 else param.grad[:, idxs_to_keep, ...])

    if (optimizer is not None and param in optimizer.state
            and ("momentum_buffer" in optimizer.state[param])):
        optimizer.state[param]["momentum_buffer"] = (
            optimizer.state[param]["momentum_buffer"][idxs_to_keep,
                                                      ...] if target_dim == 0
            else optimizer.state[param]["momentum_buffer"][:, idxs_to_keep,
                                                           ...])

    # update module attrs
    if module is not None:
        # Batch Norm
        if param.dim() == 1:
            if hasattr(module, "num_features"):
                module.num_features = param.size(0)
            # BN running mean and var are not stored as Parameters so we must
            # update them here
            if hasattr(module, "running_mean") and (module.running_mean.size(0)
                                                    == idxs_to_keep.size(0)):
                module.running_mean = module.running_mean[idxs_to_keep]
            if hasattr(module, "running_var") and (module.running_var.size(0)
                                                   == idxs_to_keep.size(0)):
                module.running_var = module.running_var[idxs_to_keep]

        # Linear
        elif target_dim == 0 and hasattr(module, "out_features"):
            module.out_features = param.size(0)
        elif target_dim == 1 and hasattr(module, "in_features"):
            module.in_features = param.size(1)
        # Conv
        elif target_dim == 0 and hasattr(module, "out_channels"):
            module.out_channels = param.size(0)
        elif target_dim == 1 and hasattr(module, "in_channels"):
            module.in_channels = param.size(1)

        if (hasattr(module, "groups") and module.groups > 1
                and (hasattr(module, "out_channels")
                     and hasattr(module, "in_channels"))):
            module.groups = param.size(0) // param.size(1)
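
# A minimal sketch, added for illustration, of pruning the output dimension of a
# Linear layer with the helper above; the boolean keep-mask, layer sizes and the
# SGD optimizer are arbitrary assumptions, not part of the original source.
import torch
from torch.nn import Linear

layer = Linear(in_features=8, out_features=4)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
idxs_to_keep = torch.tensor([True, False, True, False])  # keep rows 0 and 2

_compress_module_param_dim(layer.weight, target_dim=0,
                           idxs_to_keep=idxs_to_keep, module=layer,
                           optimizer=opt)
_compress_module_param_dim(layer.bias, target_dim=0,
                           idxs_to_keep=idxs_to_keep, module=layer)

assert layer.weight.shape == (2, 8) and layer.out_features == 2
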
Example #3
class ADKL_KRR_net(MetaNetwork):
    TRAIN = 0
    DESCR = 1
    BOTH = 2
    NB_KERNEL_PARAMS = 1

    def __init__(self,
                 input_features_extractor_params,
                 target_features_extractor_params,
                 condition_on='train',
                 task_descr_extractor_params=None,
                 dataset_encoder_params=None,
                 hp_mode='fixe',
                 l2=0.1,
                 device='cuda',
                 task_encoder_reg=0.,
                 n_pseudo_inputs=0,
                 pseudo_inputs_reg=0,
                 stationary_kernel=False):
        """
        In the constructor we instantiate an lstm module
        """
        super(ADKL_KRR_net, self).__init__()

        if condition_on.lower() in ['train', 'train_samples']:
            assert dataset_encoder_params is not None, 'dataset_encoder_params must be specified'
            self.condition_on = self.TRAIN
        elif condition_on.lower() in ['descr', 'task_descr']:
            assert task_descr_extractor_params is not None, 'task_descr_extractor_params must be specified'
            self.condition_on = self.DESCR
        elif condition_on.lower() in ['both']:
            assert dataset_encoder_params is not None, 'dataset_encoder_params must be specified'
            assert task_descr_extractor_params is not None, 'task_descr_extractor_params must be specified'
            self.condition_on = self.BOTH
        else:
            raise ValueError('Invalid option for parameter condition_on')

        self.task_encoder_reg = task_encoder_reg

        if input_features_extractor_params.get('pooling_fn', 0) is None:
            pooling = GlobalAvgPool1d(dim=1)
            spectral_kernel = True
        else:
            pooling = None
            spectral_kernel = False

        task_encoder = TaskEncoderNet(
            input_features_extractor_params,
            target_features_extractor_params,
            dataset_encoder_params,
            complement_module_input_fextractor=pooling)
        self.features_extractor = task_encoder.input_fextractor
        fe_dim = self.features_extractor.output_dim

        self.task_descr_extractor = None
        tde_dim, de_dim = 0, 0
        if self.condition_on in [self.DESCR, self.BOTH]:
            self.task_descr_extractor = FeaturesExtractorFactory()(
                **task_descr_extractor_params)
            tde_dim = self.task_descr_extractor.output_dim
        if self.condition_on in [self.TRAIN, self.BOTH]:
            self.dataset_encoder = task_encoder
            de_dim = self.dataset_encoder.output_dim

        self.l2 = l2
        self.pseudo_inputs_reg = pseudo_inputs_reg
        self.hp_mode = hp_mode
        self.device = device
        if not stationary_kernel:
            self.kernel_network = NonStationaryKernel(fe_dim, de_dim + tde_dim,
                                                      fe_dim, spectral_kernel)
        else:
            self.kernel_network = StationaryKernel(fe_dim, de_dim + tde_dim,
                                                   fe_dim, spectral_kernel)

        if n_pseudo_inputs > 0:
            # same pseudo-input shape for the spectral and non-spectral cases;
            # move the tensor to the device before wrapping it in Parameter so
            # it stays a registered, learnable parameter
            self.pseudo_inputs = Parameter(
                torch.Tensor(n_pseudo_inputs, fe_dim).to(device))
        else:
            self.pseudo_inputs = None
        self.phis_train_mean, self.phis_train_std = 0, 0

        if hp_mode.lower() in ['fixe', 'fixed', 'f']:
            self.hp_mode = 'f'
        elif hp_mode.lower() in ['learn', 'learned', 'l']:
            self.hp_mode = 'l'
        elif hp_mode.lower() in [
                'predicted', 'predict', 'p', 't', 'task-specific', 'per-task'
        ]:
            self.hp_mode = 't'
            d = (de_dim if de_dim else 0) + (tde_dim if tde_dim else 0)
            self.hp_net = Linear(d, self.NB_KERNEL_PARAMS)
        else:
            raise ValueError(
                "hp_mode should be one of: 'fixe', 'learn', 'predicted'")
        self._init_kernel_params(device)

    def _init_kernel_params(self, device):
        if self.pseudo_inputs is not None:
            init.kaiming_uniform_(self.pseudo_inputs, a=math.sqrt(5))
        self.l2 = torch.FloatTensor([self.l2]).to(device)

        if self.hp_mode == 'l':
            self.l2 = Parameter(self.l2)

    def compute_batch_gram_matrix(self, x, y, task_phis):
        k_ = self.kernel_network(x, y, task_phis)
        if self.pseudo_inputs is not None:
            ps = self.pseudo_inputs.unsqueeze(0).expand(
                x.shape[0], *self.pseudo_inputs.shape)
            k_g = self.kernel_network(x, ps, task_phis)
            k_ = torch.cat((k_, k_g), dim=-1)
        return k_

    def set_kernel_params(self, task_phis):
        if self.hp_mode == 't':
            self.l2 = self.hp_net(task_phis).squeeze(-1)

        l2 = hardtanh(self.l2.exp(), 1e-4, 1e1)
        return l2

    def add_pseudo_inputs_loss(self, loss):
        n = self.pseudo_inputs.shape[0]
        d = self.pseudo_inputs.shape[-1]
        p = self.pseudo_inputs.reshape(-1, d)

        # reg = torch.exp(-0.5 * (p.unsqueeze(2) - p.unsqueeze(1)).pow(2).sum(-1))
        # reg = torch.tril(reg).sum() / (n * (n - 1))

        if self.pseudo_inputs.dim() == 2:
            pi_mean = torch.mean(self.pseudo_inputs, dim=0)
            pi_std = torch.std(self.pseudo_inputs, dim=0)
        elif self.pseudo_inputs.dim() == 3:
            pi_mean = torch.mean(self.pseudo_inputs, dim=(0, 1))
            pi_std = torch.std(self.pseudo_inputs, dim=(0, 1))
        else:
            raise Exception(
                'Pseudo inputs: the number of dimensions is incorrect')

        kl = kl_divergence(
            MultivariateNormal(pi_mean, torch.diag(pi_std + 0.1)),
            MultivariateNormal(self.phis_train_mean,
                               torch.diag(self.phis_train_std + 0.1)))

        res = self.pseudo_inputs_reg * kl
        return loss + res, res

    def add_task_encoder_loss(self, loss):
        reg = self.task_encoder_reg * self.task_encoder_loss
        return loss + reg, reg

    def get_alphas(self, phis, ys, masks, task_phis=None):
        l2 = self.set_kernel_params(task_phis)
        bsize, n_train = phis.shape[:2]

        k_ = self.compute_batch_gram_matrix(phis, phis, task_phis=task_phis)
        k = torch.bmm(k_, k_.transpose(1, 2))
        k_mask = masks[:, None, :] * masks[:, :, None]
        k = k * k_mask

        identity = torch.eye(n_train, device=k.device).unsqueeze(0).expand(
            (bsize, n_train, n_train))
        batch_K_inv = torch.inverse(k +
                                    l2.unsqueeze(1).unsqueeze(1) * identity)
        alphas = torch.bmm(batch_K_inv, ys)

        return alphas, k_

    def get_preds(self,
                  alphas,
                  K_train,
                  phis_train,
                  masks_train,
                  phis_test,
                  masks_test,
                  task_phis=None):
        k = self.compute_batch_gram_matrix(phis_test,
                                           phis_train,
                                           task_phis=task_phis)
        k = torch.bmm(k, K_train.transpose(1, 2))
        k_mask = masks_test[:, :, None] * masks_train[:, None, :]
        k = k * k_mask

        preds = torch.bmm(k, alphas)
        return preds

    def get_task_phis(self, tasks_descr, xs_train, ys_train, mask_train):
        if self.condition_on == self.DESCR:
            task_phis = self.task_descr_extractor(tasks_descr)
        elif self.condition_on == self.TRAIN:
            task_phis = self.dataset_encoder(None, xs_train, ys_train,
                                             mask_train)
        else:
            task_phis = torch.cat([
                self.task_descr_extractor(tasks_descr),
                self.dataset_encoder(None, xs_train, ys_train, mask_train)
            ],
                                  dim=1)
        return task_phis

    def get_phis(self, xs, train=False):
        phis = self.features_extractor(xs.reshape(-1, xs.shape[2]))
        if train:
            alpha = 0.8
            if phis.dim() == 2:
                self.phis_train_mean = (
                    1 - alpha) * self.phis_train_mean + alpha * torch.mean(
                        phis, dim=0).detach()
                self.phis_train_std = (
                    1 - alpha) * self.phis_train_std + alpha * torch.std(
                        phis, dim=0).detach()
            elif phis.dim() == 3:
                self.phis_train_mean = (
                    1 - alpha) * self.phis_train_mean + alpha * torch.mean(
                        phis, dim=(0, 1)).detach()
                self.phis_train_std = (
                    1 - alpha) * self.phis_train_std + alpha * torch.std(
                        phis, dim=(0, 1)).detach()

        phis = phis.reshape(*xs.shape[:2], *phis.shape[1:])
        return phis

    def forward(self, episodes):
        if self.condition_on == self.TRAIN:
            train, test = pack_episodes(episodes, return_tasks_descr=False)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(None, xs_train, ys_train,
                                           mask_train)
        else:
            train, test, tasks_descr = pack_episodes(episodes,
                                                     return_tasks_descr=True)
            xs_train, ys_train, lens_train, mask_train = train
            xs_test, lens_test, mask_test = test
            task_phis = self.get_task_phis(tasks_descr, xs_train, ys_train,
                                           mask_train)

        phis_train, phis_test = self.get_phis(
            xs_train, train=True), self.get_phis(xs_test, train=False)

        # training
        alphas, K_train = self.get_alphas(phis_train, ys_train, mask_train,
                                          task_phis)

        # testing
        preds = self.get_preds(alphas, K_train, phis_train, mask_train,
                               phis_test, mask_test, task_phis)

        self.compute_task_encoder_loss_last_batch(episodes)

        if isinstance(preds, tuple):
            return [
                tuple(x[:n] for x in pred)
                for n, pred in zip(lens_test, preds)
            ]
        else:
            return [x[:n] for n, x in zip(lens_test, preds)]

    def compute_task_encoder_loss_last_batch(self, episodes):
        # train x test
        set_code = self.dataset_encoder(episodes)

        y_preds_class = torch.arange(len(set_code), device=set_code.device)

        accuracy = (set_code.argmax(dim=1)
                    == y_preds_class).sum().item() / len(set_code)

        b = set_code.size(0)
        off_diag = 1 - torch.eye(b, device=set_code.device)
        mi = set_code.diagonal().mean() \
             - torch.log((set_code * off_diag).exp().sum() / (b * (b - 1)))
        loss = -mi

        self.accuracy = accuracy
        self.task_encoder_loss = loss
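
# A self-contained sketch, added for illustration, of the batched
# kernel-ridge-regression closed form that get_alphas / get_preds implement
# above: alphas = (K + l2 * I)^{-1} y on the support set, then
# preds = K_query_support @ alphas on the query set. A plain linear kernel
# stands in for the learned kernel_network; every shape below is arbitrary.
import torch

bsize, n_train, n_test, dim = 4, 10, 5, 8
phis_train = torch.randn(bsize, n_train, dim)
phis_test = torch.randn(bsize, n_test, dim)
ys_train = torch.randn(bsize, n_train, 1)
l2 = torch.full((bsize,), 0.1)

K = torch.bmm(phis_train, phis_train.transpose(1, 2))        # support Gram matrix
identity = torch.eye(n_train).expand(bsize, n_train, n_train)
alphas = torch.linalg.solve(K + l2.view(-1, 1, 1) * identity, ys_train)

K_qs = torch.bmm(phis_test, phis_train.transpose(1, 2))      # query-support kernel
preds = torch.bmm(K_qs, alphas)                              # (bsize, n_test, 1)
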
Example #4
class VariationalGP(GPModel):
    def __init__(self,
                 X,
                 y,
                 kernel,
                 likelihood,
                 mean_function=None,
                 latent_shape=None,
                 whiten=False,
                 jitter=1e-6,
                 use_cuda=False):
        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood
        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size(
            [])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N, ))
        self.f_loc = Parameter(f_loc)

        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = PyroParam(f_scale_tril, constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True
        if use_cuda:
            self.cuda()

    @pyro_method
    def model(self):
        self.set_mode("model")
        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc,
                    scale_tril=identity).to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

    @pyro_method
    def guide(self):
        self.set_mode("guide")
        self._load_pyro_samples()

        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(
                self.f_loc,
                scale_tril=self.f_scale_tril).to_event(self.f_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``, together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew,
                               self.X,
                               self.kernel,
                               self.f_loc,
                               self.f_scale_tril,
                               full_cov=full_cov,
                               whiten=self.whiten,
                               jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
Example #5
class Table(Module):
    r""" Defines a Table of parameters of an insurance contract or an assumption
    with padding mechanism supported throw `pad_mode` and `pad_value`.

    .. math::

        \text{out}_{i, j} = \text{table}_{\text{index}_i, j}

    All Table are inherited from this class.

    Attributes:
        - :attr:`name` (str)
           the name of Table.
        - :attr:`table` (Tensor)
           the table to be indexed; its dim should not be more than 2.
           Padding can be used for long, tedious tables.
        - :attr:`n_col` (int)
           the total column number of the table after padding.

    """
    def __init__(self,
                 name: str,
                 table: Tensor,
                 n_col: int = None,
                 *,
                 pad_value: float = 0,
                 pad_mode=0):
        """
        :param str name: the name of Table
        :param Tensor table: raw table
        :param int n_col: if None (default), no padding
            is applied to the input `table`; if provided, `n_col` is the
            total column number of the table after padding.
        :param float pad_value: value needed for the `pad_mode` to work,
            for example `pad_value` is the value filled in
            if `pad_mode=PadMode.Constant`.
        :param Union[int, PadMode] pad_mode: how to perform the padding,
            Constant padding by default.
        """
        super().__init__()
        self.name = name.strip()
        self.table = Parameter(table)
        self.n_col = n_col
        if self.table.nelement() == 1:
            self._need_lookup = False
        elif self.table.dim() == 1:
            self.n_col = self.table.nelement()
            self._need_lookup = False
        elif n_col and n_col > self.table.shape[1]:
            self._need_lookup = True
        else:
            self.n_col = self.table.shape[1]
            self._need_lookup = True

        self.pad_value = pad_value
        self.pad_mode = pad_mode

    def forward(self, index: Tensor):
        """
        :param index: 1-D Tensor, index for index select
        :return: rows at `index`
        """
        table = pad(self.table, self.n_col, self.pad_value, self.pad_mode)
        if self._need_lookup:
            return torch.index_select(table, 0, index.long())
        else:
            return table.expand(index.nelement(), self.n_col)
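
# A brief usage sketch added for illustration; it assumes the surrounding
# module's `pad` helper leaves the table unchanged when `n_col` equals the
# table's existing width, and the rate values below are placeholders.
import torch

qx = Table("qx", torch.tensor([[0.001, 0.002, 0.003],
                               [0.002, 0.003, 0.004]]))
rows = qx(torch.tensor([1, 0, 1]))   # rows 1, 0 and 1 of the table, shape (3, 3)
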
Example #6
class VariationalGP(GPModel):
    r"""
    Variational Gaussian Process model.

    This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs\
    :math:`X` and their noisy observations :math:`y`, the model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We will use a variational approach in this model by approximating :math:`q(f)` to
    the posterior :math:`p(f\mid y)`. Precisely, :math:`q(f)` will be a multivariate
    normal distribution with two parameters ``f_loc`` and ``f_scale_tril``, which will
    be learned during a variational inference process.

    .. note:: This model can be seen as a special version of
        :class:`.VariationalSparseGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training,
        :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs. Size of variational parameters is :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A likelihood
        object.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes (`batch_shape` of
        :math:`q(f)`). By default, it equals the output batch shape ``y.shape[:-1]``.
        For multi-class classification problems, ``latent_shape[-1]`` should
        correspond to the number of classes.
    :param bool whiten: A flag to tell if variational parameters ``f_loc`` and
        ``f_scale_tril`` are transformed by the inverse of ``Lff``, where ``Lff`` is
        the lower triangular Cholesky factor of :math:`kernel(X, X)`. Enabling this
        flag will help optimization.
    :param float jitter: A small positive term added to the diagonal of the
        covariance matrix to help stabilize its Cholesky decomposition.
    """
    def __init__(self, X, y, kernel, likelihood, mean_function=None,
                 latent_shape=None, whiten=False, jitter=1e-6):
        super(VariationalGP, self).__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N,))
        self.f_loc = Parameter(f_loc)

        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = Parameter(f_scale_tril)
        self.set_constraint("f_scale_tril", constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True

    @autoname.scope(prefix="VGP")
    def model(self):
        self.set_mode("model")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

    @autoname.scope(prefix="VGP")
    def guide(self):
        self.set_mode("guide")

        pyro.sample("f",
                    dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril)
                        .to_event(self.f_loc.dim()-1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``, together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew, self.X, self.kernel, self.f_loc, self.f_scale_tril,
                               full_cov=full_cov, whiten=self.whiten, jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
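
# A minimal training sketch added for illustration, assuming this class matches
# pyro.contrib.gp.models.VariationalGP; the data, kernel choice, learning rate
# and step count are arbitrary.
import torch
import pyro
import pyro.contrib.gp as gp

X = torch.linspace(-3, 3, 25).unsqueeze(-1)            # 25 training inputs
y = torch.sin(X).squeeze(-1) + 0.1 * torch.randn(25)   # noisy observations

vgp = gp.models.VariationalGP(X, y, gp.kernels.RBF(input_dim=1),
                              gp.likelihoods.Gaussian(), whiten=True)

optimizer = torch.optim.Adam(vgp.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for _ in range(200):
    optimizer.zero_grad()
    loss_fn(vgp.model, vgp.guide).backward()           # ELBO on model/guide
    optimizer.step()

mean, var = vgp(torch.linspace(-4, 4, 50).unsqueeze(-1), full_cov=False)
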
Example #7
class VariationalSparseGP(GPModel):
    r"""
    Variational Sparse Gaussian Process model.

    In :class:`.VariationalGP` model, when the number of input data :math:`X` is large,
    the covariance matrix :math:`k(X, X)` will require a lot of computational steps to
    compute its inverse (for log likelihood and for prediction). This model introduces
    an additional inducing-input parameter :math:`X_u` to solve that problem. Given
    inputs :math:`X`, their noisy observations :math:`y`, and the inducing-input
    parameters :math:`X_u`, the model takes the form:

    .. math::
        [f, u] &\sim \mathcal{GP}(0, k([X, X_u], [X, X_u])),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We will use a variational approach in this model by approximating :math:`q(f,u)`
    to the posterior :math:`p(f,u \mid y)`. Precisely, :math:`q(f) = p(f\mid u)q(u)`,
    where :math:`q(u)` is a multivariate normal distribution with two parameters
    ``u_loc`` and ``u_scale_tril``, which will be learned during a variational
    inference process.

    .. note:: This model can be learned using MCMC method as in reference [2]. See also
        :class:`.GPModel`.

    .. note:: This model has :math:`\mathcal{O}(NM^2)` complexity for training,
        :math:`\mathcal{O}(M^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs, :math:`M` is the number of inducing inputs. Size of
        variational parameters is :math:`\mathcal{O}(M^2)`.

    References:

    [1] `Scalable variational Gaussian process classification`,
    James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani

    [2] `MCMC for Variationally Sparse Gaussian Processes`,
    James Hensman, Alexander G. de G. Matthews, Maurizio Filippone, Zoubin Ghahramani

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param torch.Tensor Xu: Initial values for inducing points, which are parameters
        of our model.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A likelihood
        object.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes (`batch_shape` of
        :math:`q(u)`). By default, it equals the output batch shape ``y.shape[:-1]``.
        For multi-class classification problems, ``latent_shape[-1]`` should
        correspond to the number of classes.
    :param int num_data: The size of the full training dataset. It is useful for
        training this model with mini-batches.
    :param bool whiten: A flag to tell if variational parameters ``u_loc`` and
        ``u_scale_tril`` are transformed by the inverse of ``Luu``, where ``Luu`` is
        the lower triangular Cholesky factor of :math:`kernel(X_u, X_u)`. Enabling
        this flag will help optimization.
    :param float jitter: A small positive term added to the diagonal of the
        covariance matrix to help stabilize its Cholesky decomposition.
    """
    def __init__(self,
                 X,
                 y,
                 kernel,
                 Xu,
                 likelihood,
                 mean_function=None,
                 latent_shape=None,
                 num_data=None,
                 whiten=False,
                 jitter=1e-6):
        super(VariationalSparseGP, self).__init__(X, y, kernel, mean_function,
                                                  jitter)

        self.likelihood = likelihood
        self.Xu = Parameter(Xu)

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size(
            [])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        M = self.Xu.size(0)
        u_loc = self.Xu.new_zeros(self.latent_shape + (M, ))
        self.u_loc = Parameter(u_loc)

        identity = eye_like(self.Xu, M)
        u_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.u_scale_tril = Parameter(u_scale_tril)
        self.set_constraint("u_scale_tril", constraints.lower_cholesky)

        self.num_data = num_data if num_data is not None else self.X.size(0)
        self.whiten = whiten
        self._sample_latent = True

    @autoname.scope(prefix="VSGP")
    def model(self):
        self.set_mode("model")

        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()

        zero_loc = self.Xu.new_zeros(self.u_loc.shape)
        if self.whiten:
            identity = eye_like(self.Xu, M)
            pyro.sample(
                "u",
                dist.MultivariateNormal(
                    zero_loc,
                    scale_tril=identity).to_event(zero_loc.dim() - 1))
        else:
            pyro.sample(
                "u",
                dist.MultivariateNormal(
                    zero_loc, scale_tril=Luu).to_event(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X,
                                   self.Xu,
                                   self.kernel,
                                   self.u_loc,
                                   self.u_scale_tril,
                                   Luu,
                                   full_cov=False,
                                   whiten=self.whiten,
                                   jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            with poutine.scale(scale=self.num_data / self.X.size(0)):
                return self.likelihood(f_loc, f_var, self.y)

    @autoname.scope(prefix="VSGP")
    def guide(self):
        self.set_mode("guide")

        pyro.sample(
            "u",
            dist.MultivariateNormal(
                self.u_loc,
                scale_tril=self.u_scale_tril).to_event(self.u_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, X_u, u_{loc}, u_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``u_loc``, ``u_scale_tril``, the
            inducing-point parameter ``Xu``, together with kernel's parameters have
            been learned from a training procedure (MCMC or SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew,
                               self.Xu,
                               self.kernel,
                               self.u_loc,
                               self.u_scale_tril,
                               full_cov=full_cov,
                               whiten=self.whiten,
                               jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
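
# An analogous sketch added for illustration, assuming this class matches
# pyro.contrib.gp.models.VariationalSparseGP; the inducing points Xu and the
# num_data scaling (useful for mini-batch training) are arbitrary choices.
import torch
import pyro
import pyro.contrib.gp as gp

X = torch.linspace(-3, 3, 200).unsqueeze(-1)
y = torch.sin(X).squeeze(-1) + 0.1 * torch.randn(200)
Xu = torch.linspace(-3, 3, 15).unsqueeze(-1)           # 15 inducing inputs

vsgp = gp.models.VariationalSparseGP(X, y, gp.kernels.RBF(input_dim=1), Xu,
                                     likelihood=gp.likelihoods.Gaussian(),
                                     num_data=200, whiten=True)

optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.01)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
for _ in range(200):
    optimizer.zero_grad()
    loss_fn(vsgp.model, vsgp.guide).backward()
    optimizer.step()

mean, var = vsgp(torch.linspace(-4, 4, 50).unsqueeze(-1), full_cov=False)
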
Example #8
class SAGETwoSoftmax(torch.jit.ScriptModule):
    def __init__(self, in_dim, hidden, out_dim):
        super(SAGETwoSoftmax, self).__init__()
        self.weight1 = Parameter(torch.zeros(in_dim * 2, hidden))
        self.bias1 = Parameter(torch.zeros(hidden))

        self.weight2 = Parameter(torch.zeros(hidden * 2, hidden))
        self.bias2 = Parameter(torch.zeros(hidden))

        self.weight3 = Parameter(torch.zeros(hidden, out_dim))
        self.bias3 = Parameter(torch.zeros(out_dim))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight1)
        torch.nn.init.xavier_uniform_(self.weight2)
        torch.nn.init.xavier_uniform_(self.weight3)
        if self.bias1.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias1)
        if self.bias2.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias2)
        if self.bias3.dim() > 1:
            torch.nn.init.xavier_uniform_(self.bias3)

    @torch.jit.script_method
    def forward_(self, x, first_edge_index, second_edge_index):
        embedding = self.embedding_(x, first_edge_index, second_edge_index)
        out = torch.matmul(embedding, self.weight3)
        out = out + self.bias3
        return F.log_softmax(out, dim=1)

    @torch.jit.script_method
    def loss(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return F.nll_loss(y_pred, y_true)

    @torch.jit.script_method
    def predict_(self, x, first_edge_index, second_edge_index):
        output = self.forward_(x, first_edge_index, second_edge_index)
        return output.max(1)[1]

    @torch.jit.script_method
    def embedding_predict_(self, embedding):
        out = torch.matmul(embedding, self.weight3)
        out = out + self.bias3
        return F.log_softmax(out, dim=1)

    @torch.jit.script_method
    def embedding_(self, x, first_edge_index, second_edge_index):
        # first layer
        row, col = second_edge_index[0], second_edge_index[1]
        out = scatter_mean(x[col], row, dim=0)  # do not set dim_size
        out = torch.cat([x[0:out.size(0)], out], dim=1)
        out = torch.matmul(out, self.weight1)
        out = out + self.bias1
        out = torch.relu(out)
        out = F.normalize(out, p=2.0, dim=-1)

        # second layer
        row, col = first_edge_index[0], first_edge_index[1]
        neighbors = scatter_mean(out[col], row, dim=0)  # do not set dim_size
        out = torch.cat([out[0:neighbors.size(0)], neighbors], dim=1)
        out = torch.matmul(out, self.weight2)
        out = out + self.bias2
        out = F.normalize(out, p=2.0, dim=-1)
        return out

    def acc(self, y_pred, y_true):
        y_true = y_true.view(-1).to(torch.long)
        return y_pred.max(1)[1].eq(y_true).sum().item() / y_pred.size(0)
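
# A usage sketch added for illustration; it assumes the surrounding module's
# imports (F, Parameter, scatter_mean from torch_scatter) so the ScriptModule
# compiles, and builds a random two-hop sampled subgraph as a placeholder.
import torch

model = SAGETwoSoftmax(in_dim=16, hidden=32, out_dim=7)
x = torch.randn(100, 16)                       # sampled node features

row = torch.arange(100).repeat(3)              # every node appears as a target
col = torch.randint(0, 100, (300,))            # random source neighbours
first_edge_index = torch.stack([row, col])     # consumed by the second layer
second_edge_index = torch.stack([row, col])    # consumed by the first layer

log_probs = model.forward_(x, first_edge_index, second_edge_index)
y_true = torch.randint(0, 7, (100,))           # placeholder labels
loss = model.loss(log_probs, y_true)
acc = model.acc(log_probs, y_true)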