def __init__(self,
             size_in,
             prior_factor=1.0,
             weight_prior_std=1.0,
             bias_prior_std=3.0,
             **kwargs):
    """Build a ``VectorizedGP`` and register a Normal prior for each parameter.

    Args:
        size_in: input dimensionality, forwarded to ``VectorizedGP``.
        prior_factor: stored as ``self.prior_factor``; not used here.
        weight_prior_std: std of the zero-mean Normal prior placed on
            ``mean_nn`` / ``kernel_nn`` parameters whose name contains
            ``"weight"``.
        bias_prior_std: std of the zero-mean Normal prior placed on
            ``mean_nn`` / ``kernel_nn`` parameters whose name contains
            ``"bias"``.
        **kwargs: forwarded to ``VectorizedGP``.

    Raises:
        NotImplementedError: if a ``mean_nn`` / ``kernel_nn`` parameter name
            contains neither ``"weight"`` nor ``"bias"``.
        AssertionError: if the registered priors do not correspond
            one-to-one, in order, to the GP's named parameters.
    """
    self._params = OrderedDict()
    self._param_dists = OrderedDict()

    self.prior_factor = prior_factor
    self.gp = VectorizedGP(size_in, **kwargs)

    for name, shape in self.gp.parameter_shapes().items():

        if name == 'constant_mean':
            mean_p_loc = torch.zeros(1).to(device)
            mean_p_scale = torch.ones(1).to(device)
            self._param_dist(name,
                             Normal(mean_p_loc, mean_p_scale).to_event(1))

        elif name == 'lengthscale_raw':
            lengthscale_p_loc = torch.zeros(shape[-1]).to(device)
            lengthscale_p_scale = torch.ones(shape[-1]).to(device)
            self._param_dist(
                name,
                Normal(lengthscale_p_loc, lengthscale_p_scale).to_event(1))

        elif name == 'noise_raw':
            noise_p_loc = -1. * torch.ones(1).to(device)
            noise_p_scale = torch.ones(1).to(device)
            self._param_dist(
                name,
                Normal(noise_p_loc, noise_p_scale).to_event(1))

        elif 'mean_nn' in name or 'kernel_nn' in name:
            mean = torch.zeros(shape).to(device)
            if "weight" in name:
                std = weight_prior_std * torch.ones(shape).to(device)
            elif "bias" in name:
                std = bias_prior_std * torch.ones(shape).to(device)
            else:
                raise NotImplementedError
            self._param_dist(name, Normal(mean, std).to_event(1))

        # Any other name gets no prior here; the alignment check below
        # reports it.

    # Check that parameters in the prior and the GP module are aligned.
    # Compare the complete name lists: zip() would stop at the shorter one
    # and silently accept trailing GP parameters that received no prior.
    gp_param_names = list(self.gp.named_parameters().keys())
    prior_param_names = list(self._param_dists.keys())
    assert gp_param_names == prior_param_names, (
        "GP parameters and prior parameters are not aligned: "
        f"gp={gp_param_names}, prior={prior_param_names}")

    self.hyper_prior = CatDist(list(self._param_dists.values()))
Example #2
0
    def test_sampling2(self):
        """Each block of a CatDist sample is centred on its component's mean."""
        torch.manual_seed(22)
        n_samples, dim_a, dim_b = 100, 5, 3
        normal = pyro.distributions.Normal

        dist_a = normal(torch.ones(dim_a),
                        0.01 * torch.ones(dim_a)).to_event(1)
        dist_b = normal(-1 * torch.ones(dim_b),
                        0.01 * torch.ones(dim_b)).to_event(1)

        draws = CatDist([dist_a, dist_b]).rsample((n_samples, ))
        # The sample concatenates both event dimensions along the last axis.
        assert draws.shape == (n_samples, dim_a + dim_b)

        mean_a = draws[:, :dim_a].mean().item()
        mean_b = draws[:, dim_a:].mean().item()
        assert abs(mean_a - 1) < 0.2
        assert abs(mean_b + 1) < 0.2
Example #3
0
    def test_pdf(self):
        """log_prob is tiny for mismatched blocks and large for matched ones."""
        torch.manual_seed(22)
        split, total = 7, 10
        normal = pyro.distributions.Normal

        head = normal(torch.ones(split),
                      0.01 * torch.ones(split)).to_event(1)
        tail = normal(-1 * torch.ones(total - split),
                      0.01 * torch.ones(total - split)).to_event(1)
        catdist = CatDist([head, tail])

        all_pos = torch.normal(1.0, 0.01, size=(150, total))
        all_neg = torch.normal(-1.0, 0.01, size=(150, total))
        # First `split` columns near +1, remaining columns near -1.
        matched = torch.cat([all_pos[:, :split], all_neg[:, split:]], dim=-1)

        logp_mismatched = catdist.log_prob(all_pos)
        logp_matched = catdist.log_prob(matched)
        assert torch.mean(logp_mismatched).item() < -1000
        assert torch.mean(logp_matched).item() > 10
Example #4
0
    def test_pdf2(self):
        """Summing unreduced per-component log-probs equals the reduced form."""
        torch.manual_seed(22)
        split, total = 7, 10
        normal = pyro.distributions.Normal

        head = normal(torch.ones(split),
                      0.01 * torch.ones(split)).to_event(1)
        tail = normal(-1 * torch.ones(total - split),
                      0.01 * torch.ones(total - split)).to_event(1)
        unreduced = CatDist([head, tail], reduce_event_dim=False)
        reduced = CatDist([head, tail], reduce_event_dim=True)

        all_pos = torch.normal(1.0, 0.01, size=(150, total))
        all_neg = torch.normal(-1.0, 0.01, size=(150, total))
        matched = torch.cat([all_pos[:, :split], all_neg[:, split:]], dim=-1)

        summed = torch.sum(unreduced.log_prob(matched), dim=0).numpy()
        direct = reduced.log_prob(matched).numpy()
        assert np.array_equal(summed, direct)
class _RandomGPBase:
    def __init__(self,
                 size_in,
                 prior_factor=1.0,
                 weight_prior_std=1.0,
                 bias_prior_std=3.0,
                 **kwargs):

        self._params = OrderedDict()
        self._param_dists = OrderedDict()

        self.prior_factor = prior_factor
        self.gp = VectorizedGP(size_in, **kwargs)

        for name, shape in self.gp.parameter_shapes().items():

            if name == 'constant_mean':
                mean_p_loc = torch.zeros(1).to(device)
                mean_p_scale = torch.ones(1).to(device)
                self._param_dist(name,
                                 Normal(mean_p_loc, mean_p_scale).to_event(1))

            if name == 'lengthscale_raw':
                lengthscale_p_loc = torch.zeros(shape[-1]).to(device)
                lengthscale_p_scale = torch.ones(shape[-1]).to(device)
                self._param_dist(
                    name,
                    Normal(lengthscale_p_loc, lengthscale_p_scale).to_event(1))

            if name == 'noise_raw':
                noise_p_loc = -1. * torch.ones(1).to(device)
                noise_p_scale = torch.ones(1).to(device)
                self._param_dist(
                    name,
                    Normal(noise_p_loc, noise_p_scale).to_event(1))

            if 'mean_nn' in name or 'kernel_nn' in name:
                mean = torch.zeros(shape).to(device)
                if "weight" in name:
                    std = weight_prior_std * torch.ones(shape).to(device)
                elif "bias" in name:
                    std = bias_prior_std * torch.ones(shape).to(device)
                else:
                    raise NotImplementedError
                self._param_dist(name, Normal(mean, std).to_event(1))

        # check that parameters in prior and gp modules are aligned
        for param_name_gp, param_name_prior in zip(
                self.gp.named_parameters().keys(), self._param_dists.keys()):
            assert param_name_gp == param_name_prior

        self.hyper_prior = CatDist(self._param_dists.values())

    def sample_params_from_prior(self, shape=torch.Size()):
        return self.hyper_prior.sample(shape)

    def sample_fn_from_prior(self, shape=torch.Size()):
        params = self.sample_params_from_prior(shape=shape)
        return self.get_forward_fn(params)

    def get_forward_fn(self, params):
        gp_model = copy.deepcopy(self.gp)
        gp_model.set_parameters_as_vector(params)
        return gp_model

    def _param_dist(self, name, dist):
        assert type(name) == str
        assert isinstance(dist, torch.distributions.Distribution)
        assert name not in list(self._param_dists.keys())
        assert hasattr(dist, 'rsample')
        self._param_dists[name] = dist
        return dist

    def _log_prob_prior(self, params):
        return self.hyper_prior.log_prob(params)

    def _log_prob_likelihood(self, *args):
        raise NotImplementedError

    def log_prob(self, *args):
        raise NotImplementedError

    def parameter_shapes(self):
        param_shapes_dict = OrderedDict()
        for name, dist in self._param_dists.items():
            param_shapes_dict[name] = dist.event_shape
        return param_shapes_dict