def testSimple(self):
  p = self._DualEncoderParamsForTest()
  p.loss_weights = {('x', 'y'): 0.5, ('y', 'x'): 0.5}
  p.joint_embedding_dim = 7
  batch_size = 2
  p.label_fn = lambda _: tf.eye(batch_size)
  model = p.Instantiate()
  input_batch = py_utils.NestedMap(
      x_input=tf.ones([batch_size, 3], dtype=tf.float32),
      x_ids=tf.range(batch_size, dtype=tf.int64),
      y_input=tf.ones([batch_size, 4], dtype=tf.float32),
      y_ids=tf.range(batch_size, dtype=tf.int64))
  preds = model.ComputePredictions(model.theta, input_batch)
  self.assertEqual([batch_size, p.joint_embedding_dim],
                   preds.x.encodings.shape.as_list())
  self.assertEqual([batch_size, p.joint_embedding_dim],
                   preds.y.encodings.shape.as_list())
  self.assertEqual(input_batch.x_ids.shape, preds.x.ids.shape)
  self.assertEqual(input_batch.y_ids.shape, preds.y.ids.shape)
def _generalized_inverse_pth_root(self, input_t, exponent, epsilon=1e-12):
  """Computes `input_t`^`exponent` via SVD, clamping singular values.

  Returns the result and max |u - v|, a diagnostic that is near zero when
  the input is symmetric PSD (so that u and v agree).
  """
  input_t_f64 = tf.cast(input_t, tf.float64)
  s, u, v = tf.linalg.svd(
      input_t_f64 +
      tf.eye(tf.shape(input_t_f64)[0], dtype=tf.float64) * epsilon,
      full_matrices=True)
  inv_s = tf.reshape(
      tf.pow(tf.maximum(s, epsilon), tf.cast(exponent, tf.float64)), [1, -1])
  val = tf.matmul(u * inv_s, v, adjoint_b=True)
  return tf.cast(val, tf.float32), tf.reduce_max(tf.abs(u - v))
def inverse_pth_root(self, input_t, exponent, epsilon=1e-12):
  """Like `_generalized_inverse_pth_root`, with an explicit diagonal matmul."""
  input_t_f64 = tf.cast(input_t, tf.float64)
  s, u, v = tf.linalg.svd(
      input_t_f64 +
      tf.eye(tf.shape(input_t_f64)[0], dtype=tf.float64) * epsilon,
      full_matrices=True)
  val = tf.matmul(
      tf.matmul(
          u,
          tf.linalg.tensor_diag(
              tf.pow(tf.maximum(s, epsilon), tf.cast(exponent, tf.float64)))),
      tf.transpose(v))
  return tf.cast(val, tf.float32), tf.reduce_max(tf.abs(u - v))
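# Added sketch, not part of the original module: the two formulations above
# are equivalent, since `u * inv_s` scales the columns of `u` exactly like
# right-multiplying by diag(inv_s). A quick self-check, assuming TF2 eager
# mode and the module-level `tf` import (the helper name is hypothetical):
def _example_svd_root_formulations_agree():
  b = tf.random.normal([3, 3], dtype=tf.float64)
  mat = tf.matmul(b, b, transpose_b=True)  # Symmetric PSD input.
  s, u, v = tf.linalg.svd(
      mat + tf.eye(3, dtype=tf.float64) * 1e-12, full_matrices=True)
  powered = tf.pow(tf.maximum(s, 1e-12), -0.25)  # exponent = -1/4.
  # Broadcasting form, as in _generalized_inverse_pth_root.
  via_broadcast = tf.matmul(u * powered[None, :], v, adjoint_b=True)
  # Explicit diagonal form, as in inverse_pth_root.
  via_diag = tf.matmul(
      tf.matmul(u, tf.linalg.tensor_diag(powered)), tf.transpose(v))
  # Should print a value near zero.
  print(tf.reduce_max(tf.abs(via_broadcast - via_diag)).numpy())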
@classmethod
def BetweenLocalAndGlobalBatches(cls, local_batch, **kwargs) -> 'ExamplePairs':
  """Creates an instance representing (local, global) example pairs."""
  local_batch = py_utils.NestedMap(local_batch)
  global_batch = tpu_utils.ConcatenateAcrossReplicas(local_batch)
  correspondences = tf.eye(
      utils.InferBatchSize(local_batch),
      utils.InferBatchSize(global_batch),
      dtype=tf.bool)
  return ExamplePairs(
      query_examples=local_batch,
      result_examples=global_batch,
      correspondences=correspondences,
      **kwargs)
def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
  """Iterative method to compute the matrix square root.

  Stable iterations for the matrix square root, Nicholas J. Higham.
  Page 231, Eq 2.6b:
  http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf

  Args:
    mat_a: the symmetric PSD matrix whose matrix square root is to be computed.
    mat_a_size: size of mat_a.
    iter_count: Maximum number of iterations.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_a^0.5
  """

  def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
                      unused_old_mat_z, err, old_err):
    """This method requires that we check for divergence at every step."""
    return tf.math.logical_and(i < iter_count, err < old_err)

  def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
                 unused_old_err):
    """One step of the coupled Newton iteration for the square root."""
    current_iterate = 0.5 * (3.0 * identity - tf.matmul(mat_z, mat_y))
    current_mat_y = tf.matmul(mat_y, current_iterate)
    current_mat_z = tf.matmul(current_iterate, mat_z)
    # Compute the error in the approximation.
    mat_sqrt_a = current_mat_y * tf.sqrt(norm)
    mat_a_approx = tf.matmul(mat_sqrt_a, mat_sqrt_a)
    residual = mat_a - mat_a_approx
    current_err = tf.sqrt(tf.reduce_sum(residual * residual)) / norm
    return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err

  identity = tf.eye(tf.cast(mat_a_size, tf.int32))
  mat_a = mat_a + ridge_epsilon * identity
  norm = tf.sqrt(tf.reduce_sum(mat_a * mat_a))
  mat_init_y = mat_a / norm
  mat_init_z = identity
  init_err = norm
  _, _, prev_mat_y, _, _, _, _ = tf.while_loop(_iter_condition, _iter_body, [
      0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
      init_err + 1.0
  ])
  return prev_mat_y * tf.sqrt(norm)
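# Minimal usage sketch (an addition, not from the original source): square
# the result and compare against the ridge-damped input, which is what the
# iteration actually targets. Assumes TF2 eager mode and the module-level
# `tf` import (the helper name is hypothetical):
def _example_matrix_square_root():
  b = tf.random.normal([4, 4])
  mat = tf.matmul(b, b, transpose_b=True)  # Symmetric PSD input.
  root = matrix_square_root(mat, 4)
  # The iteration approximates sqrt(mat + ridge_epsilon * I).
  target = mat + 1e-4 * tf.eye(4)
  # Should print a small residual (float32 precision).
  print(tf.reduce_max(tf.abs(tf.matmul(root, root) - target)).numpy())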
def __call__(self, inputs: ExamplePairs):
  """Labels item pairs in `inputs`."""
  # Generate labels for the example pairs. If examples only have one item in
  # each modality, everything else below is a no-op.
  example_pair_labels = self._example_pair_labeler(inputs)
  example_pair_labels.shape.assert_is_compatible_with(
      [inputs.query_batch_size, inputs.result_batch_size])
  query_batch_shape = utils.ResolveBatchDim(
      self._modality_batch_shapes[inputs.query_modality],
      inputs.query_batch_size)
  result_batch_shape = utils.ResolveBatchDim(
      self._modality_batch_shapes[inputs.result_modality],
      inputs.result_batch_size)
  # Broadcast example-level labels to all pairs of their items.
  item_pair_labels = _BroadcastExamplePairLabelsToAllItemPairs(
      example_pair_labels, query_batch_shape, result_batch_shape)
  if inputs.query_modality == inputs.result_modality:
    # Intra-modal retrieval. Give self pairs the "ignore" label so that items
    # aren't used as their own targets during training.
    # For this case we require the batch shape to have rank 2.
    # - If rank == 1, each example contains only a single item, so the only
    #   within-example pairs are the self pairs. Once these are dropped,
    #   there are typically no positively labeled pairs, meaning there is
    #   no training signal. (The exception is if two *distinct* examples
    #   are given a positive label.) Rather than hoping the trainer will do
    #   something smart in this weird corner case, we simply die if
    #   rank == 1.
    # - Any rank > 1 is sufficient to get around the above problem, but
    #   currently there's no use case for ranks above 2.
    query_batch_shape.assert_has_rank(2)
    n = tf.compat.dimension_value(query_batch_shape[1])
    assert n > 1
    # Item self-pairs are those at indices [q, i, r, j] where
    # - examples q and r are the same example, and
    # - i == j refer to the same item in the example.
    is_item_self_pair = (
        # [q, 1, r, 1]
        inputs.correspondences[:, None, :, None]
        # [1, n, 1, n]
        & tf.eye(n, dtype=tf.bool)[None, :, None, :])
    item_pair_labels = _IgnorePairsWhere(is_item_self_pair, item_pair_labels)
  return item_pair_labels
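# Shape sketch (added illustration, hypothetical helper name): how the two
# broadcasted boolean masks above combine. With a batch of 2 examples, each
# holding n=3 items, the mask is True exactly at [q, i, q, i]. Assumes the
# module-level `tf` import.
def _example_item_self_pair_mask():
  correspondences = tf.eye(2, dtype=tf.bool)  # [q, r]: same-example pairs.
  n = 3
  is_item_self_pair = (
      # [q, 1, r, 1]
      correspondences[:, None, :, None]
      # [1, n, 1, n]
      & tf.eye(n, dtype=tf.bool)[None, :, None, :])
  print(is_item_self_pair.shape)  # (2, 3, 2, 3)
  # Item 1 of example 0 pairs with itself only:
  print(is_item_self_pair[0, 1].numpy())  # True only at position [0, 1].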
@classmethod
def WithinBatch(cls, batch, **kwargs) -> 'ExamplePairs':
  """Creates an instance representing all example pairs within `batch`.

  Args:
    batch: Dict of input examples; this same set of examples represents both
      the query_examples and result_examples.
    **kwargs: kwargs to forward to the ExamplePairs constructor.

  Returns:
    `ExamplePairs`
  """
  correspondences = tf.eye(utils.InferBatchSize(batch), dtype=tf.bool)
  return ExamplePairs(
      query_examples=batch,
      result_examples=batch,
      correspondences=correspondences,
      **kwargs)
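# Quick illustration (added, not from the original source): both constructors
# encode "query i and result j are the same example" as a boolean identity
# mask. WithinBatch uses a square eye; BetweenLocalAndGlobalBatches uses a
# rectangular one whose main diagonal marks the leading slice of the global
# batch (whether the local batch actually lands there depends on
# tpu_utils.ConcatenateAcrossReplicas). Assumes the module-level `tf` import.
def _example_correspondence_masks():
  print(tf.eye(3, dtype=tf.bool).numpy())     # WithinBatch: 3 x 3 identity.
  print(tf.eye(2, 6, dtype=tf.bool).numpy())  # 2 local vs. 6 global examples.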
def inlined_matrix_inverse_pth_root(mat_g,
                                    mat_g_size,
                                    alpha,
                                    iter_count=100,
                                    error_tolerance=1e-6,
                                    ridge_epsilon=1e-6):
  """Computes mat_g^alpha, where alpha = -1/p, for p one of 2, 4, or 8.

  We use an iterative Schur-Newton method from equation 3.2 on page 9 of:

  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3: pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

  Args:
    mat_g: the symmetric PSD matrix whose power is to be computed.
    mat_g_size: size of mat_g.
    alpha: exponent; must be -1/p for p a positive integer.
    iter_count: Maximum number of iterations.
    error_tolerance: Error indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^alpha
  """
  alpha = tf.cast(alpha, tf.float64)
  neg_alpha = -1.0 * alpha
  exponent = 1.0 / neg_alpha
  identity = tf.eye(tf.cast(mat_g_size, tf.int32), dtype=tf.float64)

  def _unrolled_mat_pow_2(mat_m):
    """Computes mat_m^2."""
    return tf.matmul(mat_m, mat_m)

  def _unrolled_mat_pow_4(mat_m):
    """Computes mat_m^4."""
    mat_pow_2 = _unrolled_mat_pow_2(mat_m)
    return tf.matmul(mat_pow_2, mat_pow_2)

  def _unrolled_mat_pow_8(mat_m):
    """Computes mat_m^8."""
    mat_pow_4 = _unrolled_mat_pow_4(mat_m)
    return tf.matmul(mat_pow_4, mat_pow_4)

  def mat_power(mat_m, p):
    """Computes mat_m^p, for p == 2, 4, or 8.

    Args:
      mat_m: a square matrix.
      p: a positive integer.

    Returns:
      mat_m^p
    """
    branch_index = tf.cast(p / 2 - 1, tf.int32)
    return tf.switch_case(
        branch_index, {
            0: functools.partial(_unrolled_mat_pow_2, mat_m),
            1: functools.partial(_unrolled_mat_pow_4, mat_m),
            2: functools.partial(_unrolled_mat_pow_8, mat_m),
        })

  def _iter_condition(i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
                      run_step):
    return tf.math.logical_and(
        tf.math.logical_and(i < iter_count, error > error_tolerance), run_step)

  def _iter_body(i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step):
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    new_mat_m = tf.matmul(mat_power(mat_m_i, exponent), mat_m)
    new_mat_h = tf.matmul(mat_h, mat_m_i)
    new_error = tf.reduce_max(tf.abs(new_mat_m - identity))
    return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error)

  if mat_g_size == 1:
    # Scalar case: compute the damped power directly.
    resultant_mat_h = tf.pow(mat_g + ridge_epsilon, alpha)
    error = tf.zeros([], tf.float64)
  else:
    damped_mat_g = mat_g + ridge_epsilon * identity
    z = (1 - 1 / alpha) / (2 * tf.norm(damped_mat_g))
    # The best value for z is
    #   (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
    #   (c_max^{1-alpha} - c_min^{1-alpha})
    # where c_max and c_min are the largest and smallest singular values of
    # damped_mat_g.
    # The above estimate assumes that c_max > c_min * 2^p (with p = -1/alpha).
    # The line above can be replaced by the one below, but it is less
    # accurate and hence needs more iterations to converge:
    #   z = (1 - 1/alpha) / tf.trace(damped_mat_g)
    # If we want the method to always converge, use z = 1/norm(damped_mat_g)
    # or z = 1/tf.trace(damped_mat_g), but these can result in many extra
    # iterations.
    new_mat_m_0 = damped_mat_g * z
    new_error = tf.reduce_max(tf.abs(new_mat_m_0 - identity))
    new_mat_h_0 = identity * tf.pow(z, neg_alpha)
    _, mat_m, mat_h, old_mat_h, error, convergence = tf.while_loop(
        _iter_condition, _iter_body,
        [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
    error = tf.reduce_max(tf.abs(mat_m - identity))
    is_converged = tf.cast(convergence, old_mat_h.dtype)
    resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
  return resultant_mat_h, error
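# Usage sketch (an added example, not from the original source): for
# alpha = -1/2 (p = 2), squaring the result should approximate the inverse of
# the ridge-damped input. Assumes TF2 eager mode, float64 inputs, and the
# module-level `tf` import (the helper name is hypothetical):
def _example_inlined_inverse_pth_root():
  b = tf.random.normal([4, 4], dtype=tf.float64)
  mat = tf.matmul(b, b, transpose_b=True)  # Symmetric PSD input.
  root, error = inlined_matrix_inverse_pth_root(mat, 4, alpha=-0.5)
  approx_inverse = tf.matmul(root, root)
  # approx_inverse @ (mat + ridge * I) should be close to the identity.
  residual = tf.matmul(approx_inverse,
                       mat + 1e-6 * tf.eye(4, dtype=tf.float64))
  print(tf.reduce_max(tf.abs(residual - tf.eye(4, dtype=tf.float64))).numpy(),
        error.numpy())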