Exemple #1
0
def val_where(cond, tval, fval):
    """Like tf.where but works on namedtuples.

    Walks `tval` and `fval` in lockstep. Tuple nodes are rebuilt with the type
    of `tval`; `tf.Tensor` leaves are combined with
    `broadcast_util.where_left_justified_mask(cond, tval, fval)`.

    Args:
      cond: Mask forwarded unchanged to `where_left_justified_mask` for every
        tensor leaf.
      tval: A `tf.Tensor`, or a (possibly nested) namedtuple of them, passed as
        the "true" argument.
      fval: Value with the same structure as `tval`, passed as the "false"
        argument. Its own type is not inspected; only `tval` drives dispatch.

    Returns:
      A value with the same structure and tuple types as `tval`.

    Raises:
      TypeError: If `tval`, or any element nested inside it, is neither a
        tuple nor a `tf.Tensor`.
    """
    # Tuples are dispatched first. A `tf.Tensor` is never a tuple, so this is
    # equivalent to testing for tensors first.
    if isinstance(tval, tuple):
        cls = type(tval)
        # Fields are passed positionally, which is how namedtuple classes are
        # constructed. A plain `tuple` with several elements is not supported.
        return cls(*(val_where(cond, t, f) for t, f in zip(tval, fval)))
    if isinstance(tval, tf.Tensor):
        return broadcast_util.where_left_justified_mask(cond, tval, fval)
    # Raise a real TypeError (still an `Exception` subclass) with a message
    # that names the offending type, instead of `Exception(TypeError)`.
    raise TypeError(
        'val_where expects a tf.Tensor or a (named)tuple of them, '
        'got {}.'.format(type(tval).__name__))
Exemple #2
0
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory, seed):
        """Base case in tree doubling.

        Performs one iteration of sub-tree construction: calls `integrator`
        once from `prev_tree_state`, updates the running momentum sum and the
        leapfrog counter, checks for U-turns against the saved
        momentum/state memory, writes the new point into that memory, flags
        divergence from the energy change, and randomly decides whether the
        new point replaces the current candidate.

        Args:
          directions: Forwarded unchanged to `has_not_u_turn_at_all_index`.
          integrator: Callable invoked as
            `integrator(momentum, state, target, target_grad_parts)` and
            returning the corresponding four "next" values.
          current_step_meta_info: Object providing `write_instruction`,
            `read_instruction`, `init_energy` and, when `MULTINOMIAL_SAMPLE`
            is false, `log_slice_sample`.
          iter_: Index of this iteration; used to look up the read and write
            instructions and returned incremented by one.
          energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))`
            from earlier iterations.
          momentum_cumsum_previous: List of running momentum sums, one entry
            per momentum part.
          leapfrogs_taken: Running count, incremented where
            `continue_tree_previous` is true.
          prev_tree_state: Object with `momentum`, `state`, `target` and
            `target_grad_parts`; the point the integrator starts from.
          candidate_tree_state: Current candidate with `state`, `target`,
            `target_grad_parts`, `energy` and `weight`.
          continue_tree_previous: Boolean mask; true where the tree had not
            terminated before this iteration.
          not_divergent_previous: Boolean mask accumulated from earlier
            iterations.
          momentum_state_memory: Object with list-valued `momentum_swap` and
            `state_swap` that store earlier points for the U-turn check.
          seed: Random seed; split into an acceptance seed and the seed
            returned for the next iteration.

        Returns:
          A tuple `(iter_ + 1, next_seed, energy_diff_sum, momentum_cumsum,
          leapfrogs_taken, next_tree_state, next_candidate_tree_state,
          continue_tree_next, not_divergent, momentum_state_memory)`. The
          entries follow the order of this method's arguments from `iter_`
          onward, with the seed moved to second position (presumably the loop
          variables of the enclosing while-loop -- confirm against the caller).
        """
        acceptance_seed, next_seed = samplers.split_seed(seed)
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            # Running sum of momenta, including the point just produced.
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not yet terminated previously, we count this
            # leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            # With the generalized U-turn criterion the memory stores the
            # momentum sum *before* this step and the check uses the sum
            # *after* it; otherwise both use the new position.
            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum_previous
                state_to_check = momentum_cumsum
            else:
                state_to_write = next_state_parts
                state_to_check = next_state_parts

            # The shape of the target value is used as the batch shape for
            # all boolean masks and random draws below.
            batch_shape = ps.shape(next_target)
            has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

            # Look up which memory slots to compare against at this
            # iteration.
            read_index = read_instruction.gather([iter_])[0]
            no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
                read_index,
                directions,
                momentum_state_memory,
                next_momentum_parts,
                state_to_check,
                has_not_u_turn_init,
                log_prob_rank=ps.rank(next_target))

            # Get index to write state into memory swap
            write_index = write_instruction.gather([iter_])
            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            # A NaN energy is replaced by -inf so that the comparisons below
            # stay well defined; such a point then fails the divergence test.
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                # Fold this point's log-weight (`energy_diff`) into the
                # candidate's accumulated weight. NOTE(review): `log_add_exp`
                # is presumably log(exp(a) + exp(b)) -- confirm.
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                # Here `weight` is a count of valid points, so the acceptance
                # probability of a valid point is 1 / weight_sum.
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            # `u` is log(1 - U) for a uniform draw U; the new point is
            # accepted where it does not exceed the log threshold.
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            # Per field, take the new point where accepted and keep the old
            # candidate elsewhere; the weight is always the updated sum.
            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(is_sample_accepted, s0, s1)
                    for s0, s1 in zip(next_state_parts,
                                      candidate_tree_state.state)
                ],
                target=bu.where_left_justified_mask(
                    is_sample_accepted, next_target,
                    candidate_tree_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(is_sample_accepted, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        next_target_grad_parts,
                        candidate_tree_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    is_sample_accepted, current_energy,
                    candidate_tree_state.energy),
                weight=weight_sum)

            # `continue_tree` ignores the U-turn check and gates the energy
            # statistic below; `continue_tree_next` also requires no U-turn
            # and is what the next iteration sees.
            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            # Only record divergence for trees that were still running;
            # already-stopped trees are treated as not divergent here.
            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                ps.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                next_seed,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
Exemple #3
0
    def _loop_tree_doubling(self, step_size, momentum_state_memory,
                            current_step_meta_info, iter_, initial_step_state,
                            initial_step_metastate, seed):
        """Main loop for tree doubling.

        Performs one doubling step: picks a random direction per batch
        member, builds a sub-tree of `2**iter_` steps from the corresponding
        end of the current trajectory via `self._build_sub_tree`, randomly
        decides whether the sub-tree's candidate replaces the current
        candidate, stores the new trajectory end, and checks the U-turn
        condition across the whole trajectory.

        Args:
          step_size: List of step sizes, one per state part; each is negated
            where the sampled direction is false.
          momentum_state_memory: Forwarded to `self._build_sub_tree`.
          current_step_meta_info: Object providing `init_energy` (its shape
            is used as the batch shape); also forwarded to
            `self._build_sub_tree`.
          iter_: Doubling iteration index; returned incremented by one.
          initial_step_state: Nested structure with `state` and `momentum`
            fields whose leaves hold both ends of the trajectory along a
            leading axis of size 2 (indexed here as `[0]` and `[1]`).
          initial_step_metastate: Object with `candidate_state`,
            `is_accepted`, `momentum_sum`, `energy_diff_sum`,
            `continue_tree`, `not_divergence` and `leapfrog_count`.
          seed: Random seed; split four ways for the direction draw, the
            sub-tree, the acceptance draw and the next iteration.

        Returns:
          A tuple `(iter_ + 1, next_seed, new_step_state,
          new_step_metastate)`.
        """
        with tf.name_scope('loop_tree_doubling'):
            (direction_seed, subtree_seed, acceptance_seed,
             next_seed) = samplers.split_seed(seed, n=4)
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            # Random boolean per batch member, drawn as an integer with
            # minval=0 / maxval=2 and cast to bool.
            direction = tf.cast(samplers.uniform(shape=batch_shape,
                                                 minval=0,
                                                 maxval=2,
                                                 dtype=tf.int32,
                                                 seed=direction_seed),
                                dtype=tf.bool)

            # Start from end [1] where `direction` is true, else from end
            # [0].
            tree_start_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                bu.left_justified_expand_dims_like(direction, state)
                for state in tree_start_states.state
            ]

            # The sign of each step size encodes the integration direction.
            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory,
                seed=subtree_seed)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
            if MULTINOMIAL_SAMPLE:
                # Weights are combined with `log_add_exp`; a sub-tree that
                # did not finish cleanly gets weight -inf.
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                # Weights are counts; a sub-tree that did not finish cleanly
                # contributes zero.
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            # A NaN threshold is replaced by 0.
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            # `u` is log(1 - U) for a uniform draw U.
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            # A candidate from a sub-tree that stopped early is never taken.
            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(choose_new_state, s0, s1)
                    for s0, s1 in zip(candidate_tree_state.state,
                                      last_candidate_state.state)
                ],
                target=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.target,
                    last_candidate_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(choose_new_state, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        candidate_tree_state.target_grad_parts,
                        last_candidate_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.energy,
                    last_candidate_state.energy),
                weight=weight_sum)

            # Copy the static shapes of the previous candidate onto the new
            # tensors.
            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update left right information of the trajectory, and check trajectory
            # level U turn
            # The end that was NOT extended: [0] where `direction` is true,
            # else [1].
            tree_otherend_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
                initial_step_state)

            # Re-stack both ends along axis 0. In the zip below `left`
            # iterates over the sub-tree's final states and `right` over the
            # untouched end, so where `direction` is true the new point
            # lands at index 1, otherwise at index 0.
            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            bu.where_left_justified_mask(
                                direction, right, left),
                            bu.where_left_justified_mask(
                                direction, left, right),
                        ],
                        axis=0) for left, right in zip(
                            tf.nest.flatten(tree_final_states),
                            tf.nest.flatten(tree_otherend_states))
                ])

            # Add the sub-tree's momentum sum to the trajectory's running
            # sum, keeping the static shape of the previous sum.
            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            # The U-turn check uses either the summed momentum or the
            # displacement between the two trajectory ends.
            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=ps.rank_from_shape(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, next_seed, new_step_state, new_step_metastate
Exemple #4
0
 def pick(new, old):
   """Combine `new` and `old` under the enclosing `new_good_values_mask`.

   Delegates to `bu.where_left_justified_mask`, passing the mask first,
   then `new`, then `old`, and returns its result unchanged.
   """
   selected = bu.where_left_justified_mask(new_good_values_mask, new, old)
   return selected