def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
  """Checks sparse_tensor_dense_matmul against a NumPy matrix product.

  The non-zero entries of `x` are turned into a sparse representation and
  multiplied with `y` twice: once passing a `SparseTensorValue` and once
  passing a `SparseTensor` built from it. Both results are compared with the
  NumPy reference.

  Args:
    x: dense NumPy array whose non-zero entries form the sparse operand.
    y: dense NumPy array used as the right-hand operand.
    adjoint_a: if True, the adjoint of `x` is used in the product.
    adjoint_b: if True, the adjoint of `y` is used in the product.
  """
  # NumPy reference result.
  lhs = np.matrix(x).H if adjoint_a else np.matrix(x)
  rhs = np.matrix(y).H if adjoint_b else np.matrix(y)
  np_ans = lhs * rhs

  # Sparse (COO-style) description of x: coordinates and values of non-zeros.
  nonzero = np.where(x)
  x_indices = np.vstack(nonzero).astype(np.int64).T
  x_values = x[nonzero]
  x_shape = x.shape

  # float64 gets a tighter tolerance; every other dtype uses 1e-4.
  tol = 1e-6 if x.dtype == np.float64 else 1e-4

  with self.test_session(use_gpu=True):
    sp_x_value = tf.SparseTensorValue(
        indices=x_indices, values=x_values, shape=x_shape)
    tf_value_ans = sparse_ops.sparse_tensor_dense_matmul(
        sp_x_value, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
    tf_tensor_ans = sparse_ops.sparse_tensor_dense_matmul(
        tf.SparseTensor.from_value(sp_x_value),
        y,
        adjoint_a=adjoint_a,
        adjoint_b=adjoint_b)

    # The column count of the result must be statically known.
    expected_cols = np_ans.shape[1]
    self.assertEqual(tf_value_ans.get_shape()[1], expected_cols)
    self.assertEqual(tf_tensor_ans.get_shape()[1], expected_cols)

    for out in (tf_value_ans.eval(), tf_tensor_ans.eval()):
      self.assertAllClose(np_ans, out, rtol=tol, atol=tol)
def testInvalidIndicesForSparseTensorDenseMatmul(self):
  """Out-of-range sparse indices must raise an op error on the CPU kernel."""
  k_error = "k .10. from index.0,1. out of bounds .>=2."
  m_error = "m .10. from index.0,1. out of bounds .>=2."

  # Note: use_gpu=False because nice errors are only returned from CPU kernel.
  with self.session(use_gpu=False):
    # A single entry at (1, 10) inside a [3, 2] matrix: column 10 is invalid.
    bad_indices = np.matrix([[1, 10]]).astype(np.int64)
    bad_values = np.array([10]).astype(np.float32)
    sparse_t = sparse_tensor.SparseTensor(bad_indices, bad_values, [3, 2])

    # Use both a small and a large dense matrix, to hit both cases in the
    # kernel.
    for width in (5, 500):
      dense_t = np.matrix([[1] * width, [2] * width], dtype=np.float32)
      with self.assertRaisesOpError(k_error):
        sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t).eval()

    # Repeat with adjoint_a, to get a different error.
    for width in (5, 500):
      dense_t = np.matrix(
          [[1] * width, [2] * width, [3] * width], dtype=np.float32)
      with self.assertRaisesOpError(m_error):
        sparse_ops.sparse_tensor_dense_matmul(
            sparse_t, dense_t, adjoint_a=True).eval()
def _ExtractImagePatchesGrad(op, grad):
  """Computes the gradient of ExtractImagePatches w.r.t. its input.

  A sparse 0/1 matrix is built that maps every input pixel position to the
  output patch positions it was copied to. It is obtained by running
  extract_image_patches on a tensor of input position ids. Multiplying that
  matrix with the flattened incoming gradient sums the gradient over all
  patches a pixel takes part in.

  Args:
    op: the ExtractImagePatches op; attrs "ksizes", "strides", "rates" and
      "padding" are read from it.
    grad: incoming gradient w.r.t. the op output.

  Returns:
    A one-element list with the gradient w.r.t. the op's input.
  """
  # Only the spatial sizes are needed as static Python values; batch size and
  # channel count are taken from the runtime shape below.
  _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].shape.dims]
  input_bhwc = array_ops.shape(op.inputs[0])
  batch_size = input_bhwc[0]
  channels = input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx,
      op.get_attr("ksizes"),
      op.get_attr("strides"),
      op.get_attr("rates"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat(
      [array_ops.expand_dims(input_idx_patched, axis=-1),
       array_ops.expand_dims(output_idx, axis=-1)],
      axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map,
      array_ops.ones([output_indices_num], dtype=grad.dtype),
      sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  # Bring grad into (output position, batch * channels) layout so that one
  # sparse matmul handles every batch element and channel at once.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Restore the (batch, rows, cols, channels) ordering.
  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
def _SparseTensorDenseMatMulGrad(op, grad):
  """Gradients for the dense tensor in the SparseTensorDenseMatMul op.

  Gradients are only provided for the dense tensor.

  Args:
    op: the SparseTensorDenseMatMul op
    grad: the incoming gradient

  Returns:
    Gradient for each of the 4 input tensors:
      (sparse_indices, sparse_values, sparse_shape, dense_tensor)
    The sparse tensor gradients are always None.

  Raises:
    TypeError: When the two operands don't have the same type.
    NotImplementedError: When the operands are complex64.
  """
  sp_t = ops.SparseTensor(*op.inputs[:3])
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")

  # Compare base dtypes so that a reference dtype and its plain counterpart
  # are treated as the same type.
  a_type = sp_t.values.dtype.base_dtype
  b_type = op.inputs[3].dtype.base_dtype
  # Raise explicitly instead of asserting: asserts vanish under `python -O`.
  if a_type != b_type:
    raise TypeError(
        "SparseTensorDenseMatMul op received operands with different types: "
        "%s and %s" % (a_type, b_type))
  if a_type == ops.dtypes.complex64:
    raise NotImplementedError("SparseTensorDenseMatMul op does not support "
                              "complex gradients.")

  # Gradient w.r.t. the dense operand: multiply grad by the sparse operand
  # with the opposite adjoint setting of the forward pass.
  b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
                                                 adjoint_a=not adj_a)
  if adj_b:
    b_grad = array_ops.transpose(b_grad)

  return (None, None, None, b_grad)
def _SparseTensorDenseMatMulGrad(op, grad):
  """Gradients for the SparseTensorDenseMatMul op.

  Gradients are computed for the dense tensor and for the values of the
  sparse tensor.

  Args:
    op: the SparseTensorDenseMatMul op
    grad: the incoming gradient

  Returns:
    Gradient for each of the 4 input tensors:
      (sparse_indices, sparse_values, sparse_shape, dense_tensor)
    The gradients for indices and shape are None.

  Raises:
    TypeError: When the two operands don't have the same type.
    NotImplementedError: When the operands are complex64.
  """
  sp_t = ops.SparseTensor(*op.inputs[:3])
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")

  a_type = sp_t.values.dtype.base_dtype
  b_type = op.inputs[3].dtype.base_dtype
  if a_type != b_type:
    # Build one message string; passing the pieces as separate exception
    # arguments would render the error as a tuple.
    raise TypeError(
        "SparseTensorDenseMatMul op received operands with different types: "
        "%s and %s" % (a_type, b_type))
  is_complex = a_type == ops.dtypes.complex64
  if is_complex:
    raise NotImplementedError("SparseTensorDenseMatMul op does not support "
                              "complex gradients.")

  # gradient w.r.t. dense
  b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
                                                 adjoint_a=not adj_a)
  if adj_b:
    b_grad = array_ops.transpose(b_grad)

  # gradient w.r.t. sparse values
  a_indices = op.inputs[0]
  b = op.inputs[3]

  rows = a_indices[:, 0]
  cols = a_indices[:, 1]

  # TODO(zongheng, ebrevdo): add conjugates in the right places when complex
  # values are allowed.
  # TODO(zongheng): these gather calls could potentially duplicate rows/cols in
  # memory.  If there is a need, we should look into implementing this more
  # intelligently to avoid duplicating data.
  # For every stored entry, pick the matching row of grad and of b (or b
  # transposed); their row-wise dot product is that entry's gradient.
  parts_a = array_ops.gather(grad, rows if not adj_a else cols)
  parts_b = array_ops.gather(b if not adj_b else array_ops.transpose(b),
                             cols if not adj_a else rows)
  a_values_grad = math_ops.reduce_sum(parts_a * parts_b, reduction_indices=1)

  # gradients w.r.t. (a_indices, a_values, a_shape, b)
  return (None, a_values_grad, None, b_grad)
def testConsumers(self):
  """SparseTensor.consumers() reports each op that uses the sparse tensor."""
  sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
  w = ops.convert_to_tensor(np.ones([4, 1], np.float32))

  # After the matmul, it is the only consumer.
  out = sparse_ops.sparse_tensor_dense_matmul(sp, w)
  self.assertEqual(len(sp.consumers()), 1)
  self.assertEqual(sp.consumers()[0], out.op)

  # Converting to dense adds a second consumer; the first one stays listed.
  dense = sparse_ops.sparse_tensor_to_dense(sp)
  self.assertEqual(len(sp.consumers()), 2)
  # assertIn reports the container contents on failure, unlike assertTrue.
  self.assertIn(dense.op, sp.consumers())
  self.assertIn(out.op, sp.consumers())
def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
  """Checks the values produced on GPU for an out-of-range sparse index.

  This test runs with use_gpu=True and asserts on the returned values rather
  than on an error. It returns immediately when no GPU is available.
  """
  if not test.is_gpu_available():
    return

  with self.session(use_gpu=True):
    # A single entry at (1, 10) inside a [3, 2] matrix: column 10 is invalid.
    bad_indices = np.array([[1, 10]]).astype(np.int64)
    bad_values = np.array([10]).astype(np.float32)
    sparse_t = sparse_tensor.SparseTensor(bad_indices, bad_values, [3, 2])

    # Use both a small and a large dense matrix, to hit both cases in the
    # kernel. The row holding the invalid entry is expected to be NaN.
    for width in (5, 500):
      dense_t = np.matrix([[1] * width, [2] * width], dtype=np.float32)
      expected_t = np.array(
          [[0] * width, [np.nan] * width, [0] * width], dtype=np.float32)
      actual_t = sparse_ops.sparse_tensor_dense_matmul(
          sparse_t, dense_t).eval()
      self.assertAllClose(expected_t, actual_t)

    # Repeat with adjoint_a: now the sparse index is out of bounds w.r.t.
    # the output. The GPU kernel can't do much here, so it just doesn't
    # accumulate, and the result is expected to be all zeros.
    for width in (5, 500):
      dense_t = np.matrix(
          [[1] * width, [2] * width, [3] * width], dtype=np.float32)
      expected_t = np.array([[0] * width, [0] * width], dtype=np.float32)
      actual_t = sparse_ops.sparse_tensor_dense_matmul(
          sparse_t, dense_t, adjoint_a=True).eval()
      self.assertAllClose(expected_t, actual_t)
def testShapeInference(self):
  """Static shape inference of sparse_tensor_dense_matmul.

  Covers a fully known sparse shape, an unknown sparse shape and a sparse
  shape whose inner dimension does not match the dense operand.
  """
  x = np.random.rand(10, 10)
  x[np.abs(x) < 0.5] = 0  # Make it sparse
  y = np.random.randn(10, 20)

  nonzero = np.where(x)
  x_indices = np.vstack(nonzero).astype(np.int64).T
  x_values = x[nonzero]

  # Known sparse shape: the result shape is fully determined.
  known_st = sparse_tensor.SparseTensor(x_indices, x_values, x.shape)
  known_result = sparse_ops.sparse_tensor_dense_matmul(known_st, y)
  self.assertEqual(known_result.get_shape(), (10, 20))

  # Unknown sparse shape: only the column count can be inferred.
  shape_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=None)
  unknown_st = sparse_tensor.SparseTensor(
      x_indices, x_values, shape_placeholder)
  unknown_result = sparse_ops.sparse_tensor_dense_matmul(unknown_st, y)
  self.assertEqual(unknown_result.get_shape().as_list(), [None, 20])

  # Inner dimensions 15 vs. 10 do not match: building the op must fail.
  mismatched_st = sparse_tensor.SparseTensor(x_indices, x_values, [10, 15])
  with self.assertRaisesRegexp(ValueError, "Dimensions must be equal"):
    sparse_ops.sparse_tensor_dense_matmul(mismatched_st, y)
def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype):
  """Numerically checks the gradients of sparse_tensor_dense_matmul.

  Random sizes in [1, 10) are drawn for the operands. The gradient error is
  computed w.r.t. the dense operand and the sparse values, printed, and
  required to be below 1e-3.

  Args:
    adjoint_a: adjoint flag for the sparse operand.
    adjoint_b: adjoint flag for the dense operand.
    name: op name, also used in the printed message.
    np_dtype: NumPy dtype of the generated operands.
  """
  n, k, m = np.random.randint(1, 10, size=3)
  sp_t, nnz = self._randomTensor(
      [n, k], np_dtype, adjoint=adjoint_a, sparse=True)
  dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)

  matmul = sparse_ops.sparse_tensor_dense_matmul(
      sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)

  with self.test_session(use_gpu=True):
    # The dense operand's shape is given transposed when adjoint_b is set.
    if adjoint_b:
      dense_t_shape = [m, k]
    else:
      dense_t_shape = [k, m]
    inputs = [dense_t, sp_t.values]
    input_shapes = [dense_t_shape, [nnz]]

    err = gradient_checker.compute_gradient_error(
        inputs, input_shapes, matmul, [n, m])
    print("%s gradient err = %s" % (name, err))
    self.assertLess(err, 1e-3)
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Creates the graph for processing a sparse slice of input.

  The rows (or columns) present in sp_input are re-indexed to contiguous ids,
  the normal equations for the corresponding factors are assembled and
  solved, an op writing the solution back into the factors is created, and
  the loss terms for the slice are computed.

  Args:
    update_row_factors: if True, update or project the row_factors, else
      update or project the column factors.
    sp_input: Please refer to comments for update_row_factors,
      update_col_factors, project_row_factors, and project_col_factors for
      restrictions.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: If not None, this is the row/column weights to be used for
      the update or projection. If None, use the corresponding weights from
      the model. Note that the feature (column/row) weights will be
      determined by the model. When not None, it can either be a scalar or
      a rank-1 tensor with the same number of elements as the number of rows
      or columns to be updated/projected.

  Returns:
    A tuple consisting of the following elements:
    new_values: New values for the row/column factors.
    update_op: An op that assigns the newly computed values to the row/column
      factors.
    unregularized_loss: A tensor (scalar) that contains the normalized
      minibatch loss corresponding to sp_input, without the regularization
      term. Add the regularization term below to yield the loss.
    regularization: A tensor (scalar) that contains the normalized
      regularization term for the minibatch loss corresponding to sp_input.
    sum_weights: The sum of the weights corresponding to sp_input. This
      can be used with unregularized loss to calculate the root weighted
      squared error.
  """
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Select the factors being solved for ("left"), the cached factors and
  # weights of the other side, and the matching gramian.
  if update_row_factors:
    left = self._row_factors
    right_factors = self._col_factors_cache
    row_wt = self._row_wt_cache
    col_wt = self._col_wt_cache
    total_rows = self._input_rows
    total_cols = self._input_cols
    sharding_func = WALSModel._get_sharding_func(
        self._input_rows, self._num_row_shards)
    gramian = self._col_gramian_cache
  else:
    left = self._col_factors
    right_factors = self._row_factors_cache
    row_wt = self._col_wt_cache
    col_wt = self._row_wt_cache
    total_rows = self._input_cols
    total_cols = self._input_rows
    sharding_func = WALSModel._get_sharding_func(
        self._input_cols, self._num_col_shards)
    gramian = self._row_gramian_cache
    # Updating column factors is handled as a row update on the transposed
    # input, so the caller's flag is flipped.
    transpose_input = not transpose_input

  # Note that the row indices of sp_input are based on the original full input
  # Here we reindex the rows and give them contiguous ids starting at 0.
  # We use tf.unique to achieve this reindexing. Note that this is done so
  # that the downstream kernel can assume that the input is "dense" along the
  # row dimension.
  row_ids, col_ids = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
  update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
  col_ids = array_ops.expand_dims(
      math_ops.cast(all_col_ids, dtypes.int64), 1)
  row_ids = array_ops.expand_dims(
      math_ops.cast(all_row_ids, dtypes.int64), 1)

  # update_indices: original ids of the factors that get new values.
  # gather_indices: original ids of the other-side factors to look up.
  if transpose_input:
    update_indices = update_col_indices
    row_shape = [
        math_ops.cast(
            array_ops.shape(update_row_indices)[0], dtypes.int64)
    ]
    gather_indices = update_row_indices
  else:
    update_indices = update_row_indices
    row_shape = [
        math_ops.cast(
            array_ops.shape(update_col_indices)[0], dtypes.int64)
    ]
    gather_indices = update_col_indices

  num_rows = math_ops.cast(
      array_ops.shape(update_indices)[0], dtypes.int64)
  col_shape = [num_rows]
  right = embedding_ops.embedding_lookup(right_factors, gather_indices,
                                         partition_strategy="div")
  # Rebuild the sparse input on the compacted (re-indexed) coordinates.
  new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
  new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                  if transpose_input else array_ops.concat(
                      [col_shape, row_shape], 0))
  new_sp_input = sparse_tensor.SparseTensor(indices=new_sp_indices,
                                            values=sp_input.values,
                                            dense_shape=new_sp_shape)

  # Compute lhs and rhs of the normal equations
  total_lhs = (self._unobserved_weight * gramian)
  if self._regularization_matrix is not None:
    total_lhs += self._regularization_matrix
  if self._row_weights is None:
    # Special case of ALS. Use a much simpler update rule.
    total_rhs = (self._unobserved_weight *
                 sparse_ops.sparse_tensor_dense_matmul(
                     new_sp_input, right, adjoint_a=transpose_input))
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    new_left_values = array_ops.transpose(
        linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
  else:
    if row_weights is None:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below.
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy="div")
    else:
      num_indices = array_ops.shape(update_indices)[0]
      # row_weights must be a scalar or rank-1; a scalar is broadcast to one
      # weight per updated row.
      with ops.control_dependencies([
          check_ops.assert_less_equal(
              array_ops.rank(row_weights), 1)
      ]):
        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            lambda: (array_ops.ones([num_indices]) * row_weights),
            lambda: math_ops.cast(row_weights, dtypes.float32))

    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = (
        gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
            right, col_weights, self._unobserved_weight, row_weights_slice,
            new_sp_input.indices, new_sp_input.values, num_rows,
            transpose_input, name="wals_compute_partial_lhs_rhs"))
    # One linear system per updated row: shared lhs plus the per-row part.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    new_left_values = array_ops.squeeze(
        linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

  update_op_name = "row_update" if update_row_factors else "col_update"
  update_op = self.scatter_update(left, update_indices, new_left_values,
                                  sharding_func, name=update_op_name)

  # Create the loss subgraph
  loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                   if transpose_input else new_sp_input)
  # sp_approx is the low rank estimate of the input matrix, formed by
  # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
  sp_approx_vals = gen_factorization_ops.masked_matmul(
      new_left_values, right, loss_sp_input.indices,
      transpose_a=False, transpose_b=True)
  sp_approx = sparse_tensor.SparseTensor(loss_sp_input.indices,
                                         sp_approx_vals,
                                         loss_sp_input.dense_shape)

  sp_approx_sq = math_ops.square(sp_approx)
  sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
  sp_residual_sq = math_ops.square(sp_residual)
  # NOTE(review): col_weights is only bound in the weighted branch above
  # (self._row_weights is not None). If row weights are None while column
  # weights are not, the col_wt_mat expression would hit an unbound name --
  # confirm the model guarantees both are set or unset together.
  row_wt_mat = (constant_op.constant(0.) if self._row_weights is None else
                array_ops.expand_dims(row_weights_slice, 1))
  col_wt_mat = (constant_op.constant(0.) if self._col_weights is None else
                array_ops.expand_dims(col_weights, 0))

  # We return the normalized loss
  partial_row_gramian = math_ops.matmul(new_left_values, new_left_values,
                                        transpose_a=True)
  normalization_factor = total_rows / math_ops.cast(
      num_rows, dtypes.float32)

  unregularized_loss = (
      self._unobserved_weight * (  # pyformat line break
          sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
          sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
          math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
      sparse_ops.sparse_reduce_sum(
          row_wt_mat * (sp_residual_sq * col_wt_mat))) * normalization_factor

  if self._regularization is not None:
    regularization = self._regularization * (
        math_ops.trace(partial_row_gramian) * normalization_factor +
        math_ops.trace(gramian))
  else:
    regularization = constant_op.constant(0.)

  sum_weights = self._unobserved_weight * math_ops.cast(
      total_rows * total_cols, dtypes.float32)
  if self._row_weights is not None and self._col_weights is not None:
    ones = sparse_tensor.SparseTensor(
        indices=loss_sp_input.indices,
        values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
        dense_shape=loss_sp_input.dense_shape)
    sum_weights += sparse_ops.sparse_reduce_sum(
        row_wt_mat * (ones * col_wt_mat)) * normalization_factor

  return (new_left_values, update_op, unregularized_loss, regularization,
          sum_weights)
def _ExtractImagePatchesGrad(op, grad):
  """Computes the gradient of ExtractImagePatches w.r.t. its input.

  A sparse 0/1 matrix is built that maps every input pixel position to the
  output patch positions it was copied to. It is obtained by running
  extract_image_patches on a tensor of input position ids. Multiplying that
  matrix with the flattened incoming gradient sums the gradient over all
  patches a pixel takes part in.

  Args:
    op: the ExtractImagePatches op; attrs "ksizes", "strides", "rates" and
      "padding" are read from it.
    grad: incoming gradient w.r.t. the op output.

  Returns:
    A one-element list with the gradient w.r.t. the op's input.
  """
  # Only the spatial sizes are needed as static Python values; batch size and
  # channel count are taken from the runtime shape below.
  _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].shape.dims]
  input_bhwc = array_ops.shape(op.inputs[0])
  batch_size = input_bhwc[0]
  channels = input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx,
      op.get_attr("ksizes"),
      op.get_attr("strides"),
      op.get_attr("rates"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ], axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype),
      sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  # Silence the sparse-to-dense conversion warning while the gradient is
  # reshaped and multiplied.
  # NOTE(review): the extent of this `with` block was reconstructed from
  # flattened source; it is assumed to end after the matmul -- confirm.
  with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Converting sparse IndexedSlices to a dense Tensor.*")
    # Bring grad into (output position, batch * channels) layout so that one
    # sparse matmul handles every batch element and channel at once.
    grad_expanded = array_ops.transpose(
        array_ops.reshape(
            grad,
            (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
        (1, 2, 3, 4, 0, 5))
    grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

    jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Restore the (batch, rows, cols, channels) ordering.
  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
def patches_to_images(self, grad, batch_size, rows_in, cols_in, channels,
                      rows_out, cols_out, ksize_r, ksize_c, stride_h,
                      stride_r):
    """Scatters patch-shaped values back onto image positions.

    A sparse 0/1 matrix mapping each image position to the patch entries
    that cover it is built in Python, then multiplied with the flattened
    `grad`, so overlapping patch entries are summed per image position.

    Args:
        grad: values laid out per patch; it is reshaped to
            (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels).
        batch_size: number of images.
        rows_in: image height.
        cols_in: image width.
        channels: number of channels.
        rows_out: ignored; recomputed from the padding mode.
        cols_out: ignored; recomputed from the padding mode.
        ksize_r: patch height.
        ksize_c: patch width.
        stride_h: stride applied along the column axis (despite its name).
        stride_r: stride applied along the row axis.

    Returns:
        Tensor of shape (batch_size, rows_in, cols_in, channels).

    Raises:
        ValueError: if `self.pad` is neither 'SAME' nor 'VALID'.
    """
    # Dilation rates are fixed to 1, so the effective kernel size equals the
    # nominal one; the general formula is kept for clarity.
    rate_r = 1
    rate_c = 1
    padding = self.pad

    ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
    ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

    # float() makes the division a true division even where `/` on ints
    # floors (Python 2 without the division future import).
    if padding == 'SAME':
        rows_out = int(ceil(float(rows_in) / stride_r))
        cols_out = int(ceil(float(cols_in) / stride_h))
        pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
        pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
    elif padding == 'VALID':
        rows_out = int(ceil(float(rows_in - ksize_r_eff + 1) / stride_r))
        cols_out = int(ceil(float(cols_in - ksize_c_eff + 1) / stride_h))
        pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
        pad_cols = (cols_out - 1) * stride_h + ksize_c_eff - cols_in
    else:
        # Without this branch pad_rows/pad_cols would be unbound below.
        raise ValueError(
            "Unsupported padding %r; expected 'SAME' or 'VALID'." % (padding,))

    pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)

    # Bring grad into (patch entry, batch * channels) layout so that a single
    # sparse matmul handles every batch element and channel at once.
    grad_expanded = array_ops.transpose(
        array_ops.reshape(
            grad,
            (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
        (1, 2, 3, 4, 0, 5))
    grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

    row_steps = range(0, rows_out * stride_r, stride_r)
    col_steps = range(0, cols_out * stride_h, stride_h)

    # Collect (flat image position, flat patch entry) pairs, skipping patch
    # entries that fall into the padding area.
    patch_size = ksize_r * ksize_c
    idx = []
    for i in range(rows_out):
        for j in range(cols_out):
            r_low = row_steps[i] - pad_rows
            c_low = col_steps[j] - pad_cols
            r_high = r_low + ksize_r_eff
            c_high = c_low + ksize_c_eff
            patch_base = i * (cols_out * patch_size) + j * patch_size

            idx.extend([
                (r * cols_in + c, patch_base + ri * ksize_c + ci)
                for (ri, r) in enumerate(range(r_low, r_high, rate_r))
                for (ci, c) in enumerate(range(c_low, c_high, rate_c))
                if 0 <= r < rows_in and 0 <= c < cols_in
            ])

    sp_shape = (rows_in * cols_in, rows_out * cols_out * patch_size)
    # NOTE(review): the matrix values are always float32 -- confirm grad is
    # float32 as well, otherwise the matmul operands differ in dtype.
    sp_mat = sparse_tensor.SparseTensor(
        array_ops.constant(idx, dtype=ops.dtypes.int64),
        array_ops.ones((len(idx),), dtype=ops.dtypes.float32),
        sp_shape)

    jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

    # Restore the (batch, rows, cols, channels) ordering.
    grad_out = array_ops.reshape(
        jac, (rows_in, cols_in, batch_size, channels))
    grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

    return grad_out
def _ExtractVolumePatchesGrad(op, grad):
  """Computes the gradient of ExtractVolumePatches w.r.t. its input.

  A sparse 0/1 matrix is built that maps every input voxel position to the
  output patch positions it was copied to. It is obtained by running
  extract_volume_patches on a tensor of input position ids. Multiplying that
  matrix with the flattened incoming gradient sums the gradient over all
  patches a voxel takes part in.

  Args:
    op: the ExtractVolumePatches op; attrs "ksizes", "strides" and "padding"
      are read from it.
    grad: incoming gradient w.r.t. the op output.

  Returns:
    A one-element list with the gradient w.r.t. the op's input.
  """
  # Only the spatial sizes are needed as static Python values; batch size and
  # channel count are taken from the runtime shape below.
  _, planes_in, rows_in, cols_in, _ = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ], axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype),
      sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  # Silence the sparse-to-dense conversion warning while the gradient is
  # reshaped and multiplied.
  # NOTE(review): the extent of this `with` block was reconstructed from
  # flattened source; it is assumed to end after the matmul -- confirm.
  with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Converting sparse IndexedSlices to a dense Tensor.*")
    # Bring grad into (output position, batch * channels) layout so that one
    # sparse matmul handles every batch element and channel at once.
    grad_expanded = array_ops.transpose(
        array_ops.reshape(
            grad,
            (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
             ksize_c, channels)),
        (1, 2, 3, 4, 5, 6, 0, 7))
    grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

    jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Restore the (batch, planes, rows, cols, channels) ordering.
  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Creates the graph for processing a sparse slice of input.

  The rows (or columns) present in sp_input are re-indexed to contiguous ids,
  the normal equations for the corresponding factors are assembled and
  solved, and an op writing the solution back into the factors is created.

  Args:
    update_row_factors: if True, update or project the row_factors, else
      update or project the column factors.
    sp_input: Please refer to comments for update_row_factors,
      update_col_factors, project_row_factors, and project_col_factors for
      restrictions.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: If not None, this is the row/column weights to be used for
      the update or projection. If None, use the corresponding weights from
      the model. Note that the feature (column/row) weights will be
      determined by the model. When not None, it can either be a scalar or
      a rank-1 tensor with the same number of elements as the number of rows
      or columns to be updated/projected.

  Returns:
    A tuple consisting of the following two elements:
    new_values: New values for the row/column factors.
    update_op: An op that assigns the newly computed values to the row/column
      factors.
  """
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Select the factors being solved for ("left"), the cached factors and
  # weights of the other side, and the matching gramian.
  if update_row_factors:
    left = self._row_factors
    right_factors = self._col_factors_cache
    row_wt = self._row_wt_cache
    col_wt = self._col_wt_cache
    sharding_func = WALSModel._get_sharding_func(
        self._input_rows, self._num_row_shards)
    gramian = self._col_gramian_cache
  else:
    left = self._col_factors
    right_factors = self._row_factors_cache
    row_wt = self._col_wt_cache
    col_wt = self._row_wt_cache
    sharding_func = WALSModel._get_sharding_func(
        self._input_cols, self._num_col_shards)
    gramian = self._row_gramian_cache
    # Updating column factors is handled as a row update on the transposed
    # input, so the caller's flag is flipped.
    transpose_input = not transpose_input

  # Note that the row indices of sp_input are based on the original full input
  # Here we reindex the rows and give them contiguous ids starting at 0.
  # We use tf.unique to achieve this reindexing. Note that this is done so
  # that the downstream kernel can assume that the input is "dense" along the
  # row dimension.
  row_ids, col_ids = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
  update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
  col_ids = array_ops.expand_dims(
      math_ops.cast(all_col_ids, dtypes.int64), 1)
  row_ids = array_ops.expand_dims(
      math_ops.cast(all_row_ids, dtypes.int64), 1)

  # update_indices: original ids of the factors that get new values.
  # gather_indices: original ids of the other-side factors to look up.
  if transpose_input:
    update_indices = update_col_indices
    row_shape = [
        math_ops.cast(
            array_ops.shape(update_row_indices)[0], dtypes.int64)
    ]
    gather_indices = update_row_indices
  else:
    update_indices = update_row_indices
    row_shape = [
        math_ops.cast(
            array_ops.shape(update_col_indices)[0], dtypes.int64)
    ]
    gather_indices = update_col_indices

  num_rows = math_ops.cast(
      array_ops.shape(update_indices)[0], dtypes.int64)
  col_shape = [num_rows]
  right = embedding_ops.embedding_lookup(right_factors, gather_indices,
                                         partition_strategy="div")
  # Rebuild the sparse input on the compacted (re-indexed) coordinates.
  new_sp_indices = array_ops.concat_v2([row_ids, col_ids], 1)
  new_sp_shape = (array_ops.concat_v2([row_shape, col_shape], 0)
                  if transpose_input else array_ops.concat_v2(
                      [col_shape, row_shape], 0))
  new_sp_input = sparse_tensor.SparseTensor(indices=new_sp_indices,
                                            values=sp_input.values,
                                            dense_shape=new_sp_shape)

  # Compute lhs and rhs of the normal equations
  total_lhs = (self._unobserved_weight * gramian)
  if self._regularization is not None:
    total_lhs += self._regularization
  if self._row_weights is None:
    # Special case of ALS. Use a much simpler update rule.
    total_rhs = (self._unobserved_weight *
                 sparse_ops.sparse_tensor_dense_matmul(
                     new_sp_input, right, adjoint_a=transpose_input))
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    new_left_values = array_ops.transpose(
        linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
  else:
    if row_weights is None:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below.
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy="div")
    else:
      # row_weights must be a scalar or rank-1; a scalar is broadcast to one
      # weight per updated row.
      with ops.control_dependencies([
          check_ops.assert_less_equal(
              array_ops.rank(row_weights), 1)
      ]):
        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            lambda: (array_ops.ones([array_ops.shape(update_indices)[0]]) *
                     row_weights),
            lambda: math_ops.cast(row_weights, dtypes.float32))

    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
        right, col_weights, self._unobserved_weight, row_weights_slice,
        new_sp_input.indices, new_sp_input.values, num_rows,
        transpose_input, name="wals_compute_partial_lhs_rhs")
    # One linear system per updated row: shared lhs plus the per-row part.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    new_left_values = array_ops.squeeze(
        linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

  return (new_left_values,
          self.scatter_update(left, update_indices, new_left_values,
                              sharding_func))
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Builds the graph that solves one side's factors for a sparse slice.

  Besides the factor update itself, this also builds the loss subgraph for
  the slice (unregularized loss, regularization term and weight sum).

  Args:
    update_row_factors: If True, the row factors are updated or projected;
      otherwise the column factors are.
    sp_input: The sparse slice to process; must be a SparseTensor. See the
      comments on update_row_factors, update_col_factors,
      project_row_factors and project_col_factors for restrictions.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: Optional row/column weights for the update or projection.
      If None, the corresponding weights stored in the model are looked up.
      The feature (column/row) weights always come from the model. When
      given, it must be a scalar or a rank-1 tensor (an assertion enforces
      rank <= 1) with as many elements as rows or columns being
      updated/projected.

  Returns:
    A tuple of five elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the
        row/column factors.
      unregularized_loss: A scalar tensor with the normalized minibatch loss
        for sp_input, without the regularization term.
      regularization: A scalar tensor with the normalized regularization term
        for the minibatch loss for sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. Together
        with unregularized_loss it can be used to calculate the root weighted
        squared error.
  """
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Pick the side being solved for ("solved") and the cached opposite side
  # ("fixed"). When solving for columns the transpose flag is flipped.
  if update_row_factors:
    solved_factors = self._row_factors
    fixed_factors_cache = self._col_factors_cache
    solved_wt_cache = self._row_wt_cache
    fixed_wt_cache = self._col_wt_cache
    n_solved_total = self._input_rows
    n_fixed_total = self._input_cols
    shard_fn = WALSModel._get_sharding_func(self._input_rows,
                                            self._num_row_shards)
    fixed_gramian = self._col_gramian_cache
    use_transpose = transpose_input
  else:
    solved_factors = self._col_factors
    fixed_factors_cache = self._row_factors_cache
    solved_wt_cache = self._col_wt_cache
    fixed_wt_cache = self._row_wt_cache
    n_solved_total = self._input_cols
    n_fixed_total = self._input_rows
    shard_fn = WALSModel._get_sharding_func(self._input_cols,
                                            self._num_col_shards)
    fixed_gramian = self._row_gramian_cache
    use_transpose = not transpose_input

  # The indices in sp_input refer to the full input. tf.unique is used to
  # re-number both rows and columns with contiguous ids starting at 0, so the
  # downstream kernel can treat the slice as "dense" along the row dimension.
  raw_rows, raw_cols = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  uniq_rows, row_positions = array_ops.unique(raw_rows[:, 0])
  uniq_cols, col_positions = array_ops.unique(raw_cols[:, 0])
  dense_col_ids = array_ops.expand_dims(
      math_ops.cast(col_positions, dtypes.int64), 1)
  dense_row_ids = array_ops.expand_dims(
      math_ops.cast(row_positions, dtypes.int64), 1)

  # update_indices: ids on the solved side; gather_indices: ids on the fixed
  # side whose factors (and weights) must be looked up.
  if use_transpose:
    update_indices, gather_indices = uniq_cols, uniq_rows
  else:
    update_indices, gather_indices = uniq_rows, uniq_cols
  gather_dim = [
      math_ops.cast(array_ops.shape(gather_indices)[0], dtypes.int64)
  ]
  num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
  update_dim = [num_rows]

  right = embedding_ops.embedding_lookup(
      fixed_factors_cache, gather_indices, partition_strategy="div")

  new_sp_indices = array_ops.concat([dense_row_ids, dense_col_ids], 1)
  if use_transpose:
    new_sp_shape = array_ops.concat([gather_dim, update_dim], 0)
  else:
    new_sp_shape = array_ops.concat([update_dim, gather_dim], 0)
  new_sp_input = sparse_tensor.SparseTensor(
      indices=new_sp_indices,
      values=sp_input.values,
      dense_shape=new_sp_shape)

  # Left- and right-hand sides of the normal equations.
  total_lhs = self._unobserved_weight * fixed_gramian
  if self._regularization_matrix is not None:
    total_lhs += self._regularization_matrix

  if self._row_weights is None:
    # Plain ALS: one shared LHS, so a single solve covers every row.
    sp_times_right = sparse_ops.sparse_tensor_dense_matmul(
        new_sp_input, right, adjoint_a=use_transpose)
    total_rhs = self._unobserved_weight * sp_times_right
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    solved_t = linalg_ops.matrix_solve(total_lhs,
                                       array_ops.transpose(total_rhs))
    new_left_values = array_ops.transpose(solved_t)
  else:
    if row_weights is None:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below.
      row_weights_slice = embedding_ops.embedding_lookup(
          solved_wt_cache, update_indices, partition_strategy="div")
    else:
      num_indices = array_ops.shape(update_indices)[0]
      rank_check = check_ops.assert_less_equal(
          array_ops.rank(row_weights), 1)
      with ops.control_dependencies([rank_check]):

        def _broadcast_scalar():
          # Scalar weight: repeat it once per updated row.
          return array_ops.ones([num_indices]) * row_weights

        def _as_float_vector():
          return math_ops.cast(row_weights, dtypes.float32)

        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            _broadcast_scalar, _as_float_vector)

    col_weights = embedding_ops.embedding_lookup(
        fixed_wt_cache, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = (
        gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
            right,
            col_weights,
            self._unobserved_weight,
            row_weights_slice,
            new_sp_input.indices,
            new_sp_input.values,
            num_rows,
            use_transpose,
            name="wals_compute_partial_lhs_rhs"))
    # One LHS per row: the shared part plus the per-row partial term.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    batched_solution = linalg_ops.matrix_solve(total_lhs, total_rhs)
    new_left_values = array_ops.squeeze(batched_solution, [2])

  if update_row_factors:
    update_op_name = "row_update"
  else:
    update_op_name = "col_update"
  update_op = self.scatter_update(
      solved_factors,
      update_indices,
      new_left_values,
      shard_fn,
      name=update_op_name)

  # ---- Loss subgraph ----
  if use_transpose:
    loss_sp_input = sparse_ops.sparse_transpose(new_sp_input)
  else:
    loss_sp_input = new_sp_input

  # sp_approx is the low-rank estimate of the input: the products <u_i, v_j>
  # evaluated only at the (i, j) positions present in loss_sp_input.
  approx_vals = gen_factorization_ops.masked_matmul(
      new_left_values,
      right,
      loss_sp_input.indices,
      transpose_a=False,
      transpose_b=True)
  sp_approx = sparse_tensor.SparseTensor(
      loss_sp_input.indices, approx_vals, loss_sp_input.dense_shape)
  sp_approx_sq = math_ops.square(sp_approx)
  sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
  sp_residual_sq = math_ops.square(sp_residual)

  if self._row_weights is None:
    row_wt_mat = constant_op.constant(0.)
  else:
    row_wt_mat = array_ops.expand_dims(row_weights_slice, 1)
  if self._col_weights is None:
    col_wt_mat = constant_op.constant(0.)
  else:
    col_wt_mat = array_ops.expand_dims(col_weights, 0)

  # The loss is scaled by (total rows on this side) / (rows in the slice).
  partial_row_gramian = math_ops.matmul(
      new_left_values, new_left_values, transpose_a=True)
  normalization_factor = n_solved_total / math_ops.cast(
      num_rows, dtypes.float32)

  residual_sum = sparse_ops.sparse_reduce_sum(sp_residual_sq)
  approx_sum = sparse_ops.sparse_reduce_sum(sp_approx_sq)
  residual_minus_approx = residual_sum - approx_sum
  gramian_trace = math_ops.trace(
      math_ops.matmul(partial_row_gramian, fixed_gramian))
  unobserved_term = self._unobserved_weight * (
      residual_minus_approx + gramian_trace)
  weighted_term = sparse_ops.sparse_reduce_sum(
      row_wt_mat * (sp_residual_sq * col_wt_mat))
  unregularized_loss = (
      unobserved_term + weighted_term) * normalization_factor

  if self._regularization is None:
    regularization = constant_op.constant(0.)
  else:
    regularization = self._regularization * (
        math_ops.trace(partial_row_gramian) * normalization_factor +
        math_ops.trace(fixed_gramian))

  sum_weights = self._unobserved_weight * math_ops.cast(
      n_solved_total * n_fixed_total, dtypes.float32)
  if not (self._row_weights is None or self._col_weights is None):
    ones = sparse_tensor.SparseTensor(
        indices=loss_sp_input.indices,
        values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
        dense_shape=loss_sp_input.dense_shape)
    sum_weights += sparse_ops.sparse_reduce_sum(
        row_wt_mat * (ones * col_wt_mat)) * normalization_factor

  return (new_left_values, update_op, unregularized_loss, regularization,
          sum_weights)
def body(t, prev):
  """Returns `(t + 1, product)`, both built under a dependency on `prev`.

  `tf`, `sp_x`, `y`, `adjoint_a` and `adjoint_b` are taken from the
  enclosing scope.
  """
  with tf.control_dependencies([prev]):
    next_t = t + 1
    product = sparse_ops.sparse_tensor_dense_matmul(
        sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
  return (next_t, product)
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Builds the graph that solves one side's factors for a sparse slice.

  Args:
    update_row_factors: If True, the row factors are updated or projected;
      otherwise the column factors are.
    sp_input: The sparse slice to process; must be a SparseTensor. See the
      comments on update_row_factors, update_col_factors,
      project_row_factors and project_col_factors for restrictions.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: Optional row/column weights for the update or projection.
      If None, the corresponding weights stored in the model are looked up.
      The feature (column/row) weights always come from the model. When
      given, it must be a scalar or a rank-1 tensor (an assertion enforces
      rank <= 1) with as many elements as rows or columns being
      updated/projected.

  Returns:
    A pair:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the
        row/column factors.
  """
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Choose which side is solved for ("left") and which cached side is held
  # fixed ("right"). Solving for columns flips the transpose flag.
  if update_row_factors:
    left, right_factors = self._row_factors, self._col_factors_cache
    row_wt, col_wt = self._row_wt_cache, self._col_wt_cache
    sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                 self._num_row_shards)
    gramian = self._col_gramian_cache
    flip = transpose_input
  else:
    left, right_factors = self._col_factors, self._row_factors_cache
    row_wt, col_wt = self._col_wt_cache, self._row_wt_cache
    sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                 self._num_col_shards)
    gramian = self._row_gramian_cache
    flip = not transpose_input

  # sp_input's indices refer to the full input. tf.unique re-numbers rows and
  # columns with contiguous ids from 0 so that the downstream kernel can
  # assume the slice is "dense" along the row dimension.
  orig_rows, orig_cols = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  distinct_rows, row_remap = array_ops.unique(orig_rows[:, 0])
  distinct_cols, col_remap = array_ops.unique(orig_cols[:, 0])
  remapped_cols = array_ops.expand_dims(
      math_ops.cast(col_remap, dtypes.int64), 1)
  remapped_rows = array_ops.expand_dims(
      math_ops.cast(row_remap, dtypes.int64), 1)

  if flip:
    update_indices, gather_indices = distinct_cols, distinct_rows
  else:
    update_indices, gather_indices = distinct_rows, distinct_cols
  n_gathered = [
      math_ops.cast(array_ops.shape(gather_indices)[0], dtypes.int64)
  ]
  num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
  n_updated = [num_rows]

  right = embedding_ops.embedding_lookup(
      right_factors, gather_indices, partition_strategy="div")

  new_sp_indices = array_ops.concat_v2([remapped_rows, remapped_cols], 1)
  if flip:
    new_sp_shape = array_ops.concat_v2([n_gathered, n_updated], 0)
  else:
    new_sp_shape = array_ops.concat_v2([n_updated, n_gathered], 0)
  new_sp_input = sparse_tensor.SparseTensor(
      indices=new_sp_indices,
      values=sp_input.values,
      dense_shape=new_sp_shape)

  # Left- and right-hand sides of the normal equations.
  total_lhs = self._unobserved_weight * gramian
  if self._regularization is not None:
    total_lhs += self._regularization

  if self._row_weights is None:
    # Plain ALS: one shared LHS, so a single solve covers every row.
    sp_times_right = sparse_ops.sparse_tensor_dense_matmul(
        new_sp_input, right, adjoint_a=flip)
    total_rhs = self._unobserved_weight * sp_times_right
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    solution_t = linalg_ops.matrix_solve(total_lhs,
                                         array_ops.transpose(total_rhs))
    new_left_values = array_ops.transpose(solution_t)
  else:
    if row_weights is None:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below.
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy="div")
    else:
      rank_check = check_ops.assert_less_equal(
          array_ops.rank(row_weights), 1)
      with ops.control_dependencies([rank_check]):

        def _broadcast_scalar():
          # Scalar weight: repeat it once per updated row. The shape op is
          # deliberately created inside this branch function.
          count = array_ops.shape(update_indices)[0]
          return array_ops.ones([count]) * row_weights

        def _as_float_vector():
          return math_ops.cast(row_weights, dtypes.float32)

        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            _broadcast_scalar, _as_float_vector)

    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
        right,
        col_weights,
        self._unobserved_weight,
        row_weights_slice,
        new_sp_input.indices,
        new_sp_input.values,
        num_rows,
        flip,
        name="wals_compute_partial_lhs_rhs")
    # One LHS per row: the shared part plus the per-row partial term.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    batched_solution = linalg_ops.matrix_solve(total_lhs, total_rhs)
    new_left_values = array_ops.squeeze(batched_solution, [2])

  update_op = self.scatter_update(left, update_indices, new_left_values,
                                  sharding_func)
  return (new_left_values, update_op)
def _ExtractImagePatchesGrad(op, grad):
  """Computes the gradient of image-patch extraction w.r.t. its input.

  A constant sparse 0/1 matrix is built in Python with one row per input
  pixel (`r * cols_in + c`) and one column per (output position, kernel
  element) pair. Multiplying it by the flattened `grad` sums, for every input
  pixel, the gradient of each patch element that was read from it.

  Args:
    op: The patch-extraction op. The static shapes of `op.inputs[0]` and
      `op.outputs[0]` are read via `dim.value` and used in Python arithmetic,
      and the `ksizes`, `strides`, `rates` and `padding` attrs are read.
    grad: Tensor that is reshaped to
      (batch, rows_out, cols_out, ksize_r, ksize_c, channels); presumably the
      gradient with respect to `op.outputs[0]`.

  Returns:
    A single-element list holding a tensor laid out as
    (batch, rows_in, cols_in, channels).

  Raises:
    ValueError: If the `padding` attr is neither SAME nor VALID.
  """

  def _ceil_div(numerator, denominator):
    # Exact integer ceiling that does not depend on `/` being true division.
    return -(-numerator // denominator)

  batch_size, rows_in, cols_in, channels = [
      dim.value for dim in op.inputs[0].get_shape()
  ]
  # rows_out / cols_out are recomputed below from the padding rule.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
  _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
  _, stride_r, stride_c, _ = op.get_attr('strides')
  _, rate_r, rate_c, _ = op.get_attr('rates')

  # get_attr may hand back bytes for string attrs (it does under Python 3);
  # normalise so the comparisons below match in either case.
  padding = op.get_attr('padding')
  if isinstance(padding, bytes):
    padding = padding.decode('ascii')

  # Effective kernel extent once the dilation rate is taken into account.
  ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
  ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

  if padding == 'SAME':
    rows_out = _ceil_div(rows_in, stride_r)
    cols_out = _ceil_div(cols_in, stride_c)
    pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
    pad_cols = ((cols_out - 1) * stride_c + ksize_c_eff - cols_in) // 2
  elif padding == 'VALID':
    rows_out = _ceil_div(rows_in - ksize_r_eff + 1, stride_r)
    cols_out = _ceil_div(cols_in - ksize_c_eff + 1, stride_c)
    pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
    pad_cols = (cols_out - 1) * stride_c + ksize_c_eff - cols_in
  else:
    # Previously this fell through and failed later on an unbound pad_rows.
    raise ValueError('Unsupported padding: %r' % (padding,))

  pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)

  # Move batch next to channels so that each row of grad_flat corresponds to
  # one (output row, output col, kernel row, kernel col) combination.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  row_steps = range(0, rows_out * stride_r, stride_r)
  col_steps = range(0, cols_out * stride_c, stride_c)

  # Collect (input pixel index, patch element index) pairs, skipping kernel
  # taps that land outside the input (i.e. in the padding).
  idx = []
  for i in range(rows_out):
    for j in range(cols_out):
      r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols
      r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff

      idx.extend([
          (r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j *
           (ksize_r * ksize_c) + ri * (ksize_c) + ci)
          for (ri, r) in enumerate(range(r_low, r_high, rate_r))
          for (ci, c) in enumerate(range(c_low, c_high, rate_c))
          if 0 <= r < rows_in and 0 <= c < cols_in
      ])

  sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)

  sp_mat = ops.SparseTensor(
      array_ops.constant(idx, dtype=ops.dtypes.int64),
      array_ops.ones((len(idx),), dtype=ops.dtypes.float32), sp_shape)

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
def body(t, prev):
  """Returns `(t + 1, product)`, both built under a dependency on `prev`.

  `tf`, `sp_x`, `y`, `adjoint_a` and `adjoint_b` are taken from the
  enclosing scope.
  """
  with tf.control_dependencies([prev]):
    next_t = t + 1
    product = sparse_ops.sparse_tensor_dense_matmul(
        sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
  return (next_t, product)
def _ExtractVolumePatchesGrad(op, grad):
  """Computes the gradient of volume-patch extraction w.r.t. its input.

  The input positions are numbered, the same patch extraction is run over
  that numbering, and the result is paired with a numbering of the output
  elements to form a sparse 0/1 matrix. That matrix is multiplied by the
  flattened `grad`.

  Args:
    op: The patch-extraction op. The static plane/row/col sizes of
      `op.inputs[0]` and `op.outputs[0]` are read; batch and channel sizes
      are taken from the dynamic shape. The `ksizes`, `strides` and
      `padding` attrs are read.
    grad: Passed through `_IndexedSlicesToTensorNoWarning` and reshaped to
      (batch, planes_out, rows_out, cols_out, ksize_p, ksize_r, ksize_c,
      channels); presumably the gradient with respect to `op.outputs[0]`.

  Returns:
    A single-element list holding a tensor laid out as
    (batch, planes_in, rows_in, cols_in, channels).
  """
  _, planes_in, rows_in, cols_in, _ = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  dynamic_in_shape = array_ops.shape(op.inputs[0])
  batch_size = dynamic_in_shape[0]
  channels = dynamic_in_shape[4]

  # Number the input positions 1 .. planes_in * rows_in * cols_in. Index 0 is
  # reserved for padding locations.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_numbering = math_ops.range(
      1, input_indices_num, dtype=ops.dtypes.int64)
  input_idx = array_ops.reshape(input_numbering,
                                (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Number the output elements starting from 0.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  kernel_volume = ksize_p * ksize_r * ksize_c
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_numbering = math_ops.range(
      output_indices_num, dtype=ops.dtypes.int64)
  output_idx = array_ops.reshape(
      output_numbering, (1, planes_out, rows_out, cols_out, kernel_volume))

  # Pair every output element with the input position it was read from.
  paired = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ], axis=-1)
  idx_map = array_ops.reshape(paired, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  unit_values = array_ops.ones([output_indices_num], dtype=grad.dtype)
  sp_mat_full = sparse_tensor.SparseTensor(idx_map, unit_values, sp_shape)
  # Drop row 0, i.e. every entry that points at a padding location.
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  # Move batch next to channels so that each row of grad_flat corresponds to
  # one output element.
  grad_dense = _IndexedSlicesToTensorNoWarning(grad)
  grad_reshaped = array_ops.reshape(
      grad_dense, (batch_size, planes_out, rows_out, cols_out, ksize_p,
                   ksize_r, ksize_c, channels))
  grad_expanded = array_ops.transpose(grad_reshaped,
                                      (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]
def _ExtractImagePatchesGrad(op, grad):
  """Computes the gradient of image-patch extraction w.r.t. its input.

  A constant sparse 0/1 matrix is built in Python with one row per input
  pixel (`r * cols_in + c`) and one column per (output position, kernel
  element) pair, and is multiplied by the flattened `grad`.

  Args:
    op: The patch-extraction op. The static row/col sizes of `op.inputs[0]`
      are read; batch and channel sizes are taken from the dynamic shape.
      The `ksizes`, `strides`, `rates` and `padding` attrs are read.
    grad: Tensor that is reshaped to
      (batch, rows_out, cols_out, ksize_r, ksize_c, channels); presumably the
      gradient with respect to `op.outputs[0]`.

  Returns:
    A single-element list holding a tensor laid out as
    (batch, rows_in, cols_in, channels).
  """
  _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].get_shape()]
  dynamic_in_shape = array_ops.shape(op.inputs[0])
  batch_size = dynamic_in_shape[0]
  channels = dynamic_in_shape[3]

  # rows_out / cols_out are recomputed below from the padding rule.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
  _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
  _, stride_r, stride_c, _ = op.get_attr('strides')
  _, rate_r, rate_c, _ = op.get_attr('rates')
  padding = op.get_attr('padding')

  # Effective kernel extent once the dilation rate is taken into account.
  ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
  ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

  if padding == b'SAME':
    rows_out = int(ceil(rows_in / stride_r))
    cols_out = int(ceil(cols_in / stride_c))
    pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
    pad_cols = ((cols_out - 1) * stride_c + ksize_c_eff - cols_in) // 2
  elif padding == b'VALID':
    rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
    cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_c))
    pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
    pad_cols = (cols_out - 1) * stride_c + ksize_c_eff - cols_in

  pad_rows = max(0, pad_rows)
  pad_cols = max(0, pad_cols)

  # Move batch next to channels so that each row of grad_flat corresponds to
  # one (output row, output col, kernel row, kernel col) combination.
  grad_reshaped = array_ops.reshape(
      grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels))
  grad_expanded = array_ops.transpose(grad_reshaped, (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  row_steps = range(0, rows_out * stride_r, stride_r)
  col_steps = range(0, cols_out * stride_c, stride_c)

  # Collect (input pixel index, patch element index) pairs, skipping kernel
  # taps that land outside the input (i.e. in the padding).
  kernel_area = ksize_r * ksize_c
  idx = []
  for i in range(rows_out):
    r_low = row_steps[i] - pad_rows
    r_high = r_low + ksize_r_eff
    for j in range(cols_out):
      c_low = col_steps[j] - pad_cols
      c_high = c_low + ksize_c_eff
      patch_base = i * (cols_out * kernel_area) + j * kernel_area
      for ri, r in enumerate(range(r_low, r_high, rate_r)):
        for ci, c in enumerate(range(c_low, c_high, rate_c)):
          if 0 <= r < rows_in and 0 <= c < cols_in:
            idx.append((r * cols_in + c, patch_base + ri * ksize_c + ci))

  sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)
  sp_indices = array_ops.constant(idx, dtype=ops.dtypes.int64)
  sp_values = array_ops.ones((len(idx),), dtype=ops.dtypes.float32)
  sp_mat = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
def dense_matmul(sp, w):
  """Returns `sparse_ops.sparse_tensor_dense_matmul(sp, w)`."""
  product = sparse_ops.sparse_tensor_dense_matmul(sp, w)
  return product