def transition_kernel_directional(
        self,
        state: State,
        forward: bool,
        training: bool = None,
):
    """Implements a series of directional updates.

    Runs the first ``num_steps // 2`` leapfrog steps with
    ``self._forward_lf``, negates the momentum, then runs the remaining
    steps with ``self._backward_lf``.

    Args:
        state (State): Initial state.
        forward (bool): Not used by this kernel; accepted so the signature
            matches the other transition kernels.
        training (bool): Currently training?

    Returns:
        state_prop (State): Proposed state at the end of the trajectory.
        metrics (AttrDict): ``sumlogdet`` and ``accept_prob``. When
            ``self._verbose`` is set, also ``energies`` and ``logdets``:
            one entry per leapfrog step, recorded *before* that step is
            applied (``logdets`` holds the running total ``sumlogdet``).
    """
    num_steps = self.config.num_steps
    half = num_steps // 2

    state_prop = State(state.x, state.v, state.beta)
    sumlogdet = tf.zeros((self.batch_size,), dtype=TF_FLOAT)
    logdets = tf.TensorArray(TF_FLOAT, dynamic_size=True,
                             size=self.batch_size, clear_after_read=True)
    energies = tf.TensorArray(TF_FLOAT, dynamic_size=True,
                              size=self.batch_size, clear_after_read=True)

    # ====
    # Forward for first half of trajectory
    for step in range(half):
        if self._verbose:
            logdets = logdets.write(step, sumlogdet)
            energies = energies.write(step, self.hamiltonian(state_prop))
        state_prop, logdet = self._forward_lf(step, state_prop, training)
        sumlogdet += logdet

    # ====
    # Flip momentum
    state_prop = State(state_prop.x, -1. * state_prop.v, state_prop.beta)

    # ====
    # Backward for second half of trajectory
    for step in range(half, num_steps):
        # Record exactly as in the first half: only when verbose, before the
        # step is applied, and the running total `sumlogdet`. This keeps the
        # meaning of every index in `logdets` / `energies` uniform and avoids
        # evaluating the Hamiltonian when the metrics are not requested.
        if self._verbose:
            logdets = logdets.write(step, sumlogdet)
            energies = energies.write(step, self.hamiltonian(state_prop))
        state_prop, logdet = self._backward_lf(step, state_prop, training)
        sumlogdet += logdet

    accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
    metrics = AttrDict({
        'sumlogdet': sumlogdet,
        'accept_prob': accept_prob,
    })
    if self._verbose:
        metrics.update({
            'energies': [energies.read(i) for i in range(num_steps)],
            'logdets': [logdets.read(i) for i in range(num_steps)],
        })

    return state_prop, metrics
def _update_v_backward(self, state: State, step: int, training: bool = None):
    """Update the momentum `v` in the backward leapfrog step.

    Computes ``v' = exp(s) * (v + 0.5 * eps * (grad * exp(q) - t))`` with
    ``s = -0.5 * eps * S`` (weighted by ``self._vsw``), where ``S, T, Q``
    come from ``self._call_vnet``.

    Args:
        state (State): Input state.
        step (int): Current leapfrog step; converted to a time encoding
            via ``self._get_time``.
        training (bool): Currently training?

    Returns:
        new_state (State): New state, with updated momentum (and the
            normalized position).
        logdet (tf.Tensor): Log-Jacobian of the update, ``scale`` summed
            over axis 1.
    """
    # Mirror `_update_v_forward` / `_update_x_backward`: in HMC mode defer
    # to the parent implementation instead of calling the momentum network.
    if self.config.hmc:
        return super()._update_v_backward(state, step, training)

    x = self.normalizer(state.x)
    grad = self.grad_potential(x, state.beta)
    t = self._get_time(step, tile=tf.shape(x)[0])

    S, T, Q = self._call_vnet((x, grad, t), step, training)

    scale = self._vsw * (-0.5 * self.eps * S)
    transf = self._vqw * (self.eps * Q)
    transl = self._vtw * T

    expS = tf.exp(scale)
    expQ = tf.exp(transf)

    vb = expS * (state.v + 0.5 * self.eps * (grad * expQ - transl))

    state_out = State(x=x, v=vb, beta=state.beta)
    logdet = tf.reduce_sum(scale, axis=1)

    return state_out, logdet
def _half_v_update_forward(
        self,
        state: State,
        step: int,
        training: bool = None,
):
    """Perform a half-step momentum update in the forward direction.

    The new momentum is ``v * exp(s) - 0.5 * eps * (grad * exp(q) - t)``,
    where ``s``, ``q`` and ``t`` are the weighted scale, transformation and
    translation outputs of the momentum network.

    Args:
        state (State): Input state.
        step (int): Current leapfrog step.
        training (bool): Currently training?

    Returns:
        new_state (State): State with normalized position and updated
            momentum.
        logdet (tf.Tensor): ``scale`` summed over axis 1.
    """
    x_norm = self.normalizer(state.x)
    grad_u = self.grad_potential(x_norm, state.beta)
    time_enc = self._get_time(step, tile=tf.shape(x_norm)[0])
    s_out, t_out, q_out = self._call_vnet(
        (x_norm, grad_u, time_enc), step, training
    )

    half_eps = 0.5 * self.eps
    scale = self._vsw * (half_eps * s_out)
    translation = self._vtw * t_out
    transformation = self._vqw * (self.eps * q_out)

    exp_scale = tf.exp(scale)
    exp_transf = tf.exp(transformation)
    new_v = state.v * exp_scale - half_eps * (grad_u * exp_transf - translation)

    return (
        State(x=x_norm, v=new_v, beta=state.beta),
        tf.reduce_sum(scale, axis=1),
    )
def _update_x_backward(
        self,
        state: State,
        step: int,
        masks: Tuple[tf.Tensor, tf.Tensor],  # [m, 1. - m]
        training: bool = None,
):
    """Update the position `x` in the backward leapfrog step.

    Components selected by ``1 - m`` are updated; components selected by
    ``m`` keep their (normalized) value. The result is passed through
    ``self.normalizer`` again before being returned.

    Args:
        state (State): Input state.
        step (int): Current leapfrog step; converted to a time encoding
            via ``self._get_time``.
        masks (Tuple[tf.Tensor, tf.Tensor]): ``(m, 1 - m)`` mask pair.
        training (bool): Currently training?

    Returns:
        new_state (State): New state, with updated position.
        logdet (tf.Tensor): log|det J| of the update over the ``1 - m``
            components, summed over axis 1.
    """
    if self.config.hmc:
        return super()._update_x_backward(state, step, masks, training)

    m, mc = masks
    x = self.normalizer(state.x)
    t = self._get_time(step, tile=tf.shape(x)[0])

    S, T, Q = self._call_xnet((x, state.v, t), m, step, training)

    scale = self._xsw * (-self.eps * S)
    transl = self._xtw * T
    transf = self._xqw * (self.eps * Q)

    expS = tf.exp(scale)
    expQ = tf.exp(transf)

    if self.config.use_ncp:
        # `logdet_` below is the log-derivative of 2*atan(expS*tan(x/2))
        # evaluated at the normalized `x`, so the transform itself must be
        # evaluated at that same `x` (not the raw `state.x`).
        # NOTE(review): if `_update_x_forward` computes
        #   2*atan(exp(s)*tan(x/2)) + eps*(v*expQ + T),
        # its exact inverse is 2*atan(expS*tan((x - eps*(v*expQ + T))/2))
        # with the Jacobian taken at the shifted point, which differs from
        # `term1 - term2` here. Confirm against the forward update.
        term1 = 2 * tf.math.atan(expS * tf.math.tan(x / 2))
        term2 = expS * self.eps * (state.v * expQ + transl)
        y = term1 - term2
        xb = (m * x) + (mc * y)

        cterm = tf.math.cos(x / 2) ** 2
        sterm = (expS * tf.math.sin(x / 2)) ** 2
        logdet_ = tf.math.log(expS / (cterm + sterm))
        logdet = tf.reduce_sum(mc * logdet_, axis=1)
    else:
        y = expS * (x - self.eps * (state.v * expQ + transl))
        xb = m * x + mc * y
        logdet = tf.reduce_sum(mc * scale, axis=1)

    xb = self.normalizer(xb)
    state_out = State(xb, v=state.v, beta=state.beta)

    return state_out, logdet
def _transition_kernel_backward(self, state: State, training: bool = None):
    """Run the augmented leapfrog sampler in the backward direction.

    Sequence: a half momentum update, then ``num_steps`` full position
    updates interleaved with ``num_steps - 1`` full momentum updates, then a
    final half momentum update. The log-Jacobian of every update is
    accumulated into ``sumlogdet``.

    Args:
        state (State): Initial state.
        training (bool): Currently training?

    Returns:
        state_prop (State): Proposed state.
        metrics (AttrDict): ``sumlogdet`` and ``accept_prob``. When
            ``self._verbose`` is set, also ``energies`` and ``logdets``:
            one entry per leapfrog step, recorded before that step's
            position update.
    """
    num_steps = self.config.num_steps
    kwargs = {
        'dynamic_size': True,
        'size': self.batch_size,
        'clear_after_read': True
    }
    logdets = tf.TensorArray(TF_FLOAT, **kwargs)
    energies = tf.TensorArray(TF_FLOAT, **kwargs)
    # Same dtype as the TensorArrays `sumlogdet` is written into.
    sumlogdet = tf.zeros((self.batch_size,), dtype=TF_FLOAT)

    state_prop = State(state.x, state.v, state.beta)
    state_prop, logdet = self._half_v_update_backward(state_prop, 0, training)
    sumlogdet += logdet

    for step in range(num_steps):
        if self._verbose:
            logdets = logdets.write(step, sumlogdet)
            energies = energies.write(step, self.hamiltonian(state_prop))

        state_prop, logdet = self._full_x_update_backward(
            state_prop, step, training
        )
        # The position update's Jacobian is part of the total log-det too.
        sumlogdet += logdet

        if step < num_steps - 1:
            state_prop, logdet = self._full_v_update_backward(
                state_prop, step, training
            )
            sumlogdet += logdet

    # Final half-step uses the last step index explicitly rather than the
    # loop variable, which would be undefined when `num_steps == 0`.
    state_prop, logdet = self._half_v_update_backward(
        state_prop, num_steps - 1, training
    )
    sumlogdet += logdet

    accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
    metrics = AttrDict({
        'sumlogdet': sumlogdet,
        'accept_prob': accept_prob,
    })
    if self._verbose:
        # NOTE(review): only `num_steps` entries are returned, matching the
        # other kernels; the end-of-trajectory state is not included. If it
        # is wanted, add it consistently across all kernels.
        metrics.update({
            'energies': [energies.read(i) for i in range(num_steps)],
            'logdets': [logdets.read(i) for i in range(num_steps)],
        })

    return state_prop, metrics
def transition_kernel_sep_nets(
        self,
        state: State,
        forward: bool,
        training: bool = None,
):
    """Implements a transition kernel when using separate networks.

    Applies ``self._forward_lf`` (if ``forward``) or ``self._backward_lf``
    for ``num_steps`` leapfrog steps, accumulating the log-Jacobians.

    Args:
        state (State): Initial state.
        forward (bool): Direction of the trajectory.
        training (bool): Currently training?

    Returns:
        state_prop (State): Proposed state.
        metrics (AttrDict): ``sumlogdet`` and ``accept_prob``. When
            ``self._verbose`` is set, also ``energies`` and ``logdets``:
            one entry per leapfrog step, recorded before that step.
    """
    num_steps = self.config.num_steps
    lf_fn = self._forward_lf if forward else self._backward_lf

    state_prop = State(x=state.x, v=state.v, beta=state.beta)
    # `sumlogdet` is written into TF_FLOAT TensorArrays below, so it must
    # share that dtype (as in `transition_kernel_directional`).
    sumlogdet = tf.zeros((self.batch_size,), dtype=TF_FLOAT)
    logdets = tf.TensorArray(TF_FLOAT, dynamic_size=True,
                             size=self.batch_size, clear_after_read=True)
    energies = tf.TensorArray(TF_FLOAT, dynamic_size=True,
                              size=self.batch_size, clear_after_read=True)

    for step in range(num_steps):
        if self._verbose:
            logdets = logdets.write(step, sumlogdet)
            energies = energies.write(step, self.hamiltonian(state_prop))

        state_prop, logdet = lf_fn(step, state_prop, training)
        sumlogdet += logdet

    accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
    metrics = AttrDict({
        'sumlogdet': sumlogdet,
        'accept_prob': accept_prob,
    })
    if self._verbose:
        metrics.update({
            'energies': [energies.read(i) for i in range(num_steps)],
            'logdets': [logdets.read(i) for i in range(num_steps)],
        })

    return state_prop, metrics
def _update_v_forward(self, state: State, step: int, training: bool = None):
    """Update the momentum `v` in the forward leapfrog step.

    In HMC mode the parent implementation is used. Otherwise the new
    momentum is ``v * exp(s) - 0.5 * eps * (grad * exp(q) - t)``, with the
    weighted scale ``s``, transformation ``q`` and translation ``t`` taken
    from the momentum network.

    Args:
        state (State): Input state.
        step (int): Current leapfrog step.
        training (bool): Currently training?

    Returns:
        new_state (State): New state, with normalized position and updated
            momentum.
        logdet (tf.Tensor): Jacobian factor, ``scale`` summed over axis 1.
    """
    if self.config.hmc:
        return super()._update_v_forward(state, step, training)

    x_norm = self.normalizer(state.x)
    grad_u = self.grad_potential(x_norm, state.beta)
    time_enc = self._get_time(step, tile=tf.shape(x_norm)[0])
    s_out, t_out, q_out = self._call_vnet(
        (x_norm, grad_u, time_enc), step, training
    )

    half_eps = 0.5 * self.eps
    scale = self._vsw * (half_eps * s_out)
    translation = self._vtw * t_out
    transformation = self._vqw * (self.eps * q_out)

    exp_scale = tf.exp(scale)
    exp_transf = tf.exp(transformation)
    new_v = state.v * exp_scale - half_eps * (grad_u * exp_transf - translation)

    return (
        State(x=x_norm, v=new_v, beta=state.beta),
        tf.reduce_sum(scale, axis=1),
    )
def _half_v_update_backward(self, state: State, step: int, training: bool = None):
    """Perform a half update of the momentum in the backward direction.

    Both the time encoding and the momentum-network call use the reversed
    index ``num_steps - step - 1``. The new momentum is
    ``exp(s) * (v + 0.5 * eps * (grad * exp(q) - t))`` with
    ``s = -0.5 * eps * S`` weighted by ``self._vsw``.

    Args:
        state (State): Input state.
        step (int): Current backward leapfrog step.
        training (bool): Currently training?

    Returns:
        new_state (State): State with normalized position and updated
            momentum.
        logdet (tf.Tensor): ``scale`` summed over axis 1.
    """
    rev_step = self.config.num_steps - step - 1

    x_norm = self.normalizer(state.x)
    grad_u = self.grad_potential(x_norm, state.beta)
    time_enc = self._get_time(rev_step, tile=tf.shape(x_norm)[0])
    s_out, t_out, q_out = self._call_vnet(
        (x_norm, grad_u, time_enc), rev_step, training
    )

    half_eps = 0.5 * self.eps
    scale = self._vsw * ((-0.5 * self.eps) * s_out)
    transformation = self._vqw * (self.eps * q_out)
    translation = self._vtw * t_out

    exp_scale = tf.exp(scale)
    exp_transf = tf.exp(transformation)
    new_v = exp_scale * (
        state.v + half_eps * (grad_u * exp_transf - translation)
    )

    return (
        State(x=x_norm, v=new_v, beta=state.beta),
        tf.reduce_sum(scale, axis=1),
    )