def testSparseSoftmax_Replicated(self, r, m, n, sparsity):
  # Helpers to set up the matrices.
  connector = connectors.Uniform(sparsity)
  initializer = initializers.Uniform()

  # Numpy matrix for verification.
  mask = connector(np.ones([m, n]))
  matrix_np = np.expand_dims(mask, axis=0) * initializer([r, m, n])

  # TensorFlow graph.
  topology = sparse_matrix.SparseTopology("topology", mask=mask)
  values = tf.Variable(
      np.reshape(matrix_np[matrix_np != 0], [r, -1]),
      dtype=tf.float32)
  output = ops.replicated_sparse_softmax(values, topology)

  with self.test_session(use_gpu=True) as sess:
    sess.run(tf.global_variables_initializer())
    v, ro, ci = sess.run(
        [output, topology.row_offsets, topology.column_indices])

  # Zero terms should not contribute to the softmax.
  matrix_np[matrix_np == 0] = -1e9

  def softmax(x):
    maxs = np.expand_dims(x.max(axis=1), axis=1)
    exps = np.exp(x - maxs)
    return exps / np.expand_dims(np.sum(exps, axis=1), axis=1)

  for i in range(r):
    expected_output = self.dense_to_scipy(softmax(matrix_np[i, :, :]))
    actual_output = self.sparse_to_scipy(
        v[i, :], ro, ci, expected_output.shape)
    self.assert_sparse_matrix_equal(
        actual_output, expected_output, atol=1e-03, rtol=1e-05)
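# The scipy conversion helpers referenced by the test (`dense_to_scipy`,
# `sparse_to_scipy`) belong to the test base class and are not shown here.
# A minimal sketch of what they are assumed to do follows; the bodies are
# illustrative assumptions, not the library's actual implementation. Note
# that after the -1e9 shift the reference softmax underflows masked entries
# to exactly 0.0, so CSR construction drops them.
import scipy.sparse

def dense_to_scipy(self, matrix):
  # Wrap a dense numpy matrix as CSR; exact zeros are not stored.
  return scipy.sparse.csr_matrix(matrix)

def sparse_to_scipy(self, values, row_offsets, column_indices, shape):
  # Rebuild a CSR matrix from the raw (values, column_indices,
  # row_offsets) triplet produced by the sparse ops.
  return scipy.sparse.csr_matrix(
      (values, column_indices, row_offsets), shape=shape)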
def sparse_dot_product_attention(q, k, v, topology, **_):
  # Collapse q/k/v to 3D so the replicated sparse ops can treat each
  # attention head as one replica over the shared sparsity topology.
  q_3d, k_3d, v_3d = [preprocess_attention_component(x) for x in [q, k, v]]
  # Sampled dense-dense matmul: compute q @ k^T logits only at the
  # nonzero positions of the topology.
  logits = ops.replicated_sddmm(q_3d, k_3d, topology, transpose_rhs=True)
  # Softmax over each sparse row of logits.
  weights = ops.replicated_sparse_softmax(logits, topology)
  # Sparse-dense matmul with the values to produce the attention output.
  out = ops.replicated_spmm(weights, topology, v_3d)
  return tf.reshape(out, tf.shape(q))
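# `preprocess_attention_component` is referenced above but not shown. A
# plausible sketch, assuming q/k/v arrive as [batch, heads, seq_len,
# head_dim] and the replicated ops expect a [batch * heads, seq_len,
# head_dim] layout (the shape assumptions are ours, not the library's):
def preprocess_attention_component(x):
  # Fold the batch and head dimensions together so each head becomes one
  # replica sharing the same topology.
  assert len(x.shape) == 4
  return tf.reshape(x, [-1] + x.shape.as_list()[2:])

# Example usage (hypothetical mask and tensor names):
#   topology = sparse_matrix.SparseTopology("attn_topology", mask=band_mask)
#   out = sparse_dot_product_attention(q, k, v, topology)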