def user_item_node_interaction_loss(self, probs, user_node_distance,
                                        item_node_distance, user_item_distance,
                                        neg_node_ind):
        """Computes pairwise hinge based loss, as in the reference below.

    Args:
      probs: Tensor of size batch_size x tot_node_batch containing the
        probability a node is the ancestor of the positive item.
      user_node_distance: Tensor of size batch_size x tot_node_batch containing
        square of the distances between the nodes and the user.
      item_node_distance: Tensor of size batch_size x tot_node_batch containing
        square of the distances between the nodes and the positive item.
      user_item_distance: Tensor of size batch_size x 2 containing
        square of the distances between the user and the positive and negative
        items.
      neg_node_ind: Tensor of size batch_size x tot_node_batch x 2 containing
        indices of negative nodes (within the sampled batch, from the relevant
        level), in tf.gather_nd format.

    Returns:
      loss within the input_batch.
    """
        # TODO(advaw): change Idea 2 above to a real reference when possible.
        user_to_node = self.pos_and_neg_loss(
            user_node_distance, tf.gather_nd(user_node_distance, neg_node_ind))
        item_to_node = self.pos_and_neg_loss(
            item_node_distance, tf.gather_nd(item_node_distance, neg_node_ind))
        nodes_loss = tf.reduce_sum(probs * (user_to_node + item_to_node),
                                   axis=1)
        user_to_item = self.pos_and_neg_loss(user_item_distance[:, 0],
                                             user_item_distance[:, 1])
        loss = tf.reduce_mean(user_to_item + nodes_loss)
        return loss
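
A minimal illustration (values made up, not from the original source) of the tf.gather_nd indexing pattern that neg_node_ind relies on: each innermost pair is a (batch, node) index, so gathering from a batch_size x tot_node_batch distance matrix returns one negative distance per (user, node) cell.

import tensorflow as tf

# Hypothetical squared distances for 2 users x 3 sampled nodes.
user_node_distance = tf.constant([[0.1, 0.5, 0.9],
                                  [0.2, 0.4, 0.8]])
# One negative node per (user, node) cell, already in tf.gather_nd format:
# shape [batch_size, tot_node_batch, 2].
neg_node_ind = tf.constant([[[0, 1], [0, 2], [0, 0]],
                            [[1, 2], [1, 0], [1, 1]]])
neg_distances = tf.gather_nd(user_node_distance, neg_node_ind)
# neg_distances: [[0.5, 0.9, 0.1], [0.8, 0.2, 0.4]], shape [2, 3].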
Example #2
def _get_q_slice(q, k, ind, b=None, batch_shape=None):
    """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`."""
    q_ind = tf.concat([ind, tf.expand_dims(tf.gather_nd(k, ind), -1)], axis=1)
    b_updates = tf.gather_nd(q, q_ind)
    if b is None:
        return tf.scatter_nd(ind, b_updates, batch_shape)
    return tf.tensor_scatter_nd_update(b, ind, b_updates)
Example #3
def contrastive_loss(similarity_matrix,
                     metric_values,
                     temperature,
                     coupling_temperature=1.0,
                     use_coupling_weights=True):
  """Contrastive loss with soft coupling."""
  logging.info('Using alternative contrastive loss.')
  metric_shape = tf.shape(metric_values)
  similarity_matrix /= temperature
  neg_logits1 = similarity_matrix

  col_indices = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
  pos_indices1 = tf.stack(
      (tf.range(metric_shape[0], dtype=tf.int32), col_indices), axis=1)
  pos_logits1 = tf.gather_nd(similarity_matrix, pos_indices1)

  if use_coupling_weights:
    metric_values /= coupling_temperature
    coupling = tf.exp(-metric_values)
    pos_weights1 = -tf.gather_nd(metric_values, pos_indices1)
    pos_logits1 += pos_weights1
    negative_weights = tf.math.log((1.0 - coupling) + EPS)
    neg_logits1 += tf.tensor_scatter_nd_update(negative_weights, pos_indices1,
                                               pos_weights1)
  neg_logits1 = tf.math.reduce_logsumexp(neg_logits1, axis=1)
  return tf.reduce_mean(neg_logits1 - pos_logits1)
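
A self-contained sketch (toy values) of the core indexing step in the loss above, roughly the use_coupling_weights=False path with temperature 1: tf.argmin over the metric picks one positive column per row, tf.stack turns that into (row, column) pairs, and tf.gather_nd pulls the matching logits out of the similarity matrix.

import tensorflow as tf

similarity_matrix = tf.constant([[2.0, 0.1, 0.3],
                                 [0.2, 1.5, 0.4],
                                 [0.1, 0.6, 1.8]])
metric_values = tf.constant([[0.0, 3.0, 4.0],
                             [2.0, 0.0, 5.0],
                             [3.0, 1.0, 0.0]])

col_indices = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
pos_indices = tf.stack(
    (tf.range(tf.shape(metric_values)[0], dtype=tf.int32), col_indices), axis=1)
pos_logits = tf.gather_nd(similarity_matrix, pos_indices)   # [2.0, 1.5, 1.8]
loss = tf.reduce_mean(
    tf.math.reduce_logsumexp(similarity_matrix, axis=1) - pos_logits)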
Example #4
def _piecewise_constant_integrate(x1, x2, jump_locations, values, batch_rank):
  """Integrates piecewise constant function between `x1` and `x2`."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape.
  # Expand batch size to one if there is no batch shape.
  if x1.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x1 = tf.expand_dims(x1, 0)
    x2 = tf.expand_dims(x2, 0)
  if not jump_locations.shape.as_list()[:-1]:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
    batch_rank += 1

  # Compute the index matrix that is later used for `tf.gather_nd`.
  index_matrix = _prepare_index_matrix(
      x1.shape.as_list()[:-1], x1.shape.as_list()[-1], tf.int32)
  # Compute integral values at the jump locations starting from the first jump
  # location.
  event_shape = values.shape[(batch_rank+1):]
  num_data_points = values.shape.as_list()[batch_rank]
  diff = jump_locations[..., 1:] - jump_locations[..., :-1]
  # Broadcast `diff` to the shape of
  # `batch_shape + [num_data_points - 2] + [1] * sample_rank`.
  for _ in event_shape:
    diff = tf.expand_dims(diff, -1)
  slice_indices = batch_rank * [slice(None)]
  slice_indices += [slice(1, num_data_points - 1)]
  integrals = tf.cumsum(values[slice_indices] * diff, batch_rank)
  # Pad integrals with zero values on left and right.
  batch_shape = integrals.shape.as_list()[:batch_rank]
  zeros = tf.zeros(batch_shape + [1] + event_shape, dtype=integrals.dtype)
  integrals = tf.concat([zeros, integrals, zeros], axis=batch_rank)
  # Get jump locations and values and the integration end points
  value1, jump_location1, indices_nd1 = _get_indices_and_values(
      x1, index_matrix, jump_locations, values, 'left', batch_rank)
  value2, jump_location2, indices_nd2 = _get_indices_and_values(
      x2, index_matrix, jump_locations, values, 'right', batch_rank)
  integrals1 = tf.gather_nd(integrals, indices_nd1)
  integrals2 = tf.gather_nd(integrals, indices_nd2)
  # Broadcast `x1`, `x2`, `jump_location1`, `jump_location2` to the shape
  # `batch_shape + [num_points] + [1] * sample_rank`.
  for _ in event_shape:
    x1 = tf.expand_dims(x1, -1)
    x2 = tf.expand_dims(x2, -1)
    jump_location1 = tf.expand_dims(jump_location1, -1)
    jump_location2 = tf.expand_dims(jump_location2, -1)
  # Compute the value of the integral.
  res = ((jump_location1 - x1) * value1
         + (x2 - jump_location2) * value2
         + integrals2 - integrals1)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res
Example #5
  def _get_coordinatewise_learning_rate(self, grad, var):
    # Compute the learning rate using a moving average for the diagonal of BB^T
    avg_first = self.get_slot(var, 'first_moment')
    avg_second = self.get_slot(var, 'second_moment')
    decay_tensor = tf.cast(self._decay_tensor, var.dtype)
    batch_size = tf.cast(self._batch_size_tensor, var.dtype)

    # Create an estimator for the moving average of gradient mean and variance
    # via Welford's algorithm
    if isinstance(grad, tf.Tensor):
      delta = grad - avg_first
      first_moment_update = avg_first.assign_add(
          delta * tf.where(
              self.iterations < 1,
              dtype_util.as_numpy_dtype(var.dtype)(1.),
              1. - decay_tensor))

      with tf.control_dependencies([first_moment_update]):
        second_moment_update = avg_second.assign_add(
            tf.cast(self.iterations < 1, var.dtype) * -(1. - decay_tensor) *
            (avg_second - decay_tensor * tf.square(delta)))
      diag_preconditioner = distribution_util.with_dependencies(
          [second_moment_update],
          tf.clip_by_value(avg_second, 1e-12, 1e12))
    elif isinstance(grad, tf.IndexedSlices):
      delta = grad.values - tf.gather_nd(avg_first, grad.indices)
      first_moment_update = tf.compat.v1.scatter_add(
          avg_first, grad.indices,
          delta * tf.where(
              self.iterations < 1,
              dtype_util.as_numpy_dtype(var.dtype)(1.),
              1. - decay_tensor))

      with tf.control_dependencies([first_moment_update]):
        avg_second = tf.compat.v1.scatter_add(
            avg_second, grad.indices,
            tf.cast(self.iterations < 1, var.dtype) * -(1. - decay_tensor) *
            (tf.gather_nd(avg_second, grad.indices) -
             decay_tensor * tf.square(delta)))
        avg_second = tf.gather_nd(avg_second, grad.indices)
        # TODO(b/70783772): Needs dtype specific clipping.
        diag_preconditioner = tf.clip_by_value(avg_second, 1e-12, 1e12)
    else:
      raise tf.errors.InvalidArgumentError(
          None, None, 'grad must be of type Tensor or IndexedSlices')

    diag_preconditioner *= batch_size

    if self._use_single_learning_rate:
      diag_preconditioner = tf.reduce_mean(diag_preconditioner)

    # From Theorem 2 Corollary 1 of Mandt et al. 2017
    return 2. * batch_size / (
        tf.cast(self._total_num_examples, var.dtype.base_dtype) *
        diag_preconditioner)
Example #6
 def _retrieve_from_cache(
     self, query_embeddings,
     cache):
   sorted_data_sources = sorted(cache.keys())
   all_query_embeddings = util.cross_replica_concat(query_embeddings, axis=0)
   num_replicas = tf.distribute.get_replica_context().num_replicas_in_sync
   # Performs approximate top k across replicas.
   if self.top_k:
     top_k_per_replica = self.top_k // num_replicas
   else:
     top_k_per_replica = self.top_k
   retrieval_return = _retrieve_from_caches(all_query_embeddings, cache,
                                            self._retrieval_fn,
                                            self.embedding_key, self.data_keys,
                                            sorted_data_sources,
                                            self.score_transform,
                                            top_k_per_replica)
   # We transfer all queries to all replicas and retrieve from every shard.
   all_queries_local_weight = tf.math.reduce_logsumexp(
       retrieval_return.scores, axis=1)
   local_queries_global_weights = _get_local_elements_global_data(
       all_queries_local_weight, num_replicas)
   local_queries_all_retrieved_data = {}
   for key in retrieval_return.retrieved_data:
     local_queries_all_retrieved_data[key] = _get_local_elements_global_data(
         retrieval_return.retrieved_data[key], num_replicas)
   local_queries_all_retrieved_embeddings = _get_local_elements_global_data(
       retrieval_return.retrieved_cache_embeddings, num_replicas)
   # We then sample a shard index proportional to its total weight.
   # This allows us to do Gumbel-Max sampling without modifying APIs.
   selected_replica = self._retrieval_fn(local_queries_global_weights)
   selected_replica = tf.stop_gradient(selected_replica)
   num_elements = tf.shape(selected_replica)[0]
   batch_indices = tf.range(num_elements)
   batch_indices = tf.cast(batch_indices, tf.int64)
   batch_indices = tf.expand_dims(batch_indices, axis=1)
   selected_replica_with_batch = tf.concat([batch_indices, selected_replica],
                                           axis=1)
   retrieved_data = {
       k: tf.gather_nd(v, selected_replica_with_batch)
       for k, v in local_queries_all_retrieved_data.items()
   }
   retrieved_cache_embeddings = tf.gather_nd(
       local_queries_all_retrieved_embeddings, selected_replica_with_batch)
   return _RetrievalReturn(
       retrieved_data=retrieved_data,
       scores=local_queries_global_weights,
       retrieved_indices=None,
       retrieved_cache_embeddings=retrieved_cache_embeddings)
Example #7
def _get_indices_and_values(x, index_matrix, jump_locations, values, side,
                            batch_rank):
  """Computes values and jump locations of the piecewise constant function.

  Given `jump_locations` and the `values` on the corresponding segments of the
  piecewise constant function, the function identifies the nearest jump to `x`
  from the right or left (which is determined by the `side` argument) and the
  corresponding value of the piecewise constant function at `x`.

  Args:
    x: A real `Tensor` of shape `batch_shape + [num_points]`. Points at which
      the function has to be evaluated.
    index_matrix: An `int32` `Tensor` of shape
      `batch_shape + [num_points] + [len(batch_shape)]` such that if
      `batch_shape = [i1, .., in]`, then for all `j1, ..., jn, l`,
      `index_matrix[j1,..,jn, l] = [j1, ..., jn]`.
    jump_locations: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points]`. The locations where the function
      changes its values. Note that the values are expected to be ordered
      along the last dimension.
    values: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points + 1]`. Defines `values[..., i]` on
      `jump_locations[..., i - 1], jump_locations[..., i]`.
    side: A Python string. Whether the function is left- or right-continuous.
      The corresponding values for side should be `left` and `right`.
    batch_rank: A Python scalar stating the batch rank of `x`.

  Returns:
    A tuple of three `Tensor` of the same `dtype` as `x` and shapes
    `batch_shape + [num_points] + event_shape`, `batch_shape + [num_points]`,
    and `batch_shape + [num_points] + [2 * len(batch_shape)]`. The `Tensor`s
    correspond to the values, jump locations at `x`, and the corresponding
    indices used to obtain jump locations via `tf.gather_nd`.
  """
  indices = tf.searchsorted(jump_locations, x, side=side)
  num_data_points = tf.shape(values)[batch_rank] - 2
  if side == 'right':
    indices_jump = indices - 1
    indices_jump = tf.maximum(indices_jump, 0)
  else:
    indices_jump = tf.minimum(indices, num_data_points)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  indices_jump_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices_jump, -1)], -1)
  value = tf.gather_nd(values, indices_nd)
  jump_location = tf.gather_nd(jump_locations, indices_jump_nd)
  return value, jump_location, indices_jump_nd
Example #8
def approximate_top_k_with_indices(negative_scores, k):
    """Approximately mines the top k highest scoring negatives with indices.

  This function groups the negative scores into num_negatives / k groups and
  returns the highest scoring element from each group. It also returns the
  indices where the selected elements were found in the score matrix.

  Args:
    negative_scores: A matrix with the scores of the negative elements.
    k: The number of negatives to mine.

  Returns:
    The tuple (top_k_scores, top_k_indices), where top_k_indices describes the
    index of the mined elements in the given score matrix.
  """
    bs = tf.shape(negative_scores)[0]
    num_elem = tf.shape(negative_scores)[1]
    batch_indices = tf.range(num_elem)
    indices = tf.tile(tf.expand_dims(batch_indices, axis=0), multiples=[bs, 1])
    grouped_negative_scores = tf.reshape(negative_scores, [bs * k, -1])
    grouped_batch_indices = tf.range(tf.shape(grouped_negative_scores)[0])
    grouped_top_k_scores, grouped_top_k_indices = tf.math.top_k(
        grouped_negative_scores)
    grouped_top_k_indices = tf.squeeze(grouped_top_k_indices, axis=1)
    gather_indices = tf.stack([grouped_batch_indices, grouped_top_k_indices],
                              axis=1)
    grouped_indices = tf.reshape(indices, [bs * k, -1])
    grouped_top_k_indices = tf.gather_nd(grouped_indices, gather_indices)
    top_k_indices = tf.reshape(grouped_top_k_indices, [bs, k])
    top_k_scores = tf.reshape(grouped_top_k_scores, [bs, k])
    return top_k_scores, top_k_indices
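
A small usage sketch for the function above with made-up scores; the number of negatives (8) must be divisible by k (here 4), so each row is split into 4 groups of 2.

import tensorflow as tf

negative_scores = tf.constant([[0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6],
                               [0.5, 0.1, 0.9, 0.2, 0.6, 0.3, 0.8, 0.4]])
top_k_scores, top_k_indices = approximate_top_k_with_indices(negative_scores, 4)
# top_k_scores:  [[0.9, 0.7, 0.8, 0.6], [0.5, 0.9, 0.6, 0.8]]
# top_k_indices: [[1, 3, 5, 7], [0, 2, 4, 6]] -- columns in negative_scores.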
Example #9
 def get_slice(x, encoding):
     if optimize_for_tpu:
         return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                   encoding,
                                   axis=-1)
     else:
         return tf.gather_nd(x, encoding)
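
A standalone sketch (toy tensors) of the idea behind the two branches: on TPU a one-hot reduction is often preferred over a gather, and for a hard one-hot encoding both give the same slice. The encoding tensor here is assumed to be a one-hot selection over the last axis.

import tensorflow as tf

x = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])
# Select column 2 of row 0 and column 0 of row 1.
indices = tf.constant([[0, 2], [1, 0]])
gathered = tf.gather_nd(x, indices)                           # [3., 4.]

encoding = tf.one_hot(indices[:, 1], depth=3, dtype=x.dtype)  # [2, 3]
via_one_hot = tf.math.reduce_sum(x * encoding, axis=-1)       # [3., 4.]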
Example #10
    def _mode(self, samples=None):
        # Samples count can vary by batch member. Use map_fn to compute mode for
        # each batch separately.
        def _get_mode(samples):
            count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0]).count
            return tf.argmax(count)

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = tf.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples)
            ],
                                   axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        indices = tf.map_fn(_get_mode,
                            flattened_samples,
                            fn_output_signature=tf.int64)
        full_indices = tf.stack(
            [tf.range(tf.shape(indices)[0]),
             tf.cast(indices, tf.int32)],
            axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
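
A self-contained toy sketch of a per-row mode; unlike the method above it indexes the unique values directly rather than the flattened sample positions, which keeps the example short.

import tensorflow as tf

samples = tf.constant([[3, 1, 3, 2, 3],
                       [7, 7, 5, 5, 5]])

def row_mode(row):
    uniques = tf.raw_ops.UniqueWithCountsV2(x=row, axis=[0])
    return uniques.y[tf.argmax(uniques.count)]

modes = tf.map_fn(row_mode, samples)  # [3, 5]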
Example #11
  def _mode(self, samples=None):
    # Samples count can vary by batch member. Use map_fn to compute mode for
    # each batch separately.
    def _get_mode(samples):
      # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
      count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
      return tf.argmax(count)

    if samples is None:
      samples = tf.convert_to_tensor(self._samples)
    num_samples = self._compute_num_samples(samples)

    # Flatten samples for each batch.
    if self._event_ndims == 0:
      flattened_samples = tf.reshape(samples, [-1, num_samples])
      mode_shape = self._batch_shape_tensor(samples)
    else:
      event_size = tf.reduce_prod(self._event_shape_tensor(samples))
      mode_shape = tf.concat(
          [self._batch_shape_tensor(samples),
           self._event_shape_tensor(samples)],
          axis=0)
      flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

    indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
    full_indices = tf.stack(
        [tf.range(tf.shape(indices)[0]),
         tf.cast(indices, tf.int32)], axis=1)

    mode = tf.gather_nd(flattened_samples, full_indices)
    return tf.reshape(mode, mode_shape)
Example #12
def dense_to_sparse(x, ignore_value=None, name=None):
    """Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells.

  Args:
    x: A `Tensor`.
    ignore_value: Entries in `x` equal to this value will be
      absent from the returned `SparseTensor`. If `None`, the default value of
      `x`'s dtype will be used (e.g. '' for `str`, 0 for `int`).
    name: Python `str` prefix for ops created by this function.

  Returns:
    sparse_x: A `tf.SparseTensor` with the same shape as `x`.

  Raises:
    ValueError: when `x`'s rank is `None`.
  """
    # Copied (with modifications) from:
    # tensorflow/contrib/layers/python/ops/sparse_ops.py.
    with tf.name_scope(name or 'dense_to_sparse'):
        x = tf.convert_to_tensor(x, name='x')
        if ignore_value is None:
            if dtype_util.base_dtype(x.dtype) == tf.string:
                # Exception because TF strings are converted to numpy objects by default.
                ignore_value = ''
            else:
                ignore_value = dtype_util.as_numpy_dtype(x.dtype)(0)
            ignore_value = tf.cast(ignore_value, x.dtype, name='ignore_value')
        indices = tf.where(tf.not_equal(x, ignore_value), name='indices')
        return tf.SparseTensor(indices=indices,
                               values=tf.gather_nd(x, indices, name='values'),
                               dense_shape=tf.shape(x,
                                                    out_type=tf.int64,
                                                    name='dense_shape'))
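
A minimal standalone version of the same pattern without the TFP helpers: tf.where finds the indices of the non-ignored cells and tf.gather_nd reads their values.

import tensorflow as tf

x = tf.constant([[1, 0, 3],
                 [0, 5, 0]])
indices = tf.where(tf.not_equal(x, 0))          # int64 indices, shape [3, 2]
sparse_x = tf.SparseTensor(
    indices=indices,
    values=tf.gather_nd(x, indices),            # [1, 3, 5]
    dense_shape=tf.shape(x, out_type=tf.int64))
# tf.sparse.to_dense(sparse_x) reproduces x.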
Example #13
def top_k_boxes(boxes, scores, k):
  """Sort and select top k boxes according to the scores.

  Args:
    boxes: a tensor of shape [batch_size, N, 4] representing the coordinates of
      the boxes. N is the number of boxes per image.
    scores: a tensor of shape [batch_size, N] representing the scores of the
      boxes.
    k: an integer or a tensor indicating the top k number.

  Returns:
    selected_boxes: a tensor of shape [batch_size, k, 4] representing the
      selected top k box coordinates.
    selected_scores: a tensor of shape [batch_size, k] representing the selected
      top k box scores.
  """
  with tf.name_scope('top_k_boxes'):
    selected_scores, top_k_indices = tf.nn.top_k(scores, k=k, sorted=True)

    batch_size, _ = scores.get_shape().as_list()
    if batch_size == 1:
      selected_boxes = tf.squeeze(
          tf.gather(boxes, top_k_indices, axis=1), axis=1)
    else:
      top_k_indices_shape = tf.shape(top_k_indices)
      batch_indices = (
          tf.expand_dims(tf.range(top_k_indices_shape[0]), axis=-1) *
          tf.ones([1, top_k_indices_shape[-1]], dtype=tf.int32))
      gather_nd_indices = tf.stack([batch_indices, top_k_indices], axis=-1)
      selected_boxes = tf.gather_nd(boxes, gather_nd_indices)

    return selected_boxes, selected_scores
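
A usage sketch for top_k_boxes above with made-up inputs: a batch of 2 images, 5 boxes each, keeping the 3 highest-scoring boxes per image.

import tensorflow as tf

boxes = tf.random.uniform([2, 5, 4])
scores = tf.constant([[0.1, 0.9, 0.3, 0.7, 0.5],
                      [0.8, 0.2, 0.6, 0.4, 0.05]])
selected_boxes, selected_scores = top_k_boxes(boxes, scores, k=3)
# selected_boxes: shape [2, 3, 4]
# selected_scores: [[0.9, 0.7, 0.5], [0.8, 0.6, 0.4]]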
Example #14
  def _build_target_quantile_values_op(self):
    """Build an op used as a target for return values at given quantiles.

    Returns:
      An op calculating the target quantile return.
    """
    batch_size = tf.shape(self._replay.rewards)[0]

    # Calculate AL modified rewards.
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    replay_target_q = tf.reduce_max(
        self._replay_target_q_values,
        axis=1,
        name='replay_chosen_target_q')
    replay_target_q_al = tf.reduce_sum(
        replay_action_one_hot * self._replay_target_q_values,
        axis=1,
        name='replay_chosen_target_q_al')

    if self._clip > 0.:
      al_bonus = self._alpha * tf.clip_by_value(
          (replay_target_q_al - replay_target_q),
          -self._clip, self._clip)
    else:
      al_bonus = self._alpha * (
          replay_target_q_al - replay_target_q)

    # Shape of rewards: (num_tau_prime_samples x batch_size) x 1.
    rewards = (self._replay.rewards + al_bonus)[:, None]
    rewards = tf.tile(rewards, [self.num_tau_prime_samples, 1])

    is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
    # Incorporate terminal state to discount factor.
    # size of gamma_with_terminal: (num_tau_prime_samples x batch_size) x 1.
    gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
    gamma_with_terminal = tf.tile(gamma_with_terminal[:, None],
                                  [self.num_tau_prime_samples, 1])

    # Get the indices of the maximum Q-value across the action dimension.
    # Shape of replay_next_qt_argmax: (num_tau_prime_samples x batch_size) x 1.

    replay_next_qt_argmax = tf.tile(
        self._replay_next_qt_argmax[:, None], [self.num_tau_prime_samples, 1])

    # Shape of batch_indices: (num_tau_prime_samples x batch_size) x 1.
    batch_indices = tf.cast(tf.range(
        self.num_tau_prime_samples * batch_size)[:, None], tf.int64)

    # Shape of batch_indexed_target_values:
    # (num_tau_prime_samples x batch_size) x 2.
    batch_indexed_target_values = tf.concat(
        [batch_indices, replay_next_qt_argmax], axis=1)

    # Shape of next_target_values: (num_tau_prime_samples x batch_size) x 1.
    target_quantile_values = tf.gather_nd(
        self._replay_net_target_quantile_values,
        batch_indexed_target_values)[:, None]

    return rewards + gamma_with_terminal * target_quantile_values
Example #15
def _piecewise_constant_function(x, jump_locations, values,
                                 batch_rank, side='left'):
  """Computes value of the piecewise constant function."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape
  batch_shape = jump_locations.shape.as_list()[:-1]
  # Check that the batch shape of `x` is the same as of `jump_locations` and
  # `values`
  batch_shape_x = x.shape.as_list()[:batch_rank]
  if batch_shape_x != batch_shape:
    raise ValueError('Batch shape of `x` is {1} but should be {0}'.format(
        batch_shape, batch_shape_x))
  if x.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x = tf.expand_dims(x, 0)
  # Expand batch size to one if there is no batch shape
  if not batch_shape:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
  indices = tf.searchsorted(jump_locations, x, side=side)
  index_matrix = _prepare_index_matrix(
      indices.shape.as_list()[:-1], indices.shape.as_list()[-1], indices.dtype)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  res = tf.gather_nd(values, indices_nd)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res
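
A standalone sketch of the same lookup without the _prepare_index_matrix helper (toy data, batch of 2): tf.searchsorted finds the segment of each point and a stacked (batch, segment) index feeds tf.gather_nd.

import tensorflow as tf

jump_locations = tf.constant([[0., 1., 2.],
                              [0., 10., 20.]])
values = tf.constant([[1., 2., 3., 4.],
                      [5., 6., 7., 8.]])
x = tf.constant([[-1., 0.5, 1.5, 2.5, 3.],
                 [-5., 5., 15., 25., 30.]])

indices = tf.searchsorted(jump_locations, x, side='left')          # [2, 5]
batch_idx = tf.broadcast_to(tf.range(2)[:, tf.newaxis], tf.shape(indices))
indices_nd = tf.stack([batch_idx, indices], axis=-1)               # [2, 5, 2]
result = tf.gather_nd(values, indices_nd)
# result: [[1., 2., 3., 4., 4.], [5., 6., 7., 8., 8.]]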
Example #16
 def _inverse(self, y):
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=y.shape)
Example #17
def _analytic_valuation(expiries, floating_leg_start_times,
                        floating_leg_end_times, fixed_leg_payment_times,
                        fixed_leg_daycount_fractions, fixed_leg_coupon,
                        reference_rate_fn, dim, mean_reversion, volatility,
                        notional, is_payer_swaption, output_shape,
                        dtype, name):
  """Helper function for analytic valuation."""
  # The below inputs are needed for midcurve swaptions
  del floating_leg_start_times, floating_leg_end_times
  with tf.name_scope(name):
    is_call_options = tf.where(is_payer_swaption,
                               tf.convert_to_tensor(False, dtype=tf.bool),
                               tf.convert_to_tensor(True, dtype=tf.bool))

    model = vector_hull_white.VectorHullWhiteModel(
        dim,
        mean_reversion,
        volatility,
        initial_discount_rate_fn=reference_rate_fn,
        dtype=dtype)
    coefficients = fixed_leg_daycount_fractions * fixed_leg_coupon
    jamshidian_coefficients = tf.concat([
        -coefficients[..., :-1],
        tf.expand_dims(-1.0 - coefficients[..., -1], axis=-1)], axis=-1)

    breakeven_bond_option_strikes = _jamshidian_decomposition(
        model, expiries,
        fixed_leg_payment_times, jamshidian_coefficients, dtype,
        name=name + '_jamshidian_decomposition')

    bond_strike_rank = breakeven_bond_option_strikes.shape.rank
    perm = [bond_strike_rank-1] + [x for x in range(0, bond_strike_rank - 1)]
    breakeven_bond_option_strikes = tf.transpose(
        breakeven_bond_option_strikes, perm=perm)
    bond_option_prices = zcb.bond_option_price(
        strikes=breakeven_bond_option_strikes,
        expiries=expiries,
        maturities=fixed_leg_payment_times,
        discount_rate_fn=reference_rate_fn,
        dim=dim,
        mean_reversion=mean_reversion,
        volatility=volatility,
        is_call_options=is_call_options,
        use_analytic_pricing=True,
        dtype=dtype,
        name=name + '_bond_option')

    # Now compute P(T0, TN) + sum_i (c_i * tau_i * P(T0, Ti))
    # bond_option_prices.shape = [dim] + batch_shape + [m] + [dim], where `m`
    # denotes the number of fixed payments for the underlying swaps.
    swaption_values = (
        tf.reduce_sum(
            bond_option_prices * tf.expand_dims(coefficients, axis=-1),
            axis=-2) + bond_option_prices[..., -1, :])
    swaption_shape = swaption_values.shape
    gather_index = _prepare_swaption_indices(swaption_shape.as_list())
    swaption_values = tf.reshape(
        tf.gather_nd(swaption_values, gather_index), output_shape)
    return notional * swaption_values
Example #18
def hard_quantile_normalization(inputs, quantiles):
  """Applies the quantile function `quantiles` to the inputs."""
  n_rows = inputs.shape[0]
  rows = tf.range(n_rows)[:, tf.newaxis] * tf.ones_like(inputs, dtype=tf.int32)
  indices = tf.stack(
      [rows, tf.argsort(tf.argsort(inputs, axis=1), axis=1)], axis=-1)
  ordered_quantiles = tf.gather_nd(quantiles, tf.reshape(indices, (-1, 2)))
  return tf.reshape(ordered_quantiles, inputs.shape)
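
A small usage sketch (toy tensors): each input is replaced by the quantile whose rank matches the input's rank within its row.

import tensorflow as tf

inputs = tf.constant([[3.0, 1.0, 2.0],
                      [0.5, 0.7, 0.1]])
quantiles = tf.constant([[10.0, 20.0, 30.0],
                         [10.0, 20.0, 30.0]])
normalized = hard_quantile_normalization(inputs, quantiles)
# normalized: [[30., 10., 20.], [20., 30., 10.]]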
Example #19
def _retrieve_from_caches(query_embeddings,
                          cache,
                          retrieval_fn,
                          embedding_key,
                          data_keys,
                          sorted_data_sources,
                          score_transform=None,
                          top_k=None):
    """Retrieve elements from a cache with the given retrieval function."""
    all_embeddings = _batch_concat_with_no_op([
        cache[data_source].data[embedding_key]
        for data_source in sorted_data_sources
    ])
    all_data = {}
    for key in data_keys:
        all_data[key] = _batch_concat_with_no_op([
            cache[data_source].data[key] for data_source in sorted_data_sources
        ])
    scores = _score_documents(query_embeddings,
                              all_embeddings,
                              score_transform=score_transform,
                              all_pairs=True)
    if top_k:
        scores, top_k_indices = util.approximate_top_k_with_indices(
            scores, top_k)
        top_k_indices = tf.cast(top_k_indices, dtype=tf.int64)
        retrieved_indices = retrieval_fn(scores)
        batch_index = tf.expand_dims(tf.range(tf.shape(retrieved_indices)[0],
                                              dtype=tf.int64),
                                     axis=1)
        retrieved_indices_with_batch_index = tf.concat(
            [batch_index, retrieved_indices], axis=1)
        retrieved_indices = tf.gather_nd(top_k_indices,
                                         retrieved_indices_with_batch_index)
        retrieved_indices = tf.expand_dims(retrieved_indices, axis=1)
    else:
        retrieved_indices = retrieval_fn(scores)
    retrieved_indices = tf.stop_gradient(retrieved_indices)
    retrieved_data = {
        k: tf.gather_nd(v, retrieved_indices)
        for k, v in all_data.items()
    }
    retrieved_cache_embeddings = tf.gather_nd(all_embeddings,
                                              retrieved_indices)
    return _RetrievalReturn(retrieved_data, scores, retrieved_indices,
                            retrieved_cache_embeddings)
Example #20
    def __call__(self,
                 logits,
                 scaled_labels,
                 classes,
                 category_loss=True,
                 mse_loss=False):
        """Compute instance segmentation loss.

    Args:
      logits: A Tensor of shape [batch_size * num_points, height, width,
        num_classes]. The logits are not necessarily between 0 and 1.
      scaled_labels: A float16 Tensor of shape [batch_size, num_instances,
          mask_size, mask_size], where mask_size =
          mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
          for coarse masks and shape priors.
      classes: An int tensor of shape [batch_size, num_instances].
      category_loss: whether to use class-specific mask prediction.
      mse_loss: whether to use mean squared error for the mask loss.

    Returns:
      mask_loss: a float tensor representing the total mask classification loss.
      iou: a float tensor representing the IoU between target and prediction.
    """
        classes = tf.reshape(classes, [-1])
        _, _, height, width = scaled_labels.get_shape().as_list()
        scaled_labels = tf.reshape(scaled_labels, [-1, height, width])

        if not category_loss:
            logits = logits[:, :, :, 0]
        else:
            logits = tf.transpose(a=logits, perm=(0, 3, 1, 2))
            gather_idx = tf.stack(
                [tf.range(tf.size(input=classes)), classes - 1], axis=1)
            logits = tf.gather_nd(logits, gather_idx)

        # Ignore loss on empty mask targets.
        valid_labels = tf.reduce_any(input_tensor=tf.greater(scaled_labels, 0),
                                     axis=[1, 2])
        if mse_loss:
            # Logits are probabilities in the case of shape prior prediction.
            logits *= tf.reshape(tf.cast(valid_labels, logits.dtype),
                                 [-1, 1, 1])
            weighted_loss = tf.nn.l2_loss(scaled_labels - logits)
            probs = logits
        else:
            weighted_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=scaled_labels, logits=logits)
            probs = tf.sigmoid(logits)
            weighted_loss *= tf.reshape(
                tf.cast(valid_labels, weighted_loss.dtype), [-1, 1, 1])

        iou = tf.reduce_sum(
            input_tensor=tf.minimum(scaled_labels, probs)) / tf.reduce_sum(
                input_tensor=tf.maximum(scaled_labels, probs))
        mask_loss = tf.reduce_sum(input_tensor=weighted_loss) / tf.reduce_sum(
            input_tensor=scaled_labels)
        return tf.cast(mask_loss, tf.float32), tf.cast(iou, tf.float32)
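
A standalone sketch (made-up shapes) of the class-specific selection in the category_loss branch above: after moving the class axis forward, one (instance, class - 1) index pair per instance lets tf.gather_nd keep only the mask logits of that instance's own class.

import tensorflow as tf

num_instances, height, width, num_classes = 3, 2, 2, 4
logits = tf.random.normal([num_instances, height, width, num_classes])
classes = tf.constant([2, 4, 1])  # 1-based class ids, one per instance

logits_t = tf.transpose(logits, perm=(0, 3, 1, 2))              # [N, C, H, W]
gather_idx = tf.stack([tf.range(tf.size(classes)), classes - 1], axis=1)
per_class_logits = tf.gather_nd(logits_t, gather_idx)           # [N, H, W]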
Example #21
 def model():
   raining = yield Root(tfd.Bernoulli(probs=0.2, dtype=tf.int32))
   sprinkler_prob = [0.4, 0.01]
   sprinkler_prob = tf.gather(sprinkler_prob, raining)
   sprinkler = yield tfd.Bernoulli(probs=sprinkler_prob, dtype=tf.int32)
   grass_wet_prob = [[0.0, 0.8],
                     [0.9, 0.99]]
   grass_wet_prob = tf.gather_nd(grass_wet_prob, _stack(sprinkler, raining))
   grass_wet = yield tfd.Bernoulli(probs=grass_wet_prob, dtype=tf.int32)
Example #22
    def _classify_and_fuse_detection_priors(self, uniform_priors,
                                            detection_prior_classes,
                                            crop_features):
        """Classify the uniform prior by predicting the shape modes.

    Classify the object crop features into K modes of the clusters for each
    category.

    Args:
      uniform_priors: A float Tensor of shape [batch_size, num_instances,
        mask_size, mask_size] representing the uniform detection priors.
      detection_prior_classes: An int Tensor of shape [batch_size, num_instances]
        of detection class ids.
      crop_features: A float Tensor of shape [batch_size * num_instances,
        mask_size, mask_size, num_channels].

    Returns:
      shape_weights: A float Tensor of shape
        [batch_size * num_instances, num_clusters] representing the classifier
        output probability over all possible shapes.
    """
        location_detection_priors = tf.reshape(
            uniform_priors,
            [-1, self._mask_crop_size, self._mask_crop_size, 1])
        # Generate image embedding to shape.
        fused_shape_features = crop_features * location_detection_priors

        shape_embedding = tf.reduce_mean(input_tensor=fused_shape_features,
                                         axis=(1, 2))
        if not self._use_category_for_mask:
            # TODO(weicheng) use custom op for performance
            shape_logits = tf.keras.layers.Dense(
                self._num_clusters,
                kernel_initializer=tf.keras.initializers.RandomNormal(
                    stddev=0.01))(shape_embedding)
            shape_logits = tf.reshape(
                shape_logits, [-1, self._num_clusters]) / self._temperature
            shape_weights = tf.nn.softmax(shape_logits,
                                          name='shape_prior_weights')
        else:
            shape_logits = tf.keras.layers.Dense(
                self._mask_num_classes * self._num_clusters,
                kernel_initializer=tf.keras.initializers.RandomNormal(
                    stddev=0.01))(shape_embedding)
            shape_logits = tf.reshape(
                shape_logits, [-1, self._mask_num_classes, self._num_clusters])
            training_classes = tf.reshape(detection_prior_classes, [-1])
            class_idx = tf.stack([
                tf.range(tf.size(input=training_classes)), training_classes - 1
            ],
                                 axis=1)
            shape_logits = tf.gather_nd(shape_logits,
                                        class_idx) / self._temperature
            shape_weights = tf.nn.softmax(shape_logits,
                                          name='shape_prior_weights')

        return shape_weights
Example #23
 def seperation_loss(self, model, node_tensor, neg_node_ind):
     """Calculates -d(n,n')^2."""
     neg_nodes_actual_ind = tf.gather_nd(node_tensor, neg_node_ind)
     nodes = model.get_batch_nodes(node_tensor)
     neg_nodes = model.get_batch_nodes(neg_nodes_actual_ind)
     node_neg_node_dist = hyp_utils.hyp_distance(nodes, neg_nodes,
                                                 tf.math.softplus(model.c))
     seperation = tf.reduce_mean(-model.square_distance(node_neg_node_dist))
     return seperation
Example #24
 def _get_new_item_indices(self, age, updates, mask=None):
     any_update = list(updates.values())[0]
     num_updates = tf.shape(any_update)[0]
     _, new_item_indices = tf.math.top_k(age, num_updates)
     if mask is not None:
         mask = tf.cast(mask, dtype=tf.int32)
         unmasked_indices = (tf.cumsum(mask) - 1) * mask
         unmasked_indices = tf.expand_dims(unmasked_indices, axis=1)
         new_item_indices = tf.gather_nd(new_item_indices, unmasked_indices)
     return new_item_indices
Example #25
 def _dense_to_sparse(self, student_ids, question_ids, dense_correct):
   test_y_idx = np.stack([student_ids, question_ids], axis=-1)
   # Need to tile the indices across the batch, for gather_nd.
   batch_shape = ps.shape(dense_correct)[:-2]
   broadcast_shape = ps.concat([ps.ones_like(batch_shape), test_y_idx.shape],
                               axis=-1)
   test_y_idx = tf.reshape(test_y_idx, broadcast_shape)
   test_y_idx = tf.tile(test_y_idx, ps.concat([batch_shape, [1, 1]], axis=-1))
   return tf.gather_nd(
       dense_correct, test_y_idx, batch_dims=ps.size(batch_shape))
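
A minimal sketch of the batch_dims feature this helper relies on: the same (student, question) index pairs are applied independently to every batch member of the dense matrix.

import tensorflow as tf

dense_correct = tf.constant([[[1., 0.], [0., 1.]],
                             [[0., 1.], [1., 0.]]])   # batch of 2, each 2 x 2
idx = tf.constant([[[0, 1], [1, 0]],
                   [[0, 1], [1, 0]]])                 # same index pairs per batch
sparse_values = tf.gather_nd(dense_correct, idx, batch_dims=1)
# sparse_values: [[0., 0.], [1., 1.]]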
Example #26
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
    """Determine the beginning of the interval, `a`."""
    # if i < 0: a = q0[j]
    i_lt_0 = tf.less(i, 0)
    ind = tf.where(i_lt_0)
    a_update = tf.gather_nd(q0_j, ind)
    a = tf.scatter_nd(ind, a_update, batch_shape)

    # elif j < 0: a = q1[i]
    j_lt_0 = tf.less(j, 0)
    ind = tf.where(j_lt_0)
    a_update = tf.gather_nd(q1_i, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)

    # else: a = max(q0[j], q1[i])
    ind = tf.where(~(i_lt_0 | j_lt_0))
    q_max = tf.maximum(q0_j, q1_i)
    a_update = tf.gather_nd(q_max, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)
    return a
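
A toy sketch of the select-and-overwrite pattern used above: tf.where yields the positions satisfying a condition, tf.gather_nd reads values at those positions, and tf.tensor_scatter_nd_update writes them into the result.

import tensorflow as tf

i = tf.constant([-1, 0, 2, -3])
q0_j = tf.constant([10., 11., 12., 13.])
q1_i = tf.constant([20., 21., 22., 23.])

# Start from q1[i] everywhere, then overwrite positions where i < 0 with q0[j].
ind = tf.where(tf.less(i, 0))                                   # [[0], [3]]
a = tf.tensor_scatter_nd_update(q1_i, ind, tf.gather_nd(q0_j, ind))
# a: [10., 21., 22., 13.]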
Example #27
def _compute_2d_sparsemax(logits):
    """Performs the sparsemax operation when axis=-1."""
    shape_op = tf.shape(logits)
    obs = tf.math.reduce_prod(shape_op[:-1])
    dims = shape_op[-1]

    # In the paper, they call the logits z.
    # The mean(logits) can be subtracted from logits to make the algorithm
    # more numerically stable. The instability in this algorithm comes mostly
    # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
    # to zero. However, in practice the numerical instability issues are very
    # minor and subtracting the mean causes extra issues with inf and nan
    # input.
    # Reshape to [obs, dims] as it is almost free and means the remaining
    # code doesn't need to worry about the rank.
    z = tf.reshape(logits, [obs, dims])

    # sort z
    z_sorted, _ = tf.nn.top_k(z, k=dims)

    # calculate k(z)
    z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
    k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
    z_check = 1 + k * z_sorted > z_cumsum
    # because the z_check vector is always [1,1,...1,0,0,...0] finding the
    # (index + 1) of the last `1` is the same as just summing the number of 1.
    k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)

    # calculate tau(z)
    # If there are inf values or all values are -inf, the k_z will be zero,
    # this is mathematically invalid and will also cause the gather_nd to fail.
    # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then
    # fixed later (see p_safe) by returning p = nan. This results in the same
    # behavior as softmax.
    k_z_safe = tf.math.maximum(k_z, 1)
    indices = tf.stack(
        [tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
    tau_sum = tf.gather_nd(z_cumsum, indices)
    tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)

    # calculate p
    p = tf.math.maximum(tf.cast(0, logits.dtype),
                        z - tf.expand_dims(tau_z, -1))
    # If k_z = 0 or if z = nan, then the input is invalid
    p_safe = tf.where(
        tf.expand_dims(tf.math.logical_or(tf.math.equal(k_z, 0),
                                          tf.math.is_nan(z_cumsum[:, -1])),
                       axis=-1),
        tf.fill([obs, dims], tf.cast(float('nan'), logits.dtype)), p)

    # Reshape back to original size
    p_safe = tf.reshape(p_safe, shape_op)
    return p_safe
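
A quick usage check for _compute_2d_sparsemax above: each output row is a sparse probability vector (non-negative, summing to 1) in which sufficiently small logits are pushed exactly to zero.

import tensorflow as tf

logits = tf.constant([[0.0, 2.0, 0.1],
                      [1.2, 1.0, 0.0]])
p = _compute_2d_sparsemax(logits)
# p: [[0. , 1. , 0. ],
#     [0.6, 0.4, 0. ]]
# tf.reduce_sum(p, axis=-1) == [1., 1.]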
Example #28
def ensemble_crossentropy(labels, logits, ensemble_size):
  """Return ensemble cross-entropy."""
  tile_logp = tf.nn.log_softmax(logits, axis=-1)
  # (1,ens_size*batch,n_classes)
  tile_logp = tf.expand_dims(tile_logp, 0)
  tile_logp = tf.concat(
      tf.split(tile_logp, ensemble_size, axis=1), 0)
  logp = tfp.math.reduce_logmeanexp(tile_logp, axis=0)

  mask = tf.stack([
      tf.range(len(labels), dtype=tf.int32),
      tf.cast(labels, dtype=tf.int32)], axis=1)
  return -tf.reduce_mean(tf.gather_nd(logp, mask))
Example #29
    def compute_logits(self,
                       context_features=None,
                       example_features=None,
                       training=None,
                       mask=None):
        """Scores context and examples to return a score per document.

    Args:
      context_features: (dict) context feature names to 2D tensors of shape
        [batch_size, feature_dims].
      example_features: (dict) example feature names to 3D tensors of shape
        [batch_size, list_size, feature_dims].
      training: (bool) whether in train or inference mode.
      mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
        is True for a valid example and False for an invalid one. If mask is
        None, all entries are valid.

    Returns:
      (tf.Tensor) A score tensor of shape [batch_size, list_size].
    """
        tensor = next(six.itervalues(example_features))
        batch_size = tf.shape(tensor)[0]
        list_size = tf.shape(tensor)[1]
        if mask is None:
            mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
        nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

        # Expand query features to be of [batch_size, list_size, ...].
        large_batch_context_features = {}
        for name, tensor in six.iteritems(context_features):
            x = tf.expand_dims(input=tensor, axis=1)
            x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
            large_batch_context_features[name] = utils.reshape_first_ndims(
                x, 2, [batch_size * list_size])

        large_batch_example_features = {}
        for name, tensor in six.iteritems(example_features):
            # Replace invalid example features with valid ones.
            padded_tensor = tf.gather_nd(tensor, nd_indices)
            large_batch_example_features[name] = utils.reshape_first_ndims(
                padded_tensor, 2, [batch_size * list_size])

        # Get scores for large batch.
        scores = self.score(context_features=large_batch_context_features,
                            example_features=large_batch_example_features,
                            training=training)
        logits = tf.reshape(scores, shape=[batch_size, list_size])

        # Apply nd_mask to zero out invalid entries.
        logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
        return logits
Example #30
 def _get_reset_state_indices():
     reset_indices_obs = tf.nest.map_structure(
         lambda t: tf.gather_nd(t, reset_indices), observation)
     # shape: [num_indices_to_reset, ...]
     reset_indices_state = self.get_initial_state(
         reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
     # Scatter tensors in `reset_indices_state` to shape: [num_timesteps,
     # batch_size, ...]
     return tf.nest.map_structure(
         lambda reset_tensor: tf.scatter_nd(indices=reset_indices,
                                            updates=reset_tensor,
                                            shape=done.shape.as_list() +
                                            reset_tensor.shape.as_list(
                                            )[1:]), reset_indices_state)
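
A standalone toy sketch of the gather/scatter pair used above: tf.gather_nd pulls out the batch entries being reset and tf.scatter_nd places their freshly initialized values back into a tensor of the full batch shape, leaving the other rows at zero.

import tensorflow as tf

observation = tf.constant([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
reset_indices = tf.constant([[1], [3]])               # batch entries to reset

reset_obs = tf.gather_nd(observation, reset_indices)  # [[2., 2.], [4., 4.]]
initial_state = tf.ones_like(reset_obs)               # stand-in for get_initial_state
scattered = tf.scatter_nd(indices=reset_indices,
                          updates=initial_state,
                          shape=tf.shape(observation))
# scattered: [[0., 0.], [1., 1.], [0., 0.], [1., 1.]]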