def append(self, transitions, rows=None):
  """Append a batch of transitions to rows of the memory.

  Every transition is written at the position given by the currently stored
  length of its row. Afterwards the stored length of each row is increased by
  the number of times that row occurs in `rows`.

  Args:
    transitions: Tuple of transition quantities with batch dimension.
    rows: Episodes to append to, defaults to all.

  Returns:
    Operation.
  """
  rows = tf.range(self._capacity) if rows is None else rows
  assert rows.shape.ndims == 1
  assert_capacity = tf.assert_less(
      rows, self._capacity, message='capacity exceeded')
  with tf.control_dependencies([assert_capacity]):
    assert_max_length = tf.assert_less(
        tf.gather(self._length, rows), self._max_length,
        message='max length exceeded')
  with tf.control_dependencies([assert_max_length]):
    # The write position depends only on `rows` and the stored lengths, not on
    # the individual buffer, so it is built once and shared by all updates.
    timestep = tf.gather(self._length, rows)
    indices = tf.stack([rows, timestep], 1)
    append_ops = [
        tf.scatter_nd_update(buffer_, indices, elements)
        for buffer_, elements in zip(self._buffers, transitions)]
  # The lengths may only advance after every buffer has been written.
  with tf.control_dependencies(append_ops):
    # Summing the one-hot encodings over the batch gives, per episode slot,
    # how many transitions were appended to it.
    episode_mask = tf.reduce_sum(
        tf.one_hot(rows, self._capacity, dtype=tf.int32), 0)
    return self._length.assign_add(episode_mask)
def assert_univariate_target_conservation(
    test, mk_target, step_size, stackless):
  """Builds checks that a single NUTS step preserves a univariate target.

  The initial states are drawn from the target itself, every chain is advanced
  by one NUTS transition, and assertions on the outcome are assembled.

  Args:
    test: Test-case object; only its `assertAllEqual` method is used here.
    mk_target: Zero-argument callable returning the target distribution. The
      returned object must provide `sample`, `log_prob` and `cdf`.
    step_size: Step size handed to the sampler; also used as the lower bound
      on the mean absolute movement of the chains.
    stackless: Forwarded to `NoUTurnSampler` as its `stackless` argument.

  Returns:
    Tuple `(check_cdf_agrees, check_enough_power, check_leapfrogs_vary,
    check_leapfrogs, check_movement)` of the assembled assertions.
  """
  # Sample count limited partly by memory reliably available on Forge.  The
  # test remains reasonable even if the nuts recursion limit is severely
  # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
  # footprint and bump the sample count.
  num_samples = int(5e4)
  num_steps = 1
  target_d = mk_target()
  seed_stream = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
  # We wrap the initial values in `tf.identity` to avoid broken gradients
  # resulting from a bijector cache hit, since bijectors of the same
  # type/parameterization now share a cache.
  # TODO(b/72831017): Fix broken gradients caused by bijector caching.
  initial_draws = target_d.sample([num_samples], seed=seed_stream())
  initialization = tf.identity(initial_draws)

  def target(*args):
    # TODO(axch): Just use target_d.log_prob directly, and accept target_d
    # itself as an argument instead of a maker function.  Blocked by
    # b/128932888.  It would then also be nice not to eta-expand
    # target_d.log_prob; that was blocked by b/122414321, but maybe tfp's port
    # of value_and_gradients_function fixed that bug.
    return mk_target().log_prob(*args)

  nuts_kernel = tfp.experimental.mcmc.NoUTurnSampler(
      target,
      step_size=step_size,
      max_tree_depth=3,
      use_auto_batching=True,
      stackless=stackless,
      unrolled_leapfrog_steps=2,
      seed=seed_stream())
  states, kernel_extra = tfp.mcmc.sample_chain(
      num_results=num_steps,
      num_burnin_steps=0,
      current_state=initialization,
      kernel=nuts_kernel)
  # Note: sample_chain puts the chain history on top, not the (independent)
  # chains.
  test.assertAllEqual([num_steps, num_samples], states.shape)
  final_state = states[0]

  check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(
      final_state, target_d.cdf, false_fail_rate=1e-6)
  detectable_discrepancy = st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
      num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6)
  check_enough_power = tf1.assert_less(detectable_discrepancy, 0.025)

  leapfrogs = kernel_extra.leapfrogs_taken[0]
  test.assertAllEqual([num_samples], leapfrogs.shape)
  distinct_leapfrogs, _ = tf.unique(leapfrogs)
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=distinct_leapfrogs)[0], 3)
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrogs)
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))

  movement = tf.abs(final_state - initialization)
  test.assertAllEqual([num_samples], movement.shape)
  # This movement distance (1 * step_size) was selected by reducing until 100
  # runs with independent seeds all passed.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 1 * step_size)

  return (check_cdf_agrees, check_enough_power, check_leapfrogs_vary,
          check_leapfrogs, check_movement)
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  """Builds checks that one NUTS step preserves a standard normal target.

  The initial states are drawn from a zero-mean, identity-covariance
  multivariate normal and then passed through `run_nuts_chain` for one step.

  Args:
    event_size: Dimensionality of the multivariate normal.
    batch_size: Number of initial states to draw.
    **kwargs: Forwarded unchanged to `run_nuts_chain`.

  Returns:
    Tuple `(check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
    check_leapfrogs, check_movement, check_enough_power)` of the assembled
    assertions.
  """
  standard_normal = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size), covariance_matrix=tf.eye(event_size))
  initialization = standard_normal.sample(batch_size, seed=4)
  samples, leapfrogs = run_nuts_chain(
      event_size,
      batch_size,
      num_steps=1,
      initial_state=initialization,
      **kwargs)
  # NOTE(review): presumably the last step of the first state part -- confirm
  # against the return structure of `run_nuts_chain`.
  final_state = samples[0][-1]

  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          final_state,
          initialization,
          num_projections=100,
          false_fail_rate=1e-6))
  check_sample_shape = tf1.assert_equal(
      tf.shape(input=final_state)[0], batch_size)

  first_leapfrogs = leapfrogs[0]
  distinct_leapfrogs, _ = tf.unique(first_leapfrogs)
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=distinct_leapfrogs)[0], 3)
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=first_leapfrogs)
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))

  movement = tf.linalg.norm(tensor=final_state - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 0.3)

  detectable_discrepancy = (
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size,
          batch_size,
          false_fail_rate=1e-8,
          false_pass_rate=1e-6))
  check_enough_power = tf1.assert_less(detectable_discrepancy, 0.055)

  return (check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
          check_leapfrogs, check_movement, check_enough_power)
def replace(self, episodes, length, rows=None):
  """Overwrite whole episodes and their stored lengths.

  Args:
    episodes: Tuple of transition quantities with batch and time dimensions.
    length: Batch of sequence lengths.
    rows: Episodes to replace, defaults to all.

  Returns:
    Operation.
  """
  if rows is None:
    rows = tf.range(self._capacity)
  assert rows.shape.ndims == 1
  capacity_check = tf.assert_less(
      rows, self._capacity, message='capacity exceeded')
  with tf.control_dependencies([capacity_check]):
    length_check = tf.assert_less_equal(
        length, self._max_length, message='max length exceeded')
  with tf.control_dependencies([length_check]):
    buffer_updates = [
        tf.scatter_update(buffer_, rows, elements)
        for buffer_, elements in zip(self._buffers, episodes)]
  # Only record the new lengths once every buffer has been overwritten.
  with tf.control_dependencies(buffer_updates):
    return tf.scatter_update(self._length, rows, length)
def testRejection4D(self):
  """Compares rejection-sampling weights for size 4 with the exact volume."""
  num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
  det_bounds = np.array([0.0], dtype=np.float32)
  exact_volumes = [four_by_four_volume()]
  sampled = corr.correlation_matrix_volume_rejection_samples(
      det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
  rej_weights, rej_proposal_volume = sampled
  # shape of rej_weights: [num_samples, 1, 4, 4]
  mean_matches = st.assert_true_mean_equal_by_dkwm(
      rej_weights,
      low=0.,
      high=rej_proposal_volume,
      expected=exact_volumes,
      false_fail_rate=1e-6)
  detectable_discrepancy = st.min_discrepancy_of_true_means_detectable_by_dkwm(
      num_samples,
      low=0.,
      high=rej_proposal_volume,
      false_fail_rate=1e-6,
      false_pass_rate=1e-6)
  # Going for about a 10% relative error
  enough_power = tf1.assert_less(detectable_discrepancy, 1.1)
  # Evaluating the identity forces both checks to run.
  with tf.control_dependencies([mean_matches, enough_power]):
    checked_weights = tf.identity(rej_weights)
  self.evaluate(checked_weights)
def testRejection2D(self):
  """Compares rejection-sampling weights for size 2 with the exact volumes."""
  num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
  det_bounds = np.array(
      [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
  exact_volumes = two_by_two_volume(det_bounds)
  sampled = corr.correlation_matrix_volume_rejection_samples(
      det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
  rej_weights, rej_proposal_volume = sampled
  # shape of rej_weights: [num_samples, 9, 2, 2]
  mean_matches = st.assert_true_mean_equal_by_dkwm(
      rej_weights,
      low=0.,
      high=rej_proposal_volume,
      expected=exact_volumes,
      false_fail_rate=1e-6)
  detectable_discrepancy = st.min_discrepancy_of_true_means_detectable_by_dkwm(
      num_samples,
      low=0.,
      high=rej_proposal_volume,
      # Correct the false fail rate due to different broadcasting
      false_fail_rate=1.1e-7,
      false_pass_rate=1e-6)
  enough_power = tf1.assert_less(detectable_discrepancy, 0.036)
  # Evaluating the identity forces both checks to run.
  with tf.control_dependencies([mean_matches, enough_power]):
    checked_weights = tf.identity(rej_weights)
  self.evaluate(checked_weights)
def area_range_to_index(area_range, length, max_area_width):
  """Computes the indices of each area in the area expansion.

  The index is the start position of the range plus a base offset that
  depends only on the width of the range.

  Args:
    area_range: tensor in shape of [batch_size, 2]
    length: a scalar tensor gives the length of the original feature space.
    max_area_width: a constant scalar.

  Returns:
    indices: area indices tensor in shape of [batch_size]
  """
  shape_checks = [
      tf.assert_equal(tf.rank(area_range), 2),
      tf.assert_equal(tf.shape(area_range)[1], 2),
  ]
  with tf.control_dependencies(shape_checks):
    area_range = tf.cast(area_range, tf.int32)
  target_size = area_range[:, 1] - area_range[:, 0]
  width_check = tf.assert_less(
      target_size, max_area_width + 1, summarize=100000)
  with tf.control_dependencies([width_check]):
    narrower_widths = target_size - 1
    series_last = length
    series_first = length - narrower_widths + 1
    # Sum of the consecutive integers series_first..series_last, i.e.
    # (first + last) * count // 2.  Presumably this counts all areas that are
    # narrower than the target -- confirm against the area expansion layout.
    series_sum = (series_last + series_first) * (
        series_last - series_first + 1) // 2
    # Ranges of width one or less get no base offset.
    base = tf.where(
        tf.less_equal(target_size, 1),
        tf.zeros_like(target_size),
        series_sum)
    start = area_range[:, 0]
    return base + start
def push(self, value, mask, name=None):
  """Pushes `value` onto the stack, advances frame of batch members in `mask`.

  In this impl, we update each thread's top-of-stack (regardless of `mask`)
  to the corresponding `value`, then advance the stack pointers of only those
  threads indicated by `mask`.

  Args:
    value: `Tensor` having the shape of a single batch of the variable.
    mask: Boolean `Tensor` of shape `[batch_size]`. Threads at `True` indices
      of `mask` have their stack frames advanced; the others remain.
    name: Optional name for this op.

  Returns:
    stack: Updated stack. Does not mutate `self`.
    asserted_value: An assertion-bound snapshot of the input `value`; the
      assertions are used to catch stack overflows.
  """
  with tf.name_scope(name or 'Stack.push'):
    value = tf.convert_to_tensor(value=value, name='value')
    mask = tf.convert_to_tensor(value=mask, name='mask')
    # Shapes:
    #   self.stack:       [max_stack_depth * batch_size, ...]
    #   self.stack_index: [batch_size]
    #   value:            [batch_size, ...]
    # Statically known sizes are preferred; the dynamic shape is the fallback.
    batch_size = (
        tf.compat.dimension_value(self.stack_index.shape[0]) or
        tf.shape(input=self.stack_index)[0])
    flat_stack_size = (
        tf.compat.dimension_value(self.stack.shape[0]) or
        tf.shape(input=self.stack)[0])
    max_stack_depth = flat_stack_size // batch_size
    depth_tensor = tf.convert_to_tensor(value=max_stack_depth)

    # One copy of `value` per stack level; all other dimensions are kept.
    expanded_value = value[tf.newaxis, ...]
    tile_multiples = tf.concat(
        [[depth_tensor],
         tf.ones(tf.rank(value), dtype=depth_tensor.dtype)],
        axis=0)
    tiled_value = tf.tile(input=expanded_value, multiples=tile_multiples)

    # True exactly at each thread's current top-of-stack slot.
    write_mask = tf.one_hot(
        self.stack_index,
        depth=max_stack_depth,
        axis=0,  # Stack depth x batch are both in outermost dim, stack major.
        on_value=True,
        off_value=False,
        dtype=tf.bool)
    new_stack = tf1.where(
        tf.reshape(write_mask, [-1]),
        tf.reshape(tiled_value, tf.shape(input=self.stack)),
        self.stack)
    new_stack.set_shape(self.stack.shape)

    pointer_step = tf.cast(mask, self.stack_index.dtype)
    new_stack_index = self.stack_index + pointer_step
    new_stack_index.set_shape(self.stack_index.shape)

    if self._safety_checks():
      overflow_check = tf1.assert_less(
          new_stack_index, tf.cast(depth_tensor, new_stack_index.dtype))
      with tf.control_dependencies([overflow_check]):
        value = tf.identity(value)
        new_stack_index = tf.identity(new_stack_index)
    return type(self)(new_stack, new_stack_index), value
def _scan_fn(*_):
  """Proposes one exchange and validates the replica indices it contains.

  All positional arguments are ignored.  The checks require that no replica
  index occurs twice and that every index lies in `[0, num_replica)`.

  Returns:
    The size of the first dimension of the proposed exchange, bound to the
    validity checks.
  """
  proposed = exchange_proposed_fn(num_replica, seed)
  flat_replicas = tf.reshape(proposed, [-1])
  num_entries = tf.size(input=flat_replicas)
  num_distinct = tf.size(input=tf.unique(flat_replicas)[0])
  validity_checks = [
      # As many distinct entries as entries in total means no duplicates.
      tf1.assert_equal(num_entries, num_distinct),
      tf1.assert_greater_equal(flat_replicas, 0),
      tf1.assert_less(flat_replicas, num_replica),
  ]
  with tf.control_dependencies(validity_checks):
    return tf.shape(input=proposed)[0]
def _maybe_validate_target_accept_prob(target_accept_prob, validate_args):
  """Validates that target_accept_prob is in (0, 1).

  Args:
    target_accept_prob: Value to check.
    validate_args: If falsy, no checks are added and the input is returned
      unchanged.

  Returns:
    `target_accept_prob` itself when validation is off, otherwise an identity
    of it that depends on the range assertions.
  """
  if validate_args:
    range_checks = [
        tf1.assert_positive(
            target_accept_prob,
            message='`target_accept_prob` must be > 0.'),
        tf1.assert_less(
            target_accept_prob,
            tf.ones_like(target_accept_prob),
            message='`target_accept_prob` must be < 1.'),
    ]
    with tf.control_dependencies(range_checks):
      return tf.identity(target_accept_prob)
  return target_accept_prob
def append(self, value):
  """Appends a new tensor to the end of the buffer.

  When the current size equals the capacity, the buffer is first extended by
  a zero block of its own size and the capacity is doubled.

  Args:
    value: The tensor to append. Must match the shape specified in the
      initializer.

  Returns:
    An op appending the new tensor to the end of the buffer.
  """

  def _double_capacity():
    """Doubles the capacity of the current tensor buffer."""
    padding = tf.zeros_like(self._buffer, self._buffer.dtype)
    new_buffer = tf.concat([self._buffer, padding], axis=0)
    if tf.executing_eagerly():
      with tf.compat.v1.variable_scope(self._name, reuse=True):
        # `get_variable` is reached through `tf.compat.v1`, like every other
        # TF1-style call in this method; the bare `tf.get_variable` is not
        # part of the TF2 top-level API.
        self._buffer = tf.compat.v1.get_variable(
            name='buffer',
            dtype=self._dtype,
            initializer=new_buffer,
            trainable=False)
        return self._buffer, tf.compat.v1.assign(
            self._capacity, tf.multiply(self._capacity, 2))
    else:
      # The first dimension grows, so shape validation must be switched off.
      return tf.compat.v1.assign(
          self._buffer, new_buffer,
          validate_shape=False), tf.compat.v1.assign(
              self._capacity, tf.multiply(self._capacity, 2))

  update_buffer, update_capacity = tf.cond(
      pred=tf.equal(self._current_size, self._capacity),
      true_fn=_double_capacity,
      false_fn=lambda: (self._buffer, self._capacity))

  # Order: (maybe) grow the buffer, then validate, then write, then advance.
  with tf.control_dependencies([update_buffer, update_capacity]):
    with tf.control_dependencies([
        tf.assert_less(
            self._current_size,
            self._capacity,
            message='Appending past end of TensorBuffer.'),
        tf.assert_equal(
            tf.shape(input=value),
            tf.shape(input=self._buffer)[1:],
            message='Appending value of inconsistent shape.')
    ]):
      with tf.control_dependencies([
          tf.compat.v1.assign(self._buffer[self._current_size, :], value)
      ]):
        return tf.compat.v1.assign_add(self._current_size, 1)
def sparse_softmax_cross_entropy(labels,
                                 logits,
                                 num_classes,
                                 weights=1.0,
                                 label_smoothing=0.1):
  """Softmax cross entropy with example weights, label smoothing.

  Args:
    labels: Integer class ids; flattened to one dimension before being one-hot
      encoded. Every id must lie in `[0, num_classes)`.
    logits: Logits passed to `tf.losses.softmax_cross_entropy`.
    num_classes: Number of classes, used as one-hot depth and as the exclusive
      upper bound of the label check.
    weights: Forwarded as `weights` to `tf.losses.softmax_cross_entropy`.
    label_smoothing: Forwarded as `label_smoothing` to
      `tf.losses.softmax_cross_entropy`.

  Returns:
    The loss returned by `tf.losses.softmax_cross_entropy`.
  """
  # The bounds are int64, so compare against an int64 view of the labels.
  # This keeps int64 labels unchanged and lets other integer label dtypes
  # (e.g. int32) pass the range check instead of failing on a dtype mismatch.
  labels_for_check = tf.cast(labels, tf.int64)
  assert_valid_label = [
      tf.assert_greater_equal(labels_for_check, tf.cast(0, dtype=tf.int64)),
      tf.assert_less(labels_for_check, tf.cast(num_classes, dtype=tf.int64)),
  ]
  with tf.control_dependencies(assert_valid_label):
    labels = tf.reshape(labels, [-1])
    dense_labels = tf.one_hot(labels, num_classes)
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=dense_labels,
        logits=logits,
        weights=weights,
        label_smoothing=label_smoothing)
  return loss
def pad_to_fixed_size(data, pad_value, output_shape):
  """Pad data to a fixed length at the first dimension.

  `data` is reshaped to `[-1, output_shape[1]]` and rows filled with
  `pad_value` are appended until it has `output_shape[0]` rows.

  Args:
    data: Tensor to be padded to output_shape.
    pad_value: A constant value assigned to the paddings.
    output_shape: The output shape of a 2D tensor.

  Returns:
    The Padded tensor with output_shape [max_instances_per_image, dimension].
  """
  max_instances_per_image = output_shape[0]
  dimension = output_shape[1]
  data = tf.reshape(data, [-1, dimension])
  num_instances = tf.shape(data)[0]
  msg = 'ERROR: please increase config.max_instances_per_image'
  # Exactly `max_instances_per_image` rows is a valid input that needs zero
  # padding rows, so the bound is inclusive.
  fits_check = tf.debugging.assert_less_equal(
      num_instances, max_instances_per_image, message=msg)
  with tf.control_dependencies([fits_check]):
    pad_length = max_instances_per_image - num_instances
    paddings = pad_value * tf.ones([pad_length, dimension])
    padded_data = tf.concat([data, paddings], axis=0)
    padded_data = tf.reshape(padded_data, output_shape)
  return padded_data