def _verifySolve(self, x, y, batch_dims=None):
  for adjoint in False, True:
    for np_type in [np.float32, np.float64]:
      a = x.astype(np_type)
      b = y.astype(np_type)
      if adjoint:
        a_np = np.conj(np.transpose(a))
      else:
        a_np = a
      if batch_dims is not None:
        a = np.tile(a, batch_dims + [1, 1])
        a_np = np.tile(a_np, batch_dims + [1, 1])
        b = np.tile(b, batch_dims + [1, 1])
      np_ans = np.linalg.solve(a_np, b)
      with self.test_session():
        # Test the batch version, which works for ndim >= 2
        tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
        out = tf_ans.eval()
        self.assertEqual(tf_ans.get_shape(), out.shape)
        self.assertEqual(np_ans.shape, out.shape)
        self.assertAllClose(np_ans, out)

        if a.ndim == 2:
          # Test the simple version
          tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
          out = tf_ans.eval()
          self.assertEqual(out.shape, tf_ans.get_shape())
          self.assertEqual(np_ans.shape, out.shape)
          self.assertAllClose(np_ans, out)
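# A minimal NumPy sketch (not part of the test above) of the adjoint
# convention being verified: with adjoint=True the solve is expected to
# satisfy conj(A)^T x = b rather than A x = b. The matrix and
# right-hand side here are illustrative only.
import numpy as np

a = np.array([[2.0, 1.0], [0.0, 3.0]])
b = np.array([[1.0], [2.0]])

x_plain = np.linalg.solve(a, b)               # solves a @ x = b
x_adjoint = np.linalg.solve(np.conj(a).T, b)  # solves conj(a)^T @ x = b
assert x_plain.shape == x_adjoint.shape == (2, 1)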
def testBatchResultSize(self):
  # 3x3x3 matrices, 3x3x1 right-hand sides.
  matrix = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.] * 3).reshape(3, 3, 3)
  rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)
  answer = tf.batch_matrix_solve(matrix, rhs)
  ls_answer = tf.batch_matrix_solve_ls(matrix, rhs)
  self.assertEqual(ls_answer.get_shape(), [3, 3, 1])
  self.assertEqual(answer.get_shape(), [3, 3, 1])
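# A small NumPy check (an illustrative sketch, not part of the test
# above) of the batched shape contract: [3, 3, 3] matrices with
# [3, 3, 1] right-hand sides yield a [3, 3, 1] solution. A diagonal
# shift keeps each batch member well conditioned; np.linalg.solve
# broadcasts over the leading batch dimension, mirroring the batch op.
import numpy as np

rng = np.random.RandomState(0)
matrix = rng.randn(3, 3, 3) + 3.0 * np.eye(3)
rhs = rng.randn(3, 3, 1)

solution = np.linalg.solve(matrix, rhs)
assert solution.shape == (3, 3, 1)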
def test_solve(self):
  with self.test_session():
    for batch_shape in [(), (2, 3,)]:
      for k in [1, 4]:
        operator, mat = self._build_operator_and_mat(batch_shape, k)

        # Work with 5 simultaneous systems. 5 is arbitrary.
        x = self._rng.randn(*(batch_shape + (k, 5)))

        self._compare_results(
            expected=tf.batch_matrix_solve(mat, x).eval(),
            actual=operator.solve(x))
def _verifySolve(self, x, y):
  for np_type in [np.float32, np.float64]:
    a = x.astype(np_type)
    b = y.astype(np_type)
    with self.test_session():
      if a.ndim == 2:
        tf_ans = tf.matrix_solve(a, b)
      else:
        tf_ans = tf.batch_matrix_solve(a, b)
      out = tf_ans.eval()
      np_ans = np.linalg.solve(a, b)
      self.assertEqual(np_ans.shape, out.shape)
      self.assertAllClose(np_ans, out)
def test_sqrt_solve(self):
  # Square roots are not unique, but we should still have
  # S^{-T} S^{-1} x = A^{-1} x.
  # In our case, we should have S = S^T, so then S^{-1} S^{-1} x = A^{-1} x.
  with self.test_session():
    for batch_shape in [(), (2, 3,)]:
      for k in [1, 4]:
        operator, mat = self._build_operator_and_mat(batch_shape, k)

        # Work with 5 simultaneous systems. 5 is arbitrary.
        x = self._rng.randn(*(batch_shape + (k, 5)))

        self._compare_results(
            expected=tf.batch_matrix_solve(mat, x).eval(),
            actual=operator.sqrt_solve(operator.sqrt_solve(x)))
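# A minimal NumPy sketch (illustrative, not from the test above) of the
# identity checked in test_sqrt_solve, assuming a symmetric square root
# S with A = S S: applying S^{-1} twice matches A^{-1} applied once.
import numpy as np

rng = np.random.RandomState(0)
s = rng.randn(4, 4)
s = s @ s.T + 4.0 * np.eye(4)  # symmetric positive-definite "square root" S
a = s @ s                      # A = S S (= S^T S, since S is symmetric)

x = rng.randn(4, 5)
twice = np.linalg.solve(s, np.linalg.solve(s, x))  # S^{-1} S^{-1} x
once = np.linalg.solve(a, x)                       # A^{-1} x
assert np.allclose(twice, once)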
def test_BatchMatrixSolve(self):
  t = tf.batch_matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)))
  self.check(t)
def _process_input_helper(self, update_row_factors, sp_input=None,
                          transpose_input=False):
  """Creates the graph for processing a sparse slice of input.

  Args:
    update_row_factors: if True, update the row_factors, else update the
      column factors.
    sp_input: Please refer to comments for update_row_factors and
      update_col_factors.
    transpose_input: If true, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.

  Returns:
    A tuple consisting of the following two elements:
    new_values: New values for the row/column factors.
    update_op: An op that assigns the newly computed values to the row/column
      factors.
  """
  assert isinstance(sp_input, ops.SparseTensor)

  if update_row_factors:
    left = self._row_factors
    right_factors = self._col_factors_cache
    row_wt = self._row_wt_cache
    col_wt = self._col_wt_cache
    sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                 self._num_row_shards)
    gramian = self._col_gramian_cache
  else:
    left = self._col_factors
    right_factors = self._row_factors_cache
    row_wt = self._col_wt_cache
    col_wt = self._row_wt_cache
    sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                 self._num_col_shards)
    gramian = self._row_gramian_cache
    transpose_input = not transpose_input

  # Note that the row indices of sp_input are based on the original full input
  # Here we reindex the rows and give them contiguous ids starting at 0.
  # We use tf.unique to achieve this reindexing. Note that this is done so
  # that the downstream kernel can assume that the input is "dense" along the
  # row dimension.
  row_ids, col_ids = tf.split(1, 2, sp_input.indices)
  update_row_indices, all_row_ids = tf.unique(row_ids[:, 0])
  update_col_indices, all_col_ids = tf.unique(col_ids[:, 0])
  col_ids = tf.expand_dims(tf.cast(all_col_ids, tf.int64), 1)
  row_ids = tf.expand_dims(tf.cast(all_row_ids, tf.int64), 1)

  if transpose_input:
    update_indices = update_col_indices
    row_shape = [tf.cast(tf.shape(update_row_indices)[0], tf.int64)]
    gather_indices = update_row_indices
  else:
    update_indices = update_row_indices
    row_shape = [tf.cast(tf.shape(update_col_indices)[0], tf.int64)]
    gather_indices = update_col_indices

  num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64)
  col_shape = [num_rows]
  right = embedding_ops.embedding_lookup(right_factors, gather_indices,
                                         partition_strategy='div')
  new_sp_indices = tf.concat(1, [row_ids, col_ids])
  new_sp_shape = (tf.concat(0, [row_shape, col_shape]) if transpose_input
                  else tf.concat(0, [col_shape, row_shape]))
  new_sp_input = tf.SparseTensor(indices=new_sp_indices,
                                 values=sp_input.values,
                                 shape=new_sp_shape)

  # Compute lhs and rhs of the normal equations
  total_lhs = (self._unobserved_weight * gramian)
  if self._regularization is not None:
    total_lhs += self._regularization
  if self._row_weights is None:
    # Special case of ALS. Use a much simpler update rule.
    total_rhs = (self._unobserved_weight *
                 tf.sparse_tensor_dense_matmul(new_sp_input, right,
                                               adjoint_a=transpose_input))
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    new_left_values = tf.transpose(
        tf.matrix_solve(total_lhs, tf.transpose(total_rhs)))
  else:
    # TODO(yifanchen): Add special handling for single shard without using
    # embedding_lookup and perform benchmarks for those cases.
    row_weights_slice = embedding_ops.embedding_lookup(
        row_wt, update_indices, partition_strategy='div')
    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy='div')
    partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
        right,
        col_weights,
        self._unobserved_weight,
        row_weights_slice,
        new_sp_input.indices,
        new_sp_input.values,
        num_rows,
        transpose_input,
        name="wals_compute_partial_lhs_rhs")
    total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = tf.expand_dims(total_rhs, -1)
    new_left_values = tf.squeeze(
        tf.batch_matrix_solve(total_lhs, total_rhs), [2])

  return (new_left_values,
          self.scatter_update(left, update_indices, new_left_values,
                              sharding_func))
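# A hedged NumPy sketch of the ALS special case above (row_weights is
# None): the factors being updated solve the normal equations
# (w_unobserved * V^T V + reg * I) X^T = (w_unobserved * R V)^T, where V
# holds the fixed factors and R is the (here dense) input slice. All
# names and values below are illustrative stand-ins, not the library's
# API.
import numpy as np

rng = np.random.RandomState(0)
num_rows, num_cols, k = 6, 5, 3
w_unobserved, reg = 0.1, 0.01

r = rng.randn(num_rows, num_cols)  # dense stand-in for the sparse input slice
v = rng.randn(num_cols, k)         # fixed right factors

gramian = v.T @ v                                    # role of the cached Gramian
total_lhs = w_unobserved * gramian + reg * np.eye(k)
total_rhs = w_unobserved * (r @ v)                   # role of the sparse matmul

# Mirrors transpose(matrix_solve(total_lhs, transpose(total_rhs))).
new_left = np.linalg.solve(total_lhs, total_rhs.T).T
assert new_left.shape == (num_rows, k)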
def _batch_solve(self, rhs):
  return tf.batch_matrix_solve(self._pos_def_matrix, rhs)