Exemplo n.º 1
0
    def testSddmm_Replicated(self, r, m, k, n, sparsity, use_gpu):
        """Check `ops.replicated_sddmm` against a per-replica NumPy reference.

        Args:
            r: Leading dimension of both dense operands; the comparison
                loop runs once per index along it.
            m: Second dimension of the lhs operand and row count of the mask.
            k: Trailing dimension shared by both dense operands.
            n: Second dimension of the rhs operand and column count of the
                mask.
            sparsity: Forwarded to `connectors.Uniform` to build the mask.
            use_gpu: Forwarded to `self.test_session`.
        """
        # Build the helpers first, then draw the data in a fixed order
        # (lhs, rhs, mask).
        # NOTE(review): presumably both helpers draw random values, which
        # would make this ordering matter under a fixed seed -- confirm.
        make_mask = connectors.Uniform(sparsity)
        make_dense = initializers.Uniform()

        lhs_np = make_dense([r, m, k])
        rhs_np = make_dense([r, n, k])
        mask_np = make_mask(np.ones([m, n]))

        # Graph construction: the mask defines the output topology and the
        # dense operands are held in float32 variables.
        topology = sparse_matrix.SparseTopology("output_topology",
                                                mask=mask_np)
        lhs_var = tf.Variable(lhs_np, dtype=tf.float32)
        rhs_var = tf.Variable(rhs_np, dtype=tf.float32)
        sddmm_out = ops.replicated_sddmm(
            lhs_var, rhs_var, topology, transpose_rhs=True)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())

            # Fetch the op values together with the topology's index arrays.
            values, row_offsets, column_indices = sess.run([
                sddmm_out,
                topology.row_offsets,
                topology.column_indices,
            ])

            # Every replica is checked against mask * (lhs @ rhs^T).
            for replica in range(r):
                dense_product = np.dot(lhs_np[replica], rhs_np[replica].T)
                expected = self.dense_to_scipy(mask_np * dense_product)
                actual = self.sparse_to_scipy(values[replica, :],
                                              row_offsets,
                                              column_indices,
                                              shape=expected.shape)
                self.assert_sparse_matrix_equal(
                    actual, expected, atol=1e-03, rtol=1e-05)
Exemplo n.º 2
0
def sparse_dot_product_attention(q, k, v, topology, **_):
    """Compute attention through the replicated sparse kernels in `ops`.

    Each of `q`, `k` and `v` is first passed through
    `preprocess_attention_component`. The logits come from
    `ops.replicated_sddmm` (called with `transpose_rhs=True`), are fed to
    `ops.replicated_sparse_softmax`, and the result is combined with the
    preprocessed `v` by `ops.replicated_spmm`.

    Args:
        q: First attention input (presumably the queries); its shape is
            also the shape of the returned tensor.
        k: Second attention input (presumably the keys).
        v: Third attention input (presumably the values).
        topology: Sparse topology handed to all three sparse ops.
        **_: Any extra keyword arguments are accepted and ignored.

    Returns:
        The `ops.replicated_spmm` output reshaped to `tf.shape(q)`.
    """
    queries = preprocess_attention_component(q)
    keys = preprocess_attention_component(k)
    values = preprocess_attention_component(v)

    scores = ops.replicated_sddmm(queries, keys, topology, transpose_rhs=True)
    probabilities = ops.replicated_sparse_softmax(scores, topology)
    attended = ops.replicated_spmm(probabilities, topology, values)
    return tf.reshape(attended, tf.shape(q))