Example #1
    def append(self, transitions, rows=None):
        """Append a batch of transitions to rows of the memory.

    Args:
      transitions: Tuple of transition quantities with batch dimension.
      rows: Episodes to append to, defaults to all.

    Returns:
      Operation.
    """
        rows = tf.range(self._capacity) if rows is None else rows
        assert rows.shape.ndims == 1
        assert_capacity = tf.assert_less(rows,
                                         self._capacity,
                                         message='capacity exceeded')
        with tf.control_dependencies([assert_capacity]):
            assert_max_length = tf.assert_less(tf.gather(self._length, rows),
                                               self._max_length,
                                               message='max length exceeded')
        append_ops = []
        with tf.control_dependencies([assert_max_length]):
            for buffer_, elements in zip(self._buffers, transitions):
                timestep = tf.gather(self._length, rows)
                indices = tf.stack([rows, timestep], 1)
                append_ops.append(
                    tf.scatter_nd_update(buffer_, indices, elements))
        with tf.control_dependencies(append_ops):
            episode_mask = tf.reduce_sum(
                tf.one_hot(rows, self._capacity, dtype=tf.int32), 0)
            return self._length.assign_add(episode_mask)
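The heart of `append` is pairing each episode row with its current write position: `tf.stack([rows, timestep], 1)` builds 2-D indices for `tf.scatter_nd_update`, and the `tf.one_hot` sum bumps only the written rows' lengths. A minimal standalone sketch of that indexing trick, assuming TF1-style variables (the surrounding memory class is not reproduced):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

capacity, max_length = 4, 10
buffer_ = tf.get_variable('buf', [capacity, max_length], tf.float32,
                          initializer=tf.zeros_initializer())
length = tf.get_variable('len', [capacity], tf.int32,
                         initializer=tf.zeros_initializer())

rows = tf.constant([0, 2])
timestep = tf.gather(length, rows)       # current write position per row
indices = tf.stack([rows, timestep], 1)  # [[0, 0], [2, 0]]
write = tf.scatter_nd_update(buffer_, indices, tf.constant([1.0, 2.0]))
bump = length.assign_add(
    tf.reduce_sum(tf.one_hot(rows, capacity, dtype=tf.int32), 0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(write)
    sess.run(bump)  # length is now [1, 0, 1, 0]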
Example #2
def assert_univariate_target_conservation(test, mk_target, step_size,
                                          stackless):
    # Sample count limited partly by the memory reliably available on Forge.
    # The test remains reasonable even if the NUTS recursion limit is severely
    # curtailed (e.g., 3 or 4 levels), so use that to recover some memory
    # footprint and bump the sample count.
    num_samples = int(5e4)
    num_steps = 1
    target_d = mk_target()
    strm = tfp.util.SeedStream(salt='univariate_nuts_test', seed=1)
    # We wrap the initial values in `tf.identity` to avoid broken gradients
    # resulting from a bijector cache hit, since bijectors of the same
    # type/parameterization now share a cache.
    # TODO(b/72831017): Fix broken gradients caused by bijector caching.
    initialization = tf.identity(target_d.sample([num_samples], seed=strm()))

    def target(*args):
        # TODO(axch): Just use target_d.log_prob directly, and accept target_d
        # itself as an argument instead of a maker function.  Blocked by
        # b/128932888.  It would then also be nice not to eta-expand
        # target_d.log_prob; that was blocked by b/122414321, but maybe tfp's port
        # of value_and_gradients_function fixed that bug.
        return mk_target().log_prob(*args)

    operator = tfp.experimental.mcmc.NoUTurnSampler(target,
                                                    step_size=step_size,
                                                    max_tree_depth=3,
                                                    use_auto_batching=True,
                                                    stackless=stackless,
                                                    unrolled_leapfrog_steps=2,
                                                    seed=strm())
    result, extra = tfp.mcmc.sample_chain(num_results=num_steps,
                                          num_burnin_steps=0,
                                          current_state=initialization,
                                          kernel=operator)
    # Note: sample_chain puts the chain history on top, not the (independent)
    # chains.
    test.assertAllEqual([num_steps, num_samples], result.shape)
    answer = result[0]
    check_cdf_agrees = st.assert_true_cdf_equal_by_dkwm(answer,
                                                        target_d.cdf,
                                                        false_fail_rate=1e-6)
    check_enough_power = tf1.assert_less(
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm(
            num_samples, false_fail_rate=1e-6, false_pass_rate=1e-6), 0.025)
    test.assertAllEqual([num_samples], extra.leapfrogs_taken[0].shape)
    unique, _ = tf.unique(extra.leapfrogs_taken[0])
    check_leapfrogs_vary = tf1.assert_greater_equal(
        tf.shape(input=unique)[0], 3)
    avg_leapfrogs = tf.math.reduce_mean(input_tensor=extra.leapfrogs_taken[0])
    check_leapfrogs = tf1.assert_greater_equal(
        avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
    movement = tf.abs(answer - initialization)
    test.assertAllEqual([num_samples], movement.shape)
    # This movement distance (1 * step_size) was selected by reducing until 100
    # runs with independent seeds all passed.
    check_movement = tf1.assert_greater_equal(
        tf.reduce_mean(input_tensor=movement), 1 * step_size)
    return (check_cdf_agrees, check_enough_power, check_leapfrogs_vary,
            check_leapfrogs, check_movement)
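A hedged sketch of how a test case might consume the returned check ops; the target distribution and step size below are illustrative, and the function above (with its `tfp`, `st`, and `tf1` imports) is assumed to be in scope:

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

class NutsConservationTest(tf.test.TestCase):

    def test_normal_is_conserved(self):
        checks = assert_univariate_target_conservation(
            self, lambda: tfd.Normal(loc=0., scale=1.),
            step_size=0.2, stackless=False)
        # In graph mode the checks are ops; evaluating them runs the
        # DKWM, leapfrog, and movement assertions.
        self.evaluate(checks)

if __name__ == '__main__':
    tf.test.main()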
Example #3
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
    initialization = tfd.MultivariateNormalFullCovariance(
        loc=tf.zeros(event_size),
        covariance_matrix=tf.eye(event_size)).sample(batch_size, seed=4)
    samples, leapfrogs = run_nuts_chain(event_size,
                                        batch_size,
                                        num_steps=1,
                                        initial_state=initialization,
                                        **kwargs)
    answer = samples[0][-1]
    check_cdf_agrees = (
        st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
            answer, initialization, num_projections=100, false_fail_rate=1e-6))
    check_sample_shape = tf1.assert_equal(
        tf.shape(input=answer)[0], batch_size)
    unique, _ = tf.unique(leapfrogs[0])
    check_leapfrogs_vary = tf1.assert_greater_equal(
        tf.shape(input=unique)[0], 3)
    avg_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrogs[0])
    check_leapfrogs = tf1.assert_greater_equal(
        avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
    movement = tf.linalg.norm(tensor=answer - initialization, axis=-1)
    # This movement distance (0.3) was copied from the univariate case.
    check_movement = tf1.assert_greater_equal(
        tf.reduce_mean(input_tensor=movement), 0.3)
    check_enough_power = tf1.assert_less(
        st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
            batch_size, batch_size, false_fail_rate=1e-8,
            false_pass_rate=1e-6), 0.055)
    return (check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
            check_leapfrogs, check_movement, check_enough_power)
Example #4
    def replace(self, episodes, length, rows=None):
        """Replace full episodes.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
        rows = tf.range(self._capacity) if rows is None else rows
        assert rows.shape.ndims == 1
        assert_capacity = tf.assert_less(rows,
                                         self._capacity,
                                         message='capacity exceeded')
        with tf.control_dependencies([assert_capacity]):
            assert_max_length = tf.assert_less_equal(
                length, self._max_length, message='max length exceeded')
        replace_ops = []
        with tf.control_dependencies([assert_max_length]):
            for buffer_, elements in zip(self._buffers, episodes):
                replace_op = tf.scatter_update(buffer_, rows, elements)
                replace_ops.append(replace_op)
        with tf.control_dependencies(replace_ops):
            return tf.scatter_update(self._length, rows, length)
Example #5
 def testRejection4D(self):
     num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
     det_bounds = np.array([0.0], dtype=np.float32)
     exact_volumes = [four_by_four_volume()]
     (rej_weights, rej_proposal_volume
      ) = corr.correlation_matrix_volume_rejection_samples(det_bounds,
                                                           4,
                                                           [num_samples, 1],
                                                           dtype=np.float32,
                                                           seed=45)
     # shape of rej_weights: [num_samples, 1, 4, 4]
     chk1 = st.assert_true_mean_equal_by_dkwm(rej_weights,
                                              low=0.,
                                              high=rej_proposal_volume,
                                              expected=exact_volumes,
                                              false_fail_rate=1e-6)
     chk2 = tf1.assert_less(
         st.min_discrepancy_of_true_means_detectable_by_dkwm(
             num_samples,
             low=0.,
             high=rej_proposal_volume,
             false_fail_rate=1e-6,
             false_pass_rate=1e-6),
         # Going for about a 10% relative error
         1.1)
     with tf.control_dependencies([chk1, chk2]):
         rej_weights = tf.identity(rej_weights)
     self.evaluate(rej_weights)
Example #6
 def testRejection2D(self):
     num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
     det_bounds = np.array(
         [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5],
         dtype=np.float32)
     exact_volumes = two_by_two_volume(det_bounds)
     (rej_weights, rej_proposal_volume
      ) = corr.correlation_matrix_volume_rejection_samples(det_bounds,
                                                           2,
                                                           [num_samples, 9],
                                                           dtype=np.float32,
                                                           seed=43)
     # shape of rej_weights: [num_samples, 9, 2, 2]
     chk1 = st.assert_true_mean_equal_by_dkwm(rej_weights,
                                              low=0.,
                                              high=rej_proposal_volume,
                                              expected=exact_volumes,
                                              false_fail_rate=1e-6)
     chk2 = tf1.assert_less(
         st.min_discrepancy_of_true_means_detectable_by_dkwm(
             num_samples,
             low=0.,
             high=rej_proposal_volume,
             # Correct the false fail rate due to different broadcasting
             false_fail_rate=1.1e-7,
             false_pass_rate=1e-6),
         0.036)
     with tf.control_dependencies([chk1, chk2]):
         rej_weights = tf.identity(rej_weights)
     self.evaluate(rej_weights)
Example #7
def area_range_to_index(area_range, length, max_area_width):
    """Computes the indices of each area in the area expansion.

  Args:
    area_range: tensor in shape of [batch_size, 2]
    length: a scalar tensor gives the length of the original feature space.
    max_area_width: a constant scalar.
  Returns:
    indices: area indices tensor in shape of [batch_size]
  """
    with tf.control_dependencies([
            tf.assert_equal(tf.rank(area_range), 2),
            tf.assert_equal(tf.shape(area_range)[1], 2)
    ]):
        area_range = tf.cast(area_range, tf.int32)
    target_size = area_range[:, 1] - area_range[:, 0]
    with tf.control_dependencies(
        [tf.assert_less(target_size, max_area_width + 1, summarize=100000)]):
        sizes = target_size - 1
        start_length = length
        pre_end_length = length - sizes + 1
        base = (start_length + pre_end_length) *\
            (start_length - pre_end_length + 1) // 2
        base = tf.where(tf.less_equal(target_size, 1),
                        tf.zeros_like(target_size), base)
        offset = area_range[:, 0]
        return base + offset
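To make the triangular indexing concrete: with length 5 and max_area_width 3, the five width-1 areas occupy indices 0..4, the four width-2 areas 5..8, and the three width-3 areas 9..11. A small eager-mode check with illustrative values:

import tensorflow as tf

area_range = tf.constant([[0, 1], [0, 2], [1, 4]], dtype=tf.int32)
indices = area_range_to_index(area_range, tf.constant(5), max_area_width=3)
print(indices.numpy())  # [ 0  5 10]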
Example #8
  def push(self, value, mask, name=None):
    """Pushes `value` onto the stack, advances frame of batch members in `mask`.

    In this implementation, we update each thread's top-of-stack (regardless
    of `mask`) to the corresponding `value`, then advance the stack pointers
    of only those threads indicated by `mask`.

    Args:
      value: `Tensor` having the shape of a single batch of the variable.
      mask: Boolean `Tensor` of shape `[batch_size]`. Threads at `True` indices
          of `mask` have their stack frames advanced; the others remain.
      name: Optional name for this op.

    Returns:
      stack: Updated stack. Does not mutate `self`.
      asserted_value: An assertion-bound snapshot of the input `value`, with
          assertions used to catch stack overflows.
    """
    with tf.name_scope(name or 'Stack.push'):
      value = tf.convert_to_tensor(value=value, name='value')
      mask = tf.convert_to_tensor(value=mask, name='mask')
      # self.stack:       [max_stack_depth * batch_size, ...]
      # self.stack_index:                   [batch_size]
      # value:                              [batch_size, ...]
      batch_size = (
          tf.compat.dimension_value(self.stack_index.shape[0]) or
          tf.shape(input=self.stack_index)[0])
      max_stack_depth = (tf.compat.dimension_value(self.stack.shape[0]) or
                         tf.shape(input=self.stack)[0]) // batch_size
      max_stack_depth_tensor = tf.convert_to_tensor(value=max_stack_depth)
      tiled_value = tf.tile(
          input=value[tf.newaxis, ...],
          multiples=tf.concat(
              [[max_stack_depth_tensor],
               tf.ones(tf.rank(value), dtype=max_stack_depth_tensor.dtype)],
              axis=0))
      update_stack_mask = tf.one_hot(
          self.stack_index,
          depth=max_stack_depth,
          axis=0,  # Stack depth x batch are both in outermost dim, stack major.
          on_value=True,
          off_value=False,
          dtype=tf.bool)
      new_stack = tf1.where(
          tf.reshape(update_stack_mask, [-1]),
          tf.reshape(tiled_value, tf.shape(input=self.stack)), self.stack)
      new_stack.set_shape(self.stack.shape)
      new_stack_index = self.stack_index + tf.cast(mask, self.stack_index.dtype)
      new_stack_index.set_shape(self.stack_index.shape)
      if self._safety_checks():
        with tf.control_dependencies(
            [tf1.assert_less(
                new_stack_index, tf.cast(
                    max_stack_depth_tensor, new_stack_index.dtype))]):
          value = tf.identity(value)
          new_stack_index = tf.identity(new_stack_index)
      return type(self)(new_stack, new_stack_index), value
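The write above avoids scatter ops: `tf.one_hot` marks each thread's top-of-stack slot in the stack-major flat layout, and `tf.where` selects between the tiled new value and the old contents. A standalone eager-mode sketch of just that selection, with illustrative sizes:

import tensorflow as tf

batch_size, max_stack_depth = 3, 4
stack = tf.zeros([max_stack_depth * batch_size])  # stack-major flat layout
stack_index = tf.constant([0, 2, 1])              # current frame per thread
value = tf.constant([7., 8., 9.])                 # new top-of-stack per thread

update_mask = tf.one_hot(stack_index, depth=max_stack_depth, axis=0,
                         on_value=True, off_value=False, dtype=tf.bool)
tiled_value = tf.tile(value[tf.newaxis, :], [max_stack_depth, 1])
new_stack = tf.where(tf.reshape(update_mask, [-1]),
                     tf.reshape(tiled_value, [-1]), stack)
# Thread 0 wrote 7. at depth 0, thread 1 wrote 8. at depth 2, and thread 2
# wrote 9. at depth 1; every other slot keeps its old contents.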
Example #9
 def _scan_fn(*_):
     exchange = exchange_proposed_fn(num_replica, seed)
     flat_replicas = tf.reshape(exchange, [-1])
     with tf.control_dependencies([
             tf1.assert_equal(
                 tf.size(input=flat_replicas),
                 tf.size(input=tf.unique(flat_replicas)[0])),
             tf1.assert_greater_equal(flat_replicas, 0),
             tf1.assert_less(flat_replicas, num_replica),
     ]):
         return tf.shape(input=exchange)[0]
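A standalone sketch of the validity check above: a proposed exchange passes only if no replica index repeats (compared via `tf.unique`) and every index lies in [0, num_replica). Eager mode, illustrative values:

import tensorflow as tf
import tensorflow.compat.v1 as tf1

num_replica = 4
exchange = tf.constant([[0, 2], [1, 3]])  # two proposed swaps
flat_replicas = tf.reshape(exchange, [-1])
tf1.assert_equal(tf.size(flat_replicas),
                 tf.size(tf.unique(flat_replicas)[0]))  # no repeats
tf1.assert_greater_equal(flat_replicas, 0)
tf1.assert_less(flat_replicas, num_replica)  # each raises on failure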
Example #10
def _maybe_validate_target_accept_prob(target_accept_prob, validate_args):
    """Validates that target_accept_prob is in (0, 1)."""
    if not validate_args:
        return target_accept_prob
    with tf.control_dependencies([
            tf1.assert_positive(target_accept_prob,
                                message='`target_accept_prob` must be > 0.'),
            tf1.assert_less(target_accept_prob,
                            tf.ones_like(target_accept_prob),
                            message='`target_accept_prob` must be < 1.')
    ]):
        return tf.identity(target_accept_prob)
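A minimal sketch of exercising the validator in eager mode, with illustrative values:

import tensorflow as tf

ok = _maybe_validate_target_accept_prob(
    tf.constant([0.65, 0.75]), validate_args=True)  # passes, returns input
# _maybe_validate_target_accept_prob(tf.constant(1.), validate_args=True)
# would raise InvalidArgumentError: `target_accept_prob` must be < 1.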
Example #11
    def append(self, value):
        """Appends a new tensor to the end of the buffer.

    Args:
      value: The tensor to append. Must match the shape specified in the
        initializer.

    Returns:
      An op appending the new tensor to the end of the buffer.
    """
        def _double_capacity():
            """Doubles the capacity of the current tensor buffer."""
            padding = tf.zeros_like(self._buffer, self._buffer.dtype)
            new_buffer = tf.concat([self._buffer, padding], axis=0)
            if tf.executing_eagerly():
                with tf.compat.v1.variable_scope(self._name, reuse=True):
                    self._buffer = tf.compat.v1.get_variable(
                        name='buffer',
                        dtype=self._dtype,
                        initializer=new_buffer,
                        trainable=False)
                    return self._buffer, tf.compat.v1.assign(
                        self._capacity, tf.multiply(self._capacity, 2))
            else:
                return tf.compat.v1.assign(
                    self._buffer, new_buffer,
                    validate_shape=False), tf.compat.v1.assign(
                        self._capacity, tf.multiply(self._capacity, 2))

        update_buffer, update_capacity = tf.cond(
            pred=tf.equal(self._current_size, self._capacity),
            true_fn=_double_capacity,
            false_fn=lambda: (self._buffer, self._capacity))

        with tf.control_dependencies([update_buffer, update_capacity]):
            with tf.control_dependencies([
                    tf.assert_less(
                        self._current_size,
                        self._capacity,
                        message='Appending past end of TensorBuffer.'),
                    tf.assert_equal(
                        tf.shape(input=value),
                        tf.shape(input=self._buffer)[1:],
                        message='Appending value of inconsistent shape.')
            ]):
                with tf.control_dependencies([
                        tf.compat.v1.assign(
                            self._buffer[self._current_size, :], value)
                ]):
                    return tf.compat.v1.assign_add(self._current_size, 1)
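The graph-mode branch grows the variable in place via `validate_shape=False`. A standalone sketch of that doubling trick, assuming TF1 graph mode (names illustrative, not the TensorBuffer class itself):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

buf = tf.Variable(tf.zeros([4, 2]), validate_shape=False, name='grow_buf')
doubled = tf.concat([buf, tf.zeros_like(buf)], axis=0)
grow = tf.assign(buf, doubled, validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grow).shape)  # (8, 2): capacity doubled in place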
Example #12
def sparse_softmax_cross_entropy(labels,
                                 logits,
                                 num_classes,
                                 weights=1.0,
                                 label_smoothing=0.1):
    """Softmax cross entropy with example weights, label smoothing."""
    assert_valid_label = [
        tf.assert_greater_equal(labels, tf.cast(0, dtype=tf.int64)),
        tf.assert_less(labels, tf.cast(num_classes, dtype=tf.int64))
    ]
    with tf.control_dependencies(assert_valid_label):
        labels = tf.reshape(labels, [-1])
        dense_labels = tf.one_hot(labels, num_classes)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=dense_labels,
                                               logits=logits,
                                               weights=weights,
                                               label_smoothing=label_smoothing)
    return loss
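A quick usage sketch, assuming TF1 behavior for `tf.losses` (shapes illustrative; labels must be int64 to match the casts in the asserts):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([0, 3, 1], dtype=tf.int64)
logits = tf.random.normal([3, 4])
loss = sparse_softmax_cross_entropy(labels, logits, num_classes=4)

with tf.Session() as sess:
    print(sess.run(loss))  # scalar label-smoothed cross-entropy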
Example #13
def pad_to_fixed_size(data, pad_value, output_shape):
    """Pad data to a fixed length at the first dimension.

  Args:
    data: Tensor to be padded to output_shape.
    pad_value: A constant value assigned to the paddings.
    output_shape: The output shape of a 2D tensor.

  Returns:
    The Padded tensor with output_shape [max_instances_per_image, dimension].
  """
    max_instances_per_image = output_shape[0]
    dimension = output_shape[1]
    data = tf.reshape(data, [-1, dimension])
    num_instances = tf.shape(data)[0]
    msg = 'ERROR: please increase config.max_instances_per_image'
    with tf.control_dependencies(
        [tf.assert_less(num_instances, max_instances_per_image, message=msg)]):
        pad_length = max_instances_per_image - num_instances
    paddings = pad_value * tf.ones([pad_length, dimension])
    padded_data = tf.concat([data, paddings], axis=0)
    padded_data = tf.reshape(padded_data, output_shape)
    return padded_data
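A short eager-mode sketch with illustrative box data; note the assert is strict, so the number of instances must stay below max_instances_per_image:

import tensorflow as tf

boxes = tf.random.uniform([5, 4])
padded = pad_to_fixed_size(boxes, pad_value=-1, output_shape=[100, 4])
print(padded.shape)  # (100, 4); rows 5..99 hold -1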