def bootstrap_results(self, init_state):
  """Builds the starting `NUTSKernelResults` for the supplied state.

  Args:
    init_state: A single state part, or a nested structure (e.g. a list) of
      state parts. A non-nested value is wrapped in a one-element list.

  Returns:
    A `NUTSKernelResults` holding the target log-prob and its gradients at
    `init_state`, zero-filled state/momentum swap buffers, the (possibly
    replicated) step sizes, zero-initialised diagnostics shaped like the
    target log-prob, the energy evaluated with an all-ones momentum, and a
    zero seed placeholder.

  Raises:
    ValueError: If `self.step_size` has neither exactly one entry nor one
      entry per state part.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    state_parts = init_state if tf.nest.is_nested(init_state) else [init_state]
    num_parts = len(state_parts)

    # A lone step size is replicated so that every state part has its own.
    step_sizes = self.step_size
    if len(step_sizes) == 1:
      step_sizes = step_sizes * num_parts
    if len(step_sizes) != num_parts:
      raise ValueError('Expected either one step size or {} (size of '
                       '`init_state`), but found {}'.format(
                           num_parts, len(step_sizes)))

    # All-ones momentum used as a stand-in; see its uses below.
    unit_momentum = [tf.ones_like(part) for part in state_parts]

    def _allocate_swap(parts):
      """Returns one zero buffer per part with an extra leading dimension."""
      # Collect every (shape, dtype) pair before allocating anything.
      specs = [(ps.shape(part), part.dtype) for part in parts]
      buffers = []
      for part_shape, part_dtype in specs:
        # Leading dimension is one more than the largest write instruction.
        num_slots = max(self._write_instruction) + 1
        buffers.append(
            ps.zeros(ps.concat([[num_slots], part_shape], axis=0),
                     dtype=part_dtype))
      return buffers

    swap_memory = MomentumStateSwap(
        momentum_swap=_allocate_swap(unit_momentum),
        state_swap=_allocate_swap(state_parts))

    _, _, target_log_prob, grads_log_prob = leapfrog_impl.process_args(
        self.target_log_prob_fn, unit_momentum, state_parts)

    def _zeros_like_log_prob(name, dtype=None):
      """Zero tensor shaped like the target log-prob."""
      return tf.zeros_like(target_log_prob, dtype=dtype, name=name)

    return NUTSKernelResults(
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_log_prob,
        momentum_state_memory=swap_memory,
        step_size=step_sizes,
        log_accept_ratio=_zeros_like_log_prob('log_accept_ratio'),
        leapfrogs_taken=_zeros_like_log_prob(
            'leapfrogs_taken', dtype=TREE_COUNT_DTYPE),
        is_accepted=_zeros_like_log_prob('is_accepted', dtype=tf.bool),
        reach_max_depth=_zeros_like_log_prob(
            'reach_max_depth', dtype=tf.bool),
        has_divergence=_zeros_like_log_prob('has_divergence', dtype=tf.bool),
        energy=compute_hamiltonian(target_log_prob, unit_momentum),
        # Placeholder slot for the seed that `one_step` will fill in.
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`.

  Args:
    init_state: A single state part, or a nested structure (e.g. a list) of
      state parts. A non-nested value is wrapped in a one-element list.

  Returns:
    A `NUTSKernelResults` with the target log-prob and its gradients at
    `init_state`, a scalar `int32` zero `leapfrogs_computed` counter, and
    all-`False` `is_accepted` / `reach_max_depth` vectors with one entry per
    element of the target log-prob.

  Raises:
    ValueError: If `self.step_size` has neither exactly one entry nor one
      entry per state part.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]

    # Pad the step size so it is compatible with the state parts: a single
    # step size is replicated once per part and stored back on the kernel.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
      self._step_size = step_size
    if len(step_size) != len(init_state):
      raise ValueError('Expected either one step size or {} (size of '
                       '`init_state`), but found {}'.format(
                           len(init_state), len(step_size)))

    # All-ones momentum passed to `process_args` alongside the state.
    dummy_momentum = [tf.ones_like(state) for state in init_state]
    [
        _,
        _,
        current_target_log_prob,
        current_grads_log_prob,
    ] = leapfrog_impl.process_args(self.target_log_prob_fn, dummy_momentum,
                                   init_state)

    # Total number of elements in the target log-prob.
    batch_size = prefer_static.size(current_target_log_prob)
    return NUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        leapfrogs_computed=tf.zeros([], dtype=tf.int32,
                                    name='leapfrogs_computed'),
        is_accepted=tf.zeros([batch_size], dtype=tf.bool,
                             name='is_accepted'),
        # Fixed: this op was previously also named 'is_accepted'.
        reach_max_depth=tf.zeros([batch_size], dtype=tf.bool,
                                 name='reach_max_depth'),
    )
def bootstrap_results(self, init_state):
  """Builds the starting `NUTSKernelResults` for the supplied state.

  Args:
    init_state: A single state part, or a nested structure (e.g. a list) of
      state parts. A non-nested value is wrapped in a one-element list.

  Returns:
    A `NUTSKernelResults` holding the target log-prob and its gradients at
    `init_state`, `self.step_size` converted to tensors of the log-prob's
    dtype, zero-initialised diagnostics shaped like the target log-prob, the
    energy evaluated with an all-ones momentum, and a zero seed placeholder.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    state_parts = init_state if tf.nest.is_nested(init_state) else [init_state]

    # All-ones momentum used as a stand-in; see its uses below.
    unit_momentum = [tf.ones_like(part) for part in state_parts]

    _, _, target_log_prob, grads_log_prob = leapfrog_impl.process_args(
        self.target_log_prob_fn, unit_momentum, state_parts)
    log_prob_dtype = target_log_prob.dtype

    # Called only as a compatibility check between the step size and the
    # state parts; its return value is intentionally discarded.
    _prepare_step_size(self.step_size, log_prob_dtype, len(state_parts))

    def _as_step_size_tensor(value):
      """Converts one step-size entry to a tensor of the log-prob dtype."""
      return tf.convert_to_tensor(
          value, dtype=target_log_prob.dtype, name='step_size')

    def _zeros_like_log_prob(name, dtype=None):
      """Zero tensor shaped like the target log-prob."""
      return tf.zeros_like(target_log_prob, dtype=dtype, name=name)

    return NUTSKernelResults(
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_log_prob,
        step_size=tf.nest.map_structure(_as_step_size_tensor,
                                        self.step_size),
        log_accept_ratio=_zeros_like_log_prob('log_accept_ratio'),
        leapfrogs_taken=_zeros_like_log_prob(
            'leapfrogs_taken', dtype=TREE_COUNT_DTYPE),
        is_accepted=_zeros_like_log_prob('is_accepted', dtype=tf.bool),
        reach_max_depth=_zeros_like_log_prob(
            'reach_max_depth', dtype=tf.bool),
        has_divergence=_zeros_like_log_prob('has_divergence', dtype=tf.bool),
        energy=compute_hamiltonian(
            target_log_prob,
            unit_momentum,
            shard_axis_names=self.experimental_shard_axis_names),
        # Placeholder slot for the seed that `one_step` will fill in.
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Builds the starting `NUTSKernelResults` for the supplied state.

  Args:
    init_state: A single state part, or a nested structure (e.g. a list) of
      state parts. A non-nested value is wrapped in a one-element list.

  Returns:
    A `NUTSKernelResults` holding the target log-prob and its gradients at
    `init_state`, zero-filled state/momentum swap buffers, `self.step_size`
    converted to tensors of the log-prob's dtype, zero-initialised
    diagnostics shaped like the target log-prob, the energy evaluated with an
    all-ones momentum, and a zero seed placeholder.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    state_parts = init_state if tf.nest.is_nested(init_state) else [init_state]

    # All-ones momentum used as a stand-in; see its uses below.
    unit_momentum = [tf.ones_like(part) for part in state_parts]

    def _allocate_swap(parts):
      """Returns one zero buffer per part with an extra leading dimension."""
      # Collect every (shape, dtype) pair before allocating anything.
      specs = [(ps.shape(part), part.dtype) for part in parts]
      buffers = []
      for part_shape, part_dtype in specs:
        # Leading dimension is one more than the largest write instruction.
        num_slots = max(self._write_instruction) + 1
        buffers.append(
            ps.zeros(ps.concat([[num_slots], part_shape], axis=0),
                     dtype=part_dtype))
      return buffers

    swap_memory = MomentumStateSwap(
        momentum_swap=_allocate_swap(unit_momentum),
        state_swap=_allocate_swap(state_parts))

    _, _, target_log_prob, grads_log_prob = leapfrog_impl.process_args(
        self.target_log_prob_fn, unit_momentum, state_parts)

    # Called only as a compatibility check between the step size and the
    # state parts; its return value is intentionally discarded.
    _prepare_step_size(
        self.step_size, target_log_prob.dtype, len(state_parts))

    def _as_step_size_tensor(value):
      """Converts one step-size entry to a tensor of the log-prob dtype."""
      return tf.convert_to_tensor(
          value, dtype=target_log_prob.dtype, name='step_size')

    def _zeros_like_log_prob(name, dtype=None):
      """Zero tensor shaped like the target log-prob."""
      return tf.zeros_like(target_log_prob, dtype=dtype, name=name)

    return NUTSKernelResults(
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_log_prob,
        momentum_state_memory=swap_memory,
        step_size=tf.nest.map_structure(_as_step_size_tensor,
                                        self.step_size),
        log_accept_ratio=_zeros_like_log_prob('log_accept_ratio'),
        leapfrogs_taken=_zeros_like_log_prob(
            'leapfrogs_taken', dtype=TREE_COUNT_DTYPE),
        is_accepted=_zeros_like_log_prob('is_accepted', dtype=tf.bool),
        reach_max_depth=_zeros_like_log_prob(
            'reach_max_depth', dtype=tf.bool),
        has_divergence=_zeros_like_log_prob('has_divergence', dtype=tf.bool),
        energy=compute_hamiltonian(
            target_log_prob,
            unit_momentum,
            shard_axis_names=self.experimental_shard_axis_names),
        # Placeholder slot for the seed that `one_step` will fill in.
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Builds the starting `NUTSKernelResults` for the supplied state.

  Args:
    init_state: A single state part, or a nested structure (e.g. a list) of
      state parts. A non-nested value is wrapped in a one-element list.

  Returns:
    A `NUTSKernelResults` holding the target log-prob and its gradients at
    `init_state`, state/momentum swap storage with `max_tree_depth + 1`
    slots per part, the (possibly replicated) step sizes, and zero-filled
    diagnostic vectors with one entry per element of the target log-prob.

  Raises:
    ValueError: If `self.step_size` has neither exactly one entry nor one
      entry per state part.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    state_parts = init_state if tf.nest.is_nested(init_state) else [init_state]
    num_parts = len(state_parts)

    # A lone step size is replicated so that every state part has its own.
    step_sizes = self.step_size
    if len(step_sizes) == 1:
      step_sizes = step_sizes * num_parts
    if len(step_sizes) != num_parts:
      raise ValueError(
          'Expected either one step size or {} (size of '
          '`init_state`), but found {}'.format(
              num_parts, len(step_sizes)))

    # All-ones momentum used as a stand-in; see its uses below.
    unit_momentum = [tf.ones_like(part) for part in state_parts]

    def _new_swap_slot(part):
      """Storage for one part: a `TensorArray` or a dense zero tensor."""
      num_slots = self.max_tree_depth + 1
      if USE_TENSORARRAY:
        return tf.TensorArray(
            dtype=part.dtype,
            size=num_slots,
            element_shape=part.shape,
            clear_after_read=False)
      dense_shape = tf.TensorShape([num_slots]).concatenate(part.shape)
      return tf.zeros(dense_shape, dtype=part.dtype)

    swap_memory = MomentumStateSwap(
        momentum_swap=[_new_swap_slot(part) for part in unit_momentum],
        state_swap=[_new_swap_slot(part) for part in state_parts])

    _, _, target_log_prob, grads_log_prob = leapfrog_impl.process_args(
        self.target_log_prob_fn, unit_momentum, state_parts)

    # Total number of elements in the target log-prob.
    num_chains = prefer_static.size(target_log_prob)

    def _batch_zeros(name, dtype):
      """Rank-1 zero tensor of length `num_chains`."""
      return tf.zeros([num_chains], dtype=dtype, name=name)

    return NUTSKernelResults(
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_log_prob,
        momentum_state_memory=swap_memory,
        step_size=step_sizes,
        log_accept_ratio=_batch_zeros('log_accept_ratio',
                                      target_log_prob.dtype),
        leapfrogs_taken=_batch_zeros('leapfrogs_taken', TREE_COUNT_DTYPE),
        is_accepted=_batch_zeros('is_accepted', tf.bool),
        reach_max_depth=_batch_zeros('reach_max_depth', tf.bool),
        has_divergence=_batch_zeros('has_divergence', tf.bool),
    )