def testGradient(self):
    """Checks the gradient of `sparse_reduce_sum` w.r.t. the sparse values.

    For every test shape, reduces over randomly chosen axis subsets of each
    size from 1 up to the rank, then once more over axis -1, asserting each
    time that the gradient-checker error stays below 1e-3.
    """
    if np.__version__ == "1.13.0":
        self.skipTest("numpy 1.13.0 bug")

    def check_gradient(sp_input, num_values, reduction_axes):
        # Build the reduction and bound the error reported by the checker.
        summed = sparse_ops.sparse_reduce_sum(sp_input, reduction_axes)
        grad_err = gradient_checker.compute_gradient_error(
            sp_input.values, (num_values,), summed, summed.eval().shape)
        self.assertLess(grad_err, 1e-3)

    np.random.seed(8161)
    with self.test_session(use_gpu=False):
        for shape in [(11, 1, 5, 7, 1), (2, 2)]:
            rank = len(shape)
            sp_t, nnz = _sparsify(np.random.randn(*shape))

            # Reduce over random axis subsets, from one axis up to all axes.
            for num_axes in range(1, rank + 1):
                picked = np.random.choice(rank, size=num_axes, replace=False)
                check_gradient(sp_t, nnz, picked.tolist())

            # A negative axis must be accepted as well.
            check_gradient(sp_t, nnz, -1)
def testInvalidAxes(self):
    """Verifies that reducing over axis -3 or 2 raises an op error.

    Each case expects the "Invalid reduction dimension <axis>" message from
    `sparse_reduce_sum` when the result is evaluated.
    """
    sp_t = ops.SparseTensor(self.ind, self.vals, self.shape)
    bad_axes = (
        (-3, "Invalid reduction dimension -3"),
        (2, "Invalid reduction dimension 2"),
    )
    with self.test_session(use_gpu=False):
        for axis, expected_message in bad_axes:
            with self.assertRaisesOpError(expected_message):
                sparse_ops.sparse_reduce_sum(sp_t, axis).eval()
def testInvalidAxes(self):
    """Verifies that axes -3 and 2 are rejected by the sparse sum/max reductions.

    Both `sparse_reduce_sum` and `sparse_reduce_max` are expected to raise an
    op error carrying "Invalid reduction dimension <axis>" when evaluated.
    """
    sp_t = sparse_tensor.SparseTensor(self.ind, self.vals, self.dense_shape)
    bad_axes = (
        (-3, "Invalid reduction dimension -3"),
        (2, "Invalid reduction dimension 2"),
    )
    with test_util.force_cpu():
        # Sum cases run first, then max, in the same axis order.
        for reduce_fn in (sparse_ops.sparse_reduce_sum,
                          sparse_ops.sparse_reduce_max):
            for axis, expected_message in bad_axes:
                with self.assertRaisesOpError(expected_message):
                    self.evaluate(reduce_fn(sp_t, axis))
def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
                   w0=1., row_weights=None, col_weights=None):
    """Calculates the loss of a given factorization.

    Uses a non-distributed method, different from the one implemented in the
    WALS model. The weight of an observed entry (i, j) (i.e. such that
    input_mat[i, j] is non zero) is
    (w0 + row_weights[i] * col_weights[j]).

    Args:
        input_mat: The input matrix, a SparseTensor of rank 2.
        row_factors: The row factors, a dense Tensor of rank 2.
        col_factors: The col factors, a dense Tensor of rank 2.
        regularization: The regularization coefficient, a scalar. Treated as
            0 when None.
        w0: The weight of unobserved entries. A scalar.
        row_weights: A dense tensor of rank 1. Treated as 1 when None.
        col_weights: A dense tensor of rank 1. Treated as 1 when None.

    Returns:
        The total loss, evaluated in the default session.
    """
    # Broadcastable weights: a column vector for rows, a row vector for cols.
    if row_weights is not None:
        row_wt = array_ops.expand_dims(row_weights, 1)
    else:
        row_wt = constant_op.constant(1.)
    if col_weights is not None:
        col_wt = array_ops.expand_dims(col_weights, 0)
    else:
        col_wt = constant_op.constant(1.)
    if regularization is not None:
        reg_coeff = regularization
    else:
        reg_coeff = constant_op.constant(0.)

    # Low-rank approximation restricted to the observed (i, j) positions.
    observed_rows, observed_cols = array_ops.split(
        input_mat.indices, axis=1, num_or_size_splits=2)
    row_vecs = array_ops.gather(row_factors, observed_rows)
    col_vecs = array_ops.gather(col_factors, observed_cols)
    approx_values = array_ops.squeeze(
        math_ops.matmul(row_vecs, col_vecs, adjoint_b=True))
    approx = sparse_tensor.SparseTensor(
        indices=input_mat.indices,
        values=approx_values,
        dense_shape=input_mat.dense_shape)
    approx_sq = math_ops.square(approx)

    # Squared norms of the factors and of the full dense product.
    row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
    col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
    dense_product = math_ops.matmul(row_factors, col_factors, transpose_b=True)
    product_norm = math_ops.reduce_sum(math_ops.square(dense_product))

    residual = sparse_ops.sparse_add(input_mat, approx * (-1))
    residual_sq = math_ops.square(residual)

    unweighted_part = w0 * (
        sparse_ops.sparse_reduce_sum(residual_sq) -
        sparse_ops.sparse_reduce_sum(approx_sq))
    weighted_part = (
        sparse_ops.sparse_reduce_sum(row_wt * (residual_sq * col_wt)) +
        w0 * product_norm + reg_coeff * (row_norm + col_norm))
    total = unweighted_part + weighted_part
    return total.eval()
def _compare(self, sp_t, reduction_axes, ndims, keep_dims):
    """Compares sparse sum reductions of `sp_t` against a NumPy reference.

    Args:
        sp_t: The SparseTensor to reduce.
        reduction_axes: None (reduce everything), a single axis, or a list
            of axes; negative values are wrapped using `ndims`.
        ndims: Number of dimensions used to wrap negative axes.
        keep_dims: Forwarded to both NumPy and the TF reductions.
    """
    expected = sparse_ops.sparse_tensor_to_dense(sp_t).eval()

    if reduction_axes is None:
        expected = np.sum(expected, keepdims=keep_dims)
    else:
        # A bare scalar axis is promoted to a one-element list.
        if not isinstance(reduction_axes, list):
            reduction_axes = [reduction_axes]
        axes = np.array(reduction_axes).astype(np.int32)
        # Wrap negative axes, then sort so that reducing from the highest
        # axis downwards keeps the remaining axis numbers valid.
        reduction_axes = np.sort((axes + ndims) % ndims)
        for axis in reduction_axes.ravel()[::-1]:
            expected = np.sum(expected, axis=axis, keepdims=keep_dims)

    with self.test_session():
        dense_result = sparse_ops.sparse_reduce_sum(
            sp_t, reduction_axes, keep_dims)
        out_dense = dense_result.eval()
        sparse_result = sparse_ops.sparse_reduce_sum_sparse(
            sp_t, reduction_axes, keep_dims)
        # Densify the sparse result so it can be compared element-wise.
        out_sparse = sparse_ops.sparse_tensor_to_dense(sparse_result).eval()

    self.assertAllClose(expected, out_dense)
    self.assertAllClose(expected, out_sparse)
def _SparseSoftmaxGrad(op, grad):
    """Gradients for SparseSoftmax.

    The calculation is the same as SoftmaxGrad:

        grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

    where only the non-zero values present in the SparseTensors take part.

    Args:
        op: The SparseSoftmax op.
        grad: The upstream gradient w.r.t. the non-zero SparseSoftmax output
            values.

    Returns:
        Gradients w.r.t. the inputs (sp_indices, sp_values, sp_shape); only
        the values receive a gradient, the other two entries are None.
    """
    sp_indices = op.inputs[0]
    sp_shape = op.inputs[2]
    softmax_values = op.outputs[0]

    sp_softmax = sparse_tensor.SparseTensor(
        sp_indices, softmax_values, sp_shape)
    sp_upstream = sparse_tensor.SparseTensor(sp_indices, grad, sp_shape)
    sp_weighted = sparse_tensor.SparseTensor(
        sp_indices, sp_softmax.values * sp_upstream.values, sp_shape)

    # Dense [..., B, 1]: negated sum of grad * softmax over the last axis.
    neg_row_sum = -sparse_ops.sparse_reduce_sum(
        sp_weighted, [-1], keep_dims=True)

    # Sparse [..., B, C] + dense [..., B, 1] with broadcast; stays sparse.
    sp_shifted = sparse_ops.sparse_dense_cwise_add(sp_upstream, neg_row_sum)

    values_grad = sp_shifted.values * sp_softmax.values
    return [None, values_grad, None]
def _build_multilabel_adjacency(sparse_labels):
    """Builds multilabel adjacency matrix.

    As of March 14th, 2017, there's no op for the dot product between two
    sparse tensors in TF. However, there is a `sparse_minimum` op which is
    equivalent to an AND op between two sparse boolean tensors. Summing its
    result gives the dot product between two sparse boolean inputs.

    Args:
        sparse_labels: List of 1-D boolean sparse tensors.

    Returns:
        adjacency_matrix: 2-D dense `Tensor`.
    """
    count = len(sparse_labels)
    adjacency_matrix = array_ops.zeros([count, count])
    for row, row_labels in enumerate(sparse_labels):
        for col, col_labels in enumerate(sparse_labels):
            overlap = sparse_ops.sparse_minimum(row_labels, col_labels)
            dot = math_ops.to_float(sparse_ops.sparse_reduce_sum(overlap))
            # Lift the scalar to shape [1, 1] so it can be padded into place.
            cell = array_ops.expand_dims(array_ops.expand_dims(dot, 0), 1)
            # Zero-pad so that the value lands at position (row, col).
            paddings = [[row, count-row-1], [col, count-col-1]]
            placed = array_ops.pad(cell, paddings, 'CONSTANT')
            adjacency_matrix = adjacency_matrix + placed
    return adjacency_matrix
def testGradient(self):
    """Checks the gradient of `sparse_reduce_sum` w.r.t. the sparse values.

    For every test shape, reduces over randomly chosen axis subsets of each
    size from 1 up to the rank and asserts that the error reported by
    `tf.test.compute_gradient_error` stays below 1e-3.
    """
    np.random.seed(8161)
    shapes = [(11, 1, 5, 7, 1), (2, 2)]
    with self.test_session(use_gpu=False):
        for shape in shapes:
            rank = len(shape)
            sp_t, nnz = _sparsify(np.random.randn(*shape))
            # Reduce over random axis subsets, from one axis up to all axes.
            num_axes = 1
            while num_axes <= rank:
                picked = np.random.choice(rank, size=num_axes, replace=False)
                summed = sparse_ops.sparse_reduce_sum(sp_t, picked.tolist())
                grad_err = tf.test.compute_gradient_error(
                    sp_t.values, (nnz,), summed, summed.eval().shape)
                self.assertLess(grad_err, 1e-3)
                num_axes += 1
def _compare(self, sp_t, reduction_axes, keep_dims):
    """Compares `sparse_reduce_sum` on `sp_t` against a NumPy reference.

    Args:
        sp_t: The SparseTensor to reduce.
        reduction_axes: None (reduce everything), a single axis, or a list
            of axes.
        keep_dims: Forwarded to both NumPy and the TF reduction.
    """
    expected = sparse_ops.sparse_tensor_to_dense(sp_t).eval()

    if reduction_axes is None:
        expected = np.sum(expected, keepdims=keep_dims)
    else:
        # Sort list inputs: the loop below reduces from the highest axis
        # down, which keeps the remaining axis numbers valid.
        if isinstance(reduction_axes, list):
            reduction_axes = sorted(reduction_axes)
        reduction_axes = np.array(reduction_axes).astype(np.int32)
        for axis in reduction_axes.ravel()[::-1]:
            expected = np.sum(expected, axis=axis, keepdims=keep_dims)

    with self.test_session():
        actual = sparse_ops.sparse_reduce_sum(
            sp_t, reduction_axes, keep_dims).eval()

    self.assertAllClose(expected, actual)
def _compare(self, sp_t, reduction_axes, ndims, keep_dims, do_sum):
    """Compares sparse sum/max reductions of `sp_t` with a NumPy reference.

    Args:
        sp_t: The SparseTensor to reduce.
        reduction_axes: None (reduce everything), a single axis, or a list
            of axes; negative values are wrapped using `ndims`.
        ndims: Number of dimensions used to wrap negative axes.
        keep_dims: Forwarded to both NumPy and the TF reductions.
        do_sum: If true, checks the sum reductions; otherwise the max ones.
    """
    # Pick the matching NumPy / dense-output / sparse-output reducers once.
    if do_sum:
        np_reduce = np.sum
        dense_reduce = sparse_ops.sparse_reduce_sum
        sparse_reduce = sparse_ops.sparse_reduce_sum_sparse
    else:
        np_reduce = np.max
        dense_reduce = sparse_ops.sparse_reduce_max
        sparse_reduce = sparse_ops.sparse_reduce_max_sparse

    expected = self.evaluate(sparse_ops.sparse_tensor_to_dense(sp_t))

    if reduction_axes is None:
        expected = np_reduce(expected, keepdims=keep_dims)
    else:
        # A bare scalar axis is promoted to a one-element list.
        if not isinstance(reduction_axes, list):
            reduction_axes = [reduction_axes]
        axes = np.array(reduction_axes).astype(np.int32)
        # Wrap negative axes, then sort so that reducing from the highest
        # axis downwards keeps the remaining axis numbers valid.
        reduction_axes = np.sort((axes + ndims) % ndims)
        for axis in reduction_axes.ravel()[::-1]:
            expected = np_reduce(expected, axis=axis, keepdims=keep_dims)

    with self.cached_session():
        out_dense = self.evaluate(
            dense_reduce(sp_t, reduction_axes, keep_dims))
        sparse_result = sparse_reduce(sp_t, reduction_axes, keep_dims)
        # Densify for comparison; the tensor itself is handed to
        # assertAllClose below without an explicit evaluation here.
        out_sparse = sparse_ops.sparse_tensor_to_dense(sparse_result)

    self.assertAllClose(expected, out_dense)
    self.assertAllClose(expected, out_sparse)
def main(argv):
    """Builds a small graph and exports it as a SavedModel.

    The graph holds a sparse int32 placeholder reduced by
    `sparse_reduce_sum` and a dense (1, 3) int32 placeholder incremented by
    one. Each is exported under its own predict signature ('sparse' and
    'dense') with the SERVING tag to `FLAGS.saved_model_path`.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Clear any previous export. `ignore_errors=True` lets the first run
    # succeed when the directory does not exist yet, where a plain rmtree
    # would raise.
    # NOTE(review): presumably the builder still fails loudly if a stale
    # directory could not be removed -- confirm against SavedModelBuilder.
    shutil.rmtree(FLAGS.saved_model_path, ignore_errors=True)

    # Create the graph
    x = array_ops.sparse_placeholder(
        dtype=dtypes.int32, shape=None, name='input')
    r = sparse_ops.sparse_reduce_sum(x)

    x1 = array_ops.placeholder(
        dtype=dtypes.int32, shape=(1, 3), name='input1')
    r1 = math_ops.add(x1, 1)

    tensor_info_x = utils.build_tensor_info(x)
    tensor_info_r = utils.build_tensor_info(r)
    tensor_info_x1 = utils.build_tensor_info(x1)
    tensor_info_r1 = utils.build_tensor_info(r1)

    sparse_signature = signature_def_utils.build_signature_def(
        inputs={'x': tensor_info_x},
        outputs={'r': tensor_info_r},
        method_name=signature_constants.PREDICT_METHOD_NAME)
    dense_signature = signature_def_utils.build_signature_def(
        inputs={'x1': tensor_info_x1},
        outputs={'r1': tensor_info_r1},
        method_name=signature_constants.PREDICT_METHOD_NAME)

    # The context manager closes the session once the export is written,
    # instead of leaving it open until interpreter exit.
    with session.Session() as sess:
        sm_builder = builder.SavedModelBuilder(FLAGS.saved_model_path)
        sm_builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            signature_def_map={
                'sparse': sparse_signature,
                'dense': dense_signature,
            },
            strip_default_attrs=True)
        sm_builder.save()
def _compare(self, sp_t, reduction_axes, keep_dims):
    """Checks `sparse_reduce_sum` on `sp_t` against a NumPy computation.

    Args:
        sp_t: The SparseTensor to reduce.
        reduction_axes: None (reduce everything), a single axis, or a list
            of axes.
        keep_dims: Forwarded to both NumPy and the TF reduction.
    """
    reference = sparse_ops.sparse_tensor_to_dense(sp_t).eval()

    if reduction_axes is not None:
        if isinstance(reduction_axes, list):
            # The descending loop below relies on the axes being sorted.
            reduction_axes = sorted(reduction_axes)
        reduction_axes = np.array(reduction_axes).astype(np.int32)
        descending_axes = reduction_axes.ravel()[::-1]
        for axis in descending_axes:
            reference = np.sum(reference, axis=axis, keepdims=keep_dims)
    else:
        reference = np.sum(reference, keepdims=keep_dims)

    with self.test_session():
        reduced = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes, keep_dims)
        result = reduced.eval()

    self.assertAllClose(reference, result)
def _SparseReduceMinOrMaxGrad(op, out_grad):
    """Gradient w.r.t. the sparse values of a min/max-style sparse reduction.

    The upstream gradient is routed to the input values that equal the
    reduced output of their reduction group. When several values tie, the
    gradient is divided evenly between them (same as for MaxGrad).

    Args:
        op: The reduction op; its inputs are (input_indices, input_values,
            input_shape, reduction_axes).
        out_grad: The upstream gradient w.r.t. the op's first output.

    Returns:
        A list with one entry per op input; only input_values receives a
        gradient, the other entries are None.
    """
    indices = op.inputs[0]
    values = op.inputs[1]
    dense_shape = op.inputs[2]
    reduction_axes = op.inputs[3]
    reduced = op.outputs[0]

    # Handle keepdims: view both the output and its gradient with the
    # reduced dimensions kept.
    kept_dims_shape = math_ops.reduced_shape(dense_shape, reduction_axes)
    out_grad = array_ops.reshape(out_grad, kept_dims_shape)
    reduced = array_ops.reshape(reduced, kept_dims_shape)

    # Map every input coordinate to the coordinate of its output cell.
    shrink = dense_shape // math_ops.to_int64(kept_dims_shape)
    output_coords = indices // shrink

    # 1 where an input value equals the selected (max/min) value, else 0.
    selected_value = array_ops.gather_nd(reduced, output_coords)
    is_selected = math_ops.cast(
        math_ops.equal(values, selected_value), out_grad.dtype)
    upstream = array_ops.gather_nd(out_grad, output_coords)

    # Count the selected elements per reduction group so that the gradient
    # can be shared between ties.
    sp_is_selected = sparse_tensor.SparseTensor(
        indices, is_selected, dense_shape)
    selected_per_group = sparse_ops.sparse_reduce_sum(
        sp_is_selected, axis=reduction_axes, keep_dims=True)
    tie_count = array_ops.gather_nd(selected_per_group, output_coords)

    # The lower bound of 1 on the tie count avoids a division by zero.
    values_grad = (
        math_ops.div(is_selected, math_ops.maximum(tie_count, 1)) * upstream)

    # (input_indices, input_values, input_shape, reduction_axes)
    return [None, values_grad, None, None]
def _compare(self, sp_t, reduction_axes, ndims, keep_dims):
    """Compares `sparse_reduce_sum` on `sp_t` against a NumPy reference.

    Args:
        sp_t: The SparseTensor to reduce.
        reduction_axes: None (reduce everything), a single axis, or a list
            of axes; negative values are wrapped using `ndims`.
        ndims: Number of dimensions used to wrap negative axes.
        keep_dims: Forwarded to both NumPy and the TF reduction.
    """
    expected = sparse_ops.sparse_tensor_to_dense(sp_t).eval()

    if reduction_axes is None:
        expected = np.sum(expected, keepdims=keep_dims)
    else:
        # A bare scalar axis is promoted to a one-element list.
        axis_list = (reduction_axes if isinstance(reduction_axes, list)
                     else [reduction_axes])
        reduction_axes = np.array(axis_list).astype(np.int32)
        # Wrap negative axes into the range [0, ndims).
        reduction_axes = (reduction_axes + ndims) % ndims
        # Reducing from the highest axis down needs the axes sorted.
        reduction_axes.sort()
        for axis in reduction_axes.ravel()[::-1]:
            expected = np.sum(expected, axis=axis, keepdims=keep_dims)

    with self.test_session():
        actual = sparse_ops.sparse_reduce_sum(
            sp_t, reduction_axes, keep_dims).eval()

    self.assertAllClose(expected, actual)
def _process_input_helper(self, update_row_factors, sp_input=None,
                          transpose_input=False, row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
        update_row_factors: if True, update or project the row_factors, else
            update or project the column factors.
        sp_input: Please refer to comments for update_row_factors,
            update_col_factors, project_row_factors, and project_col_factors
            for restrictions. Must be a SparseTensor (asserted below).
        transpose_input: If True, the input is logically transposed and then
            the corresponding rows/columns of the transposed input are
            updated.
        row_weights: If not None, this is the row/column weights to be used
            for the update or projection. If None, use the corresponding
            weights from the model. Note that the feature (column/row)
            weights will be determined by the model. When not None, it can
            either be a scalar or a rank-1 tensor with the same number of
            elements as the number of rows or columns to be
            updated/projected.

    Returns:
        A tuple consisting of the following elements:
        new_values: New values for the row/column factors.
        update_op: An op that assigns the newly computed values to the
            row/column factors.
        unregularized_loss: A tensor (scalar) that contains the normalized
            minibatch loss corresponding to sp_input, without the
            regularization term. Add the regularization term below to yield
            the loss.
        regularization: A tensor (scalar) that contains the normalized
            regularization term for the minibatch loss corresponding to
            sp_input.
        sum_weights: The sum of the weights corresponding to sp_input. This
            can be used with unregularized loss to calculate the root
            weighted squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Select the "left" factors being solved for and the cached "right"
    # factors, weights and gramian of the other side.
    if update_row_factors:
        left = self._row_factors
        right_factors = self._col_factors_cache
        row_wt = self._row_wt_cache
        col_wt = self._col_wt_cache
        total_rows = self._input_rows
        total_cols = self._input_cols
        sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                     self._num_row_shards)
        gramian = self._col_gramian_cache
    else:
        left = self._col_factors
        right_factors = self._row_factors_cache
        row_wt = self._col_wt_cache
        col_wt = self._row_wt_cache
        total_rows = self._input_cols
        total_cols = self._input_rows
        sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                     self._num_col_shards)
        gramian = self._row_gramian_cache
        # Updating columns flips the transposition requested by the caller.
        transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full
    # input. Here we reindex the rows and give them contiguous ids starting
    # at 0. We use tf.unique to achieve this reindexing. Note that this is
    # done so that the downstream kernel can assume that the input is
    # "dense" along the row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(
        math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(
        math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: ids of the factors being solved for;
    # gather_indices: ids of the fixed factors looked up from the cache.
    if transpose_input:
        update_indices = update_col_indices
        row_shape = [
            math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
        ]
        gather_indices = update_row_indices
    else:
        update_indices = update_row_indices
        row_shape = [
            math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
        ]
        gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
        total_lhs += self._regularization_matrix
    if self._row_weights is None:
        # Special case of ALS. Use a much simpler update rule.
        total_rhs = (
            self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
                new_sp_input, right, adjoint_a=transpose_input))
        # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
        # transposing explicitly.
        # (id:1138, https://github.com/imdone/tensorflow/issues/1139)
        # TODO(rmlarsen): multi-thread tf.matrix_solve.
        # (id:1249, https://github.com/imdone/tensorflow/issues/1250)
        new_left_values = array_ops.transpose(
            linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
        if row_weights is None:
            # TODO(yifanchen): Add special handling for single shard without
            # using embedding_lookup and perform benchmarks for those cases.
            # Same for col_weights lookup below.
            # (id:845, https://github.com/imdone/tensorflow/issues/846)
            row_weights_slice = embedding_ops.embedding_lookup(
                row_wt, update_indices, partition_strategy="div")
        else:
            num_indices = array_ops.shape(update_indices)[0]
            # Caller-supplied weights must be rank 0 or 1; a scalar is
            # broadcast to one weight per updated index.
            with ops.control_dependencies(
                    [check_ops.assert_less_equal(
                        array_ops.rank(row_weights), 1)]):
                row_weights_slice = control_flow_ops.cond(
                    math_ops.equal(array_ops.rank(row_weights), 0),
                    lambda: (array_ops.ones([num_indices]) * row_weights),
                    lambda: math_ops.cast(row_weights, dtypes.float32))

        col_weights = embedding_ops.embedding_lookup(
            col_wt, gather_indices, partition_strategy="div")
        partial_lhs, total_rhs = (
            gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
                right,
                col_weights,
                self._unobserved_weight,
                row_weights_slice,
                new_sp_input.indices,
                new_sp_input.values,
                num_rows,
                transpose_input,
                name="wals_compute_partial_lhs_rhs"))
        total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
        total_rhs = array_ops.expand_dims(total_rhs, -1)
        new_left_values = array_ops.squeeze(
            linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    # NOTE(review): `col_weights` is only bound in the weighted branch above
    # (self._row_weights is not None). If a model can have _row_weights None
    # while _col_weights is set, this line would hit an unbound name --
    # confirm the constructor rules that combination out.
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)
    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
        regularization = self._regularization * (
            math_ops.trace(partial_row_gramian) * normalization_factor +
            math_ops.trace(gramian))
    else:
        regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
        # A sparse tensor of ones at the observed positions, used to sum the
        # row_weight * col_weight products over those positions.
        ones = sparse_tensor.SparseTensor(
            indices=loss_sp_input.indices,
            values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
            dense_shape=loss_sp_input.dense_shape)
        sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
            ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
def _process_input_helper(self, update_row_factors, sp_input=None,
                          transpose_input=False, row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
        update_row_factors: if True, update or project the row_factors, else
            update or project the column factors.
        sp_input: Please refer to comments for update_row_factors,
            update_col_factors, project_row_factors, and project_col_factors
            for restrictions. Must be a SparseTensor (asserted below).
        transpose_input: If True, the input is logically transposed and then
            the corresponding rows/columns of the transposed input are
            updated.
        row_weights: If not None, this is the row/column weights to be used
            for the update or projection. If None, use the corresponding
            weights from the model. Note that the feature (column/row)
            weights will be determined by the model. When not None, it can
            either be a scalar or a rank-1 tensor with the same number of
            elements as the number of rows or columns to be
            updated/projected.

    Returns:
        A tuple consisting of the following elements:
        new_values: New values for the row/column factors.
        update_op: An op that assigns the newly computed values to the
            row/column factors.
        unregularized_loss: A tensor (scalar) that contains the normalized
            minibatch loss corresponding to sp_input, without the
            regularization term. Add the regularization term below to yield
            the loss.
        regularization: A tensor (scalar) that contains the normalized
            regularization term for the minibatch loss corresponding to
            sp_input.
        sum_weights: The sum of the weights corresponding to sp_input. This
            can be used with unregularized loss to calculate the root
            weighted squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Select the "left" factors being solved for and the cached "right"
    # factors, weights and gramian of the other side.
    if update_row_factors:
        left = self._row_factors
        right_factors = self._col_factors_cache
        row_wt = self._row_wt_cache
        col_wt = self._col_wt_cache
        total_rows = self._input_rows
        total_cols = self._input_cols
        sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                     self._num_row_shards)
        gramian = self._col_gramian_cache
    else:
        left = self._col_factors
        right_factors = self._row_factors_cache
        row_wt = self._col_wt_cache
        col_wt = self._row_wt_cache
        total_rows = self._input_cols
        total_cols = self._input_rows
        sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                     self._num_col_shards)
        gramian = self._row_gramian_cache
        # Updating columns flips the transposition requested by the caller.
        transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full
    # input. Here we reindex the rows and give them contiguous ids starting
    # at 0. We use tf.unique to achieve this reindexing. Note that this is
    # done so that the downstream kernel can assume that the input is
    # "dense" along the row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(
        math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(
        math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: ids of the factors being solved for;
    # gather_indices: ids of the fixed factors looked up from the cache.
    if transpose_input:
        update_indices = update_col_indices
        row_shape = [
            math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
        ]
        gather_indices = update_row_indices
    else:
        update_indices = update_row_indices
        row_shape = [
            math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
        ]
        gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
        total_lhs += self._regularization_matrix
    if self._row_weights is None:
        # Special case of ALS. Use a much simpler update rule.
        total_rhs = (
            self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
                new_sp_input, right, adjoint_a=transpose_input))
        # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
        # transposing explicitly.
        # TODO(rmlarsen): multi-thread tf.matrix_solve.
        new_left_values = array_ops.transpose(
            linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
        if row_weights is None:
            # TODO(yifanchen): Add special handling for single shard without
            # using embedding_lookup and perform benchmarks for those cases.
            # Same for col_weights lookup below.
            row_weights_slice = embedding_ops.embedding_lookup(
                row_wt, update_indices, partition_strategy="div")
        else:
            num_indices = array_ops.shape(update_indices)[0]
            # Caller-supplied weights must be rank 0 or 1; a scalar is
            # broadcast to one weight per updated index.
            with ops.control_dependencies(
                    [check_ops.assert_less_equal(
                        array_ops.rank(row_weights), 1)]):
                row_weights_slice = control_flow_ops.cond(
                    math_ops.equal(array_ops.rank(row_weights), 0),
                    lambda: (array_ops.ones([num_indices]) * row_weights),
                    lambda: math_ops.cast(row_weights, dtypes.float32))

        col_weights = embedding_ops.embedding_lookup(
            col_wt, gather_indices, partition_strategy="div")
        partial_lhs, total_rhs = (
            gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
                right,
                col_weights,
                self._unobserved_weight,
                row_weights_slice,
                new_sp_input.indices,
                new_sp_input.values,
                num_rows,
                transpose_input,
                name="wals_compute_partial_lhs_rhs"))
        total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
        total_rhs = array_ops.expand_dims(total_rhs, -1)
        new_left_values = array_ops.squeeze(
            linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    # NOTE(review): `col_weights` is only bound in the weighted branch above
    # (self._row_weights is not None). If a model can have _row_weights None
    # while _col_weights is set, this line would hit an unbound name --
    # confirm the constructor rules that combination out.
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)
    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
        regularization = self._regularization * (
            math_ops.trace(partial_row_gramian) * normalization_factor +
            math_ops.trace(gramian))
    else:
        regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
        # A sparse tensor of ones at the observed positions, used to sum the
        # row_weight * col_weight products over those positions.
        ones = sparse_tensor.SparseTensor(
            indices=loss_sp_input.indices,
            values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
            dense_shape=loss_sp_input.dense_shape)
        sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
            ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)