Ejemplo n.º 1
0
def vcov(posterior):
    """
    Estimate the dense covariance matrix of the joint posterior samples.

    Every site's samples are flattened to a 2-D ``(size(0), -1)`` block, the
    blocks are concatenated column-wise, and each row is streamed through a
    dense Welford estimator.

    :param posterior: Object accepted by ``extract_samples``; its result is
        used as a mapping whose values are tensors indexed by sample along
        dim 0.
    :returns: The unregularized covariance estimate.
    """
    samples_by_site = extract_samples(posterior)
    flat_blocks = [
        site_samples.reshape(site_samples.size(0), -1)
        for site_samples in samples_by_site.values()
    ]
    joint_samples = torch.cat(flat_blocks, dim=1)

    estimator = WelfordCovariance(diagonal=False)
    for row in joint_samples:
        estimator.update(row)
    return estimator.get_covariance(regularize=False)
Ejemplo n.º 2
0
def vcov(posterior):
    """
    Estimate the dense covariance matrix of the posterior's stochastic sites.

    The marginal over the stochastic nodes of the first execution trace is
    taken, each site's support is flattened to ``(size(0), -1)``, and the
    concatenated rows are fed to a dense Welford estimator.

    :param posterior: Object exposing ``exec_traces`` and ``marginal``.
    :returns: The covariance estimate from ``get_covariance()`` with its
        default arguments, detached from autograd.
    """
    first_trace = posterior.exec_traces[0]
    marginal = posterior.marginal(first_trace.stochastic_nodes)
    supports = marginal.support()
    columns = [values.reshape(values.size(0), -1)
               for values in supports.values()]
    stacked = torch.cat(columns, dim=1)

    accumulator = WelfordCovariance(diagonal=False)
    for draw in stacked:
        accumulator.update(draw)
    covariance = accumulator.get_covariance()
    return covariance.detach()
Ejemplo n.º 3
0
class CountMeanVarianceStats(StreamingStats):
    """
    Statistic tracking the count, mean, and (diagonal) variance of a single
    :class:`torch.Tensor`.
    """
    def __init__(self):
        # Shape of the tracked tensor; fixed by the first call to ``update``.
        self.shape = None
        # Running diagonal statistics over the flattened tensor.
        self.welford = WelfordCovariance(diagonal=True)
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Incorporate one sample.

        :param torch.Tensor sample: A tensor whose shape must match the shape
            of the first sample ever passed in.
        """
        assert isinstance(sample, torch.Tensor)
        if self.shape is None:
            self.shape = sample.shape
        assert sample.shape == self.shape
        # Detach from autograd and flatten before accumulating.
        self.welford.update(sample.detach().reshape(-1))

    def merge(self,
              other: "CountMeanVarianceStats") -> "CountMeanVarianceStats":
        """
        Combine the statistics of ``self`` and ``other`` into a new object.

        Neither operand is modified. When one side has no data, a deep copy
        of the other is returned. Otherwise the sample counts are added, the
        means are combined with count weights, and the sums of squared
        deviations are combined with the pairwise (parallel) variance update.

        :param CountMeanVarianceStats other: Statistics to merge with.
        :returns: The merged statistics.
        """
        assert isinstance(other, type(self))
        if self.shape is None:
            return copy.deepcopy(other)
        if other.shape is None:
            return copy.deepcopy(self)
        # Both operands hold data, so they must describe tensors of the same
        # shape, as ``update`` requires of every sample. Without this check a
        # mismatched operand could silently broadcast into wrong statistics.
        assert other.shape == self.shape
        result = copy.deepcopy(self)
        res = result.welford
        lhs = self.welford
        rhs = other.welford
        res.n_samples = lhs.n_samples + rhs.n_samples
        lhs_weight = lhs.n_samples / res.n_samples
        rhs_weight = rhs.n_samples / res.n_samples
        res._mean = lhs_weight * lhs._mean + rhs_weight * rhs._mean
        # M2 = M2_a + M2_b + delta^2 * n_a * n_b / n
        res._m2 = (lhs._m2 + rhs._m2 +
                   (lhs.n_samples * rhs.n_samples / res.n_samples) *
                   (lhs._mean - rhs._mean)**2)
        return result

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``mean: torch.Tensor`` and ``variance:
            torch.Tensor``.
        :rtype: dict
        """
        if self.shape is None:
            return {"count": 0}
        count = self.welford.n_samples
        mean = self.welford._mean.reshape(self.shape)
        # NOTE(review): WelfordCovariance.get_covariance appears to raise
        # RuntimeError with fewer than 2 samples; confirm whether a
        # single-sample ``get`` should be supported.
        variance = self.welford.get_covariance(regularize=False).reshape(
            self.shape)
        return {"count": count, "mean": mean, "variance": variance}
Ejemplo n.º 4
0
def test_welford_dense(n_samples, dim_size):
    """
    A dense Welford estimate must equal NumPy's unbiased sample covariance.

    With only one sample, ``get_covariance`` is expected to raise
    ``RuntimeError``.
    """
    estimator = WelfordCovariance(diagonal=False)
    factor = torch.randn(dim_size, dim_size)
    mvn = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim_size),
        covariance_matrix=factor @ factor.t())
    draws = mvn.sample(torch.Size([n_samples]))
    for draw in draws:
        estimator.update(draw)

    expect_failure = n_samples == 1
    with optional(pytest.raises(RuntimeError), expect_failure):
        # Query the estimator first so the single-sample case raises before
        # NumPy is asked for a degenerate covariance.
        actual = estimator.get_covariance(regularize=False).cpu().numpy()
        expected = np.cov(draws.cpu().numpy(), bias=False, rowvar=False)
        assert_equal(actual, expected)
Ejemplo n.º 5
0
def test_welford_diagonal(n_samples, dim_size):
    """
    A diagonal Welford estimate must equal the per-coordinate unbiased
    variance of the drawn samples.
    """
    estimator = WelfordCovariance(diagonal=True)
    variances = torch.rand(dim_size)
    mvn = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim_size),
        covariance_matrix=torch.diag(variances))

    draws = []
    for _ in range(n_samples):
        draw = mvn.sample()
        draws.append(draw)
        estimator.update(draw)

    expected = torch.stack(draws).var(dim=0, unbiased=True)
    actual = estimator.get_covariance(regularize=False)
    assert_equal(actual, expected)
Ejemplo n.º 6
0
def test_welford_dense(n_samples, dim_size):
    """
    A dense Welford estimate, fed one sample at a time, must equal NumPy's
    unbiased sample covariance.
    """
    estimator = WelfordCovariance(diagonal=False)
    factor = torch.randn(dim_size, dim_size)
    mvn = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim_size),
        covariance_matrix=factor @ factor.t())

    draws = []
    for _ in range(n_samples):
        draw = mvn.sample()
        draws.append(draw)
        estimator.update(draw)

    stacked = torch.stack(draws).detach().cpu().numpy()
    expected = np.cov(stacked, bias=False, rowvar=False)
    actual = estimator.get_covariance(regularize=False).detach().cpu().numpy()
    assert_equal(actual, expected)
Ejemplo n.º 7
0
class WarmupAdapter(object):
    r"""
    Adapts tunable parameters, namely step size and mass matrix, during the
    warmup phase. This class provides lookup properties to read the latest
    values of ``step_size`` and ``inverse_mass_matrix``. These values are
    periodically updated when adaptation is engaged.

    :param step_size: Initial step size; ``None`` is treated as ``1``.
    :param bool adapt_step_size: Whether to adapt the step size with a Dual
        Averaging scheme during warmup.
    :param float target_accept_prob: Acceptance probability the step size
        adaptation aims for.
    :param bool adapt_mass_matrix: Whether to estimate the inverse mass
        matrix from warmup samples.
    :param bool is_diag_mass: Whether the mass matrix is diagonal (``True``)
        or dense (``False``).
    """
    def __init__(self,
                 step_size=1,
                 adapt_step_size=False,
                 target_accept_prob=0.8,
                 adapt_mass_matrix=False,
                 is_diag_mass=True):
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.is_diag_mass = is_diag_mass
        self.step_size = 1 if step_size is None else step_size
        self._adaptation_disabled = not (adapt_step_size or adapt_mass_matrix)
        if adapt_step_size:
            self._step_size_adapt_scheme = DualAveraging()
        if adapt_mass_matrix:
            self._mass_matrix_adapt_scheme = WelfordCovariance(
                diagonal=is_diag_mass)

        # We separate warmup_steps into windows:
        #   start_buffer + window 1 + window 2 + window 3 + ... + end_buffer
        # where the length of each window will be doubled for the next window.
        # We won't adapt mass matrix during start and end buffers; and mass
        # matrix will be updated at the end of each window. This is helpful
        # for dealing with the intense computation of sampling momentum from the
        # inverse of mass matrix.
        self._adapt_start_buffer = 75  # from Stan
        self._adapt_end_buffer = 50  # from Stan
        self._adapt_initial_window = 25  # from Stan
        self._current_window = 0  # starting window index

        # configured later on setup
        self._warmup_steps = None
        self._inverse_mass_matrix = None
        self._r_dist = None
        self._adaptation_schedule = []
        # Optional callable installed by ``configure``. Initialised here so
        # that ``reset_step_size_adaptation`` does not raise AttributeError
        # when no such callable was supplied.
        self._find_reasonable_step_size = None

    def _build_adaptation_schedule(self):
        # Returns a list of ``adapt_window(start, end)`` entries with
        # inclusive step indices covering [0, warmup_steps - 1]: a start
        # buffer, a sequence of doubling windows, then an end buffer.
        adaptation_schedule = []
        # from Stan, for small warmup_steps < 20
        if self._warmup_steps < 20:
            adaptation_schedule.append(adapt_window(0, self._warmup_steps - 1))
            return adaptation_schedule

        start_buffer_size = self._adapt_start_buffer
        end_buffer_size = self._adapt_end_buffer
        init_window_size = self._adapt_initial_window
        if (self._adapt_start_buffer + self._adapt_end_buffer +
                self._adapt_initial_window > self._warmup_steps):
            # Too few warmup steps for the default sizes: fall back to
            # 15% start buffer, 10% end buffer, and the rest as one window.
            start_buffer_size = int(0.15 * self._warmup_steps)
            end_buffer_size = int(0.1 * self._warmup_steps)
            init_window_size = self._warmup_steps - start_buffer_size - end_buffer_size
        adaptation_schedule.append(
            adapt_window(start=0, end=start_buffer_size - 1))
        end_window_start = self._warmup_steps - end_buffer_size

        next_window_size = init_window_size
        next_window_start = start_buffer_size
        while next_window_start < end_window_start:
            cur_window_start, cur_window_size = next_window_start, next_window_size
            # Ensure that slow adaptation windows are monotonically increasing
            if 3 * cur_window_size <= end_window_start - cur_window_start:
                next_window_size = 2 * cur_window_size
            else:
                # Not enough room for a doubled window after this one, so this
                # window absorbs everything up to the end buffer.
                cur_window_size = end_window_start - cur_window_start
            next_window_start = cur_window_start + cur_window_size
            adaptation_schedule.append(
                adapt_window(cur_window_start, next_window_start - 1))
        adaptation_schedule.append(
            adapt_window(end_window_start, self._warmup_steps - 1))
        return adaptation_schedule

    def reset_step_size_adaptation(self, z):
        r"""
        Finds a reasonable step size and resets step size adaptation scheme.

        The step size search is only performed if a callable was supplied to
        :meth:`configure`; otherwise the current step size is kept.

        :param dict z: latent variables passed to the step size search.
        """
        if self._find_reasonable_step_size is not None:
            with pyro.validation_enabled(False):
                self.step_size = self._find_reasonable_step_size(z)
        self._step_size_adapt_scheme.prox_center = math.log(10 *
                                                            self.step_size)
        self._step_size_adapt_scheme.reset()

    def _update_step_size(self, accept_prob):
        # calculate a statistic for Dual Averaging scheme: the gap between
        # the target and the observed acceptance probability.
        H = self.target_accept_prob - accept_prob
        self._step_size_adapt_scheme.step(H)
        log_step_size, _ = self._step_size_adapt_scheme.get_state()
        self.step_size = math.exp(log_step_size)

    def _update_r_dist(self):
        # Rebuild the momentum distribution from the inverse mass matrix.
        loc = torch.zeros(self._inverse_mass_matrix.size(0),
                          dtype=self._inverse_mass_matrix.dtype,
                          device=self._inverse_mass_matrix.device)
        if self.is_diag_mass:
            # Diagonal case: per-coordinate scale is 1 / sqrt(inverse mass).
            self._r_dist = dist.Normal(loc, self._inverse_mass_matrix.rsqrt())
        else:
            # Dense case: the inverse mass matrix serves as the precision.
            self._r_dist = dist.MultivariateNormal(
                loc, precision_matrix=self._inverse_mass_matrix)

    def _end_adaptation(self):
        # Freeze the step size at the Dual Averaging running average.
        if self.adapt_step_size:
            _, log_step_size_avg = self._step_size_adapt_scheme.get_state()
            self.step_size = math.exp(log_step_size_avg)

    def configure(self,
                  warmup_steps,
                  initial_step_size=None,
                  inv_mass_matrix=None,
                  find_reasonable_step_size_fn=None):
        r"""
        Model specific properties that are specified when the HMC kernel is setup.

        :param warmup_steps: Number of warmup steps that the sampler is initialized with.
        :param initial_step_size: Step size to use to initialize the Dual Averaging scheme.
        :param inv_mass_matrix: Initial value of the inverse mass matrix.
        :param find_reasonable_step_size_fn: A callable to find reasonable step size when
            mass matrix is changed.
        :raises ValueError: if the inverse mass matrix or step size is still
            unset after applying the arguments.
        """
        self._warmup_steps = warmup_steps
        if initial_step_size is not None:
            self.step_size = initial_step_size
        if find_reasonable_step_size_fn is not None:
            self._find_reasonable_step_size = find_reasonable_step_size_fn
        if inv_mass_matrix is not None:
            self.inverse_mass_matrix = inv_mass_matrix
        if self.inverse_mass_matrix is None or self.step_size is None:
            raise ValueError(
                "Incomplete configuration - step size and inverse mass matrix "
                "need to be initialized.")
        if not self._adaptation_disabled:
            self._adaptation_schedule = self._build_adaptation_schedule()

    def step(self, t, z, accept_prob):
        r"""
        Called at each step during the warmup phase to learn tunable
        parameters.

        :param int t: time step, beginning at 0.
        :param dict z: latent variables.
        :param float accept_prob: acceptance probability of the proposal; a
            Python number or a single-element tensor.
        """
        # Check the disabled flag first so that an adapter that never adapts
        # does not need ``_warmup_steps`` to be configured.
        if self._adaptation_disabled or t >= self._warmup_steps:
            return
        window = self._adaptation_schedule[self._current_window]
        num_windows = len(self._adaptation_schedule)
        # The mass matrix is only learned inside the slow windows, i.e. not
        # in the first (start buffer) or last (end buffer) window.
        mass_matrix_adaptation_phase = self.adapt_mass_matrix and \
            (0 < self._current_window < num_windows - 1)
        if self.adapt_step_size:
            self._update_step_size(float(accept_prob))
        if mass_matrix_adaptation_phase:
            # Sorted names give a stable ordering of the flattened latents.
            z_flat = torch.cat([z[name].reshape(-1) for name in sorted(z)])
            self._mass_matrix_adapt_scheme.update(z_flat.detach())
        if t == window.end:
            if self._current_window == num_windows - 1:
                self._current_window += 1
                self._end_adaptation()
                return

            if self._current_window == 0:
                self._current_window += 1
                return

            if mass_matrix_adaptation_phase:
                # Setting the property also resets the Welford accumulator so
                # the next window starts from fresh statistics.
                self.inverse_mass_matrix = self._mass_matrix_adapt_scheme.get_covariance(
                )
                if self.adapt_step_size:
                    self.reset_step_size_adaptation(z)

            self._current_window += 1

    @property
    def adaptation_schedule(self):
        return self._adaptation_schedule

    @property
    def inverse_mass_matrix(self):
        return self._inverse_mass_matrix

    @inverse_mass_matrix.setter
    def inverse_mass_matrix(self, value):
        self._inverse_mass_matrix = value
        self._update_r_dist()
        if self.adapt_mass_matrix:
            self._mass_matrix_adapt_scheme.reset()

    @property
    def r_dist(self):
        return self._r_dist