def testSpmmGradient(self, m, k, n, sparsity, use_gpu):
  # Helpers to set up the matrices.
  connector = connectors.Uniform(sparsity)
  initializer = initializers.Uniform()

  # Numpy matrices for verification.
  lhs_np = connector(initializer([m, k]))
  rhs_np = initializer([k, n])

  # TensorFlow graph.
  lhs = sparse_matrix.SparseMatrix("lhs", matrix=lhs_np)
  rhs = tf.Variable(rhs_np, dtype=tf.float32)
  output = ops.spmm(lhs, rhs)

  # Numerically check the gradient with respect to the sparse values and rhs.
  with self.test_session(use_gpu=use_gpu) as sess:
    sess.run(tf.global_variables_initializer())
    error = tf.test.compute_gradient_error(
        [lhs.values, rhs],
        [lhs.values.shape.as_list(), [k, n]],
        output,
        [m, n])
    self.assertLess(error, 1e-3)
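# A minimal sketch of how the (m, k, n, sparsity, use_gpu) arguments to the
# test method above might be supplied, assuming the suite uses
# absl.testing.parameterized. The class name and the specific problem sizes
# and sparsity levels below are illustrative assumptions, not the
# repository's actual test grid.
from absl.testing import parameterized
import tensorflow.compat.v1 as tf


class SpmmTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      (64, 64, 64, 0.5, False),   # small square problem on CPU
      (128, 256, 64, 0.8, True))  # rectangular problem on GPU
  def testSpmmGradient(self, m, k, n, sparsity, use_gpu):
    ...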
def testSpmm(self, m, k, n, sparsity, use_gpu):
  # Helpers to set up the matrices.
  connector = connectors.Uniform(sparsity)
  initializer = initializers.Uniform()

  # Numpy matrices for verification.
  lhs_np = connector(initializer([m, k]))
  rhs_np = initializer([k, n])

  # TensorFlow graph.
  lhs = sparse_matrix.SparseMatrix("lhs", matrix=lhs_np)
  rhs = tf.Variable(rhs_np, dtype=tf.float32)
  output = ops.spmm(lhs, rhs)

  # Execute the op and compare the results.
  with self.test_session(use_gpu=use_gpu) as sess:
    sess.run(tf.global_variables_initializer())
    self.assertAllClose(
        sess.run(output),
        np.dot(lhs_np, rhs_np),
        atol=1e-03,
        rtol=1e-05)
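# Rough numpy equivalent of what connectors.Uniform(sparsity) appears to do to
# the dense initializer output in the tests above: zero out roughly a
# `sparsity` fraction of entries, chosen uniformly at random. This is an
# assumption based on how lhs_np is used, not the connector's actual
# implementation; the helper name below is hypothetical.
import numpy as np


def _uniform_mask(dense, sparsity, seed=0):
  rng = np.random.RandomState(seed)
  keep = rng.uniform(size=dense.shape) >= sparsity  # keep ~(1 - sparsity)
  return (dense * keep).astype(np.float32)


lhs_np = _uniform_mask(np.random.uniform(size=(64, 256)), sparsity=0.8)
print(1.0 - np.count_nonzero(lhs_np) / lhs_np.size)  # roughly 0.8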
def call(self, inputs, training=None):
  # TODO(tgale): The following code assumes that the input channels,
  # height, and width are all defined and that the batch dimension
  # is undefined. Fix this to handle arbitrary input shapes correctly.
  input_shape = inputs.shape.as_list()
  flat_inputs = tf.reshape(inputs, [-1, input_shape[2] * input_shape[3]])
  output_shape = [-1, self.filters, input_shape[2], input_shape[3]]

  # Use the fused kernel if possible.
  if self.use_bias and self.activation == tf.nn.relu:
    flat_output = ops.fused_spmm(self.kernel, flat_inputs, self.bias)
    return tf.reshape(flat_output, output_shape)

  # Otherwise, apply the sparse matmul, bias, and activation separately.
  flat_output = ops.spmm(self.kernel, flat_inputs)
  out = tf.reshape(flat_output, output_shape)
  if self.use_bias:
    out = tf.nn.bias_add(out, self.bias, data_format="NCHW")
  if self.activation:
    out = self.activation(out)
  return out
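# A dense-equivalent sketch of the data flow in call() above: a 1x1
# convolution over NCHW inputs expressed as a matrix multiply plus reshapes,
# with tf.matmul standing in for ops.spmm. The shapes are illustrative and,
# as the TODO notes, this reshape only covers the simple layout (batch of 1).
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

batch, channels, height, width = 1, 64, 32, 32
filters = 128

inputs = tf.constant(
    np.random.rand(batch, channels, height, width), dtype=tf.float32)
kernel = tf.constant(np.random.rand(filters, channels), dtype=tf.float32)

# Collapse the spatial dimensions so the 1x1 convolution becomes one matmul.
flat_inputs = tf.reshape(inputs, [-1, height * width])  # [channels, H * W]
flat_output = tf.matmul(kernel, flat_inputs)            # [filters, H * W]
out = tf.reshape(flat_output, [-1, filters, height, width])

with tf.Session() as sess:
  print(sess.run(tf.shape(out)))  # [1, 128, 32, 32]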