def testSpmm_Replicated(self, r, m, k, n, sparsity, use_gpu):
  """Checks replicated spmm output against a dense numpy matmul per replica.

  Args:
    r: Number of replicas (leading batch dimension of lhs/rhs).
    m: Rows of the sparse lhs matrix.
    k: Shared inner dimension.
    n: Columns of the dense rhs matrix.
    sparsity: Target sparsity level handed to the Uniform connector.
    use_gpu: Whether the session should place the op on GPU.
  """
  # Build a single sparsity mask shared across all replicas, plus dense
  # reference operands for verification.
  make_mask = connectors.Uniform(sparsity, round_to=4)
  init = initializers.Uniform()

  sparsity_mask = make_mask(init([m, k]))
  sparsity_mask[sparsity_mask != 0] = 1.0
  dense_lhs = np.expand_dims(sparsity_mask, axis=0) * init([r, m, k])
  dense_rhs = init([r, k, n])

  # TensorFlow graph: the sparse lhs variable holds only the nonzero values,
  # flattened to one row of values per replica.
  topology = sparse_matrix.SparseTopology("topology", mask=sparsity_mask)
  lhs_values = tf.Variable(
      np.reshape(dense_lhs[dense_lhs != 0], [r, -1]), dtype=tf.float32)
  rhs = tf.Variable(dense_rhs, dtype=tf.float32)
  result = ops.replicated_spmm(lhs_values, topology, rhs)

  # Execute the op and compare each replica slice against numpy.
  with self.test_session(use_gpu=use_gpu) as sess:
    sess.run(tf.global_variables_initializer())
    result_np = sess.run(result)
    for rep in range(r):
      reference = np.dot(dense_lhs[rep, :, :], dense_rhs[rep, :, :])
      self.assertAllClose(
          result_np[rep, :], reference, atol=1e-03, rtol=1e-05)
def sparse_dot_product_attention(q, k, v, topology, **_):
  """Dot-product attention with a fixed sparse attention pattern.

  Pipeline: sampled dense-dense matmul (q @ k^T restricted to the topology's
  nonzeros) -> sparse softmax over those logits -> sparse matmul with v.
  Extra keyword arguments are accepted and ignored so this can be dropped in
  where callers pass dense-attention options.

  Args:
    q: Query tensor; reshaped to 3-D by preprocess_attention_component.
    k: Key tensor; reshaped like q.
    v: Value tensor; reshaped like q.
    topology: SparseTopology describing which attention entries are computed.
    **_: Unused keyword arguments, swallowed for interface compatibility.

  Returns:
    The attention output, reshaped back to the (dynamic) shape of q.
  """
  q_3d = preprocess_attention_component(q)
  k_3d = preprocess_attention_component(k)
  v_3d = preprocess_attention_component(v)
  attention_logits = ops.replicated_sddmm(
      q_3d, k_3d, topology, transpose_rhs=True)
  attention_weights = ops.replicated_sparse_softmax(attention_logits, topology)
  attention_out = ops.replicated_spmm(attention_weights, topology, v_3d)
  return tf.reshape(attention_out, tf.shape(q))