def testSddmm_Replicated(self, r, m, k, n, sparsity, use_gpu):
    """Check `ops.replicated_sddmm` against a dense NumPy reference.

    Builds `r` replicas of an `[m, k]` left operand and an `[n, k]` right
    operand, plus a single `[m, n]` sparsity mask shared by all replicas.
    For each replica, the op's values are compared with the masked dense
    product `mask * (lhs @ rhs^T)`.

    Args:
        r: Number of replicas (leading dimension of both operands).
        m: Number of rows in each left operand and in the output.
        k: Inner (contracted) dimension.
        n: Number of rows in each right operand / columns in the output.
        sparsity: Value handed to `connectors.Uniform` to build the mask.
        use_gpu: Forwarded to `self.test_session`.
    """
    # Helpers that generate the mask and the dense operand values.
    mask_connector = connectors.Uniform(sparsity)
    dense_initializer = initializers.Uniform()

    # Reference data kept on the NumPy side for verification.
    lhs_values = dense_initializer([r, m, k])
    rhs_values = dense_initializer([r, n, k])
    mask = mask_connector(np.ones([m, n]))

    # TensorFlow graph under test.
    topology = sparse_matrix.SparseTopology("output_topology", mask=mask)
    lhs = tf.Variable(lhs_values, dtype=tf.float32)
    rhs = tf.Variable(rhs_values, dtype=tf.float32)
    output = ops.replicated_sddmm(lhs, rhs, topology, transpose_rhs=True)

    with self.test_session(use_gpu=use_gpu) as sess:
        sess.run(tf.global_variables_initializer())

        # Fetch the op result together with the topology's index arrays
        # so the sparse output can be rebuilt on the host.
        fetches = [output, topology.row_offsets, topology.column_indices]
        values, row_offsets, column_indices = sess.run(fetches)

        for replica in range(r):
            dense_product = np.dot(
                lhs_values[replica, :, :],
                np.transpose(rhs_values[replica, :, :]),
            )
            expected = self.dense_to_scipy(mask * dense_product)
            actual = self.sparse_to_scipy(
                values[replica, :],
                row_offsets,
                column_indices,
                shape=expected.shape,
            )
            self.assert_sparse_matrix_equal(
                actual, expected, atol=1e-03, rtol=1e-05)
def sparse_dot_product_attention(q, k, v, topology, **_):
    """Compute dot-product attention restricted to a sparse topology.

    Each of `q`, `k` and `v` is first passed through
    `preprocess_attention_component`. The logits come from a replicated
    SDDMM of the processed queries and keys (with the keys transposed),
    are normalized with a replicated sparse softmax, and are then applied
    to the processed values with a replicated SpMM.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        topology: Sparse topology passed to all three sparse ops.
        **_: Extra keyword arguments; accepted and ignored.

    Returns:
        The attention output reshaped to the dynamic shape of `q`.
    """
    queries_3d = preprocess_attention_component(q)
    keys_3d = preprocess_attention_component(k)
    values_3d = preprocess_attention_component(v)

    scores = ops.replicated_sddmm(
        queries_3d, keys_3d, topology, transpose_rhs=True)
    attention = ops.replicated_sparse_softmax(scores, topology)
    attended = ops.replicated_spmm(attention, topology, values_3d)

    # Restore the caller's original query layout.
    return tf.reshape(attended, tf.shape(q))