def val_where(cond, tval, fval):
    """Like tf.where but works on namedtuples.

    A `tf.Tensor` is selected directly with
    `broadcast_util.where_left_justified_mask(cond, tval, fval)`. A tuple
    (including a namedtuple) is processed element by element, pairing the
    entries of `tval` and `fval` with `zip`, and the results are rebuilt into
    the same tuple type as `tval`.

    Args:
      cond: Mask forwarded unchanged to `where_left_justified_mask`.
      tval: A `tf.Tensor`, or a (possibly nested) tuple of them, supplying the
        values used where `cond` selects the first operand.
      fval: Value with the same structure as `tval`, supplying the alternative
        values.

    Returns:
      A value with the same structure and tuple type as `tval`.

    Raises:
      TypeError: If `tval` (or one of its nested elements) is neither a
        `tf.Tensor` nor a tuple.
    """
    if isinstance(tval, tf.Tensor):
        return broadcast_util.where_left_justified_mask(cond, tval, fval)
    if isinstance(tval, tuple):
        cls = type(tval)
        return cls(*(val_where(cond, t, f) for t, f in zip(tval, fval)))
    # Previously this raised `Exception(TypeError)`: a generic exception whose
    # only payload was the TypeError class. Raise the real type with a message.
    raise TypeError(
        'val_where expects `tval` to be a tf.Tensor or a tuple of them; '
        'got {}.'.format(type(tval).__name__))
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         current_step_meta_info,
                         iter_,
                         energy_diff_sum_previous,
                         momentum_cumsum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory,
                         seed):
    """Run one iteration of the sub-tree loop (base case in tree doubling).

    Calls `integrator` once on `prev_tree_state`, checks for U-turns against
    the stored memory, writes the new point into that memory, decides whether
    the new point replaces the current candidate, and updates the running
    bookkeeping (momentum sum, leapfrog count, energy-difference sum,
    continuation and divergence flags).

    Args:
      directions: Forwarded to `has_not_u_turn_at_all_index`.
      integrator: Callable invoked as `integrator(momentum, state, target,
        target_grad_parts)`; must return those four quantities for the next
        point, in that order.
      current_step_meta_info: Object providing `write_instruction`,
        `read_instruction`, `init_energy` and, when `MULTINOMIAL_SAMPLE` is
        falsy, `log_slice_sample`.
      iter_: Loop counter; selects the read/write instruction entries.
      energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))`.
      momentum_cumsum_previous: List of running momentum sums, one per part.
      leapfrogs_taken: Running leapfrog counter.
      prev_tree_state: State with `momentum`, `state`, `target` and
        `target_grad_parts` fields from which the step is taken.
      candidate_tree_state: Current candidate, with `state`, `target`,
        `target_grad_parts`, `energy` and `weight` fields.
      continue_tree_previous: Boolean mask; true where the tree was still
        being extended before this iteration.
      not_divergent_previous: Boolean mask accumulated over earlier iterations.
      momentum_state_memory: `MomentumStateSwap` with `momentum_swap` and
        `state_swap` lists.
      seed: Seed split into an acceptance seed and a seed for the next
        iteration.

    Returns:
      A 10-tuple `(iter_ + 1, next_seed, energy_diff_sum, momentum_cumsum,
      leapfrogs_taken, next_tree_state, next_candidate_tree_state,
      continue_tree_next, not_divergent, momentum_state_memory)` mirroring the
      loop-carried inputs.
    """
    accept_seed, carry_seed = samplers.split_seed(seed)

    def _select(mask, if_true, if_false):
        # Shorthand for the left-justified masked selection used below.
        return bu.where_left_justified_mask(mask, if_true, if_false)

    def _scatter_parts(buffers, updates, index):
        # Write each update into its buffer at the given memory slot.
        return [
            tf.tensor_scatter_nd_update(buf, [index], [upd])
            for buf, upd in zip(buffers, updates)
        ]

    with tf.name_scope('loop_build_sub_tree'):
        # Take one leapfrog step in the direction v and check divergence.
        stepped = integrator(prev_tree_state.momentum,
                             prev_tree_state.state,
                             prev_tree_state.target,
                             prev_tree_state.target_grad_parts)
        [new_momentum, new_state, new_target, new_target_grads] = stepped

        advanced_tree_state = TreeDoublingState(
            momentum=new_momentum,
            state=new_state,
            target=new_target,
            target_grad_parts=new_target_grads)

        running_momentum = [
            acc + part
            for acc, part in zip(momentum_cumsum_previous, new_momentum)
        ]

        # Only count this leapfrog where the tree had not already terminated.
        leapfrog_total = tf.where(continue_tree_previous,
                                  leapfrogs_taken + 1,
                                  leapfrogs_taken)

        write_plan = current_step_meta_info.write_instruction
        read_plan = current_step_meta_info.read_instruction
        starting_energy = current_step_meta_info.init_energy

        if GENERALIZED_UTURN:
            memory_payload = momentum_cumsum_previous
            uturn_probe = running_momentum
        else:
            memory_payload = new_state
            uturn_probe = new_state

        batch_shape = ps.shape(new_target)
        all_true_init = ps.ones(batch_shape, dtype=tf.bool)

        # The U-turn check must read the memory *before* it is overwritten
        # with the newly computed point below.
        read_slot = read_plan.gather([iter_])[0]
        no_inner_u_turn = has_not_u_turn_at_all_index(
            read_slot,
            directions,
            momentum_state_memory,
            new_momentum,
            uturn_probe,
            all_true_init,
            log_prob_rank=ps.rank(new_target))

        # Get index to write state into memory swap.
        write_slot = write_plan.gather([iter_])
        updated_memory = MomentumStateSwap(
            momentum_swap=_scatter_parts(
                momentum_state_memory.momentum_swap, new_momentum, write_slot),
            state_swap=_scatter_parts(
                momentum_state_memory.state_swap, memory_payload, write_slot))

        raw_energy = compute_hamiltonian(new_target, new_momentum)
        # NaN energies are replaced by -inf.
        energy_now = tf.where(tf.math.is_nan(raw_energy),
                              tf.constant(-np.inf, dtype=raw_energy.dtype),
                              raw_energy)
        delta_energy = energy_now - starting_energy

        if MULTINOMIAL_SAMPLE:
            still_stable = -delta_energy < self.max_energy_diff
            new_weight = log_add_exp(candidate_tree_state.weight, delta_energy)
            log_threshold = delta_energy - new_weight
        else:
            log_slice = current_step_meta_info.log_slice_sample
            still_stable = log_slice - delta_energy < self.max_energy_diff
            # Uniform sampling on the trajectory within the subtree across
            # valid samples.
            within_slice = log_slice <= delta_energy
            new_weight = tf.where(within_slice,
                                  candidate_tree_state.weight + 1,
                                  candidate_tree_state.weight)
            log_threshold = tf.where(
                within_slice,
                -tf.math.log(tf.cast(new_weight, dtype=tf.float32)),
                tf.constant(-np.inf, dtype=tf.float32))

        log_draw = tf.math.log1p(-samplers.uniform(
            shape=batch_shape,
            dtype=log_threshold.dtype,
            seed=accept_seed))
        take_new_point = log_draw <= log_threshold

        updated_candidate = TreeDoublingStateCandidate(
            state=[
                _select(take_new_point, fresh, kept)
                for fresh, kept in zip(new_state, candidate_tree_state.state)
            ],
            target=_select(
                take_new_point, new_target, candidate_tree_state.target),
            target_grad_parts=[
                _select(take_new_point, fresh, kept)
                for fresh, kept in zip(
                    new_target_grads, candidate_tree_state.target_grad_parts)
            ],
            energy=_select(
                take_new_point, energy_now, candidate_tree_state.energy),
            weight=new_weight)

        keep_growing = still_stable & continue_tree_previous
        keep_growing_next = no_inner_u_turn & keep_growing

        # Where the tree had already stopped, this step's divergence flag is
        # ignored (treated as "not divergent").
        stable_to_record = tf.where(continue_tree_previous,
                                    still_stable,
                                    ps.ones(batch_shape, dtype=tf.bool))

        # min(1., exp(energy_diff)).
        clipped_ratio = tf.math.exp(tf.minimum(delta_energy, 0.))
        energy_diff_total = tf.where(keep_growing,
                                     energy_diff_sum_previous + clipped_ratio,
                                     energy_diff_sum_previous)

        return (
            iter_ + 1,
            carry_seed,
            energy_diff_total,
            running_momentum,
            leapfrog_total,
            advanced_tree_state,
            updated_candidate,
            keep_growing_next,
            not_divergent_previous & stable_to_record,
            updated_memory,
        )
def _loop_tree_doubling(self,
                        step_size,
                        momentum_state_memory,
                        current_step_meta_info,
                        iter_,
                        initial_step_state,
                        initial_step_metastate,
                        seed):
    """Run one iteration of the main tree-doubling loop.

    Draws a random boolean direction per batch member, builds a sub-tree of
    `1 << iter_` steps from the matching end of the trajectory through
    `self._build_sub_tree`, decides whether the sub-tree's candidate replaces
    the current one, and rebuilds the two-ended trajectory state and the
    meta-state.

    Args:
      step_size: Sequence of step sizes, one per state part; each is negated
        where the drawn direction is false.
      momentum_state_memory: Forwarded to `self._build_sub_tree`.
      current_step_meta_info: Object whose `init_energy` defines the batch
        shape; also forwarded to `self._build_sub_tree`.
      iter_: Loop counter; the sub-tree is asked for `1 << iter_` steps.
      initial_step_state: Nested structure whose leaves are indexable with
        `[0]` and `[1]` (the two trajectory ends) and which exposes `state`
        and `momentum`.
      initial_step_metastate: Meta-state with `candidate_state`,
        `continue_tree`, `not_divergence`, `energy_diff_sum`, `momentum_sum`,
        `is_accepted` and `leapfrog_count` fields.
      seed: Seed split four ways: direction, sub-tree, acceptance, and next
        iteration.

    Returns:
      A 4-tuple `(iter_ + 1, next_seed, new_step_state, new_step_metastate)`.
    """

    def _select(mask, if_true, if_false):
        # Shorthand for the left-justified masked selection used below.
        return bu.where_left_justified_mask(mask, if_true, if_false)

    def _pin_shapes(tensors, references):
        # Give each new tensor the shape attribute of its reference tensor.
        for fresh, reference in zip(tensors, references):
            tensorshape_util.set_shape(fresh, reference.shape)

    with tf.name_scope('loop_tree_doubling'):
        split_seeds = samplers.split_seed(seed, n=4)
        (direction_seed, subtree_seed, accept_seed, carry_seed) = split_seeds

        batch_shape = ps.shape(current_step_meta_info.init_energy)
        coin_flip = samplers.uniform(shape=batch_shape,
                                     minval=0,
                                     maxval=2,
                                     dtype=tf.int32,
                                     seed=direction_seed)
        direction = tf.cast(coin_flip, dtype=tf.bool)

        # Pick the end of the trajectory from which the sub-tree grows:
        # index 1 where `direction` is true, index 0 otherwise.
        growing_end = tf.nest.map_structure(
            lambda ends: _select(direction, ends[1], ends[0]),
            initial_step_state)

        directions_expanded = [
            bu.left_justified_expand_dims_like(direction, part)
            for part in growing_end.state
        ]

        signed_step_sizes = [
            tf.where(flag, size, -size)
            for flag, size in zip(directions_expanded, step_size)
        ]
        integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
            self.target_log_prob_fn,
            step_sizes=signed_step_sizes,
            num_steps=self.unrolled_leapfrog_steps)

        subtree_result = self._build_sub_tree(
            directions_expanded,
            integrator,
            current_step_meta_info,
            # num_steps_at_this_depth = 2**iter_ = 1 << iter_
            tf.bitwise.left_shift(1, iter_),
            growing_end,
            initial_step_metastate.continue_tree,
            initial_step_metastate.not_divergence,
            momentum_state_memory,
            seed=subtree_seed)
        [
            subtree_candidate,
            subtree_end_states,
            subtree_not_divergent,
            subtree_continue,
            subtree_energy_diff_sum,
            subtree_momentum_sum,
            subtree_leapfrogs,
        ] = subtree_result

        current_candidate = initial_step_metastate.candidate_state

        energy_diff_total = (
            subtree_energy_diff_sum + initial_step_metastate.energy_diff_sum)

        if MULTINOMIAL_SAMPLE:
            subtree_weight = tf.where(
                subtree_continue,
                subtree_candidate.weight,
                tf.constant(-np.inf, dtype=subtree_candidate.weight.dtype))
            combined_weight = log_add_exp(
                subtree_weight, current_candidate.weight)
            log_threshold = subtree_weight - current_candidate.weight
        else:
            subtree_weight = tf.where(
                subtree_continue,
                subtree_candidate.weight,
                tf.zeros([], dtype=TREE_COUNT_DTYPE))
            combined_weight = subtree_weight + current_candidate.weight
            weight_ratio = (tf.cast(subtree_weight, tf.float32) /
                            tf.cast(current_candidate.weight, tf.float32))
            log_threshold = tf.math.log(weight_ratio)

        # A NaN threshold is replaced by zero.
        log_threshold = tf.where(tf.math.is_nan(log_threshold),
                                 tf.zeros([], log_threshold.dtype),
                                 log_threshold)
        log_draw = tf.math.log1p(-samplers.uniform(
            shape=batch_shape,
            dtype=log_threshold.dtype,
            seed=accept_seed))
        draw_accepted = log_draw <= log_threshold

        switch_candidate = draw_accepted & subtree_continue

        merged_candidate = TreeDoublingStateCandidate(
            state=[
                _select(switch_candidate, fresh, kept)
                for fresh, kept in zip(
                    subtree_candidate.state, current_candidate.state)
            ],
            target=_select(switch_candidate,
                           subtree_candidate.target,
                           current_candidate.target),
            target_grad_parts=[
                _select(switch_candidate, fresh, kept)
                for fresh, kept in zip(
                    subtree_candidate.target_grad_parts,
                    current_candidate.target_grad_parts)
            ],
            energy=_select(switch_candidate,
                           subtree_candidate.energy,
                           current_candidate.energy),
            weight=combined_weight)

        _pin_shapes(merged_candidate.state, current_candidate.state)
        _pin_shapes(merged_candidate.target_grad_parts,
                    current_candidate.target_grad_parts)

        # Update left right information of the trajectory, and check
        # trajectory level U turn.
        untouched_end = tf.nest.map_structure(
            lambda ends: _select(direction, ends[0], ends[1]),
            initial_step_state)

        restacked_leaves = [
            tf.stack(
                [
                    _select(direction, anchored, grown),
                    _select(direction, grown, anchored),
                ],
                axis=0)
            for grown, anchored in zip(
                tf.nest.flatten(subtree_end_states),
                tf.nest.flatten(untouched_end))
        ]
        updated_step_state = tf.nest.pack_sequence_as(
            initial_step_state, restacked_leaves)

        trajectory_momentum_sum = [
            earlier + added
            for earlier, added in zip(
                initial_step_metastate.momentum_sum, subtree_momentum_sum)
        ]
        _pin_shapes(trajectory_momentum_sum,
                    initial_step_metastate.momentum_sum)

        _pin_shapes(tf.nest.flatten(updated_step_state),
                    tf.nest.flatten(initial_step_state))

        if GENERALIZED_UTURN:
            span = trajectory_momentum_sum
        else:
            span = [ends[1] - ends[0] for ends in updated_step_state.state]

        momentum_at_0 = [ends[0] for ends in updated_step_state.momentum]
        momentum_at_1 = [ends[1] for ends in updated_step_state.momentum]
        no_trajectory_u_turn = has_not_u_turn(
            span,
            momentum_at_0,
            momentum_at_1,
            log_prob_rank=ps.rank_from_shape(batch_shape))

        updated_metastate = TreeDoublingMetaState(
            candidate_state=merged_candidate,
            is_accepted=switch_candidate | initial_step_metastate.is_accepted,
            momentum_sum=trajectory_momentum_sum,
            energy_diff_sum=energy_diff_total,
            continue_tree=subtree_continue & no_trajectory_u_turn,
            not_divergence=subtree_not_divergent,
            leapfrog_count=(
                initial_step_metastate.leapfrog_count + subtree_leapfrogs))

        return iter_ + 1, carry_seed, updated_step_state, updated_metastate
def pick(new, old):
    """Select between `new` and `old` using `new_good_values_mask`.

    The mask comes from the enclosing scope and is passed as the first
    argument to `bu.where_left_justified_mask`, with `new` and `old` as the
    second and third arguments.
    """
    mask = new_good_values_mask
    return bu.where_left_justified_mask(mask, new, old)