def test_simple(train_window, min_train_window, test_window, min_test_window, stride, warm_start):
    duration = 30
    obs_dim = 2
    covariates = torch.zeros(duration, 0)
    data = torch.randn(duration, obs_dim) + 4
    forecaster_options = {"num_steps": 2, "warm_start": warm_start}
    expect_error = warm_start and train_window is not None
    with optional(pytest.raises(ValueError), expect_error):
        windows = backtest(
            data,
            covariates,
            Model,
            train_window=train_window,
            min_train_window=min_train_window,
            test_window=test_window,
            min_test_window=min_test_window,
            stride=stride,
            forecaster_options=forecaster_options,
        )
    if not expect_error:
        assert any(window["t0"] == 0 for window in windows)
        if stride == 1:
            assert any(window["t2"] == duration for window in windows)
        for window in windows:
            assert window["train_walltime"] >= 0
            assert window["test_walltime"] >= 0
            for name in DEFAULT_METRICS:
                assert name in window
                assert 0 < window[name] < math.inf
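# Every example in this file conditionally enters a context manager via
# `optional(cm, condition)`. For reference, pyro's helper (pyro.util.optional)
# behaves roughly like this standalone sketch (the name `_optional_sketch` is
# illustrative, not part of the API):
from contextlib import contextmanager

@contextmanager
def _optional_sketch(context_manager, condition):
    """Enter `context_manager` only when `condition` is truthy."""
    if condition:
        with context_manager:
            yield
    else:
        yield

# e.g. `with optional(pytest.raises(ValueError), expect_error):` above asserts
# a ValueError only in the parametrizations where one is expected.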
def test_num_chains(num_chains, cpu_count, default_init_params, monkeypatch):
    monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: cpu_count)
    data = torch.tensor([1.0])
    initial_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains
    )
    if default_init_params:
        initial_params = None
    kernel = PriorKernel(normal_normal_model)
    available_cpu = max(1, cpu_count - 1)
    mp_context = "spawn"
    with optional(pytest.warns(UserWarning), available_cpu < num_chains):
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=initial_params,
            transforms=transforms,
            mp_context=mp_context,
        )
        mcmc.run(data)
    assert mcmc.num_chains == num_chains
    if mcmc.num_chains == 1 or available_cpu < num_chains:
        assert isinstance(mcmc.sampler, _UnarySampler)
    else:
        assert isinstance(mcmc.sampler, _MultiSampler)
def test_predictive(num_samples, parallel):
    model, data, true_probs = beta_bernoulli()
    init_params, potential_fn, transforms, _ = initialize_model(
        model, model_args=(data,)
    )
    nuts_kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(nuts_kernel, 100, initial_params=init_params, warmup_steps=100)
    mcmc.run(data)
    samples = mcmc.get_samples()
    with ignore_experimental_warning():
        with optional(pytest.warns(UserWarning), num_samples not in (None, 100)):
            predictive_samples = predictive(
                model,
                samples,
                num_samples=num_samples,
                return_sites=["beta", "obs"],
                parallel=parallel,
            )

    # check shapes
    assert predictive_samples["beta"].shape == (100, 5)
    assert predictive_samples["obs"].shape == (100, 1000, 5)

    # check sample mean
    assert_close(
        predictive_samples["obs"].reshape([-1, 5]).mean(0), true_probs, rtol=0.1
    )
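# A rough serial equivalent of what `predictive` computes above, to make the
# shape checks concrete. This sketch is illustrative (`manual_predictive` and
# `draw` are not part of the API) and assumes the model resamples its "obs"
# site when called without data, as beta_bernoulli()'s model does here: it
# conditions the model on one posterior draw at a time and re-runs it.
import torch
import pyro.poutine as poutine

def manual_predictive(model, posterior_samples):
    num_draws = len(next(iter(posterior_samples.values())))
    obs = []
    for i in range(num_draws):
        draw = {name: value[i] for name, value in posterior_samples.items()}
        tr = poutine.trace(poutine.condition(model, data=draw)).get_trace()
        obs.append(tr.nodes["obs"]["value"])
    # one predictive "obs" per posterior draw, stacked along a new leading dim
    return torch.stack(obs)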
def _potential_energy_jit(self, z):
    names, vals = zip(*sorted(z.items()))
    if self._compiled_potential_fn:
        return self._compiled_potential_fn(*vals)

    def compiled(*zi):
        z_constrained = list(zi)
        # transform to constrained space.
        for i, name in enumerate(names):
            if name in self.transforms:
                transform = self.transforms[name]
                z_constrained[i] = transform.inv(z_constrained[i])
        z_constrained = dict(zip(names, z_constrained))
        trace = self._get_trace(z_constrained)
        potential_energy = -self._compute_trace_log_prob(trace)
        # adjust by the jacobian for this transformation.
        for i, name in enumerate(names):
            if name in self.transforms:
                transform = self.transforms[name]
                potential_energy += transform.log_abs_det_jacobian(
                    z_constrained[name], zi[i]).sum()
        return potential_energy

    with pyro.validation_enabled(False), optional(
            ignore_jit_warnings(), self._ignore_jit_warnings):
        self._compiled_potential_fn = torch.jit.trace(compiled, vals, check_trace=False)

    return self._compiled_potential_fn(*vals)
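# The Jacobian adjustment above is the standard change-of-variables correction:
# if `transform` maps a constrained value x to an unconstrained value y, then
# -log p_unconstrained(y) = -log p_constrained(x) + log|dy/dx|. A minimal,
# self-contained illustration (the Gamma base and log transform are
# illustrative choices, not taken from the code above):
import torch
import torch.distributions as tdist

base = tdist.Gamma(2.0, 2.0)                     # density on (0, inf)
transform = tdist.transforms.ExpTransform().inv  # log: (0, inf) -> R

y = torch.randn(())   # a point in unconstrained space
x = transform.inv(y)  # back to constrained space, i.e. exp(y)
potential = -base.log_prob(x) + transform.log_abs_det_jacobian(x, y)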
def __call__(self, *args, **kwargs):
    key = _hashable_args_kwargs(args, kwargs)

    # if first time
    if key not in self.compiled:
        # param capture
        with poutine.block():
            with poutine.trace(param_only=True) as first_param_capture:
                self.fn(*args, **kwargs)

        self._param_names = list(set(first_param_capture.trace.nodes.keys()))
        unconstrained_params = tuple(
            pyro.param(name).unconstrained() for name in self._param_names
        )
        params_and_args = unconstrained_params + args
        weakself = weakref.ref(self)

        def compiled(*params_and_args):
            self = weakself()
            unconstrained_params = params_and_args[:len(self._param_names)]
            args = params_and_args[len(self._param_names):]
            constrained_params = {}
            for name, unconstrained_param in zip(self._param_names,
                                                 unconstrained_params):
                constrained_param = pyro.param(name)  # assume param has been initialized
                assert constrained_param.unconstrained() is unconstrained_param
                constrained_params[name] = constrained_param
            return poutine.replay(self.fn, params=constrained_params)(*args, **kwargs)

        if self.ignore_warnings:
            compiled = ignore_jit_warnings()(compiled)
        with pyro.validation_enabled(False):
            time_compilation = self.jit_options.pop("time_compilation", False)
            with optional(timed(), time_compilation) as t:
                self.compiled[key] = torch.jit.trace(compiled, params_and_args,
                                                     **self.jit_options)
            if time_compilation:
                self.compile_time = t.elapsed
    else:
        unconstrained_params = [
            pyro.param(name).unconstrained() for name in self._param_names
        ]
        params_and_args = unconstrained_params + list(args)

    with poutine.block(hide=self._param_names):
        with poutine.trace(param_only=True) as param_capture:
            ret = self.compiled[key](*params_and_args)

    for name in param_capture.trace.nodes.keys():
        if name not in self._param_names:
            raise NotImplementedError(
                "pyro.ops.jit.trace assumes all params are created on "
                "first invocation, but found new param: {}".format(name))

    return ret
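# Sketch of how the wrapper above is typically used (assuming it is exposed as
# `pyro.ops.jit.trace`, with the `ignore_warnings`/`jit_options` attributes
# seen above set via the constructor): the first call runs `fn` once to
# capture the pyro.params it touches, traces a closure over those params with
# torch.jit.trace, and caches the compiled graph; later calls with the same
# argument signature feed the current unconstrained param values into the
# cached graph.
#
#     loss_fn = pyro.ops.jit.trace(elbo_loss, ignore_warnings=True)
#     for step in range(1000):
#         loss = loss_fn(model, guide, data)  # compiled after the first call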
def test_num_chains(num_chains, cpu_count, monkeypatch):
    monkeypatch.setattr(torch.multiprocessing, 'cpu_count', lambda: cpu_count)
    kernel = PriorKernel(normal_normal_model)
    available_cpu = max(1, cpu_count - 1)
    with optional(pytest.warns(UserWarning), available_cpu < num_chains):
        mcmc = MCMC(kernel, num_samples=10, num_chains=num_chains)
    assert mcmc.num_chains == min(num_chains, available_cpu)
    if mcmc.num_chains == 1:
        assert isinstance(mcmc.sampler, _SingleSampler)
    else:
        assert isinstance(mcmc.sampler, _ParallelSampler)
def sample(self, params):
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return params
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
            self.step_size, self.num_steps, z_grads=z_grads)
        # apply Metropolis correction.
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
    delta_energy = energy_proposal - energy_current
    # handle the NaN case which may be the case for a diverging trajectory
    # when using a large step size.
    delta_energy = scalar_like(delta_energy, float("inf")) \
        if torch_isnan(delta_energy) else delta_energy
    if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
        self._divergences.append(self._t - self._warmup_steps)

    accept_prob = (-delta_energy).exp().clamp(max=1.)
    rand = pyro.sample("rand_t={}".format(self._t),
                       dist.Uniform(scalar_like(accept_prob, 0.),
                                    scalar_like(accept_prob, 1.)))
    accepted = False
    if rand < accept_prob:
        accepted = True
        z = z_new
        z_grads = z_grads_new
        self._cache(z, potential_energy_new, z_grads)

    self._t += 1
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t
        self._adapter.step(self._t, z, accept_prob, z_grads)

    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
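# `velocity_verlet` above is a leapfrog integrator. A minimal single-tensor,
# unit-mass sketch of one leapfrog step (illustrative, not the pyro API):
import torch

def leapfrog_step(z, r, potential_fn, step_size):
    z = z.detach().requires_grad_(True)
    grad = torch.autograd.grad(potential_fn(z), z)[0]
    r = r - 0.5 * step_size * grad                          # half-step momentum
    z = (z + step_size * r).detach().requires_grad_(True)   # full-step position
    grad = torch.autograd.grad(potential_fn(z), z)[0]
    r = r - 0.5 * step_size * grad                          # half-step momentum
    return z.detach(), r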
def run(self, *args, **kwargs):
    """
    Run StreamingMCMC to compute required `self._statistics`.
    """
    self._args, self._kwargs = args, kwargs
    num_samples = [0] * self.num_chains

    with optional(
        pyro.validation_enabled(not self.disable_validation),
        self.disable_validation is not None,
    ):
        args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
        for x, chain_id in self.sampler.run(*args, **kwargs):
            if num_samples[chain_id] == 0:
                # If transforms is not explicitly provided, infer automatically using
                # model args, kwargs.
                if self.transforms is None:
                    self._set_transforms(*args, **kwargs)
                num_samples[chain_id] += 1
                z_structure = x
            elif num_samples[chain_id] == self.num_samples + 1:
                self._diagnostics[chain_id] = x
            else:
                num_samples[chain_id] += 1
                if self.num_chains > 1:
                    x_cloned = x.clone()
                    del x
                else:
                    x_cloned = x

                # unpack latent
                pos = 0
                z_acc = z_structure.copy()
                for k in sorted(z_structure):
                    shape = z_structure[k]
                    next_pos = pos + shape.numel()
                    z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
                    pos = next_pos

                for name, z in z_acc.items():
                    if name in self.transforms:
                        z_acc[name] = self.transforms[name].inv(z)

                self._statistics.update({
                    (chain_id, name): transformed_sample
                    for name, transformed_sample in z_acc.items()
                })

    # terminate the sampler (shut down worker processes)
    self.sampler.terminate(True)
def sample(self, trace):
    z = {
        name: node["value"].detach()
        for name, node in self._iter_latent_nodes(trace)
    }
    # automatically transform `z` to unconstrained space, if needed.
    for name, transform in self.transforms.items():
        z[name] = transform(z[name])
    r, _ = self._sample_r(name="r_t={}".format(self._t))

    potential_energy, z_grads = self._fetch_from_cache()
    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self._potential_energy, self.inverse_mass_matrix,
            self.step_size, self.num_steps, z_grads=z_grads)
        # apply Metropolis correction.
        energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
        energy_current = self._kinetic_energy(r) + potential_energy \
            if potential_energy is not None else self._energy(z, r)
    delta_energy = energy_proposal - energy_current
    # Set accept prob to 0.0 if delta_energy is `NaN` which may be
    # the case for a diverging trajectory when using a large step size.
    if torch_isnan(delta_energy):
        accept_prob = delta_energy.new_tensor(0.0)
    else:
        accept_prob = (-delta_energy).exp().clamp(max=1.)
    rand = pyro.sample("rand_t={}".format(self._t),
                       dist.Uniform(torch.zeros(1), torch.ones(1)))
    if rand < accept_prob:
        self._accept_cnt += 1
        z = z_new

    if self._t < self._warmup_steps:
        self._adapter.step(self._t, z, accept_prob)

    self._t += 1

    # get trace with the constrained values for `z`.
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)
def test_welford_dense(n_samples, dim_size):
    w = WelfordCovariance(diagonal=False)
    loc = torch.zeros(dim_size)
    cov = torch.randn(dim_size, dim_size)
    cov = torch.mm(cov, cov.t())
    dist = torch.distributions.MultivariateNormal(loc=loc, covariance_matrix=cov)
    samples = dist.sample(torch.Size([n_samples]))
    for sample in samples:
        w.update(sample)

    with optional(pytest.raises(RuntimeError), n_samples == 1):
        estimates = w.get_covariance(regularize=False).cpu().numpy()
        sample_cov = np.cov(samples.cpu().numpy(), bias=False, rowvar=False)
        assert_equal(estimates, sample_cov)
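# The Welford tests in this file exercise the same online update. A minimal
# sketch of the diagonal case (illustrative; `WelfordCovariance` generalizes
# this to dense covariance and optional regularization):
import torch

class _RunningVariance:
    def __init__(self):
        self.n = 0
        self.mean = 0.
        self.m2 = 0.

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.n
        self.m2 = self.m2 + delta * (x - self.mean)

    def variance(self):
        if self.n < 2:  # mirrors the RuntimeError asserted when n_samples == 1
            raise RuntimeError("insufficient samples to estimate variance")
        return self.m2 / (self.n - 1)  # unbiased, matching var(unbiased=True)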
def test_ubersum_collide_implemented(impl, implemented):
    # Non-tree plates cause exponential blowup,
    # so ubersum() refuses to evaluate them.
    #
    #    z {a,b}
    #     /   \
    #  x {a}  y {b}
    #     \   /
    #      {}   <--- target
    a, b, c, d = 2, 3, 4, 5
    x = torch.randn(a, c)
    y = torch.randn(b, d)
    z = torch.randn(a, b, c, d)
    raises = pytest.raises(NotImplementedError,
                           match='Expected tree-structured plate nesting')
    with optional(raises, not implemented):
        impl('ac,bd,abcd->', x, y, z, plates='ab', modulo_total=True)
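# For contrast, dropping the y {b} operand leaves the chain {} -> {a} -> {a,b},
# which is tree-structured plate nesting and so should be accepted. A hedged
# sketch reusing the same shapes as the test above (not asserted by the test
# itself):
#
#     impl('ac,abcd->', x, z, plates='ab', modulo_total=True)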
def test_welford_diagonal(n_samples, dim_size):
    w = WelfordCovariance(diagonal=True)
    loc = torch.zeros(dim_size)
    cov_diagonal = torch.rand(dim_size)
    cov = torch.diag(cov_diagonal)
    dist = torch.distributions.MultivariateNormal(loc=loc, covariance_matrix=cov)
    samples = []
    for _ in range(n_samples):
        sample = dist.sample()
        samples.append(sample)
        w.update(sample)

    sample_variance = torch.stack(samples).var(dim=0, unbiased=True)
    with optional(pytest.raises(RuntimeError), n_samples == 1):
        estimates = w.get_covariance(regularize=False)
        assert_equal(estimates, sample_variance)
def _traces(self, *args, **kwargs):
    logger_id = kwargs.pop("logger_id", "")
    log_queue = kwargs.pop("log_queue", None)
    self.logger = logging.getLogger("pyro.infer.mcmc")
    is_multiprocessing = log_queue is not None
    progress_bar = None
    if not is_multiprocessing:
        progress_bar = initialize_progbar(self.warmup_steps, self.num_samples,
                                          disable=self.disable_progbar)
    self.logger = initialize_logger(self.logger, logger_id, progress_bar, log_queue)
    self.kernel.setup(self.warmup_steps, *args, **kwargs)
    trace = self.kernel.initial_trace
    with optional(progress_bar, not is_multiprocessing):
        for trace in self._gen_samples(self.warmup_steps, trace):
            continue
        if progress_bar:
            progress_bar.set_description("Sample")
        for trace in self._gen_samples(self.num_samples, trace):
            yield (trace, 1.0)
    self.kernel.cleanup()
def run(self, *args, **kwargs):
    """
    Run MCMC to generate samples and populate `self._samples`.

    Example usage:

    .. code-block:: python

        def model(data):
            ...

        nuts_kernel = NUTS(model)
        mcmc = MCMC(nuts_kernel, num_samples=500)
        mcmc.run(data)
        samples = mcmc.get_samples()

    :param args: optional arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    :param kwargs: optional keywords arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    """
    self._args, self._kwargs = args, kwargs
    num_samples = [0] * self.num_chains
    z_flat_acc = [[] for _ in range(self.num_chains)]
    with optional(pyro.validation_enabled(not self.disable_validation),
                  self.disable_validation is not None):
        for x, chain_id in self.sampler.run(*args, **kwargs):
            if num_samples[chain_id] == 0:
                num_samples[chain_id] += 1
                z_structure = x
            elif num_samples[chain_id] == self.num_samples + 1:
                self._diagnostics[chain_id] = x
            else:
                num_samples[chain_id] += 1
                if self.num_chains > 1:
                    x_cloned = x.clone()
                    del x
                else:
                    x_cloned = x
                z_flat_acc[chain_id].append(x_cloned)

    z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

    # unpack latent
    pos = 0
    z_acc = z_structure.copy()
    for k in sorted(z_structure):
        shape = z_structure[k]
        next_pos = pos + shape.numel()
        z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
            (self.num_chains, self.num_samples) + shape)
        pos = next_pos
    assert pos == z_flat_acc.shape[-1]

    # If transforms is not explicitly provided, infer automatically using
    # model args, kwargs.
    if self.transforms is None:
        # Use `kernel.transforms` when available
        if getattr(self.kernel, "transforms", None) is not None:
            self.transforms = self.kernel.transforms
        # Else, get transforms from model (e.g. in multiprocessing).
        elif self.kernel.model:
            warmup_steps = 0
            self.kernel.setup(warmup_steps, *args, **kwargs)
            self.transforms = self.kernel.transforms
        # Assign default value
        else:
            self.transforms = {}

    # transform samples back to constrained space
    for name, transform in self.transforms.items():
        z_acc[name] = transform.inv(z_acc[name])
    self._samples = z_acc

    # terminate the sampler (shut down worker processes)
    self.sampler.terminate(True)
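# The "unpack latent" step above relies on a simple packing convention: each
# sample arrives as one flat tensor, and `z_structure` maps site names to
# their shapes. A standalone sketch of the inverse mapping for a single
# sample (illustrative names):
import torch

def unpack_sample(flat, structure):
    out, pos = {}, 0
    for name in sorted(structure):
        shape = structure[name]
        out[name] = flat[pos:pos + shape.numel()].reshape(shape)
        pos += shape.numel()
    assert pos == flat.numel()
    return out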
def sample(self, params):
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        potential_energy = self.potential_fn(z)
        self._cache(z, potential_energy)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return z
    r, r_flat = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r) + potential_energy

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # However, it is not the case for real situation: there are errors during the computation.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].

    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]).
        slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                     dist.Exponential(scalar_like(energy_current, 1.)))
        log_slice = -energy_current - slice_exp_term

    z_left = z_right = z
    r_left = r_right = r
    z_left_grads = z_right_grads = z_grads
    accepted = False
    r_sum = r_flat
    sum_accept_probs = 0.
    num_proposals = 0
    tree_weight = scalar_like(energy_current,
                              0. if self.use_multinomial_sampling else 1.)

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        tree_depth = 0
        while tree_depth < self._max_tree_depth:
            direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                    dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
            direction = int(direction.item())
            if direction == 1:  # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads,
                                            log_slice, direction, tree_depth,
                                            energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                z_right_grads = new_tree.z_right_grads
            else:  # go to the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads,
                                            log_slice, direction, tree_depth,
                                            energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                z_left_grads = new_tree.z_left_grads

            sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
            num_proposals = num_proposals + new_tree.num_proposals

            # stop doubling
            if new_tree.diverging:
                if self._t >= self._warmup_steps:
                    self._divergences.append(self._t - self._warmup_steps)
                break

            if new_tree.turning:
                break

            tree_depth += 1

            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                               dist.Uniform(scalar_like(new_tree_prob, 0.),
                                            scalar_like(new_tree_prob, 1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads)

            r_sum = r_sum + new_tree.r_sum
            if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                break
            else:  # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = _logaddexp(tree_weight, new_tree.weight)
                else:
                    tree_weight = tree_weight + new_tree.weight

    accept_prob = sum_accept_probs / num_proposals

    self._t += 1
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t
        self._adapter.step(self._t, z, accept_prob)

    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
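# `_logaddexp` above combines tree weights in log space, i.e. computes
# log(exp(x) + exp(y)) without overflow. A minimal equivalent (recent torch
# also provides torch.logaddexp):
import torch

def _logaddexp_sketch(x, y):
    larger = torch.max(x, y)
    return larger + ((x - larger).exp() + (y - larger).exp()).log()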
def run(self, *args, **kwargs):
    """
    Run MCMC to generate samples and populate `self._samples`.

    Example usage:

    .. code-block:: python

        def model(data):
            ...

        nuts_kernel = NUTS(model)
        mcmc = MCMC(nuts_kernel, num_samples=500)
        mcmc.run(data)
        samples = mcmc.get_samples()

    :param args: optional arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    :param kwargs: optional keywords arguments taken by
        :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
    """
    self._args, self._kwargs = args, kwargs
    num_samples = [0] * self.num_chains
    z_flat_acc = [[] for _ in range(self.num_chains)]
    with optional(pyro.validation_enabled(not self.disable_validation),
                  self.disable_validation is not None):
        # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer"
        # at https://github.com/pytorch/pytorch/issues/10375
        # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which
        # requires_grad", which happens with `jit_compile` under PyTorch 1.7
        args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
        for x, chain_id in self.sampler.run(*args, **kwargs):
            if num_samples[chain_id] == 0:
                num_samples[chain_id] += 1
                z_structure = x
            elif num_samples[chain_id] == self.num_samples + 1:
                self._diagnostics[chain_id] = x
            else:
                num_samples[chain_id] += 1
                if self.num_chains > 1:
                    x_cloned = x.clone()
                    del x
                else:
                    x_cloned = x
                z_flat_acc[chain_id].append(x_cloned)

    z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

    # unpack latent
    pos = 0
    z_acc = z_structure.copy()
    for k in sorted(z_structure):
        shape = z_structure[k]
        next_pos = pos + shape.numel()
        z_acc[k] = z_flat_acc[:, :, pos:next_pos].reshape(
            (self.num_chains, self.num_samples) + shape)
        pos = next_pos
    assert pos == z_flat_acc.shape[-1]

    # If transforms is not explicitly provided, infer automatically using
    # model args, kwargs.
    if self.transforms is None:
        # Use `kernel.transforms` when available
        if getattr(self.kernel, "transforms", None) is not None:
            self.transforms = self.kernel.transforms
        # Else, get transforms from model (e.g. in multiprocessing).
        elif self.kernel.model:
            warmup_steps = 0
            self.kernel.setup(warmup_steps, *args, **kwargs)
            self.transforms = self.kernel.transforms
        # Assign default value
        else:
            self.transforms = {}

    # transform samples back to constrained space
    for name, transform in self.transforms.items():
        z_acc[name] = transform.inv(z_acc[name])
    self._samples = z_acc

    # terminate the sampler (shut down worker processes)
    self.sampler.terminate(True)
def sample(self, trace):
    z, potential_energy, z_grads = self._fetch_from_cache()
    r, _ = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r) + potential_energy

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self._potential_energy, self.inverse_mass_matrix,
            self.step_size, self.batch_size, self.num_steps, z_grads=z_grads)
        # apply Metropolis correction.
        energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
    delta_energy = energy_proposal - energy_current
    # Set accept prob to 0.0 if delta_energy is `NaN` which may be
    # the case for a diverging trajectory when using a large step size.
    if torch_isnan(delta_energy):
        accept_prob = delta_energy.new_tensor(0.0)
    else:
        accept_prob = (-delta_energy).exp().clamp(max=1.)
    rand = torch.rand(self.batch_size)
    accepted = rand < accept_prob
    self._accept_cnt += accepted.sum() / self.batch_size

    # select accepted zs to get z_new
    transitioned_z = {}
    transitioned_grads = {}
    for name in z:
        assert len(z_grads[name].shape) == 2
        assert z_grads[name].shape[0] == self.batch_size
        assert len(z[name].shape) == 2
        assert z[name].shape[0] == self.batch_size

        old_val = z[name]
        old_grad = z_grads[name]
        new_val = z_new[name]
        new_grad = z_grads_new[name]
        val_dim = old_val.shape[1]
        accept_val = accepted.view(self.batch_size, 1).repeat(1, val_dim)
        transitioned_z[name] = torch.where(accept_val, new_val, old_val)
        transitioned_grads[name] = torch.where(accept_val, new_grad, old_grad)

    self._cache(transitioned_z, potential_energy, transitioned_grads)
    if self._t < self._warmup_steps:
        self._adapter.step(self._t, transitioned_z, accept_prob)
    self._t += 1

    # get trace with the constrained values for `z`.
    z = transitioned_z.copy()
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)
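# The batched Metropolis step above is a per-chain masked select. In
# miniature, with 3 chains and 2-dimensional latents (illustrative values):
import torch

accepted = torch.tensor([True, False, True])           # per-chain accept flags
old = torch.zeros(3, 2)                                # current positions
new = torch.ones(3, 2)                                 # proposed positions
mixed = torch.where(accepted.unsqueeze(-1), new, old)  # rows 0, 2 move; row 1 stays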
def sample(self, trace):
    z = {
        name: node["value"].detach()
        for name, node in self._iter_latent_nodes(trace)
    }
    potential_energy, z_grads = self._fetch_from_cache()
    # automatically transform `z` to unconstrained space, if needed.
    for name, transform in self.transforms.items():
        z[name] = transform(z[name])
    r, r_flat = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r) + potential_energy \
        if potential_energy is not None else self._energy(z, r)

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # However, it is not the case for real situation: there are errors during the computation.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].

    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]).
        slice_exp_term = pyro.sample(
            "slicevar_exp_t={}".format(self._t),
            dist.Exponential(energy_current.new_tensor(1.)))
        log_slice = -energy_current - slice_exp_term

    z_left = z_right = z
    r_left = r_right = r
    z_left_grads = z_right_grads = z_grads
    accepted = False
    r_sum = r_flat
    if self.use_multinomial_sampling:
        tree_weight = energy_current.new_zeros(())
    else:
        tree_weight = energy_current.new_ones(())

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        for tree_depth in range(self._max_tree_depth + 1):
            direction = pyro.sample(
                "direction_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Bernoulli(probs=torch.ones(1) * 0.5))
            direction = int(direction.item())
            if direction == 1:  # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads,
                                            log_slice, direction, tree_depth,
                                            energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                z_right_grads = new_tree.z_right_grads
            else:  # go to the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads,
                                            log_slice, direction, tree_depth,
                                            energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                z_left_grads = new_tree.z_left_grads

            if new_tree.turning or new_tree.diverging:  # stop doubling
                break

            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample(
                "rand_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Uniform(new_tree_prob.new_tensor(0.),
                             new_tree_prob.new_tensor(1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                self._cache(new_tree.z_proposal_pe, new_tree.z_proposal_grads)

            r_sum = r_sum + new_tree.r_sum
            if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                break
            else:  # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = logsumexp(
                        torch.stack([tree_weight, new_tree.weight]), dim=0)
                else:
                    tree_weight = tree_weight + new_tree.weight

    if self._t < self._warmup_steps:
        accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
        self._adapter.step(self._t, z, accept_prob)
    if accepted:
        self._accept_cnt += 1
    self._t += 1

    # get trace with the constrained values for `z`.
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)