Example #1
0
def safe_mean_summary(name, value):
    """Add a mean summary of `value`, skipping it when `value` has no elements.

    The number of elements is taken as the product of the dimensions of
    `value`; the summary is only emitted (through `run_if`) when that number
    is positive.

    Args:
        name (str): name (tag) of the summary.
        value (Tensor): tensor whose mean is summarized.
    Returns:
        None
    """
    num_elements = tf.reduce_prod(tf.shape(value))

    def _summarize():
        # Forward the result so run_if sees exactly what the summary call returns.
        return add_mean_summary(name, value)

    run_if(num_elements > 0, _summarize)
Example #2
0
    def _add_batch(self, batch):
        """
        Add a batch of samples to the buffer. While the buffer is not full, as
        many samples as fit are appended to it. Any remaining samples overwrite
        randomly chosen slots of the (then full) buffer.

        Args:
            batch (nested Tensor): shape should be [batch_size] + data_spec.shape
        """
        # Batch size of the nest, as an int32 tensor.
        batch_size = get_nest_batch_size(batch, tf.int32)
        # Unused slots (capacity - current size). Computed once here, before
        # either branch below runs, so the split point is fixed up front.
        buffer_space_left = self._capacity - self._current_size
        # Number of trailing samples that do not fit. May be <= 0 when the
        # whole batch fits; the run_if guard at the end handles that case.
        batch_replacing_size = batch_size - buffer_space_left

        # Split the batch into two: the first one (if exists) will be directly
        # added into the buffer; the second one (if exists) will randomly replace
        # samples in the buffer
        def _fill_buffer():
            # Leading `buffer_space_left` rows of every tensor in the nest.
            # Slicing clamps, so a batch smaller than the free space is taken whole.
            data_filling_buffer = tf.nest.map_structure(
                lambda bat: bat[:buffer_space_left], batch)
            # Delegate the append to the parent class's add_batch.
            super(ReservoirSampler, self).add_batch(data_filling_buffer)

        def _replace_buffer():
            # Sample the slots to be replaced in the set: one slot per overflow
            # sample, drawn uniformly with replacement from [0, capacity)
            # (maxval is exclusive).
            # NOTE(review): every overflow sample is given a slot here (only
            # collisions drop samples). This is not classic reservoir sampling
            # (Algorithm R), where the k-th sample is kept with probability
            # capacity / k. Confirm this is the intended policy.
            replace_indices = tf.random.uniform(shape=(batch_replacing_size, ),
                                                dtype=tf.int32,
                                                minval=0,
                                                maxval=self._capacity)
            # replace_indices can contain duplicates; sort the indices and the
            # corresponding batch data
            # sort_idx is the permutation that sorts replace_indices. Gathering
            # the data with it keeps row i aligned with sorted index i.
            # NOTE(review): tf.argsort defaults to stable=False, so among equal
            # indices the resulting order (and hence which sample counts as
            # "last" below) is unspecified.
            sort_idx = tf.argsort(replace_indices)
            replace_indices = tf.sort(replace_indices)
            # Trailing `batch_replacing_size` rows (the samples _fill_buffer did
            # not consume), reordered to match the sorted indices.
            data_replacing_buffer = tf.nest.map_structure(
                lambda bat: tf.gather(
                    bat[-batch_replacing_size:], sort_idx, axis=0), batch)

            # For duplicate indices, we only need to keep the last one
            # Because replace_indices is sorted, unique_with_counts yields the
            # unique values in ascending order. cumsum(count) - 1 is then the
            # position of the last occurrence of each value in the sorted array.
            unique_indices, _, count = tf.unique_with_counts(replace_indices)
            last = tf.cumsum(count) - 1
            # scatter_nd_update expects indices shaped [num_updates, 1].
            unique_indices = tf.expand_dims(unique_indices, axis=-1)

            # Overwrite the chosen slots of each buffer tensor. Written values
            # are wrapped in stop_gradient.
            # Presumably self._buffer is a nest of tf.Variables with the same
            # structure as `batch` (scatter_nd_update is a Variable method);
            # verify against the class definition.
            tf.nest.map_structure(
                lambda buf, bat: buf.scatter_nd_update(
                    unique_indices,
                    tf.stop_gradient(tf.gather(bat, last, axis=0))),
                self._buffer, data_replacing_buffer)

        # Order matters: fill the free space first, then replace with the overflow.
        run_if(buffer_space_left > 0, _fill_buffer)
        run_if(batch_replacing_size > 0, _replace_buffer)
Example #3
0
def safe_mean_summary(name, value):
    """Add a scoped mean summary of `value`, skipping it when `value` is empty.

    The summary tag is the current scope followed by `name`. The summary is
    only emitted (through `run_if`) when the product of the dimensions of
    `value` is positive.

    Args:
        name (str): name of the summary, appended to the current scope.
        value (Tensor): tensor whose mean is summarized.
    Returns:
        None
    """
    # The scope must be read here, outside run_if: inside run_if,
    # get_current_scope() has been observed to return '', which would leave
    # the summary tag unscoped.
    scope_prefix = get_current_scope()
    has_elements = tf.reduce_prod(tf.shape(value)) > 0

    def _summarize():
        return add_mean_summary(scope_prefix + name, value)

    run_if(has_elements, _summarize)
Example #4
0
 def wrapper(*args, **kwargs):
     """Call `summary_func` with the given arguments only while summaries are recorded.

     `summary_func` is a free variable, presumably supplied by the enclosing
     decorator (not visible here).

     Args:
         *args: positional arguments forwarded to `summary_func`.
         **kwargs: keyword arguments forwarded to `summary_func`.
     Returns:
         Whatever `run_if` returns for the conditional call.
     """
     # Imported at call time rather than module import time.
     # NOTE(review): presumably this avoids a circular import with
     # alf.utils.common; confirm before hoisting.
     from alf.utils.common import run_if

     def _invoke_summary():
         return summary_func(*args, **kwargs)

     return run_if(should_record_summaries(), _invoke_summary)