예제 #1
0
    def testSparseSoftmax_Replicated(self, r, m, n, sparsity):
        """Compare `ops.replicated_sparse_softmax` with a NumPy reference.

        A single mask built from an [m, n] array of ones is shared by `r`
        sets of values. For each of the `r` slices, the op's output is
        checked against a dense row-wise softmax in which every zero entry
        has been replaced by -1e9.
        """
        # Helpers that produce the sparsity mask and the random values.
        make_mask = connectors.Uniform(sparsity)
        make_values = initializers.Uniform()

        # Dense NumPy reference: the same mask applied to each of r slices.
        mask = make_mask(np.ones([m, n]))
        matrix_np = np.expand_dims(mask, axis=0) * make_values([r, m, n])

        # TensorFlow graph: the nonzero entries are fed as r rows of values.
        topology = sparse_matrix.SparseTopology("topology", mask=mask)
        nonzero_values = matrix_np[matrix_np != 0]
        values = tf.Variable(np.reshape(nonzero_values, [r, -1]),
                             dtype=tf.float32)
        output = ops.replicated_sparse_softmax(values, topology)

        def reference_softmax(dense):
            """Row-wise softmax; the row max is subtracted before exp."""
            shifted = dense - dense.max(axis=1, keepdims=True)
            exps = np.exp(shifted)
            return exps / np.sum(exps, axis=1, keepdims=True)

        with self.test_session(use_gpu=True) as sess:
            sess.run(tf.global_variables_initializer())
            fetches = [
                output, topology.row_offsets, topology.column_indices
            ]
            softmax_values, row_offsets, column_indices = sess.run(fetches)

            # Zero terms should not contribute to the softmax, so push them
            # far below every real value before exponentiating.
            matrix_np[matrix_np == 0] = -1e9

            for i in range(r):
                expected = self.dense_to_scipy(
                    reference_softmax(matrix_np[i]))
                actual = self.sparse_to_scipy(
                    softmax_values[i, :], row_offsets, column_indices,
                    expected.shape)
                self.assert_sparse_matrix_equal(
                    actual, expected, atol=1e-03, rtol=1e-05)
예제 #2
0
def sparse_dot_product_attention(q, k, v, topology, **_):
    """Dot-product attention built from the replicated sparse ops.

    Each of `q`, `k` and `v` is passed through
    `preprocess_attention_component`; the results are then chained through
    `ops.replicated_sddmm` (with `transpose_rhs=True`),
    `ops.replicated_sparse_softmax` and `ops.replicated_spmm`, all sharing
    the same `topology`.

    Args:
        q: Query tensor; its shape is also the shape of the result.
        k: Key tensor.
        v: Value tensor.
        topology: Sparse topology handed to every replicated op
            (presumably the attention sparsity pattern -- confirm with
            the `ops` module).
        **_: Any additional keyword arguments are accepted and ignored.

    Returns:
        The output of the final op, reshaped to `tf.shape(q)`.
    """
    queries = preprocess_attention_component(q)
    keys = preprocess_attention_component(k)
    values = preprocess_attention_component(v)

    scores = ops.replicated_sddmm(queries, keys, topology, transpose_rhs=True)
    probabilities = ops.replicated_sparse_softmax(scores, topology)
    attended = ops.replicated_spmm(probabilities, topology, values)

    # Undo whatever reshaping the preprocessing applied to the query.
    return tf.reshape(attended, tf.shape(q))