Example #1
    def _integrator_conserves_energy(self, x, independent_chain_ndims):
        event_dims = tf.range(independent_chain_ndims, tf.rank(x))

        target_fn = lambda x: self._log_gamma_log_prob(x, event_dims)

        m = tf.random.normal(tf.shape(input=x))
        log_prob_0 = target_fn(x)
        old_energy = -log_prob_0 + 0.5 * tf.reduce_sum(input_tensor=m**2.,
                                                       axis=event_dims)

        event_size = np.prod(self.evaluate(x).shape[independent_chain_ndims:])

        integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
            target_fn, step_sizes=[0.1 / event_size], num_steps=1000)

        [[new_m], [_], log_prob_1, [_]] = integrator([m], [x])

        new_energy = -log_prob_1 + 0.5 * tf.reduce_sum(input_tensor=new_m**2.,
                                                       axis=event_dims)

        old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
        tf1.logging.vlog(
            1, 'average energy relative change: {}'.format(
                (1. - new_energy_ / old_energy_).mean()))
        self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
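The snippet above is a test case, so self, tf1, and _log_gamma_log_prob come from the surrounding test class. Below is a minimal standalone sketch of the same integrator call, assuming the internal import path used by the TFP test suite (tensorflow_probability.python.mcmc.internal.leapfrog_integrator), which may move between releases:

import tensorflow as tf
from tensorflow_probability.python.mcmc.internal import (
    leapfrog_integrator as leapfrog_impl)

def target_log_prob_fn(x):
    # Standard-normal log-density, up to an additive constant.
    return -0.5 * tf.reduce_sum(x**2., axis=-1)

x = tf.zeros([4, 3])                # 4 chains, 3-dimensional state.
m = tf.random.normal(tf.shape(x))   # freshly resampled Gaussian momentum

integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
    target_log_prob_fn, step_sizes=[0.1], num_steps=3)

# The integrator consumes and returns *lists* of momentum/state parts,
# which is why the example above wraps m and x as [m] and [x].
[[new_m], [new_x], new_target, [new_grad]] = integrator([m], [x])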
Example #2
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
            if MULTINOMIAL_SAMPLE:
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update the left/right ends of the trajectory and check for a
            # trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=prefer_static.rank_from_shape(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
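Two idioms the loop above relies on can be isolated in a short sketch: a per-chain boolean direction whose leftward entries negate the step size, and a bit shift standing in for 2**iter_. The batch shape below is made up:

import tensorflow as tf

batch_shape = [4]   # hypothetical number of independent chains
direction = tf.cast(
    tf.random.uniform(batch_shape, minval=0, maxval=2, dtype=tf.int32),
    tf.bool)
step_size = tf.fill(batch_shape, 0.1)
# Rightward chains integrate forward in time, leftward chains backward.
signed_step_size = tf.where(direction, step_size, -step_size)

iter_ = tf.constant(3)
num_steps_at_this_depth = tf.bitwise.left_shift(1, iter_)   # == 2**3 == 8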
Example #3
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            seeds = samplers.split_seed(seed, n=len(current_state_parts))
            seeds = distribute_lib.fold_in_axis_index(
                seeds, self.experimental_shard_axis_names)

            current_momentum_parts = []
            for part_seed, x in zip(seeds, current_state_parts):
                current_momentum_parts.append(
                    samplers.normal(shape=ps.shape(x),
                                    dtype=self._momentum_dtype
                                    or dtype_util.base_dtype(x.dtype),
                                    seed=part_seed))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = ps.rank(current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts,
                    next_momentum_parts,
                    independent_chain_ndims,
                    shard_axis_names=self.experimental_shard_axis_names),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
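For the standard Gaussian momentum drawn above, the quantity computed by _compute_log_acceptance_correction reduces to a kinetic-energy difference. A sketch under that assumption (the library helper also handles sharding and broadcasting details omitted here):

import tensorflow as tf

def gaussian_log_acceptance_correction(current_momentum_parts,
                                       next_momentum_parts,
                                       independent_chain_ndims):
    def kinetic_energy(parts):
        # Sum 0.5 * m**2 over every non-chain dimension of each part.
        return sum(
            0.5 * tf.reduce_sum(
                m**2., axis=tf.range(independent_chain_ndims, tf.rank(m)))
            for m in parts)
    # K(m_old) - K(m_new), per chain.
    return (kinetic_energy(current_momentum_parts) -
            kinetic_energy(next_momentum_parts))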
Example #4
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                momentum_distribution,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                self.momentum_distribution,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)
            current_momentum_parts = momentum_distribution.sample(seed=seed)
            momentum_log_prob = getattr(momentum_distribution,
                                        '_log_prob_unnormalized',
                                        momentum_distribution.log_prob)
            kinetic_energy_fn = lambda *args: -momentum_log_prob(*args)

            # Let the integrator handle the case where no momentum distribution
            # is provided
            if self.momentum_distribution is None:
                leapfrog_kinetic_energy_fn = None
            else:
                leapfrog_kinetic_energy_fn = kinetic_energy_fn

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(
                current_momentum_parts,
                current_state_parts,
                target=current_target_log_prob,
                target_grad_parts=current_target_log_prob_grad_parts,
                kinetic_energy_fn=leapfrog_kinetic_energy_fn)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    kinetic_energy_fn, current_momentum_parts,
                    next_momentum_parts),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
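This preconditioned variant derives its kinetic energy from the momentum distribution's log-density. A standalone sketch of that relationship; the distribution and its scale of 10 are made up for illustration, not taken from the example:

import tensorflow_probability as tfp

momentum_distribution = tfp.distributions.Sample(
    tfp.distributions.Normal(0., 10.), sample_shape=3)
# Kinetic energy is minus the (possibly unnormalized) momentum log-density.
kinetic_energy_fn = lambda m: -momentum_distribution.log_prob(m)

m = momentum_distribution.sample()
ke = kinetic_energy_fn(m)   # scalar kinetic energy for this draw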
Example #5
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            current_momentum_parts = []
            for x in current_state_parts:
                current_momentum_parts.append(
                    tf.random.normal(shape=tf.shape(x),
                                     dtype=self._momentum_dtype
                                     or dtype_util.base_dtype(x.dtype),
                                     seed=self._seed_stream()))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = prefer_static.rank(
                current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts, next_momentum_parts,
                    independent_chain_ndims),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
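This older variant draws momentum from the kernel's stateful _seed_stream rather than an explicit seed argument. A sketch of the seed-stream behavior using the public tfp.util.SeedStream class (the seed and salt values are made up):

import tensorflow_probability as tfp

stream = tfp.util.SeedStream(seed=42, salt='example_salt')
s0 = stream()   # a deterministic function of (seed, salt, call count)
s1 = stream()   # differs from s0, so successive random ops decorrelate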
Example #6
    def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                           momentum_state_memory, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_size = prefer_static.size(init_energy)
            direction = tf.cast(tf.random.uniform(shape=[batch_size],
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            left_right_index = tf.concat([
                tf.cast(direction, tf.int32)[..., tf.newaxis],
                tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
            ], axis=1)
            tree_start_states = tf.nest.map_structure(
                # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
                lambda v: tf.gather_nd(v, left_right_index),
                initial_step_state)

            directions_expanded = [
                _expand_dims_under_batch_dim(direction,
                                             prefer_static.rank(state))
                for state in tree_start_states.state
            ]
            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(direction, ss, -ss)
                    for direction, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state,
                tree_final_states,
                final_not_divergence,
                continue_tree_final,
                energy_diff_tree_sum,
                leapfrogs_taken,
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                log_slice_sample,
                init_energy,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(choose_new_state,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(candidate_tree_state.state,
                                                  last_candidate_state.state)
                ],
                target=tf.where(choose_new_state, candidate_tree_state.target,
                                last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            choose_new_state, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                weight=weight_sum)
            # Update the left/right ends of the trajectory and check for a
            # trajectory-level U-turn.

            # Alternative approach
            # left_right_mask = tf.transpose(
            #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
            #            [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
            #     [2, 0, 1])

            # trajectory_state_left_right = tf.where(
            #     tf.equal(left_right_mask, 0.),
            #     trajectory_state_left_right,
            #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))
            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    # Alternative approach:
                    # tf.where(tf.equal(left_right_mask, 0.),
                    #          v,
                    #          tf.tile(r[tf.newaxis],
                    #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
                    tf.tensor_scatter_nd_update(v, left_right_index, r)
                    for v, r in zip(tf.nest.flatten(initial_step_state),
                                    tf.nest.flatten(tree_final_states))
                ])
            no_u_turns_trajectory = has_not_u_turn(
                [s[0] for s in new_step_state.state],
                [m[0] for m in new_step_state.momentum],
                [s[1] for s in new_step_state.state],
                [m[1] for m in new_step_state.momentum])

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
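A sketch of the left/right bookkeeping above: each state tensor stacks the two trajectory ends along axis 0, and left_right_index selects (or overwrites) one end per chain. The shapes below are made up:

import tensorflow as tf

batch_size = 3
state = tf.zeros([2, batch_size])              # [left/right end, chain]
direction = tf.constant([True, False, True])   # True selects the right end
left_right_index = tf.concat([
    tf.cast(direction, tf.int32)[..., tf.newaxis],
    tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
], axis=1)
picked = tf.gather_nd(state, left_right_index)   # one end per chain
updated = tf.tensor_scatter_nd_update(           # overwrite that same end
    state, left_right_index, picked + 1.)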
Example #7
  def _loop_build_sub_tree(
      self, direction, log_slice_sample,
      iter_, prev_tree_state, candidate_tree_state,
      continue_tree_previous, trace_arrays):
    """Base case in tree doubling."""
    with tf.name_scope('loop_build_sub_tree'):
      # Take one leapfrog step in the direction v and check divergence
      directions_expanded = [
          _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
          for state in prev_tree_state.state]
      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn,
          step_sizes=[tf.where(direction, ss, -ss)
                      for direction, ss in zip(
                          directions_expanded, self.step_size)],
          num_steps=self.unrolled_leapfrog_steps)
      [
          next_momentum_parts,
          next_state_parts,
          next_target,
          next_target_grad_parts
      ] = integrator(prev_tree_state.momentum,
                     prev_tree_state.state,
                     prev_tree_state.target,
                     prev_tree_state.target_grad_parts)

      next_tree_state = TreeDoublingState(
          momentum=next_momentum_parts,
          state=next_state_parts,
          target=next_target,
          target_grad_parts=next_target_grad_parts)

      # Save the state and momentum on every other step, and check for a
      # U-turn on the remaining steps. To avoid a `tf.cond`, non-saving steps
      # still write, but to a dummy slot (index `max_tree_depth`).
      index = iter_ // 2
      if USE_RAGGED_TENSOR:
        write_index_ = self.write_instruction[index]
      else:
        write_index_ = tf.switch_case(index, self.write_instruction)

      write_index = tf.where(tf.equal(iter_ % 2, 0),
                             write_index_, self.max_tree_depth)

      if USE_TENSORARRAY:
        trace_arrays = TraceArrays(
            momentum_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.state_swap, next_state_parts)])
      else:
        trace_arrays = TraceArrays(
            momentum_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.state_swap, next_state_parts)])
      batch_size = prefer_static.size(next_target)
      has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

      if USE_RAGGED_TENSOR:
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                self.read_instruction, iter_ // 2, directions_expanded,
                trace_arrays, next_momentum_parts, next_state_parts))
      else:
        f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
            self.read_instruction, int_iter, directions_expanded, trace_arrays,
            next_momentum_parts, next_state_parts)
        branch_execution = {x: functools.partial(f, x)
                            for x in range(len(self.read_instruction))}
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: tf.switch_case(iter_ // 2, branch_execution))

      energy = compute_hamiltonian(next_target, next_momentum_parts)
      valid_candidate = log_slice_sample <= energy

      # Uniform sampling on the trajectory within the subtree
      sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
      weight_sum = candidate_tree_state.weight + sample_weight
      log_accept_thresh = tf.math.log(
          tf.cast(sample_weight, tf.float32) /
          tf.cast(weight_sum, tf.float32))
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
      u = tf.math.log1p(-tf.random.uniform(
          shape=[batch_size],
          dtype=tf.float32,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      next_candidate_tree_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(s0)), s0, s1)
              for s0, s1 in zip(next_state_parts,
                                candidate_tree_state.state)
          ],
          target=tf.where(is_sample_accepted,
                          next_target,
                          candidate_tree_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(grad0)),
                  grad0, grad1)
              for grad0, grad1 in zip(next_target_grad_parts,
                                      candidate_tree_state.target_grad_parts)
          ],
          weight=weight_sum)

      not_divergent = log_slice_sample - energy < self.max_energy_diff
      continue_tree = not_divergent & no_u_turns_within_tree
      continue_tree_next = continue_tree_previous & continue_tree

      return (
          iter_ + 1,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
          trace_arrays,
      )
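The log1p(-uniform) pattern above samples log(U) for U ~ Uniform(0, 1) without ever producing log(0), since tf.random.uniform draws from [0, 1) and therefore 1 - U stays strictly positive. Comparing the draw to a log-threshold accepts with probability exp(threshold). A minimal sketch:

import tensorflow as tf

log_accept_thresh = tf.math.log(tf.constant([0.25, 1.]))
u = tf.math.log1p(-tf.random.uniform(shape=[2]))
is_sample_accepted = u <= log_accept_thresh   # accepts w.p. 0.25 and 1.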
Example #8
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                seed = samplers.sanitize_seed(self._seed_stream())
            seeds = samplers.split_seed(seed, n=len(current_state_parts))

            current_momentum_parts = []
            for part_seed, x in zip(seeds, current_state_parts):
                current_momentum_parts.append(
                    samplers.normal(shape=tf.shape(x),
                                    dtype=self._momentum_dtype
                                    or dtype_util.base_dtype(x.dtype),
                                    seed=part_seed))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = prefer_static.rank(
                current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts, next_momentum_parts,
                    independent_chain_ndims),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
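A sketch of the stateless seed plumbing above, assuming the internal tensorflow_probability.python.internal.samplers module (its location may change between releases):

from tensorflow_probability.python.internal import samplers

seed = samplers.sanitize_seed(42)        # canonicalize to a seed tensor
seeds = samplers.split_seed(seed, n=3)   # three decorrelated child seeds
noise = samplers.normal(shape=[5], seed=seeds[0])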