Example #1
    def transition_kernel_directional(
        self,
        state: State,
        forward: bool,
        training: bool = None,
    ):
        """Implements a series of directional updates."""
        state_prop = State(state.x, state.v, state.beta)
        sumlogdet = tf.zeros((self.batch_size, ), dtype=TF_FLOAT)
        logdets = tf.TensorArray(TF_FLOAT,
                                 dynamic_size=True,
                                 size=self.batch_size,
                                 clear_after_read=True)
        energies = tf.TensorArray(TF_FLOAT,
                                  dynamic_size=True,
                                  size=self.batch_size,
                                  clear_after_read=True)
        # ====
        # Forward for first half of trajectory
        for step in range(self.config.num_steps // 2):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._forward_lf(step, state_prop, training)
            sumlogdet += logdet

        # ====
        # Flip momentum
        state_prop = State(state_prop.x, -1. * state_prop.v, state_prop.beta)

        # ====
        # Backward for second half of trajectory
        for step in range(self.config.num_steps // 2, self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._backward_lf(step, state_prop, training)
            sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)
        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            metrics.update({
                'energies':
                [energies.read(i) for i in range(self.config.num_steps)],
                'logdets':
                [logdets.read(i) for i in range(self.config.num_steps)],
            })

        return state_prop, metrics
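
Usage sketch (hypothetical): the kernel's `accept_prob` is meant to drive a standard Metropolis-Hastings accept/reject step. The `dynamics` handle below is an assumed instance of the class above, `State` is the same namedtuple its methods use, and `x` is assumed to have shape `(batch_size, dim)`; none of this is taken from the original source.

    import tensorflow as tf

    def mh_step(dynamics, state, training=False):
        """Propose with the directional kernel, then accept/reject per chain."""
        state_prop, metrics = dynamics.transition_kernel_directional(
            state, forward=True, training=training)
        u = tf.random.uniform(tf.shape(metrics['accept_prob']))
        accept = tf.cast(u < metrics['accept_prob'], state.x.dtype)
        mask = accept[:, None]  # broadcast the per-chain decision over x's dims
        x_out = mask * state_prop.x + (1. - mask) * state.x
        return State(x=x_out, v=state_prop.v, beta=state.beta), metrics
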
Example #2
    def _update_v_backward(self,
                           state: State,
                           step: int,
                           training: bool = None):
        """Update the momentum `v` in the backward leapfrog step.

        Args:
            state (State): Input state.
            step (int): Current leapfrog step (converted internally to
                periodic time).
            training (bool): Currently training?

        Returns:
            state_out (State): New state, with updated momentum.
            logdet (tf.Tensor): Log determinant of the Jacobian factor.
        """
        x = self.normalizer(state.x)
        grad = self.grad_potential(x, state.beta)
        t = self._get_time(step, tile=tf.shape(x)[0])
        S, T, Q = self._call_vnet((x, grad, t), step, training)

        scale = self._vsw * (-0.5 * self.eps * S)
        transf = self._vqw * (self.eps * Q)
        transl = self._vtw * T

        expS = tf.exp(scale)
        expQ = tf.exp(transf)

        vb = expS * (state.v + 0.5 * self.eps * (grad * expQ - transl))

        state_out = State(x=x, v=vb, beta=state.beta)
        logdet = tf.reduce_sum(scale, axis=1)

        return state_out, logdet
Example #3
    def _half_v_update_forward(
        self,
        state: State,
        step: int,
        training: bool = None,
    ):
        """Perform a half-step momentum update in the forward direction."""
        x = self.normalizer(state.x)
        grad = self.grad_potential(x, state.beta)
        t = self._get_time(step, tile=tf.shape(x)[0])

        S, T, Q = self._call_vnet((x, grad, t), step, training)

        scale = self._vsw * (0.5 * self.eps * S)
        transl = self._vtw * T
        transf = self._vqw * (self.eps * Q)

        expS = tf.exp(scale)
        expQ = tf.exp(transf)

        vf = state.v * expS - 0.5 * self.eps * (grad * expQ - transl)

        state_out = State(x=x, v=vf, beta=state.beta)
        logdet = tf.reduce_sum(scale, axis=1)

        return state_out, logdet
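
A quick standalone check of how the forward and backward momentum updates above relate: with the same `grad` and fixed network outputs S, T, Q (constants here, purely illustrative, with unit scale weights), the backward formula undoes the forward one exactly, which is what makes the sampler reversible.

    import tensorflow as tf

    eps = 0.1
    S, T, Q = 0.3, 0.2, -0.1          # stand-ins for the vnet outputs
    v = tf.random.normal((4, 8))      # toy momentum, shape (batch, dim)
    grad = tf.random.normal((4, 8))   # plays the role of grad_potential(x, beta)

    # Forward half-update (cf. _half_v_update_forward).
    vf = v * tf.exp(0.5 * eps * S) - 0.5 * eps * (grad * tf.exp(eps * Q) - T)
    # Backward half-update applied to the result (cf. _half_v_update_backward).
    vb = tf.exp(-0.5 * eps * S) * (vf + 0.5 * eps * (grad * tf.exp(eps * Q) - T))

    print(tf.reduce_max(tf.abs(vb - v)).numpy())  # ~1e-7: backward inverts forward
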
Example #4
    def _update_x_backward(
            self,
            state: State,
            step: int,
            masks: Tuple[tf.Tensor, tf.Tensor],  # [m, 1. - m]
            training: bool = None):
        """Update the position `x` in the backward leapfrog step.

        Args:
            state (State): Input state.
            step (int): Current leapfrog step (converted internally to
                periodic time).
            masks (Tuple[tf.Tensor, tf.Tensor]): Pair of complementary
                binary masks `[m, 1. - m]` selecting which sites to update.
            training (bool): Currently training?

        Returns:
            state_out (State): New state, with updated position.
            logdet (tf.Tensor): Log determinant of the Jacobian factor.
        """
        if self.config.hmc:
            return super()._update_x_backward(state, step, masks, training)
        #  if self.config.use_ncp:
        #      return self._update_xb_ncp(state, step, masks, training)

        # Call `XNet` using `self._scattered_xnet`
        m, mc = masks
        x = self.normalizer(state.x)
        t = self._get_time(step, tile=tf.shape(x)[0])
        S, T, Q = self._call_xnet((x, state.v, t), m, step, training)

        scale = self._xsw * (-self.eps * S)
        transl = self._xtw * T
        transf = self._xqw * (self.eps * Q)

        expS = tf.exp(scale)
        expQ = tf.exp(transf)

        if self.config.use_ncp:
            term1 = 2 * tf.math.atan(expS * tf.math.tan(state.x / 2))
            term2 = expS * self.eps * (state.v * expQ + transl)
            y = term1 - term2
            xb = (m * x) + (mc * y)

            cterm = tf.math.cos(x / 2)**2
            sterm = (expS * tf.math.sin(x / 2))**2
            logdet_ = tf.math.log(expS / (cterm + sterm))
            logdet = tf.reduce_sum(mc * logdet_, axis=1)

        else:
            y = expS * (x - self.eps * (state.v * expQ + transl))
            xb = m * x + mc * y
            logdet = tf.reduce_sum(mc * scale, axis=1)

        xb = self.normalizer(xb)
        state_out = State(xb, v=state.v, beta=state.beta)
        return state_out, logdet
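
The Jacobian of the NCP branch can be checked numerically: for a single site, y = 2 * atan(expS * tan(x / 2)) - const has derivative expS / (cos(x / 2)**2 + (expS * sin(x / 2))**2), which is exactly the quantity whose log is summed into `logdet` above. A scalar sketch with made-up values for eps, S, T, Q, and v (assumptions, not taken from the class):

    import tensorflow as tf

    eps, S, T, Q, v = 0.1, 0.4, 0.2, -0.3, 0.7   # toy scalars
    expS = tf.exp(-eps * S)                      # backward direction: scale = -eps * S
    expQ = tf.exp(eps * Q)

    x = tf.constant(1.3)
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = 2. * tf.math.atan(expS * tf.math.tan(x / 2.)) \
            - expS * eps * (v * expQ + T)
    dydx = tape.gradient(y, x)

    cterm = tf.math.cos(x / 2.)**2
    sterm = (expS * tf.math.sin(x / 2.))**2
    print(dydx.numpy(), (expS / (cterm + sterm)).numpy())  # the two agree
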
Example #5
    def _transition_kernel_backward(self, state: State, training: bool = None):
        """Run the augmented leapfrog sampler in the forward direction."""
        kwargs = {
            'dynamic_size': True,
            'size': self.batch_size,
            'clear_after_read': True
        }
        logdets = tf.TensorArray(TF_FLOAT, **kwargs)
        energies = tf.TensorArray(TF_FLOAT, **kwargs)
        sumlogdet = tf.zeros((self.batch_size, ))
        state_prop = State(state.x, state.v, state.beta)

        state_prop, logdet = self._half_v_update_backward(
            state_prop, 0, training)
        sumlogdet += logdet
        for step in range(self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = self._full_x_update_backward(
                state_prop, step, training)
            sumlogdet += logdet

            if step < self.config.num_steps - 1:
                state_prop, logdet = self._full_v_update_backward(
                    state_prop, step, training)
                sumlogdet += logdet

        state_prop, logdet = self._half_v_update_backward(
            state_prop, step, training)
        sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)

        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            logdets = logdets.write(self.config.num_steps, sumlogdet)
            energies = energies.write(self.config.num_steps,
                                      self.hamiltonian(state_prop))
            metrics.update({
                'energies':
                [energies.read(i) for i in range(self.config.num_steps + 1)],
                'logdets':
                [logdets.read(i) for i in range(self.config.num_steps + 1)],
            })

        return state_prop, metrics
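
The per-step `energies` and `logdets` returned in `metrics` are plain Python lists of `(batch_size,)` tensors, one entry per recorded step plus the final state. A small sketch of stacking them for logging, where `dynamics` and `state` are assumed placeholders and `self._verbose` is assumed enabled:

    import tensorflow as tf

    state_out, metrics = dynamics._transition_kernel_backward(state, training=False)
    logdets = tf.stack(metrics['logdets'], axis=0)    # (num_steps + 1, batch_size)
    energies = tf.stack(metrics['energies'], axis=0)  # (num_steps + 1, batch_size)
    print(logdets.shape, energies.shape)
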
Example #6
    def transition_kernel_sep_nets(
        self,
        state: State,
        forward: bool,
        training: bool = None,
    ):
        """Implements a transition kernel when using separate networks."""
        lf_fn = self._forward_lf if forward else self._backward_lf
        state_prop = State(x=state.x, v=state.v, beta=state.beta)
        sumlogdet = tf.zeros((self.batch_size, ))
        logdets = tf.TensorArray(TF_FLOAT,
                                 dynamic_size=True,
                                 size=self.batch_size,
                                 clear_after_read=True)
        energies = tf.TensorArray(TF_FLOAT,
                                  dynamic_size=True,
                                  size=self.batch_size,
                                  clear_after_read=True)

        for step in range(self.config.num_steps):
            if self._verbose:
                logdets = logdets.write(step, sumlogdet)
                energies = energies.write(step, self.hamiltonian(state_prop))

            state_prop, logdet = lf_fn(step, state_prop, training)
            sumlogdet += logdet

        accept_prob = self.compute_accept_prob(state, state_prop, sumlogdet)

        metrics = AttrDict({
            'sumlogdet': sumlogdet,
            'accept_prob': accept_prob,
        })
        if self._verbose:
            metrics.update({
                'energies':
                [energies.read(i) for i in range(self.config.num_steps)],
                'logdets':
                [logdets.read(i) for i in range(self.config.num_steps)],
            })

        return state_prop, metrics
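
Usage sketch (hypothetical): the `forward` flag is typically drawn at random per trajectory so both directions are explored. `dynamics` and `state` below are assumed placeholders for an instance of this class and its current state.

    import tensorflow as tf

    def propose(dynamics, state, training=False):
        """Pick a direction uniformly at random, then run the kernel."""
        forward = bool(tf.random.uniform(()) < 0.5)  # Python bool: the kernel branches on it
        return dynamics.transition_kernel_sep_nets(state, forward, training=training)
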
Example #7
    def _update_v_forward(self,
                          state: State,
                          step: int,
                          training: bool = None):
        """Update the momentum `v` in the forward leapfrog step.

        Args:
            state (State): Input state.
            step (int): Current leapfrog step (converted internally to
                periodic time).
            training (bool): Currently training?

        Returns:
            state_out (State): New state, with updated momentum.
            logdet (tf.Tensor): Log determinant of the Jacobian factor.
        """
        if self.config.hmc:
            return super()._update_v_forward(state, step, training)

        x = self.normalizer(state.x)
        grad = self.grad_potential(x, state.beta)
        t = self._get_time(step, tile=tf.shape(x)[0])

        S, T, Q = self._call_vnet((x, grad, t), step, training)

        scale = self._vsw * (0.5 * self.eps * S)
        transl = self._vtw * T
        transf = self._vqw * (self.eps * Q)

        expS = tf.exp(scale)
        expQ = tf.exp(transf)

        vf = state.v * expS - 0.5 * self.eps * (grad * expQ - transl)

        state_out = State(x=x, v=vf, beta=state.beta)
        logdet = tf.reduce_sum(scale, axis=1)

        return state_out, logdet
Example #8
    def _half_v_update_backward(self,
                                state: State,
                                step: int,
                                training: bool = None):
        """Perform a half update of the momentum in the backward direction."""
        step_r = self.config.num_steps - step - 1
        x = self.normalizer(state.x)
        grad = self.grad_potential(x, state.beta)
        t = self._get_time(step_r, tile=tf.shape(x)[0])
        S, T, Q = self._call_vnet((x, grad, t), step_r, training)

        scale = self._vsw * (-0.5 * self.eps * S)
        transf = self._vqw * (self.eps * Q)
        transl = self._vtw * T

        expS = tf.exp(scale)
        expQ = tf.exp(transf)

        vb = expS * (state.v + 0.5 * self.eps * (grad * expQ - transl))

        state_out = State(x=x, v=vb, beta=state.beta)
        logdet = tf.reduce_sum(scale, axis=1)

        return state_out, logdet
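
Note on `step_r`: the backward half-update indexes the per-step networks with `step_r = num_steps - step - 1`, so a backward trajectory queries the networks in the reverse of the forward order; this is what makes the backward kernel the time-reversal of the forward one. A trivial sketch of the mapping, assuming `num_steps = 4`:

    num_steps = 4
    for step in range(num_steps):
        step_r = num_steps - step - 1
        print(step, '->', step_r)   # 0 -> 3, 1 -> 2, 2 -> 1, 3 -> 0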