import math
from collections import OrderedDict, namedtuple

import torch
from torch.distributions import biject_to, constraints

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.mcmc.trace_kernel import TraceKernel
from pyro.ops.dual_averaging import DualAveraging
from pyro.ops.integrator import single_step_velocity_verlet, velocity_verlet
from pyro.ops.welford import WelfordCovariance
from pyro.util import torch_isinf, torch_isnan

# Fields of an adaptation window: the first and last time step it covers.
adapt_window = namedtuple("adapt_window", ["start", "end"])


class WarmupAdapter(object):
    r"""
    Adapts tunable parameters, namely step size and mass matrix, during the
    warmup phase. This class provides lookup properties to read the latest
    values of ``step_size`` and ``inverse_mass_matrix``. These values are
    periodically updated when adaptation is engaged.
    """

    def __init__(self,
                 step_size=1,
                 adapt_step_size=False,
                 target_accept_prob=0.8,
                 adapt_mass_matrix=False,
                 is_diag_mass=True):
        self.adapt_step_size = adapt_step_size
        self.adapt_mass_matrix = adapt_mass_matrix
        self.target_accept_prob = target_accept_prob
        self.is_diag_mass = is_diag_mass
        self.step_size = 1 if step_size is None else step_size
        self._adaptation_disabled = not (adapt_step_size or adapt_mass_matrix)
        if adapt_step_size:
            self._step_size_adapt_scheme = DualAveraging()
        if adapt_mass_matrix:
            self._mass_matrix_adapt_scheme = WelfordCovariance(diagonal=is_diag_mass)

        # We separate warmup_steps into windows:
        #   start_buffer + window 1 + window 2 + window 3 + ... + end_buffer
        # where the length of each window is doubled for the next window.
        # The mass matrix is not adapted during the start and end buffers, and
        # is only updated at the end of each window. This limits the expensive
        # work of rebuilding the momentum distribution from the inverse mass
        # matrix to a few points during warmup.
        self._adapt_start_buffer = 75  # from Stan
        self._adapt_end_buffer = 50  # from Stan
        self._adapt_initial_window = 25  # from Stan
        self._current_window = 0  # starting window index

        # configured later on setup
        self._warmup_steps = None
        self._inverse_mass_matrix = None
        self._r_dist = None
        self._adaptation_schedule = []
        self._find_reasonable_step_size = None  # set via configure()

    def _build_adaptation_schedule(self):
        adaptation_schedule = []
        # from Stan, for small warmup_steps < 20
        if self._warmup_steps < 20:
            adaptation_schedule.append(adapt_window(0, self._warmup_steps - 1))
            return adaptation_schedule

        start_buffer_size = self._adapt_start_buffer
        end_buffer_size = self._adapt_end_buffer
        init_window_size = self._adapt_initial_window
        if (self._adapt_start_buffer + self._adapt_end_buffer
                + self._adapt_initial_window > self._warmup_steps):
            start_buffer_size = int(0.15 * self._warmup_steps)
            end_buffer_size = int(0.1 * self._warmup_steps)
            init_window_size = self._warmup_steps - start_buffer_size - end_buffer_size

        adaptation_schedule.append(adapt_window(start=0, end=start_buffer_size - 1))
        end_window_start = self._warmup_steps - end_buffer_size

        next_window_size = init_window_size
        next_window_start = start_buffer_size
        while next_window_start < end_window_start:
            cur_window_start, cur_window_size = next_window_start, next_window_size
            # Ensure that slow adaptation windows are monotonically increasing
            if 3 * cur_window_size <= end_window_start - cur_window_start:
                next_window_size = 2 * cur_window_size
            else:
                cur_window_size = end_window_start - cur_window_start
            next_window_start = cur_window_start + cur_window_size
            adaptation_schedule.append(adapt_window(cur_window_start,
                                                    next_window_start - 1))
        adaptation_schedule.append(adapt_window(end_window_start,
                                                self._warmup_steps - 1))
        return adaptation_schedule
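
    # Worked example (illustrative numbers): with warmup_steps=1000 and the
    # Stan-like defaults above, _build_adaptation_schedule returns
    #     [adapt_window(0, 74),      # start buffer: step size adaptation only
    #      adapt_window(75, 99),     # initial slow window, size 25
    #      adapt_window(100, 149),   # size 50
    #      adapt_window(150, 249),   # size 100
    #      adapt_window(250, 449),   # size 200
    #      adapt_window(450, 949),   # final slow window absorbs the remainder
    #      adapt_window(950, 999)]   # end buffer: step size adaptation only
    # Window sizes double until tripling the current size would overshoot the
    # end buffer, at which point the current window absorbs the remainder.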
""" if self._find_reasonable_step_size is not None: with pyro.validation_enabled(False): self.step_size = self._find_reasonable_step_size(z) self._step_size_adapt_scheme.prox_center = math.log(10 * self.step_size) self._step_size_adapt_scheme.reset() def _update_step_size(self, accept_prob): # calculate a statistic for Dual Averaging scheme H = self.target_accept_prob - accept_prob self._step_size_adapt_scheme.step(H) log_step_size, _ = self._step_size_adapt_scheme.get_state() self.step_size = math.exp(log_step_size) def _update_r_dist(self): loc = torch.zeros(self._inverse_mass_matrix.size(0), dtype=self._inverse_mass_matrix.dtype, device=self._inverse_mass_matrix.device) if self.is_diag_mass: self._r_dist = dist.Normal(loc, self._inverse_mass_matrix.rsqrt()) else: self._r_dist = dist.MultivariateNormal( loc, precision_matrix=self._inverse_mass_matrix) def _end_adaptation(self): if self.adapt_step_size: _, log_step_size_avg = self._step_size_adapt_scheme.get_state() self.step_size = math.exp(log_step_size_avg) def configure(self, warmup_steps, initial_step_size=None, inv_mass_matrix=None, find_reasonable_step_size_fn=None): r""" Model specific properties that are specified when the HMC kernel is setup. :param warmup_steps: Number of warmup steps that the sampler is initialized with. :param initial_step_size: Step size to use to initialize the Dual Averaging scheme. :param inv_mass_matrix: Initial value of the inverse mass matrix. :param find_reasonable_step_size_fn: A callable to find reasonable step size when mass matrix is changed. """ self._warmup_steps = warmup_steps if initial_step_size is not None: self.step_size = initial_step_size if find_reasonable_step_size_fn is not None: self._find_reasonable_step_size = find_reasonable_step_size_fn if inv_mass_matrix is not None: self.inverse_mass_matrix = inv_mass_matrix if self.inverse_mass_matrix is None or self.step_size is None: raise ValueError( "Incomplete configuration - step size and inverse mass matrix " "need to be initialized.") if not self._adaptation_disabled: self._adaptation_schedule = self._build_adaptation_schedule() def step(self, t, z, accept_prob): r""" Called at each step during the warmup phase to learn tunable parameters. :param int t: time step, beginning at 0. :param dict z: latent variables. :param float accept_prob: acceptance probability of the proposal. 
""" if t >= self._warmup_steps or self._adaptation_disabled: return window = self._adaptation_schedule[self._current_window] num_windows = len(self._adaptation_schedule) mass_matrix_adaptation_phase = self.adapt_mass_matrix and \ (0 < self._current_window < num_windows - 1) if self.adapt_step_size: self._update_step_size(accept_prob.item()) if mass_matrix_adaptation_phase: z_flat = torch.cat([z[name].reshape(-1) for name in sorted(z)]) self._mass_matrix_adapt_scheme.update(z_flat.detach()) if t == window.end: if self._current_window == num_windows - 1: self._current_window += 1 self._end_adaptation() return if self._current_window == 0: self._current_window += 1 return if mass_matrix_adaptation_phase: self.inverse_mass_matrix = self._mass_matrix_adapt_scheme.get_covariance( ) if self.adapt_step_size: self.reset_step_size_adaptation(z) self._current_window += 1 @property def adaptation_schedule(self): return self._adaptation_schedule @property def inverse_mass_matrix(self): return self._inverse_mass_matrix @inverse_mass_matrix.setter def inverse_mass_matrix(self, value): self._inverse_mass_matrix = value self._update_r_dist() if self.adapt_mass_matrix: self._mass_matrix_adapt_scheme.reset() @property def r_dist(self): return self._r_dist


class HMC(TraceKernel):
    r"""
    Simple Hamiltonian Monte Carlo kernel, where ``step_size`` and
    ``num_steps`` need to be explicitly specified by the user.

    **References**

    [1] `MCMC Using Hamiltonian Dynamics`, Radford M. Neal

    :param model: Python callable containing Pyro primitives.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param float trajectory_length: Length of an MCMC trajectory. If not
        specified, it will be set to ``step_size x num_steps``. In case
        ``num_steps`` is not specified either, it will be set to :math:`2\pi`.
    :param int num_steps: The number of discrete steps over which to simulate
        Hamiltonian dynamics. The state at the end of the trajectory is
        returned as the proposal. This value is always equal to
        ``max(1, int(trajectory_length / step_size))``.
    :param bool adapt_step_size: A flag to decide if we want to adapt
        ``step_size`` during the warm-up phase using the Dual Averaging
        scheme.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space.
        The transform should be invertible and implement
        ``log_abs_det_jacobian``. If not specified and the model has sites
        with constrained support, automatic transformations will be applied,
        as specified in :mod:`torch.distributions.constraint_registry`.

    Example:

        >>> true_coefs = torch.tensor([1., 2., 3.])
        >>> data = torch.randn(2000, 3)
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
        >>>
        >>> def model(data):
        ...     coefs_mean = torch.zeros(dim)
        ...     coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
        ...     y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        ...     return y
        >>>
        >>> hmc_kernel = HMC(model, step_size=0.0855, num_steps=4)
        >>> mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data)
        >>> posterior = EmpiricalMarginal(mcmc_run, 'beta')
        >>> posterior.mean  # doctest: +SKIP
        tensor([ 0.9819,  1.9258,  2.9737])
    """

    def __init__(self, model, step_size=None, trajectory_length=None, num_steps=None,
                 adapt_step_size=False, transforms=None):
        self.model = model
        self.step_size = step_size if step_size is not None else 1  # from Stan
        if trajectory_length is not None:
            self.trajectory_length = trajectory_length
        elif num_steps is not None:
            self.trajectory_length = self.step_size * num_steps
        else:
            self.trajectory_length = 2 * math.pi  # from Stan
        self.num_steps = max(1, int(self.trajectory_length / self.step_size))
        self.adapt_step_size = adapt_step_size
        self._target_accept_prob = 0.8  # from Stan
        self.transforms = {} if transforms is None else transforms
        self._automatic_transform_enabled = True if transforms is None else False
        self._reset()
        super(HMC, self).__init__()

    def _get_trace(self, z):
        # reuse the prototype trace, overwriting latent values with `z`
        z_trace = self._prototype_trace
        for name, value in z.items():
            z_trace.nodes[name]["value"] = value
        trace_poutine = poutine.trace(poutine.replay(self.model, trace=z_trace))
        trace_poutine(*self._args, **self._kwargs)
        return trace_poutine.trace

    def _kinetic_energy(self, r):
        return 0.5 * sum(x.pow(2).sum() for x in r.values())

    def _potential_energy(self, z):
        # Since the model is specified in the constrained space, transform the
        # unconstrained R.V.s `z` to the constrained space.
        z_constrained = z.copy()
        for name, transform in self.transforms.items():
            z_constrained[name] = transform.inv(z_constrained[name])
        trace = self._get_trace(z_constrained)
        potential_energy = -trace.log_prob_sum()
        # adjust by the Jacobian of the transformation.
        for name, transform in self.transforms.items():
            potential_energy += transform.log_abs_det_jacobian(z_constrained[name],
                                                               z[name]).sum()
        return potential_energy
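
    # Worked example of the Jacobian adjustment above (illustrative): for a
    # site with positive support, the unconstraining transform is z = log(x),
    # so x = exp(z) and the target density in the unconstrained space is
    #     p_z(z) = p_x(exp(z)) * |dx/dz| = p_x(exp(z)) * exp(z),
    # giving U(z) = -log p_z(z) = -log p_x(exp(z)) - z. The `- z` term is
    # exactly transform.log_abs_det_jacobian(x, z) = log|dz/dx| = -z added to
    # -trace.log_prob_sum().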

    def _energy(self, z, r):
        return self._kinetic_energy(r) + self._potential_energy(z)

    def _reset(self):
        self._t = 0
        self._accept_cnt = 0
        self._r_dist = OrderedDict()
        self._args = None
        self._kwargs = None
        self._prototype_trace = None
        self._adapt_phase = False
        self._adapted_scheme = None

    def _find_reasonable_step_size(self, z):
        step_size = self.step_size
        # NOTE: This target_accept_prob is 0.5 in the NUTS paper and 0.8 in
        # Stan, and is distinct from the target_accept_prob used by the Dual
        # Averaging scheme. We need to discuss which one is better.
        target_accept_logprob = math.log(self._target_accept_prob)
        # We are going to find a step_size which makes accept_prob (the
        # Metropolis correction) close to target_accept_prob. If
        # accept_prob := exp(-delta_energy) is small, then we have to decrease
        # step_size; otherwise, we increase step_size.
        r = {name: pyro.sample("r_{}_presample".format(name), self._r_dist[name])
             for name in self._r_dist}
        energy_current = self._energy(z, r)
        z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
            z, r, self._potential_energy, step_size)
        energy_new = potential_energy + self._kinetic_energy(r_new)
        delta_energy = energy_new - energy_current
        # direction=1 means keep increasing step_size; otherwise decrease it.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be
        # the case for a diverging trajectory (e.g. when evaluating the log
        # prob of a value simulated using a large step size for a constrained
        # sample site).
        direction = 1 if target_accept_logprob < -delta_energy else -1

        # define scale for step_size: 2 for increasing, 1/2 for decreasing
        step_size_scale = 2 ** direction
        direction_new = direction
        # keep scaling step_size until accept_prob crosses its target
        # TODO: make thresholds for too small step_size or too large step_size
        while direction_new == direction:
            step_size = step_size_scale * step_size
            z_new, r_new, z_grads, potential_energy = single_step_velocity_verlet(
                z, r, self._potential_energy, step_size)
            energy_new = potential_energy + self._kinetic_energy(r_new)
            delta_energy = energy_new - energy_current
            direction_new = 1 if target_accept_logprob < -delta_energy else -1
        return step_size
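
    # For example (illustrative numbers): starting from step_size=1, if the
    # single-step proposal gives accept_prob = exp(-delta_energy) = 0.05
    # (below the 0.8 target), the loop halves the step size
    # (0.5, 0.25, 0.125, ...) and returns the first value whose accept_prob
    # rises above 0.8. Symmetrically, an initial accept_prob above the target
    # doubles the step size until accept_prob first drops below it.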

    def _adapt_step_size(self, accept_prob):
        # calculate a statistic for the Dual Averaging scheme
        H = self._target_accept_prob - accept_prob
        self._adapted_scheme.step(H)
        log_step_size, _ = self._adapted_scheme.get_state()
        self.step_size = math.exp(log_step_size)
        self.num_steps = max(1, int(self.trajectory_length / self.step_size))

    def _validate_trace(self, trace):
        trace_log_prob_sum = trace.log_prob_sum()
        if torch_isnan(trace_log_prob_sum) or torch_isinf(trace_log_prob_sum):
            raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")

    def initial_trace(self):
        return self._prototype_trace

    def setup(self, *args, **kwargs):
        self._args = args
        self._kwargs = kwargs
        # set the trace prototype to inter-convert between the trace object
        # and the dict object used by the integrator
        trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        self._prototype_trace = trace
        if self._automatic_transform_enabled:
            self.transforms = {}
        for name, node in sorted(trace.iter_stochastic_nodes(), key=lambda x: x[0]):
            site_value = node["value"]
            if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
                self.transforms[name] = biject_to(node["fn"].support).inv
                site_value = self.transforms[name](node["value"])
            r_loc = site_value.new_zeros(site_value.shape)
            r_scale = site_value.new_ones(site_value.shape)
            self._r_dist[name] = dist.Normal(loc=r_loc, scale=r_scale)
        self._validate_trace(trace)

        if self.adapt_step_size:
            self._adapt_phase = True
            z = {name: node["value"] for name, node in trace.iter_stochastic_nodes()}
            for name, transform in self.transforms.items():
                z[name] = transform(z[name])
            self.step_size = self._find_reasonable_step_size(z)
            self.num_steps = max(1, int(self.trajectory_length / self.step_size))
            # make the prox-center for the Dual Averaging scheme; following
            # the NUTS paper, the scheme shrinks log(step_size) towards
            # log(10 * initial_step_size).
            loc = math.log(10 * self.step_size)
            self._adapted_scheme = DualAveraging(prox_center=loc)

    def end_warmup(self):
        if self.adapt_step_size:
            self._adapt_phase = False
            _, log_step_size_avg = self._adapted_scheme.get_state()
            self.step_size = math.exp(log_step_size_avg)
            self.num_steps = max(1, int(self.trajectory_length / self.step_size))

    def cleanup(self):
        self._reset()
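
    # ``sample`` below proposes a new state by simulating Hamiltonian
    # dynamics and applies the standard Metropolis correction: the proposal
    # (z', r') is accepted with probability min(1, exp(-delta_energy)), where
    # delta_energy = H(z', r') - H(z, r) and H is the total (kinetic plus
    # potential) energy. For example (illustrative numbers), a proposal with
    # delta_energy = 0.5 is accepted with probability exp(-0.5) ~ 0.61, and a
    # proposal with delta_energy <= 0 is always accepted.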

    def sample(self, trace):
        z = {name: node["value"].detach() for name, node in trace.iter_stochastic_nodes()}
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        r = {name: pyro.sample("r_{}_t={}".format(name, self._t), self._r_dist[name])
             for name in self._r_dist}

        # Temporarily disable distribution argument checking, as NaNs are
        # expected during step size adaptation.
        dist_arg_check = False if self._adapt_phase else pyro.distributions.is_validation_enabled()
        with dist.validation_enabled(dist_arg_check):
            z_new, r_new = velocity_verlet(z, r, self._potential_energy,
                                           self.step_size, self.num_steps)
            # apply the Metropolis correction.
            energy_proposal = self._energy(z_new, r_new)
            energy_current = self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        rand = pyro.sample("rand_t={}".format(self._t),
                           dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < (-delta_energy).exp():
            self._accept_cnt += 1
            z = z_new

        if self._adapt_phase:
            # Set accept_prob to 0.0 if delta_energy is `NaN`, which may be
            # the case for a diverging trajectory when using a large step size.
            if torch_isnan(delta_energy):
                accept_prob = delta_energy.new_tensor(0.0)
            else:
                accept_prob = (-delta_energy).exp().clamp(max=1).item()
            self._adapt_step_size(accept_prob)

        self._t += 1
        # get the trace with constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)

    def diagnostics(self):
        return "Step size: {:.6f} \t Acceptance rate: {:.6f}".format(
            self.step_size, self._accept_cnt / self._t)