def user_item_node_interaction_loss(self, probs, user_node_distance,
                                    item_node_distance, user_item_distance,
                                    neg_node_ind):
  """Computes pairwise hinge based loss, as in the reference below.

  Args:
    probs: Tensor of size batch_size x tot_node_batch containing the
      probability a node is the ancestor of the positive item.
    user_node_distance: Tensor of size batch_size x tot_node_batch containing
      square of the distances between the nodes and the user.
    item_node_distance: Tensor of size batch_size x tot_node_batch containing
      square of the distances between the nodes and the positive item.
    user_item_distance: Tensor of size batch_size x 2 containing square of the
      distances between the user and the positive and negative items.
    neg_node_ind: Tensor of size batch_size x tot_node_batch x 2 containing
      indices of negative nodes (within the sampled batch, from the relevant
      level), in tf.gather_nd format.

  Returns:
    loss within the input_batch.
  """
  # TODO(advaw): change Idea 2 above to a real reference when possible.
  user_to_node = self.pos_and_neg_loss(
      user_node_distance, tf.gather_nd(user_node_distance, neg_node_ind))
  item_to_node = self.pos_and_neg_loss(
      item_node_distance, tf.gather_nd(item_node_distance, neg_node_ind))
  nodes_loss = tf.reduce_sum(probs * (user_to_node + item_to_node), axis=1)
  user_to_item = self.pos_and_neg_loss(user_item_distance[:, 0],
                                       user_item_distance[:, 1])
  loss = tf.reduce_mean(user_to_item + nodes_loss)
  return loss
def _get_q_slice(q, k, ind, b=None, batch_shape=None):
  """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`."""
  q_ind = tf.concat([ind, tf.expand_dims(tf.gather_nd(k, ind), -1)], axis=1)
  b_updates = tf.gather_nd(q, q_ind)
  if b is None:
    return tf.scatter_nd(ind, b_updates, batch_shape)
  return tf.tensor_scatter_nd_update(b, ind, b_updates)
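# A minimal usage sketch for `_get_q_slice` with hypothetical values (not from
# the original source): `ind` selects batch members, `k` holds a per-member
# column, and the gathered entries `q[i, k[i]]` are scattered back into a dense
# batch-shaped tensor, with unselected rows left at zero.
import tensorflow as tf

q = tf.constant([[10., 11., 12.],
                 [20., 21., 22.],
                 [30., 31., 32.]])
# `k` must share `ind`'s dtype (int64 from `tf.where`) for the concat to work.
k = tf.constant([2, 0, 1], dtype=tf.int64)        # Per-row column, q[i, k[i]].
ind = tf.where(tf.constant([True, False, True]))  # Rows to fill, shape [m, 1].
batch_shape = tf.constant([3], dtype=tf.int64)    # Same dtype as `ind`.
print(_get_q_slice(q, k, ind, batch_shape=batch_shape))  # [12., 0., 31.]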
def contrastive_loss(similarity_matrix,
                     metric_values,
                     temperature,
                     coupling_temperature=1.0,
                     use_coupling_weights=True):
  """Contrastive loss with soft coupling."""
  logging.info('Using alternative contrastive loss.')
  metric_shape = tf.shape(metric_values)
  similarity_matrix /= temperature
  neg_logits1 = similarity_matrix

  col_indices = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
  pos_indices1 = tf.stack(
      (tf.range(metric_shape[0], dtype=tf.int32), col_indices), axis=1)
  pos_logits1 = tf.gather_nd(similarity_matrix, pos_indices1)

  if use_coupling_weights:
    metric_values /= coupling_temperature
    coupling = tf.exp(-metric_values)
    pos_weights1 = -tf.gather_nd(metric_values, pos_indices1)
    pos_logits1 += pos_weights1
    negative_weights = tf.math.log((1.0 - coupling) + EPS)
    neg_logits1 += tf.tensor_scatter_nd_update(negative_weights, pos_indices1,
                                               pos_weights1)
  neg_logits1 = tf.math.reduce_logsumexp(neg_logits1, axis=1)
  return tf.reduce_mean(neg_logits1 - pos_logits1)
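# A minimal usage sketch with hypothetical shapes. The function assumes a
# module-level constant `EPS` (a small float; 1e-9 is an assumed value here)
# and absl's `logging`; both must be defined where `contrastive_loss` lives.
# Each row's positive is the column with the smallest metric value.
import tensorflow as tf
from absl import logging

EPS = 1e-9  # Assumption; the original constant is defined elsewhere.

similarity_matrix = tf.random.normal([4, 4])
metric_values = tf.random.uniform([4, 4], minval=0.0, maxval=2.0)
loss = contrastive_loss(similarity_matrix, metric_values, temperature=0.5)
print(float(loss))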
def _piecewise_constant_integrate(x1, x2, jump_locations, values, batch_rank):
  """Integrates piecewise constant function between `x1` and `x2`."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape.
  # Expand batch size to one if there is no batch shape.
  if x1.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x1 = tf.expand_dims(x1, 0)
    x2 = tf.expand_dims(x2, 0)
  if not jump_locations.shape.as_list()[:-1]:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
    batch_rank += 1

  # Compute the index matrix that is later used for `tf.gather_nd`.
  index_matrix = _prepare_index_matrix(
      x1.shape.as_list()[:-1], x1.shape.as_list()[-1], tf.int32)

  # Compute integral values at the jump locations starting from the first jump
  # location.
  event_shape = values.shape[(batch_rank + 1):]
  num_data_points = values.shape.as_list()[batch_rank]
  diff = jump_locations[..., 1:] - jump_locations[..., :-1]
  # Broadcast `diff` to the shape of
  # `batch_shape + [num_data_points - 2] + [1] * sample_rank`.
  for _ in event_shape:
    diff = tf.expand_dims(diff, -1)
  slice_indices = batch_rank * [slice(None)]
  slice_indices += [slice(1, num_data_points - 1)]
  integrals = tf.cumsum(values[slice_indices] * diff, batch_rank)

  # Pad integrals with zero values on left and right.
  batch_shape = integrals.shape.as_list()[:batch_rank]
  zeros = tf.zeros(batch_shape + [1] + event_shape, dtype=integrals.dtype)
  integrals = tf.concat([zeros, integrals, zeros], axis=batch_rank)

  # Get jump locations and values and the integration end points.
  value1, jump_location1, indices_nd1 = _get_indices_and_values(
      x1, index_matrix, jump_locations, values, 'left', batch_rank)
  value2, jump_location2, indices_nd2 = _get_indices_and_values(
      x2, index_matrix, jump_locations, values, 'right', batch_rank)
  integrals1 = tf.gather_nd(integrals, indices_nd1)
  integrals2 = tf.gather_nd(integrals, indices_nd2)

  # Broadcast `x1`, `x2`, `jump_location1`, `jump_location2` to the shape
  # `batch_shape + [num_points] + [1] * sample_rank`.
  for _ in event_shape:
    x1 = tf.expand_dims(x1, -1)
    x2 = tf.expand_dims(x2, -1)
    jump_location1 = tf.expand_dims(jump_location1, -1)
    jump_location2 = tf.expand_dims(jump_location2, -1)

  # Compute the value of the integral.
  res = ((jump_location1 - x1) * value1
         + (x2 - jump_location2) * value2
         + integrals2 - integrals1)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res
def _get_coordinatewise_learning_rate(self, grad, var):
  # Compute the learning rate using a moving average for the diagonal of BB^T.
  avg_first = self.get_slot(var, 'first_moment')
  avg_second = self.get_slot(var, 'second_moment')
  decay_tensor = tf.cast(self._decay_tensor, var.dtype)
  batch_size = tf.cast(self._batch_size_tensor, var.dtype)

  # Create an estimator for the moving average of gradient mean and variance
  # via Welford's algorithm.
  if isinstance(grad, tf.Tensor):
    delta = grad - avg_first
    first_moment_update = avg_first.assign_add(
        delta * tf.where(
            self.iterations < 1,
            dtype_util.as_numpy_dtype(var.dtype)(1.),
            1. - decay_tensor))

    with tf.control_dependencies([first_moment_update]):
      second_moment_update = avg_second.assign_add(
          tf.cast(self.iterations < 1, var.dtype) *
          -(1. - decay_tensor) *
          (avg_second - decay_tensor * tf.square(delta)))
    diag_preconditioner = distribution_util.with_dependencies(
        [second_moment_update],
        tf.clip_by_value(avg_second, 1e-12, 1e12))
  elif isinstance(grad, tf.IndexedSlices):
    delta = grad.values - tf.gather_nd(avg_first, grad.indices)
    first_moment_update = tf.compat.v1.scatter_add(
        avg_first,
        grad.indices,
        delta * tf.where(
            self.iterations < 1,
            dtype_util.as_numpy_dtype(var.dtype)(1.),
            1. - decay_tensor))

    with tf.control_dependencies([first_moment_update]):
      avg_second = tf.compat.v1.scatter_add(
          avg_second,
          grad.indices,
          tf.cast(self.iterations < 1, var.dtype) *
          -(1. - decay_tensor) *
          (tf.gather_nd(avg_second, grad.indices) -
           decay_tensor * tf.square(delta)))
      avg_second = tf.gather_nd(avg_second, grad.indices)
      # TODO(b/70783772): Needs dtype specific clipping.
      diag_preconditioner = tf.clip_by_value(avg_second, 1e-12, 1e12)
  else:
    raise tf.errors.InvalidArgumentError(
        None, None, 'grad must be of type Tensor or IndexedSlices')

  diag_preconditioner *= batch_size
  if self._use_single_learning_rate:
    diag_preconditioner = tf.reduce_mean(diag_preconditioner)

  # From Theorem 2 Corollary 1 of Mandt et al. 2017.
  return 2. * batch_size / (
      tf.cast(self._total_num_examples, var.dtype.base_dtype) *
      diag_preconditioner)
def _retrieve_from_cache(self, query_embeddings, cache):
  sorted_data_sources = sorted(cache.keys())
  all_query_embeddings = util.cross_replica_concat(query_embeddings, axis=0)
  num_replicas = tf.distribute.get_replica_context().num_replicas_in_sync

  # Performs approximate top k across replicas.
  if self.top_k:
    top_k_per_replica = self.top_k // num_replicas
  else:
    top_k_per_replica = self.top_k

  retrieval_return = _retrieve_from_caches(all_query_embeddings, cache,
                                           self._retrieval_fn,
                                           self.embedding_key, self.data_keys,
                                           sorted_data_sources,
                                           self.score_transform,
                                           top_k_per_replica)

  # We transfer all queries to all replicas and retrieve from every shard.
  all_queries_local_weight = tf.math.reduce_logsumexp(
      retrieval_return.scores, axis=1)
  local_queries_global_weights = _get_local_elements_global_data(
      all_queries_local_weight, num_replicas)
  local_queries_all_retrieved_data = {}
  for key in retrieval_return.retrieved_data:
    local_queries_all_retrieved_data[key] = _get_local_elements_global_data(
        retrieval_return.retrieved_data[key], num_replicas)
  local_queries_all_retrieved_embeddings = _get_local_elements_global_data(
      retrieval_return.retrieved_cache_embeddings, num_replicas)

  # We then sample a shard index proportional to its total weight. This
  # allows us to do Gumbel-Max sampling without modifying APIs.
  selected_replica = self._retrieval_fn(local_queries_global_weights)
  selected_replica = tf.stop_gradient(selected_replica)

  num_elements = tf.shape(selected_replica)[0]
  batch_indices = tf.range(num_elements)
  batch_indices = tf.cast(batch_indices, tf.int64)
  batch_indices = tf.expand_dims(batch_indices, axis=1)
  selected_replica_with_batch = tf.concat([batch_indices, selected_replica],
                                          axis=1)

  retrieved_data = {
      k: tf.gather_nd(v, selected_replica_with_batch)
      for k, v in local_queries_all_retrieved_data.items()
  }
  retrieved_cache_embeddings = tf.gather_nd(
      local_queries_all_retrieved_embeddings, selected_replica_with_batch)

  return _RetrievalReturn(
      retrieved_data=retrieved_data,
      scores=local_queries_global_weights,
      retrieved_indices=None,
      retrieved_cache_embeddings=retrieved_cache_embeddings)
def _get_indices_and_values(x, index_matrix, jump_locations, values, side,
                            batch_rank):
  """Computes values and jump locations of the piecewise constant function.

  Given `jump_locations` and the `values` on the corresponding segments of the
  piecewise constant function, the function identifies the nearest jump to `x`
  from the right or left (which is determined by the `side` argument) and the
  corresponding value of the piecewise constant function at `x`.

  Args:
    x: A real `Tensor` of shape `batch_shape + [num_points]`. Points at which
      the function has to be evaluated.
    index_matrix: An `int32` `Tensor` of shape
      `batch_shape + [num_points] + [len(batch_shape)]` such that if
      `batch_shape = [i1, .., in]`, then for all `j1, ..., jn, l`,
      `index_matrix[j1,..,jn, l] = [j1, ..., jn]`.
    jump_locations: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points]`. The locations where the function
      changes its values. Note that the values are expected to be ordered
      along the last dimension.
    values: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points + 1]`. Defines `values[..., i]` on
      `jump_locations[..., i - 1], jump_locations[..., i]`.
    side: A Python string. Whether the function is left- or right-continuous.
      The corresponding values for side should be `left` and `right`.
    batch_rank: A Python scalar stating the batch rank of `x`.

  Returns:
    A tuple of three `Tensor`s of the same `dtype` as `x` and shapes
    `batch_shape + [num_points] + event_shape`, `batch_shape + [num_points]`,
    and `batch_shape + [num_points] + [2 * len(batch_shape)]`. The `Tensor`s
    correspond to the values, jump locations at `x`, and the corresponding
    indices used to obtain jump locations via `tf.gather_nd`.
  """
  indices = tf.searchsorted(jump_locations, x, side=side)
  num_data_points = tf.shape(values)[batch_rank] - 2
  if side == 'right':
    indices_jump = indices - 1
    indices_jump = tf.maximum(indices_jump, 0)
  else:
    indices_jump = tf.minimum(indices, num_data_points)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  indices_jump_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices_jump, -1)], -1)
  value = tf.gather_nd(values, indices_nd)
  jump_location = tf.gather_nd(jump_locations, indices_jump_nd)
  return value, jump_location, indices_jump_nd
def approximate_top_k_with_indices(negative_scores, k):
  """Approximately mines the top k highest scoring negatives with indices.

  This function groups the negative scores into num_negatives / k groupings
  and returns the highest scoring element from each group. It also returns
  the index where the selected elements were found in the score matrix.

  Args:
    negative_scores: A matrix with the scores of the negative elements.
    k: The number of negatives to mine.

  Returns:
    The tuple (top_k_scores, top_k_indices), where top_k_indices describes the
    index of the mined elements in the given score matrix.
  """
  bs = tf.shape(negative_scores)[0]
  num_elem = tf.shape(negative_scores)[1]
  batch_indices = tf.range(num_elem)
  indices = tf.tile(tf.expand_dims(batch_indices, axis=0), multiples=[bs, 1])
  grouped_negative_scores = tf.reshape(negative_scores, [bs * k, -1])
  grouped_batch_indices = tf.range(tf.shape(grouped_negative_scores)[0])
  grouped_top_k_scores, grouped_top_k_indices = tf.math.top_k(
      grouped_negative_scores)
  grouped_top_k_indices = tf.squeeze(grouped_top_k_indices, axis=1)
  gather_indices = tf.stack([grouped_batch_indices, grouped_top_k_indices],
                            axis=1)
  grouped_indices = tf.reshape(indices, [bs * k, -1])
  grouped_top_k_indices = tf.gather_nd(grouped_indices, gather_indices)
  top_k_indices = tf.reshape(grouped_top_k_indices, [bs, k])
  top_k_scores = tf.reshape(grouped_top_k_scores, [bs, k])
  return top_k_scores, top_k_indices
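# A minimal usage sketch with hypothetical data: with 8 negatives per row and
# k=4, the scores are split into 4 groups of 2 per row and the best element of
# each group is returned together with its column index. Note the number of
# negatives must be divisible by k for the reshape to succeed.
import tensorflow as tf

negative_scores = tf.constant([[0.1, 0.9, 0.3, 0.2, 0.8, 0.0, 0.5, 0.4],
                               [0.7, 0.6, 0.2, 0.1, 0.0, 0.3, 0.9, 0.8]])
top_k_scores, top_k_indices = approximate_top_k_with_indices(
    negative_scores, k=4)
print(top_k_scores)   # Shape [2, 4]: the max of each group of 2.
print(top_k_indices)  # Shape [2, 4]: e.g. row 0 is [1, 2, 4, 6].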
def get_slice(x, encoding):
  if optimize_for_tpu:
    return tf.math.reduce_sum(
        tf.expand_dims(x, axis=-2) * encoding, axis=-1)
  else:
    return tf.gather_nd(x, encoding)
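# A sketch (hypothetical shapes) of why the two branches above agree: on TPU
# the gather is expressed as a one-hot multiply-and-sum, which compiles to
# dense ops. The `encoding` is a one-hot tensor in one branch and gather_nd
# index tuples in the other.
import tensorflow as tf

x = tf.constant([[10., 11., 12., 13.],
                 [20., 21., 22., 23.]])
positions = tf.constant([[0, 2], [3, 1]])  # Two slices per row.

# One-hot form: [batch, num_slices, num_positions].
one_hot_encoding = tf.one_hot(positions, depth=4, dtype=x.dtype)
tpu_result = tf.math.reduce_sum(
    tf.expand_dims(x, axis=-2) * one_hot_encoding, axis=-1)

# gather_nd form: [batch, num_slices, 2] index tuples (row, position).
rows = tf.broadcast_to(tf.range(2)[:, None], [2, 2])
nd_encoding = tf.stack([rows, positions], axis=-1)
gather_result = tf.gather_nd(x, nd_encoding)

tf.debugging.assert_near(tpu_result, gather_result)  # [[10., 12.], [23., 21.]]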
def _mode(self, samples=None):
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0]).count
    return tf.argmax(count)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat([
        self._batch_shape_tensor(samples),
        self._event_shape_tensor(samples)
    ], axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples,
                      fn_output_signature=tf.int64)
  full_indices = tf.stack(
      [tf.range(tf.shape(indices)[0]),
       tf.cast(indices, tf.int32)], axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)
def _mode(self, samples=None):
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    return tf.argmax(count)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat(
        [self._batch_shape_tensor(samples),
         self._event_shape_tensor(samples)],
        axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
  full_indices = tf.stack(
      [tf.range(tf.shape(indices)[0]),
       tf.cast(indices, tf.int32)], axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)
def dense_to_sparse(x, ignore_value=None, name=None):
  """Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells.

  Args:
    x: A `Tensor`.
    ignore_value: Entries in `x` equal to this value will be absent from the
      returned `SparseTensor`. If `None`, default value of `x` dtype will be
      used (e.g. '' for `str`, 0 for `int`).
    name: Python `str` prefix for ops created by this function.

  Returns:
    sparse_x: A `tf.SparseTensor` with the same shape as `x`.

  Raises:
    ValueError: when `x`'s rank is `None`.
  """
  # Copied (with modifications) from:
  # tensorflow/contrib/layers/python/ops/sparse_ops.py.
  with tf.name_scope(name or 'dense_to_sparse'):
    x = tf.convert_to_tensor(x, name='x')
    if ignore_value is None:
      if dtype_util.base_dtype(x.dtype) == tf.string:
        # Special case because TF strings are converted to numpy objects by
        # default.
        ignore_value = ''
      else:
        ignore_value = dtype_util.as_numpy_dtype(x.dtype)(0)
    ignore_value = tf.cast(ignore_value, x.dtype, name='ignore_value')
    indices = tf.where(tf.not_equal(x, ignore_value), name='indices')
    return tf.SparseTensor(
        indices=indices,
        values=tf.gather_nd(x, indices, name='values'),
        dense_shape=tf.shape(x, out_type=tf.int64, name='dense_shape'))
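# A minimal usage sketch: zeros are dropped as the default `ignore_value` for
# a numeric dtype, and `tf.sparse.to_dense` round-trips back to the input.
# Note the function relies on TFP's internal `dtype_util` helper
# (`from tensorflow_probability.python.internal import dtype_util`).
import tensorflow as tf

x = tf.constant([[1., 0., 2.],
                 [0., 0., 3.]])
sparse_x = dense_to_sparse(x)
print(sparse_x.indices.numpy())              # [[0 0] [0 2] [1 2]]
print(sparse_x.values.numpy())               # [1. 2. 3.]
print(tf.sparse.to_dense(sparse_x).numpy())  # Recovers `x`.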
def top_k_boxes(boxes, scores, k):
  """Sort and select top k boxes according to the scores.

  Args:
    boxes: a tensor of shape [batch_size, N, 4] representing the coordinates
      of the boxes. N is the number of boxes per image.
    scores: a tensor of shape [batch_size, N] representing the score of the
      boxes.
    k: an integer or a tensor indicating the top k number.

  Returns:
    selected_boxes: a tensor of shape [batch_size, k, 4] representing the
      selected top k box coordinates.
    selected_scores: a tensor of shape [batch_size, k] representing the
      selected top k box scores.
  """
  with tf.name_scope('top_k_boxes'):
    selected_scores, top_k_indices = tf.nn.top_k(scores, k=k, sorted=True)

    batch_size, _ = scores.get_shape().as_list()
    if batch_size == 1:
      selected_boxes = tf.squeeze(
          tf.gather(boxes, top_k_indices, axis=1), axis=1)
    else:
      top_k_indices_shape = tf.shape(top_k_indices)
      batch_indices = (
          tf.expand_dims(tf.range(top_k_indices_shape[0]), axis=-1) *
          tf.ones([1, top_k_indices_shape[-1]], dtype=tf.int32))
      gather_nd_indices = tf.stack([batch_indices, top_k_indices], axis=-1)
      selected_boxes = tf.gather_nd(boxes, gather_nd_indices)

    return selected_boxes, selected_scores
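# A minimal usage sketch with hypothetical boxes: with batch_size > 1 the
# gather_nd path is exercised, pairing a batch index with each top-k index.
import tensorflow as tf

boxes = tf.random.uniform([2, 5, 4])  # [batch_size, N, 4]
scores = tf.constant([[0.2, 0.9, 0.1, 0.5, 0.7],
                      [0.8, 0.3, 0.6, 0.4, 0.1]])
selected_boxes, selected_scores = top_k_boxes(boxes, scores, k=3)
print(selected_scores)       # [[0.9, 0.7, 0.5], [0.8, 0.6, 0.4]]
print(selected_boxes.shape)  # (2, 3, 4)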
def _build_target_quantile_values_op(self):
  """Build an op used as a target for return values at given quantiles.

  Returns:
    An op calculating the target quantile return.
  """
  batch_size = tf.shape(self._replay.rewards)[0]
  # Calculate AL modified rewards.
  replay_action_one_hot = tf.one_hot(
      self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
  replay_target_q = tf.reduce_max(
      self._replay_target_q_values,
      axis=1,
      name='replay_chosen_target_q')
  replay_target_q_al = tf.reduce_sum(
      replay_action_one_hot * self._replay_target_q_values,
      axis=1,
      name='replay_chosen_target_q_al')

  if self._clip > 0.:
    al_bonus = self._alpha * tf.clip_by_value(
        (replay_target_q_al - replay_target_q), -self._clip, self._clip)
  else:
    al_bonus = self._alpha * (replay_target_q_al - replay_target_q)

  # Shape of rewards: (num_tau_prime_samples x batch_size) x 1.
  rewards = (self._replay.rewards + al_bonus)[:, None]
  rewards = tf.tile(rewards, [self.num_tau_prime_samples, 1])

  is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
  # Incorporate terminal state to discount factor.
  # size of gamma_with_terminal: (num_tau_prime_samples x batch_size) x 1.
  gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
  gamma_with_terminal = tf.tile(gamma_with_terminal[:, None],
                                [self.num_tau_prime_samples, 1])

  # Get the indices of the maximum Q-value across the action dimension.
  # Shape of replay_next_qt_argmax: (num_tau_prime_samples x batch_size) x 1.
  replay_next_qt_argmax = tf.tile(
      self._replay_next_qt_argmax[:, None], [self.num_tau_prime_samples, 1])

  # Shape of batch_indices: (num_tau_prime_samples x batch_size) x 1.
  batch_indices = tf.cast(tf.range(
      self.num_tau_prime_samples * batch_size)[:, None], tf.int64)

  # Shape of batch_indexed_target_values:
  # (num_tau_prime_samples x batch_size) x 2.
  batch_indexed_target_values = tf.concat(
      [batch_indices, replay_next_qt_argmax], axis=1)

  # Shape of next_target_values: (num_tau_prime_samples x batch_size) x 1.
  target_quantile_values = tf.gather_nd(
      self._replay_net_target_quantile_values,
      batch_indexed_target_values)[:, None]

  return rewards + gamma_with_terminal * target_quantile_values
def _piecewise_constant_function(x, jump_locations, values,
                                 batch_rank, side='left'):
  """Computes value of the piecewise constant function."""
  # Initializer already verified that `jump_locations` and `values` have the
  # same shape.
  batch_shape = jump_locations.shape.as_list()[:-1]
  # Check that the batch shape of `x` is the same as of `jump_locations` and
  # `values`.
  batch_shape_x = x.shape.as_list()[:batch_rank]
  if batch_shape_x != batch_shape:
    raise ValueError('Batch shape of `x` is {1} but should be {0}'.format(
        batch_shape, batch_shape_x))
  if x.shape.as_list()[:batch_rank]:
    no_batch_shape = False
  else:
    no_batch_shape = True
    x = tf.expand_dims(x, 0)
  # Expand batch size to one if there is no batch shape.
  if not batch_shape:
    jump_locations = tf.expand_dims(jump_locations, 0)
    values = tf.expand_dims(values, 0)
  indices = tf.searchsorted(jump_locations, x, side=side)
  index_matrix = _prepare_index_matrix(
      indices.shape.as_list()[:-1], indices.shape.as_list()[-1], indices.dtype)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  res = tf.gather_nd(values, indices_nd)
  if no_batch_shape:
    return tf.squeeze(res, 0)
  else:
    return res
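# A self-contained sketch of the searchsorted + gather_nd pattern used by
# `_piecewise_constant_function` above. `_prepare_index_matrix` is not shown
# in this section, so the batch-index grid is built inline here; treat this as
# an illustration of the technique, not the library's exact helper.
import tensorflow as tf

jump_locations = tf.constant([[0.0, 1.0], [0.5, 1.5]])    # batch of 2 functions
values = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # num_jumps + 1 values
x = tf.constant([[-1.0, 0.5, 2.0], [0.0, 1.0, 2.0]])      # evaluation points

# For each x, find which segment it falls into.
indices = tf.searchsorted(jump_locations, x, side='left')  # shape [2, 3]
# Batch-index grid so that index_matrix[b, p] == [b].
index_matrix = tf.tile(tf.range(2, dtype=indices.dtype)[:, None, None],
                       [1, 3, 1])
indices_nd = tf.concat([index_matrix, indices[..., None]], axis=-1)
print(tf.gather_nd(values, indices_nd))  # [[1. 2. 3.] [4. 5. 6.]]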
def _inverse(self, y):
  map_values = tf.convert_to_tensor(self.map_values)
  flat_y = tf.reshape(y, shape=[-1])
  # Search for the indices of map_values that are closest to flat_y.
  # Since map_values is strictly increasing, the closest is either the
  # first one that is strictly greater than flat_y, or the one before it.
  upper_candidates = tf.minimum(
      tf.size(map_values) - 1,
      tf.searchsorted(map_values, values=flat_y, side='right'))
  lower_candidates = tf.maximum(0, upper_candidates - 1)
  candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
  lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
  upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_near(
            tf.minimum(lower_cand_diff, upper_cand_diff),
            0,
            message='inverse value not found')
    ]):
      candidates = tf.identity(candidates)
  candidate_selector = tf.stack([
      tf.range(tf.size(flat_y), dtype=tf.int32),
      tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
  ], axis=-1)
  return tf.reshape(
      tf.gather_nd(candidates, candidate_selector), shape=y.shape)
def _analytic_valuation(expiries, floating_leg_start_times,
                        floating_leg_end_times, fixed_leg_payment_times,
                        fixed_leg_daycount_fractions, fixed_leg_coupon,
                        reference_rate_fn, dim, mean_reversion, volatility,
                        notional, is_payer_swaption, output_shape, dtype,
                        name):
  """Helper function for analytic valuation."""
  # The below inputs are needed for midcurve swaptions.
  del floating_leg_start_times, floating_leg_end_times
  with tf.name_scope(name):
    is_call_options = tf.where(is_payer_swaption,
                               tf.convert_to_tensor(False, dtype=tf.bool),
                               tf.convert_to_tensor(True, dtype=tf.bool))

    model = vector_hull_white.VectorHullWhiteModel(
        dim,
        mean_reversion,
        volatility,
        initial_discount_rate_fn=reference_rate_fn,
        dtype=dtype)
    coefficients = fixed_leg_daycount_fractions * fixed_leg_coupon
    jamshidian_coefficients = tf.concat([
        -coefficients[..., :-1],
        tf.expand_dims(-1.0 - coefficients[..., -1], axis=-1)
    ], axis=-1)

    breakeven_bond_option_strikes = _jamshidian_decomposition(
        model,
        expiries,
        fixed_leg_payment_times,
        jamshidian_coefficients,
        dtype,
        name=name + '_jamshidian_decomposition')

    bond_strike_rank = breakeven_bond_option_strikes.shape.rank
    perm = [bond_strike_rank - 1] + [x for x in range(0, bond_strike_rank - 1)]
    breakeven_bond_option_strikes = tf.transpose(
        breakeven_bond_option_strikes, perm=perm)
    bond_option_prices = zcb.bond_option_price(
        strikes=breakeven_bond_option_strikes,
        expiries=expiries,
        maturities=fixed_leg_payment_times,
        discount_rate_fn=reference_rate_fn,
        dim=dim,
        mean_reversion=mean_reversion,
        volatility=volatility,
        is_call_options=is_call_options,
        use_analytic_pricing=True,
        dtype=dtype,
        name=name + '_bond_option')

    # Now compute P(T0, TN) + sum_i (c_i * tau_i * P(T0, Ti)).
    # bond_option_prices.shape = [dim] + batch_shape + [m] + [dim], where `m`
    # denotes the number of fixed payments for the underlying swaps.
    swaption_values = (
        tf.reduce_sum(
            bond_option_prices * tf.expand_dims(coefficients, axis=-1),
            axis=-2) + bond_option_prices[..., -1, :])
    swaption_shape = swaption_values.shape
    gather_index = _prepare_swaption_indices(swaption_shape.as_list())
    swaption_values = tf.reshape(
        tf.gather_nd(swaption_values, gather_index), output_shape)
    return notional * swaption_values
def hard_quantile_normalization(inputs, quantiles):
  """Applies the quantile function `quantiles` to the inputs."""
  n_rows = inputs.shape[0]
  rows = tf.range(n_rows)[:, tf.newaxis] * tf.ones_like(inputs, dtype=tf.int32)
  indices = tf.stack(
      [rows, tf.argsort(tf.argsort(inputs, axis=1), axis=1)], axis=-1)
  ordered_quantiles = tf.gather_nd(quantiles, tf.reshape(indices, (-1, 2)))
  return tf.reshape(ordered_quantiles, inputs.shape)
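# A minimal usage sketch with hypothetical data: each entry is replaced by the
# quantile matching its rank within its row. The double argsort computes ranks,
# and gather_nd reads the corresponding sorted target value.
import tensorflow as tf

inputs = tf.constant([[0.3, -1.2, 0.8, 0.0]])
quantiles = tf.constant([[10., 20., 30., 40.]])  # Sorted targets per row.
print(hard_quantile_normalization(inputs, quantiles))
# [[30. 10. 40. 20.]]: 0.3 is the 3rd smallest, so it maps to 30., etc.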
def _retrieve_from_caches(query_embeddings,
                          cache,
                          retrieval_fn,
                          embedding_key,
                          data_keys,
                          sorted_data_sources,
                          score_transform=None,
                          top_k=None):
  """Retrieve elements from a cache with the given retrieval function."""
  all_embeddings = _batch_concat_with_no_op([
      cache[data_source].data[embedding_key]
      for data_source in sorted_data_sources
  ])
  all_data = {}
  for key in data_keys:
    all_data[key] = _batch_concat_with_no_op([
        cache[data_source].data[key]
        for data_source in sorted_data_sources
    ])
  scores = _score_documents(
      query_embeddings,
      all_embeddings,
      score_transform=score_transform,
      all_pairs=True)

  if top_k:
    scores, top_k_indices = util.approximate_top_k_with_indices(scores, top_k)
    top_k_indices = tf.cast(top_k_indices, dtype=tf.int64)
    retrieved_indices = retrieval_fn(scores)
    batch_index = tf.expand_dims(
        tf.range(tf.shape(retrieved_indices)[0], dtype=tf.int64), axis=1)
    retrieved_indices_with_batch_index = tf.concat(
        [batch_index, retrieved_indices], axis=1)
    retrieved_indices = tf.gather_nd(top_k_indices,
                                     retrieved_indices_with_batch_index)
    retrieved_indices = tf.expand_dims(retrieved_indices, axis=1)
  else:
    retrieved_indices = retrieval_fn(scores)

  retrieved_indices = tf.stop_gradient(retrieved_indices)
  retrieved_data = {
      k: tf.gather_nd(v, retrieved_indices) for k, v in all_data.items()
  }
  retrieved_cache_embeddings = tf.gather_nd(all_embeddings, retrieved_indices)
  return _RetrievalReturn(retrieved_data, scores, retrieved_indices,
                          retrieved_cache_embeddings)
def __call__(self, logits, scaled_labels, classes, category_loss=True,
             mse_loss=False):
  """Compute instance segmentation loss.

  Args:
    logits: A Tensor of shape [batch_size * num_points, height, width,
      num_classes]. The logits are not necessarily between 0 and 1.
    scaled_labels: A float16 Tensor of shape [batch_size, num_instances,
      mask_size, mask_size], where mask_size =
      mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
      for coarse masks and shape priors.
    classes: An int Tensor of shape [batch_size, num_instances].
    category_loss: use class-specific mask prediction or not.
    mse_loss: use mean squared error for mask loss or not.

  Returns:
    mask_loss: a float tensor representing total mask classification loss.
    iou: a float tensor representing the IoU between target and prediction.
  """
  classes = tf.reshape(classes, [-1])
  _, _, height, width = scaled_labels.get_shape().as_list()
  scaled_labels = tf.reshape(scaled_labels, [-1, height, width])

  if not category_loss:
    logits = logits[:, :, :, 0]
  else:
    logits = tf.transpose(a=logits, perm=(0, 3, 1, 2))
    gather_idx = tf.stack(
        [tf.range(tf.size(input=classes)), classes - 1], axis=1)
    logits = tf.gather_nd(logits, gather_idx)

  # Ignore loss on empty mask targets.
  valid_labels = tf.reduce_any(
      input_tensor=tf.greater(scaled_labels, 0), axis=[1, 2])
  if mse_loss:
    # Logits are probabilities in the case of shape prior prediction.
    logits *= tf.reshape(tf.cast(valid_labels, logits.dtype), [-1, 1, 1])
    weighted_loss = tf.nn.l2_loss(scaled_labels - logits)
    probs = logits
  else:
    weighted_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=scaled_labels, logits=logits)
    probs = tf.sigmoid(logits)
    weighted_loss *= tf.reshape(
        tf.cast(valid_labels, weighted_loss.dtype), [-1, 1, 1])

  iou = tf.reduce_sum(
      input_tensor=tf.minimum(scaled_labels, probs)) / tf.reduce_sum(
          input_tensor=tf.maximum(scaled_labels, probs))
  mask_loss = tf.reduce_sum(input_tensor=weighted_loss) / tf.reduce_sum(
      input_tensor=scaled_labels)
  return tf.cast(mask_loss, tf.float32), tf.cast(iou, tf.float32)
def model():
  raining = yield Root(tfd.Bernoulli(probs=0.2, dtype=tf.int32))
  sprinkler_prob = [0.4, 0.01]
  sprinkler_prob = tf.gather(sprinkler_prob, raining)
  sprinkler = yield tfd.Bernoulli(probs=sprinkler_prob, dtype=tf.int32)
  grass_wet_prob = [[0.0, 0.8], [0.9, 0.99]]
  grass_wet_prob = tf.gather_nd(grass_wet_prob, _stack(sprinkler, raining))
  grass_wet = yield tfd.Bernoulli(probs=grass_wet_prob, dtype=tf.int32)
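# A minimal sketch of running the rain/sprinkler/grass model above as a TFP
# JointDistributionCoroutine. `Root` and `_stack` come from the original
# context; plausible definitions are assumed here.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root

def _stack(*args):
  # Assumed helper: stacks scalar indices into a gather_nd index vector.
  return tf.stack(args, axis=-1)

dist = tfd.JointDistributionCoroutine(model)
raining, sprinkler, grass_wet = dist.sample(seed=42)
print(raining.numpy(), sprinkler.numpy(), grass_wet.numpy())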
def _classify_and_fuse_detection_priors(self, uniform_priors,
                                        detection_prior_classes,
                                        crop_features):
  """Classify the uniform prior by predicting the shape modes.

  Classify the object crop features into K modes of the clusters for each
  category.

  Args:
    uniform_priors: A float Tensor of shape [batch_size, num_instances,
      mask_size, mask_size] representing the uniform detection priors.
    detection_prior_classes: An int Tensor of shape [batch_size,
      num_instances] of detection class ids.
    crop_features: A float Tensor of shape [batch_size * num_instances,
      mask_size, mask_size, num_channels].

  Returns:
    shape_weights: A float Tensor of shape [batch_size * num_instances,
      num_clusters] representing the classifier output probability over all
      possible shapes.
  """
  location_detection_priors = tf.reshape(
      uniform_priors, [-1, self._mask_crop_size, self._mask_crop_size, 1])
  # Generate image embedding to shape.
  fused_shape_features = crop_features * location_detection_priors
  shape_embedding = tf.reduce_mean(
      input_tensor=fused_shape_features, axis=(1, 2))

  if not self._use_category_for_mask:
    # TODO(weicheng) use custom op for performance
    shape_logits = tf.keras.layers.Dense(
        self._num_clusters,
        kernel_initializer=tf.keras.initializers.RandomNormal(
            stddev=0.01))(shape_embedding)
    shape_logits = tf.reshape(
        shape_logits, [-1, self._num_clusters]) / self._temperature
    shape_weights = tf.nn.softmax(shape_logits, name='shape_prior_weights')
  else:
    shape_logits = tf.keras.layers.Dense(
        self._mask_num_classes * self._num_clusters,
        kernel_initializer=tf.keras.initializers.RandomNormal(
            stddev=0.01))(shape_embedding)
    shape_logits = tf.reshape(
        shape_logits, [-1, self._mask_num_classes, self._num_clusters])
    training_classes = tf.reshape(detection_prior_classes, [-1])
    class_idx = tf.stack(
        [tf.range(tf.size(input=training_classes)), training_classes - 1],
        axis=1)
    shape_logits = tf.gather_nd(shape_logits, class_idx) / self._temperature
    shape_weights = tf.nn.softmax(shape_logits, name='shape_prior_weights')
  return shape_weights
def seperation_loss(self, model, node_tensor, neg_node_ind):
  """Calculates -d(n, n')^2."""
  neg_nodes_actual_ind = tf.gather_nd(node_tensor, neg_node_ind)
  nodes = model.get_batch_nodes(node_tensor)
  neg_nodes = model.get_batch_nodes(neg_nodes_actual_ind)
  node_neg_node_dist = hyp_utils.hyp_distance(nodes, neg_nodes,
                                              tf.math.softplus(model.c))
  seperation = tf.reduce_mean(-model.square_distance(node_neg_node_dist))
  return seperation
def _get_new_item_indices(self, age, updates, mask=None):
  any_update = list(updates.values())[0]
  num_updates = tf.shape(any_update)[0]
  _, new_item_indices = tf.math.top_k(age, num_updates)
  if mask is not None:
    mask = tf.cast(mask, dtype=tf.int32)
    unmasked_indices = (tf.cumsum(mask) - 1) * mask
    unmasked_indices = tf.expand_dims(unmasked_indices, axis=1)
    new_item_indices = tf.gather_nd(new_item_indices, unmasked_indices)
  return new_item_indices
def _dense_to_sparse(self, student_ids, question_ids, dense_correct):
  test_y_idx = np.stack([student_ids, question_ids], axis=-1)
  # Need to tile the indices across the batch, for gather_nd.
  batch_shape = ps.shape(dense_correct)[:-2]
  broadcast_shape = ps.concat([ps.ones_like(batch_shape), test_y_idx.shape],
                              axis=-1)
  test_y_idx = tf.reshape(test_y_idx, broadcast_shape)
  test_y_idx = tf.tile(test_y_idx, ps.concat([batch_shape, [1, 1]], axis=-1))
  return tf.gather_nd(dense_correct, test_y_idx,
                      batch_dims=ps.size(batch_shape))
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
  """Determine the beginning of the interval, `a`."""
  # if i < 0: a = q0[j]
  i_lt_0 = tf.less(i, 0)
  ind = tf.where(i_lt_0)
  a_update = tf.gather_nd(q0_j, ind)
  a = tf.scatter_nd(ind, a_update, batch_shape)

  # elif j < 0: a = q1[i]
  j_lt_0 = tf.less(j, 0)
  ind = tf.where(j_lt_0)
  a_update = tf.gather_nd(q1_i, ind)
  a = tf.tensor_scatter_nd_update(a, ind, a_update)

  # else: a = max(q0[j], q1[i])
  ind = tf.where(~(i_lt_0 | j_lt_0))
  q_max = tf.maximum(q0_j, q1_i)
  a_update = tf.gather_nd(q_max, ind)
  a = tf.tensor_scatter_nd_update(a, ind, a_update)
  return a
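# A minimal usage sketch with hypothetical values: each batch member takes its
# endpoint from whichever branch applies, mirroring the per-element
# if/elif/else written with where/gather_nd/scatter updates.
import tensorflow as tf

i = tf.constant([-1, 0, 2])
j = tf.constant([0, -1, 1])
q1_i = tf.constant([1., 2., 3.])
q0_j = tf.constant([4., 5., 6.])
batch_shape = tf.constant([3], dtype=tf.int64)  # Matches tf.where's int64.
print(_get_endpoint_a(i, j, q1_i, q0_j, batch_shape))
# [4., 2., 6.]: q0_j[0] since i<0, q1_i[1] since j<0, max(q0_j[2], q1_i[2]).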
def _compute_2d_sparsemax(logits):
  """Performs the sparsemax operation when axis=-1."""
  shape_op = tf.shape(logits)
  obs = tf.math.reduce_prod(shape_op[:-1])
  dims = shape_op[-1]

  # In the paper, they call the logits z.
  # The mean(logits) can be subtracted from logits to make the algorithm
  # more numerically stable. The instability in this algorithm comes mostly
  # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
  # to zero. However, in practice the numerical instability issues are very
  # minor and subtracting the mean causes extra issues with inf and nan
  # input.
  # Reshape to [obs, dims] as it is almost free and means the remaining
  # code doesn't need to worry about the rank.
  z = tf.reshape(logits, [obs, dims])

  # Sort z.
  z_sorted, _ = tf.nn.top_k(z, k=dims)

  # Calculate k(z).
  z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
  k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
  z_check = 1 + k * z_sorted > z_cumsum
  # Because the z_check vector is always [1,1,...1,0,0,...0], finding the
  # (index + 1) of the last `1` is the same as just summing the number of 1s.
  k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)

  # Calculate tau(z).
  # If there are inf values or all values are -inf, the k_z will be zero;
  # this is mathematically invalid and will also cause the gather_nd to fail.
  # Prevent this issue for now by setting k_z = 1 if k_z = 0; this is then
  # fixed later (see p_safe) by returning p = nan. This results in the same
  # behavior as softmax.
  k_z_safe = tf.math.maximum(k_z, 1)
  indices = tf.stack(
      [tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
  tau_sum = tf.gather_nd(z_cumsum, indices)
  tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)

  # Calculate p.
  p = tf.math.maximum(
      tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
  # If k_z = 0 or if z = nan, then the input is invalid.
  p_safe = tf.where(
      tf.expand_dims(
          tf.math.logical_or(
              tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
          axis=-1),
      tf.fill([obs, dims], tf.cast(float('nan'), logits.dtype)),
      p)

  # Reshape back to original size.
  p_safe = tf.reshape(p_safe, shape_op)
  return p_safe
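# A minimal usage sketch: sparsemax returns a probability vector like softmax,
# but it can assign exactly zero to low-scoring entries (Martins & Astudillo,
# 2016). Here the smallest logit is pruned entirely.
import tensorflow as tf

logits = tf.constant([[2.0, 1.5, 0.1]])
p = _compute_2d_sparsemax(logits)
print(p)                          # [[0.75 0.25 0.]]: sparse support.
print(tf.reduce_sum(p, axis=-1))  # Rows still sum to 1.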
def ensemble_crossentropy(labels, logits, ensemble_size):
  """Return ensemble cross-entropy."""
  tile_logp = tf.nn.log_softmax(logits, axis=-1)
  tile_logp = tf.expand_dims(tile_logp, 0)  # (1, ens_size*batch, n_classes)
  tile_logp = tf.concat(tf.split(tile_logp, ensemble_size, axis=1), 0)
  logp = tfp.math.reduce_logmeanexp(tile_logp, axis=0)
  mask = tf.stack([
      tf.range(len(labels), dtype=tf.int32),
      tf.cast(labels, dtype=tf.int32)
  ], axis=1)
  return -tf.reduce_mean(tf.gather_nd(logp, mask))
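# A minimal usage sketch with hypothetical sizes: two ensemble members, batch
# of three, four classes. Member logits are stacked along the batch axis; the
# function averages member probabilities in log space before taking the NLL.
import tensorflow as tf
import tensorflow_probability as tfp

labels = tf.constant([0, 1, 3])
member1 = tf.random.normal([3, 4])
member2 = tf.random.normal([3, 4])
logits = tf.concat([member1, member2], axis=0)  # [ensemble_size * batch, 4]
loss = ensemble_crossentropy(labels, logits, ensemble_size=2)
print(float(loss))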
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
  """Scores context and examples to return a score per document.

  Args:
    context_features: (dict) context feature names to 2D tensors of shape
      [batch_size, feature_dims].
    example_features: (dict) example feature names to 3D tensors of shape
      [batch_size, list_size, feature_dims].
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size],
      which is True for a valid example and False for an invalid one. If mask
      is None, all entries are valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size].
  """
  tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(tensor)[0]
  list_size = tf.shape(tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Expand query features to be of [batch_size, list_size, ...].
  large_batch_context_features = {}
  for name, tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size * list_size])

  large_batch_example_features = {}
  for name, tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size * list_size])

  # Get scores for large batch.
  scores = self.score(
      context_features=large_batch_context_features,
      example_features=large_batch_example_features,
      training=training)
  logits = tf.reshape(scores, shape=[batch_size, list_size])

  # Apply nd_mask to zero out invalid entries.
  logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
  return logits
def _get_reset_state_indices():
  reset_indices_obs = tf.nest.map_structure(
      lambda t: tf.gather_nd(t, reset_indices), observation)
  # Shape: [num_indices_to_reset, ...].
  reset_indices_state = self.get_initial_state(
      reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
  # Scatter tensors in `reset_indices_state` to shape:
  # [num_timesteps, batch_size, ...].
  return tf.nest.map_structure(
      lambda reset_tensor: tf.scatter_nd(
          indices=reset_indices,
          updates=reset_tensor,
          shape=done.shape.as_list() + reset_tensor.shape.as_list()[1:]),
      reset_indices_state)