def _sddmm_grad(op, grad):
    """Gradient of sampled dense-dense matrix multiplication (SDDMM).

    The forward op takes (m, n, row_indices, row_offsets, column_indices,
    lhs_matrix, rhs_matrix); only the two dense operands are differentiable.

    Args:
        op: The forward sddmm op whose inputs we differentiate against.
        grad: Tensor, gradient w.r.t. the op's sparse output values.

    Returns:
        Per-input gradients: None for the five sparse meta-data inputs,
        then the lhs and rhs dense matrix gradients.
    """
    # Unpack the forward op's inputs.
    m = op.inputs[0]
    n = op.inputs[1]
    sparse_row_indices = op.inputs[2]
    sparse_row_offsets = op.inputs[3]
    sparse_column_indices = op.inputs[4]
    lhs = op.inputs[5]
    rhs = op.inputs[6]

    # d(lhs): multiply the (sparse) output gradient by the rhs dense matrix.
    lhs_grad = kernels.spmm(m, n, grad, sparse_row_indices,
                            sparse_row_offsets, sparse_column_indices, rhs)

    # d(rhs): transpose the sparse gradient, rebuild the row-index ordering
    # for the transposed topology, then multiply with the lhs dense matrix.
    grad_t, offsets_t, columns_t = kernels.csr_transpose(
        m, n, grad, sparse_row_offsets, sparse_column_indices)
    indices_t = diffsort(offsets_t)
    rhs_grad = kernels.spmm(n, m, grad_t, indices_t, offsets_t, columns_t,
                            lhs)

    # NOTE: The sparse matrix meta-data was exposed as op arguments, so we
    # must return None as the gradient for each of those five tensors.
    #
    # TODO(tgale): Make sure there are no performance implications for this.
    return [None, None, None, None, None, lhs_grad, rhs_grad]
def step(hprev, x):
    """One recurrent step of the sparsely-gated LSTM-style cell.

    Args:
        hprev: Tensor, stacked pair [state, cell] from the previous step.
        x: Tensor, the input for this timestep.

    Returns:
        Tensor, stacked pair [state, cell] for this step.
    """
    prev_state, prev_cell = tf.unstack(hprev)

    # Build this step's sparse gate weights from the input.
    (rows, columns, values, row_indices, row_offsets,
     column_indices) = self.dynamic_gate(x)

    # Fused gate pre-activations: sparse weights times [x; state], plus bias.
    recurrent_input = tf.concat([x, prev_state], -1)
    gates = kernels.spmm(rows, columns, values, row_indices, row_offsets,
                         column_indices, tf.transpose(recurrent_input),
                         False, False)
    gates = tf.transpose(gates) + bias

    # Split into input, forget, candidate, and output gates.
    in_gate, forget_gate, candidate, out_gate = tf.split(gates, 4, axis=1)
    in_gate = tf.sigmoid(in_gate)
    forget_gate = tf.sigmoid(forget_gate)
    candidate = tf.tanh(candidate)
    out_gate = tf.sigmoid(out_gate)

    cell = prev_cell * forget_gate + candidate * in_gate
    state = tf.tanh(cell) * out_gate
    return tf.stack([state, cell])
def _spmm_grad(op, grad):
    """Gradient of sparse-dense matrix multiplication (SpMM).

    The forward op takes (m, k, values, row_indices, row_offsets,
    column_indices, dense_matrix); only the sparse values and the dense
    matrix are differentiable.

    Args:
        op: The forward spmm op whose inputs we differentiate against.
        grad: Tensor, gradient w.r.t. the op's dense output.

    Returns:
        Per-input gradients: None for the shape/meta-data inputs, the
        sparse values gradient in slot 2, and the dense matrix gradient
        in slot 6.
    """
    # Unpack the forward op's inputs.
    m = op.inputs[0]
    k = op.inputs[1]
    sparse_values = op.inputs[2]
    sparse_row_indices = op.inputs[3]
    sparse_row_offsets = op.inputs[4]
    sparse_column_indices = op.inputs[5]
    dense = op.inputs[6]

    # d(values): sample the product of the output gradient and the
    # transposed dense operand at the sparse matrix's nonzero positions.
    values_grad = kernels.sddmm(m, k, sparse_row_indices, sparse_row_offsets,
                                sparse_column_indices, grad, dense,
                                transpose_rhs=True)

    # d(dense): transpose the sparse weights, rebuild the row-index ordering
    # for the transposed topology, then multiply with the output gradient.
    values_t, offsets_t, columns_t = kernels.csr_transpose(
        m, k, sparse_values, sparse_row_offsets, sparse_column_indices)
    indices_t = diffsort(offsets_t)
    dense_grad = kernels.spmm(k, m, values_t, indices_t, offsets_t,
                              columns_t, grad)

    # NOTE: The sparse matrix meta-data was exposed as op arguments, so we
    # must return None as the gradient for each of those tensors.
    #
    # TODO(tgale): Make sure there are no performance implications for this.
    return [None, None, values_grad, None, None, None, dense_grad]
def replicated_spmm(values,
                    topology,
                    dense_matrix,
                    transpose_lhs=False,
                    transpose_rhs=False):
    """Convenience API for replicated spmm.

    TODO(tgale): Add a better matrix type instead of having this.

    Args:
        values: Tensor, the replicated sparse matrix values.
        topology: SparseTopology, the sparse matrix topology.
        dense_matrix: Tensor, the right-hand, dense operand to the matrix
            product.
        transpose_lhs: bool, whether to transpose the lhs operand.
        transpose_rhs: bool, whether to transpose the rhs operand.

    Returns:
        Tensor, the dense matrix result of the product.
    """
    # Forward the topology's shape and CSR meta-data alongside the
    # replicated values to the underlying kernel.
    operands = (topology._rows, topology._columns, values,
                topology.row_indices, topology.row_offsets,
                topology.column_indices, dense_matrix,
                transpose_lhs, transpose_rhs)
    return kernels.spmm(*operands)
def spmm(sparse_matrix,
         dense_matrix,
         transpose_lhs=False,
         transpose_rhs=False):
    """Sparse matrix matrix multiplication.

    Computes the product of a sparse matrix and a dense matrix.

    Args:
        sparse_matrix: SparseMatrix, the left-hand sparse operand to the
            matrix product.
        dense_matrix: Tensor, the right-hand, dense operand to the matrix
            product.
        transpose_lhs: bool, whether to transpose the lhs operand.
        transpose_rhs: bool, whether to transpose the rhs operand.

    Returns:
        Tensor, the dense matrix result of the product.
    """
    # Unwrap the SparseMatrix's shape, values, and CSR meta-data for the
    # underlying kernel.
    operands = (sparse_matrix._rows, sparse_matrix._columns,
                sparse_matrix.values, sparse_matrix.row_indices,
                sparse_matrix.row_offsets, sparse_matrix.column_indices,
                dense_matrix, transpose_lhs, transpose_rhs)
    return kernels.spmm(*operands)