def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
    """Compares sparse_tensor_dense_matmul with a NumPy matrix product.

    The non-zero entries of `x` are packed into a sparse value which is
    multiplied by `y` twice: once passed as a SparseTensorValue and once as a
    SparseTensor built from it. Both results must match the NumPy reference.

    Args:
      x: array whose non-zero entries form the sparse left operand; its
        `shape` and `dtype` attributes are read.
      y: right operand, handed to the op unchanged.
      adjoint_a: if True, the reference uses the conjugate transpose of `x`
        and the flag is forwarded to the op.
      adjoint_b: same as `adjoint_a`, for `y`.
    """
    lhs = np.matrix(x).H if adjoint_a else np.matrix(x)
    rhs = np.matrix(y).H if adjoint_b else np.matrix(y)
    expected = lhs * rhs

    # Coordinates of the non-zero entries, one (row, col) pair per line.
    nonzero = np.where(x)
    coords = np.vstack(nonzero).astype(np.int64).T
    entries = x[nonzero]

    with self.test_session(use_gpu=True):
      sparse_value = tf.SparseTensorValue(
          indices=coords, values=entries, shape=x.shape)
      via_value = sparse_ops.sparse_tensor_dense_matmul(
          sparse_value, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
      via_tensor = sparse_ops.sparse_tensor_dense_matmul(
          tf.SparseTensor.from_value(sparse_value), y,
          adjoint_a=adjoint_a, adjoint_b=adjoint_b)

      # The static column count must be known for both input forms.
      expected_cols = expected.shape[1]
      self.assertEqual(via_value.get_shape()[1], expected_cols)
      self.assertEqual(via_tensor.get_shape()[1], expected_cols)

      # float64 is held to 1e-6; float32 and every other dtype use 1e-4.
      tol = 1e-6 if x.dtype == np.float64 else 1e-4
      results = (via_value.eval(), via_tensor.eval())
      for actual in results:
        self.assertAllClose(expected, actual, rtol=tol, atol=tol)
  def testInvalidIndicesForSparseTensorDenseMatmul(self):
    """Checks the op errors raised for an out-of-range sparse index.

    The sparse operand has declared shape [3, 2] but a single entry at
    index (1, 10). Multiplying it must fail with a message naming the
    offending coordinate, both with and without adjoint_a.
    """
    # use_gpu=False: descriptive errors are only returned by the CPU kernel.
    with self.session(use_gpu=False):
      sparse_t = sparse_tensor.SparseTensor(
          np.matrix([[1, 10]]).astype(np.int64),
          np.array([10]).astype(np.float32),
          [3, 2])

      # Every case is run against a small and a large dense matrix so that
      # both cases in the kernel are hit.
      widths = (5, 500)

      for width in widths:
        dense_t = np.matrix([[1] * width, [2] * width], dtype=np.float32)
        with self.assertRaisesOpError(
            "k .10. from index.0,1. out of bounds .>=2."):
          sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t).eval()

      # With adjoint_a the failure is reported against "m" instead of "k".
      for width in widths:
        dense_t = np.matrix(
            [[1] * width, [2] * width, [3] * width], dtype=np.float32)
        with self.assertRaisesOpError(
            "m .10. from index.0,1. out of bounds .>=2."):
          sparse_ops.sparse_tensor_dense_matmul(
              sparse_t, dense_t, adjoint_a=True).eval()
Exemple #3
0
def _ExtractImagePatchesGrad(op, grad):
  """Gradient of an extract-image-patches op with respect to its input.

  A sparse 0/1 matrix relating every input pixel to the patch slots that
  copied it is built by running the patch extraction on a tensor of pixel
  ids. Multiplying that matrix by the flattened incoming gradient sums the
  gradient contributions of each pixel.

  Args:
    op: the forward op; its "ksizes", "strides", "rates" and "padding"
      attributes and the static spatial dimensions of its input and output
      are read.
    grad: incoming gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient for the op's input, laid out
    as (batch, rows_in, cols_in, channels).
  """
  # Spatial sizes come from the static shape; batch and channel sizes are
  # taken from the dynamic shape.
  _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].shape.dims]
  dynamic_shape = array_ops.shape(op.inputs[0])
  batch_size, channels = dynamic_shape[0], dynamic_shape[3]

  # Number the input pixels 1..rows_in*cols_in; id 0 is reserved so that
  # padded locations can be recognised after patch extraction.
  num_in = 1 + rows_in * cols_in
  pixel_ids = math_ops.range(1, num_in, dtype=ops.dtypes.int64)
  pixel_ids = array_ops.reshape(pixel_ids, (1, rows_in, cols_in, 1))
  patched_ids = gen_array_ops.extract_image_patches(
      pixel_ids,
      op.get_attr("ksizes"),
      op.get_attr("strides"),
      op.get_attr("rates"),
      op.get_attr("padding"))

  # Number the output patch slots starting at 0.
  _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  num_out = rows_out * cols_out * ksize_r * ksize_c
  slot_ids = array_ops.reshape(
      math_ops.range(num_out, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # One (pixel id, slot id) pair per output slot.
  pairs = array_ops.concat(
      [array_ops.expand_dims(patched_ids, axis=-1),
       array_ops.expand_dims(slot_ids, axis=-1)],
      axis=-1)
  pairs = array_ops.reshape(pairs, (-1, 2))

  with_padding = sparse_tensor.SparseTensor(
      pairs,
      array_ops.ones([num_out], dtype=grad.dtype),
      (num_in, num_out))
  # Row 0 collects the padded locations; slice it away.
  selector = sparse_ops.sparse_slice(
      with_padding, (1, 0), (num_in - 1, num_out))

  # Bring the gradient to (slot, batch * channels) so it lines up with the
  # columns of the selector matrix.
  grad_6d = array_ops.reshape(
      grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels))
  grad_slots_first = array_ops.transpose(grad_6d, (1, 2, 3, 4, 0, 5))
  grad_2d = array_ops.reshape(grad_slots_first, (-1, batch_size * channels))

  summed = sparse_ops.sparse_tensor_dense_matmul(selector, grad_2d)

  # Back to (batch, rows_in, cols_in, channels).
  per_pixel = array_ops.reshape(
      summed, (rows_in, cols_in, batch_size, channels))
  return [array_ops.transpose(per_pixel, (2, 0, 1, 3))]
def _SparseTensorDenseMatMulGrad(op, grad):
  """Gradients for the dense tensor in the SparseTensorDenseMatMul op.

  Gradients are only provided for the dense tensor.

  Args:
    op: the SparseTensorDenseMatMul op
    grad: the incoming gradient

  Returns:
    Gradient for each of the 4 input tensors:
      (sparse_indices, sparse_values, sparse_shape, dense_tensor)
    The sparse tensor gradients are always None.

  Raises:
    TypeError: When the two operands don't have the same type.
    NotImplementedError: When the operands are complex64.
  """
  sp_t = ops.SparseTensor(*op.inputs[:3])
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")

  a_type = sp_t.values.dtype
  b_type = op.inputs[3].dtype
  # Raise explicitly instead of asserting: assert statements are removed when
  # Python runs with -O, which would let mismatched operands through.
  if a_type != b_type:
    raise TypeError("SparseTensorDenseMatMul op received operands with "
                    "different types: %s and %s" % (a_type, b_type))
  if a_type == ops.dtypes.complex64:
    raise NotImplementedError("SparseTensorDenseMatMul op does not support "
                              "complex gradients.")

  # The dense gradient is the sparse operand, with adjoint_a flipped, times
  # the incoming gradient.
  b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
                                                 adjoint_a=not adj_a)
  if adj_b:
    # The forward op used the transposed dense operand, so transpose back.
    b_grad = array_ops.transpose(b_grad)

  return (None, None, None, b_grad)
Exemple #5
0
def _SparseTensorDenseMatMulGrad(op, grad):
  """Gradients for the SparseTensorDenseMatMul op.

  Gradients are returned for the sparse values and for the dense tensor.

  Args:
    op: the SparseTensorDenseMatMul op
    grad: the incoming gradient

  Returns:
    Gradient for each of the 4 input tensors:
      (sparse_indices, sparse_values, sparse_shape, dense_tensor)
    The gradients for indices and shape are None.

  Raises:
    TypeError: When the two operands don't have the same type.
    NotImplementedError: When the operands are complex64.
  """
  sp_t = ops.SparseTensor(*op.inputs[:3])
  adj_a = op.get_attr("adjoint_a")
  adj_b = op.get_attr("adjoint_b")

  a_type = sp_t.values.dtype.base_dtype
  b_type = op.inputs[3].dtype.base_dtype
  if a_type != b_type:
    # Format a single message string; passing the pieces as separate
    # exception arguments would render the error as a tuple.
    raise TypeError("SparseTensorDenseMatMul op received operands with "
                    "different types: %s and %s" % (a_type, b_type))
  is_complex = a_type == ops.dtypes.complex64
  if is_complex:
    raise NotImplementedError("SparseTensorDenseMatMul op does not support "
                              "complex gradients.")

  # gradient w.r.t. dense
  b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
                                                 adjoint_a=not adj_a)
  if adj_b:
    b_grad = array_ops.transpose(b_grad)

  # gradient w.r.t. sparse values
  a_indices = op.inputs[0]
  b = op.inputs[3]

  rows = a_indices[:, 0]
  cols = a_indices[:, 1]

  # TODO(zongheng, ebrevdo): add conjugates in the right places when complex
  # values are allowed.
  # TODO(zongheng): these gather calls could potentially duplicate rows/cols in
  # memory.  If there is a need, we should look into implementing this more
  # intelligently to avoid duplicating data.
  # For every stored entry, pair the matching row of `grad` with the matching
  # row of the (possibly transposed) dense operand; adjoint_a swaps which
  # index column selects which.
  parts_a = array_ops.gather(grad, rows if not adj_a else cols)
  parts_b = array_ops.gather(b if not adj_b else array_ops.transpose(b),
                             cols if not adj_a else rows)
  # Sum of elementwise products along axis 1: one value per stored entry.
  a_values_grad = math_ops.reduce_sum(parts_a * parts_b, reduction_indices=1)

  # gradients w.r.t. (a_indices, a_values, a_shape, b)
  return (None, a_values_grad, None, b_grad)
  def testConsumers(self):
    """Checks that SparseTensor.consumers() lists the ops built from it."""
    sp = sparse_tensor.SparseTensor([[0, 0], [1, 2]], [1.0, 3.0], [3, 4])
    ones_column = ops.convert_to_tensor(np.ones([4, 1], np.float32))
    product = sparse_ops.sparse_tensor_dense_matmul(sp, ones_column)

    # After the matmul there is exactly one consumer: the matmul op.
    consumers = sp.consumers()
    self.assertEqual(len(consumers), 1)
    self.assertEqual(consumers[0], product.op)

    # Converting to dense adds a second consumer and keeps the first.
    densified = sparse_ops.sparse_tensor_to_dense(sp)
    consumers = sp.consumers()
    self.assertEqual(len(consumers), 2)
    for consumer_op in (densified.op, product.op):
      self.assertTrue(consumer_op in consumers)
  def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
    """Pins the GPU results for an out-of-range sparse index.

    The sparse operand has declared shape [3, 2] and one entry at index
    (1, 10). Instead of expecting an error, this test expects NaN in the
    affected output row, or all zeros when adjoint_a is set.
    """
    # Runs with use_gpu=True, so there is nothing to check without a GPU.
    if not test.is_gpu_available():
      return
    with self.session(use_gpu=True):
      sparse_t = sparse_tensor.SparseTensor(
          np.array([[1, 10]]).astype(np.int64),
          np.array([10]).astype(np.float32),
          [3, 2])

      # Every case is run against a small and a large dense matrix so that
      # both cases in the kernel are hit.
      widths = (5, 500)

      for width in widths:
        dense_t = np.matrix([[1] * width, [2] * width], dtype=np.float32)
        expected_t = np.array(
            [[0] * width, [np.nan] * width, [0] * width], dtype=np.float32)
        self.assertAllClose(
            expected_t,
            sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t).eval())

      # With adjoint_a the sparse index is out of bounds w.r.t. the output.
      # The GPU kernel can't do much here, so it just doesn't accumulate and
      # the result stays all zeros.
      for width in widths:
        dense_t = np.matrix(
            [[1] * width, [2] * width, [3] * width], dtype=np.float32)
        expected_t = np.array([[0] * width, [0] * width], dtype=np.float32)
        self.assertAllClose(
            expected_t,
            sparse_ops.sparse_tensor_dense_matmul(
                sparse_t, dense_t, adjoint_a=True).eval())
  def testShapeInference(self):
    """Checks the static output shape of sparse_tensor_dense_matmul.

    Three cases are covered: a fully known sparse shape, a sparse shape fed
    through a placeholder of unknown shape, and a sparse shape whose inner
    dimension disagrees with the dense operand.
    """
    lhs = np.random.rand(10, 10)
    lhs[np.abs(lhs) < 0.5] = 0  # Zero out small entries to make it sparse.
    rhs = np.random.randn(10, 20)
    nonzero = np.where(lhs)
    indices = np.vstack(nonzero).astype(np.int64).T
    values = lhs[nonzero]

    def sparse_with_shape(dense_shape):
      return sparse_tensor.SparseTensor(indices, values, dense_shape)

    # Known shape: the result shape is fully determined.
    known = sparse_ops.sparse_tensor_dense_matmul(
        sparse_with_shape(lhs.shape), rhs)
    self.assertEqual(known.get_shape(), (10, 20))

    # Unknown sparse shape: only the column count can be inferred.
    shape_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=None)
    partly_known = sparse_ops.sparse_tensor_dense_matmul(
        sparse_with_shape(shape_placeholder), rhs)
    self.assertEqual(partly_known.get_shape().as_list(), [None, 20])

    # Inner dimensions 15 vs. 10 do not match: graph construction must fail.
    mismatched = sparse_with_shape([10, 15])
    with self.assertRaisesRegexp(ValueError, "Dimensions must be equal"):
      sparse_ops.sparse_tensor_dense_matmul(mismatched, rhs)
  def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype):
    """Checks the gradients of sparse_tensor_dense_matmul numerically.

    Random operands with dimensions drawn from [1, 10) are multiplied and the
    gradient error with respect to the dense operand and the sparse values is
    required to stay below 1e-3.

    Args:
      adjoint_a: forwarded to the op and to the sparse operand generator.
      adjoint_b: forwarded to the op and to the dense operand generator.
      name: name given to the op; also used in the printed report.
      np_dtype: dtype handed to the operand generator.
    """
    n, k, m = np.random.randint(1, 10, size=3)
    sp_t, nnz = self._randomTensor(
        [n, k], np_dtype, adjoint=adjoint_a, sparse=True)
    dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)

    product = sparse_ops.sparse_tensor_dense_matmul(
        sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)

    with self.test_session(use_gpu=True):
      # Inputs to differentiate against, each paired with its shape. The
      # dense shape is listed as [m, k] when adjoint_b is set.
      checked_inputs = [dense_t, sp_t.values]
      checked_shapes = [[m, k] if adjoint_b else [k, m], [nnz]]
      err = gradient_checker.compute_gradient_error(
          checked_inputs, checked_shapes, product, [n, m])
      print("%s gradient err = %s" % (name, err))
      self.assertLess(err, 1e-3)
Exemple #10
0
    def _process_input_helper(self,
                              update_row_factors,
                              sp_input=None,
                              transpose_input=False,
                              row_weights=None):
        """Creates the graph for processing a sparse slice of input.

        Args:
          update_row_factors: if True, update or project the row_factors, else
            update or project the column factors.
          sp_input: Please refer to comments for update_row_factors,
            update_col_factors, project_row_factors, and project_col_factors
            for restrictions. Despite the None default, this must be a
            SparseTensor; that is enforced with an assert.
          transpose_input: If True, the input is logically transposed and then
            the corresponding rows/columns of the transposed input are
            updated.
          row_weights: If not None, this is the row/column weights to be used
            for the update or projection. If None, use the corresponding
            weights from the model. Note that the feature (column/row) weights
            will be determined by the model. When not None, it can either be a
            scalar or a rank-1 tensor with the same number of elements as the
            number of rows or columns to be updated/projected.

        Returns:
          A tuple consisting of the following elements:
          new_values: New values for the row/column factors.
          update_op: An op that assigns the newly computed values to the
            row/column factors.
          unregularized_loss: A tensor (scalar) that contains the normalized
            minibatch loss corresponding to sp_input, without the
            regularization term. Add the regularization term below to yield
            the loss.
          regularization: A tensor (scalar) that contains the normalized
            regularization term for the minibatch loss corresponding to
            sp_input.
          sum_weights: The sum of the weights corresponding to sp_input. This
            can be used with unregularized loss to calculate the root weighted
            squared error.
        """
        assert isinstance(sp_input, sparse_tensor.SparseTensor)

        # Select the factors to solve for (`left`), the cached factors of the
        # other side (`right_factors`), and the matching weight caches, sizes,
        # sharding function and gramian.
        if update_row_factors:
            left = self._row_factors
            right_factors = self._col_factors_cache
            row_wt = self._row_wt_cache
            col_wt = self._col_wt_cache
            total_rows = self._input_rows
            total_cols = self._input_cols
            sharding_func = WALSModel._get_sharding_func(
                self._input_rows, self._num_row_shards)
            gramian = self._col_gramian_cache
        else:
            left = self._col_factors
            right_factors = self._row_factors_cache
            row_wt = self._col_wt_cache
            col_wt = self._row_wt_cache
            total_rows = self._input_cols
            total_cols = self._input_rows
            sharding_func = WALSModel._get_sharding_func(
                self._input_cols, self._num_col_shards)
            gramian = self._row_gramian_cache
            # Updating column factors flips the transposition flag.
            transpose_input = not transpose_input

        # Note that the row indices of sp_input are based on the original full input
        # Here we reindex the rows and give them contiguous ids starting at 0.
        # We use tf.unique to achieve this reindexing. Note that this is done so
        # that the downstream kernel can assume that the input is "dense" along the
        # row dimension.
        row_ids, col_ids = array_ops.split(value=sp_input.indices,
                                           num_or_size_splits=2,
                                           axis=1)
        # unique() yields the distinct original ids and, for every entry, its
        # position among them; those positions become the new contiguous ids.
        update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
        update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
        col_ids = array_ops.expand_dims(
            math_ops.cast(all_col_ids, dtypes.int64), 1)
        row_ids = array_ops.expand_dims(
            math_ops.cast(all_row_ids, dtypes.int64), 1)

        # `update_indices` selects the factors being solved for and
        # `gather_indices` the cached factors to look up; which of rows or
        # columns plays which role depends on the transposition flag.
        if transpose_input:
            update_indices = update_col_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_row_indices)[0], dtypes.int64)
            ]
            gather_indices = update_row_indices
        else:
            update_indices = update_row_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_col_indices)[0], dtypes.int64)
            ]
            gather_indices = update_col_indices

        num_rows = math_ops.cast(
            array_ops.shape(update_indices)[0], dtypes.int64)
        col_shape = [num_rows]
        right = embedding_ops.embedding_lookup(right_factors,
                                               gather_indices,
                                               partition_strategy="div")
        # Rebuild the sparse input on the reindexed coordinates, with a dense
        # shape that counts only the distinct rows/columns present.
        new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
        new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                        if transpose_input else array_ops.concat(
                            [col_shape, row_shape], 0))
        new_sp_input = sparse_tensor.SparseTensor(indices=new_sp_indices,
                                                  values=sp_input.values,
                                                  dense_shape=new_sp_shape)

        # Compute lhs and rhs of the normal equations
        total_lhs = (self._unobserved_weight * gramian)
        if self._regularization_matrix is not None:
            total_lhs += self._regularization_matrix
        if self._row_weights is None:
            # Special case of ALS. Use a much simpler update rule.
            total_rhs = (self._unobserved_weight *
                         sparse_ops.sparse_tensor_dense_matmul(
                             new_sp_input, right, adjoint_a=transpose_input))
            # TODO (rmlarsen): handle transposing in tf.matrix_solve instead of
            # transposing explicitly. id:894 gh:895
            # TODO (rmlarsen): multi-thread tf.matrix_solve. id:594 gh:594
            new_left_values = array_ops.transpose(
                linalg_ops.matrix_solve(total_lhs,
                                        array_ops.transpose(total_rhs)))
        else:
            if row_weights is None:
                # TODO (yifanchen): Add special handling for single shard without using id:635 gh:636
                # embedding_lookup and perform benchmarks for those cases. Same for
                # col_weights lookup below.
                row_weights_slice = embedding_ops.embedding_lookup(
                    row_wt, update_indices, partition_strategy="div")
            else:
                # Caller-supplied weights must have rank <= 1. A scalar is
                # expanded to one weight per updated index; a vector is cast
                # to float32 and used as given.
                num_indices = array_ops.shape(update_indices)[0]
                with ops.control_dependencies([
                        check_ops.assert_less_equal(
                            array_ops.rank(row_weights), 1)
                ]):
                    row_weights_slice = control_flow_ops.cond(
                        math_ops.equal(array_ops.rank(row_weights), 0), lambda:
                        (array_ops.ones([num_indices]) * row_weights),
                        lambda: math_ops.cast(row_weights, dtypes.float32))

            col_weights = embedding_ops.embedding_lookup(
                col_wt, gather_indices, partition_strategy="div")
            partial_lhs, total_rhs = (
                gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
                    right,
                    col_weights,
                    self._unobserved_weight,
                    row_weights_slice,
                    new_sp_input.indices,
                    new_sp_input.values,
                    num_rows,
                    transpose_input,
                    name="wals_compute_partial_lhs_rhs"))
            # The shared lhs gets a leading axis so it can be added to
            # partial_lhs; the rhs gets a trailing axis for matrix_solve,
            # which is squeezed away again from the solution.
            total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
            total_rhs = array_ops.expand_dims(total_rhs, -1)
            new_left_values = array_ops.squeeze(
                linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

        update_op_name = "row_update" if update_row_factors else "col_update"
        update_op = self.scatter_update(left,
                                        update_indices,
                                        new_left_values,
                                        sharding_func,
                                        name=update_op_name)

        # Create the loss subgraph
        loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                         if transpose_input else new_sp_input)
        # sp_approx is the low rank estimate of the input matrix, formed by
        # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
        sp_approx_vals = gen_factorization_ops.masked_matmul(
            new_left_values,
            right,
            loss_sp_input.indices,
            transpose_a=False,
            transpose_b=True)
        sp_approx = sparse_tensor.SparseTensor(loss_sp_input.indices,
                                               sp_approx_vals,
                                               loss_sp_input.dense_shape)
        sp_approx_sq = math_ops.square(sp_approx)
        sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
        sp_residual_sq = math_ops.square(sp_residual)
        # NOTE(review): `col_weights` is only bound in the weighted branch
        # above (self._row_weights is not None). This presumes _row_weights
        # and _col_weights are either both None or both set -- confirm.
        row_wt_mat = (constant_op.constant(0.) if self._row_weights is None
                      else array_ops.expand_dims(row_weights_slice, 1))
        col_wt_mat = (constant_op.constant(0.) if self._col_weights is None
                      else array_ops.expand_dims(col_weights, 0))

        # We return the normalized loss
        partial_row_gramian = math_ops.matmul(new_left_values,
                                              new_left_values,
                                              transpose_a=True)
        # Ratio of the full row count to the number of rows in this slice.
        normalization_factor = total_rows / math_ops.cast(
            num_rows, dtypes.float32)

        unregularized_loss = (
            self._unobserved_weight * (  # pyformat line break
                sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
                sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
                math_ops.trace(math_ops.matmul(partial_row_gramian, gramian)))
            + sparse_ops.sparse_reduce_sum(
                row_wt_mat *
                (sp_residual_sq * col_wt_mat))) * normalization_factor

        if self._regularization is not None:
            regularization = self._regularization * (
                math_ops.trace(partial_row_gramian) * normalization_factor +
                math_ops.trace(gramian))
        else:
            regularization = constant_op.constant(0.)

        # Start from the unobserved weight over every cell of the full input;
        # when both weight sets exist, add the weights of the observed entries.
        sum_weights = self._unobserved_weight * math_ops.cast(
            total_rows * total_cols, dtypes.float32)
        if self._row_weights is not None and self._col_weights is not None:
            ones = sparse_tensor.SparseTensor(
                indices=loss_sp_input.indices,
                values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
                dense_shape=loss_sp_input.dense_shape)
            sum_weights += sparse_ops.sparse_reduce_sum(
                row_wt_mat * (ones * col_wt_mat)) * normalization_factor

        return (new_left_values, update_op, unregularized_loss, regularization,
                sum_weights)
Exemple #11
0
def _ExtractImagePatchesGrad(op, grad):
    """Gradient of an extract-image-patches op with respect to its input.

    Patch extraction is replayed on a tensor of pixel ids to learn which
    input pixel landed in which output slot. The resulting sparse 0/1 matrix
    is multiplied by the flattened incoming gradient, which sums the
    gradient contributions of each pixel.

    Args:
      op: the forward op; its "ksizes", "strides", "rates" and "padding"
        attributes and the static spatial dimensions of its input and output
        are read.
      grad: incoming gradient with respect to the op's output.

    Returns:
      A single-element list holding the gradient for the op's input, laid
      out as (batch, rows_in, cols_in, channels).
    """
    # Spatial sizes come from the static shape; batch and channel sizes are
    # taken from the dynamic shape.
    _, rows_in, cols_in, _ = [dim.value for dim in op.inputs[0].shape.dims]
    runtime_shape = array_ops.shape(op.inputs[0])
    batch_size, channels = runtime_shape[0], runtime_shape[3]

    # Pixel ids run from 1 to rows_in * cols_in; id 0 is reserved for padded
    # locations.
    id_count = 1 + rows_in * cols_in
    id_image = array_ops.reshape(
        math_ops.range(1, id_count, dtype=ops.dtypes.int64),
        (1, rows_in, cols_in, 1))
    id_patches = gen_array_ops.extract_image_patches(
        id_image,
        op.get_attr("ksizes"),
        op.get_attr("strides"),
        op.get_attr("rates"),
        op.get_attr("padding"))

    # Output slots are numbered from 0.
    _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
    _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
    slot_count = rows_out * cols_out * ksize_r * ksize_c
    slot_image = array_ops.reshape(
        math_ops.range(slot_count, dtype=ops.dtypes.int64),
        (1, rows_out, cols_out, ksize_r * ksize_c))

    # One (pixel id, slot id) pair per output slot.
    stacked = array_ops.concat(
        [
            array_ops.expand_dims(id_patches, axis=-1),
            array_ops.expand_dims(slot_image, axis=-1),
        ],
        axis=-1)
    pair_list = array_ops.reshape(stacked, (-1, 2))

    full_matrix = sparse_tensor.SparseTensor(
        pair_list,
        array_ops.ones([slot_count], dtype=grad.dtype),
        (id_count, slot_count))
    # Row 0 collects the padded locations; slice it away.
    pixel_to_slot = sparse_ops.sparse_slice(
        full_matrix, (1, 0), (id_count - 1, slot_count))

    # The "sparse IndexedSlices to dense" warning is silenced while the
    # gradient is reshaped and transposed.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Converting sparse IndexedSlices to a dense Tensor.*")
        grad_6d = array_ops.reshape(
            grad,
            (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels))
        grad_slots_first = array_ops.transpose(grad_6d, (1, 2, 3, 4, 0, 5))
    grad_2d = array_ops.reshape(
        grad_slots_first, (-1, batch_size * channels))

    summed = sparse_ops.sparse_tensor_dense_matmul(pixel_to_slot, grad_2d)

    # Back to (batch, rows_in, cols_in, channels).
    per_pixel = array_ops.reshape(
        summed, (rows_in, cols_in, batch_size, channels))
    return [array_ops.transpose(per_pixel, (2, 0, 1, 3))]
    def patches_to_images(self, grad, batch_size, rows_in, cols_in, channels,
                          rows_out, cols_out, ksize_r, ksize_c, stride_h,
                          stride_r):
        """Sums patch values back onto the image positions they cover.

        A sparse 0/1 matrix mapping every image position to the patch slots
        that overlap it is built in Python and multiplied by the flattened
        patch tensor.

        Args:
          grad: patch values; reshaped to (batch_size, rows_out, cols_out,
            ksize_r, ksize_c, channels).
          batch_size: size of the batch dimension.
          rows_in: number of image rows.
          cols_in: number of image columns.
          channels: size of the channel dimension.
          rows_out: overwritten; recomputed from the padding mode.
          cols_out: overwritten; recomputed from the padding mode.
          ksize_r: patch height.
          ksize_c: patch width.
          stride_h: stride applied along the columns.
          stride_r: stride applied along the rows.

        Returns:
          The product reshaped and transposed to
          (batch_size, rows_in, cols_in, channels).

        Raises:
          ValueError: if `self.pad` is neither 'SAME' nor 'VALID'.
        """
        padding = self.pad
        # Fail early with a clear message; otherwise an unknown mode would
        # leave the output sizes and padding amounts undefined.
        if padding not in ('SAME', 'VALID'):
            raise ValueError(
                "Unsupported padding %r; expected 'SAME' or 'VALID'."
                % (padding,))

        # Dilation is fixed at 1, so the effective kernel size equals the
        # nominal one; the general formula is kept for clarity.
        rate_r = 1
        rate_c = 1
        ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
        ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

        # -(-a // b) is ceil(a / b) in exact integer arithmetic, independent
        # of how the interpreter treats `/`.
        if padding == 'SAME':
            rows_out = -(-rows_in // stride_r)
            cols_out = -(-cols_in // stride_h)
            pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
            pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
        else:  # 'VALID'
            rows_out = -(-(rows_in - ksize_r_eff + 1) // stride_r)
            cols_out = -(-(cols_in - ksize_c_eff + 1) // stride_h)
            pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
            pad_cols = (cols_out - 1) * stride_h + ksize_c_eff - cols_in

        pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)

        # Bring the patches to (slot, batch * channels) so that they line up
        # with the columns of the sparse matrix built below.
        grad_expanded = array_ops.transpose(
            array_ops.reshape(grad, (batch_size, rows_out,
                                     cols_out, ksize_r, ksize_c, channels)),
            (1, 2, 3, 4, 0, 5)
        )
        grad_flat = array_ops.reshape(
            grad_expanded, (-1, batch_size * channels))

        # One (image position, patch slot) pair for every patch element that
        # falls inside the image; elements in the padding are skipped.
        patch_size = ksize_r * ksize_c
        idx = []
        for i in range(rows_out):
            r_low = i * stride_r - pad_rows
            for j in range(cols_out):
                c_low = j * stride_h - pad_cols
                slot_base = (i * cols_out + j) * patch_size
                for ri, r in enumerate(
                        range(r_low, r_low + ksize_r_eff, rate_r)):
                    if not 0 <= r < rows_in:
                        continue
                    for ci, c in enumerate(
                            range(c_low, c_low + ksize_c_eff, rate_c)):
                        if 0 <= c < cols_in:
                            idx.append((r * cols_in + c,
                                        slot_base + ri * ksize_c + ci))

        sp_shape = (rows_in * cols_in, rows_out * cols_out * patch_size)

        sp_mat = sparse_tensor.SparseTensor(
            array_ops.constant(idx, dtype=ops.dtypes.int64),
            array_ops.ones((len(idx),), dtype=ops.dtypes.float32),
            sp_shape
        )

        jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

        grad_out = array_ops.reshape(
            jac, (rows_in, cols_in, batch_size, channels)
        )
        grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

        return grad_out
Exemple #13
0
def _ExtractVolumePatchesGrad(op, grad):
  """Back-propagates `grad` through a volume-patch extraction op.

  The input-to-patch mapping is recovered by running the same patch
  extraction over a tensor of 1-based input position ids. The resulting
  (input id, output id) pairs form a sparse matrix of ones that is multiplied
  with the flattened incoming gradient.

  Args:
    op: The forward op; its first input, its first output and the "ksizes",
      "strides" and "padding" attrs are read.
    grad: Gradient with respect to the op's output.

  Returns:
    A one-element list holding the gradient with respect to the op's input.
  """
  source = op.inputs[0]
  # Only the three spatial extents are taken from the static shape; batch and
  # channel sizes are read from the runtime shape instead.
  _, planes_in, rows_in, cols_in, _ = [d.value for d in source.shape.dims]
  runtime_shape = array_ops.shape(source)
  batch_size = runtime_shape[0]
  channels = runtime_shape[4]

  ksizes = op.get_attr("ksizes")

  # Number every input position starting at 1; id 0 is left free so that it
  # can stand for padding locations. Valid ids therefore run from 1 through
  # planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  position_ids = math_ops.range(
      1, input_indices_num, dtype=ops.dtypes.int64)
  input_idx = array_ops.reshape(
      position_ids, (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, ksizes, op.get_attr("strides"), op.get_attr("padding"))

  # Number every output element starting at 0.
  _, planes_out, rows_out, cols_out, _ = [
      d.value for d in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = ksizes
  patch_volume = ksize_p * ksize_r * ksize_c
  output_indices_num = planes_out * rows_out * cols_out * patch_volume
  output_ids = math_ops.range(output_indices_num, dtype=ops.dtypes.int64)
  output_idx = array_ops.reshape(
      output_ids, (1, planes_out, rows_out, cols_out, patch_volume))

  # Pair each patched input id with the output id found at the same position.
  input_column = array_ops.expand_dims(input_idx_patched, axis=-1)
  output_column = array_ops.expand_dims(output_idx, axis=-1)
  id_pairs = array_ops.concat([input_column, output_column], axis=-1)
  idx_map = array_ops.reshape(id_pairs, (-1, 2))

  unit_values = array_ops.ones([output_indices_num], dtype=grad.dtype)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, unit_values, (input_indices_num, output_indices_num))
  # Slice away row 0, i.e. every entry that refers to a padding location.
  sp_mat = sparse_ops.sparse_slice(
      sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

  with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Converting sparse IndexedSlices to a dense Tensor.*")
    grad_unpacked = array_ops.reshape(
        grad, (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
               ksize_c, channels))
    # Move batch next to channels so both can be folded into one column axis.
    grad_expanded = array_ops.transpose(
        grad_unpacked, (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Unfold the rows back into (planes, rows, cols) and restore batch-major
  # ordering.
  jac_5d = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  return [array_ops.transpose(jac_5d, (3, 0, 1, 2, 4))]
    def _process_input_helper(self,
                              update_row_factors,
                              sp_input=None,
                              transpose_input=False,
                              row_weights=None):
        """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
        # Despite the None default, sp_input is mandatory: anything other than
        # a SparseTensor fails this assertion.
        assert isinstance(sp_input, sparse_tensor.SparseTensor)

        # Pick which side is being solved for ("left") and which cached
        # factors, weights and gramian from the other side feed the solve.
        if update_row_factors:
            left = self._row_factors
            right_factors = self._col_factors_cache
            row_wt = self._row_wt_cache
            col_wt = self._col_wt_cache
            sharding_func = WALSModel._get_sharding_func(
                self._input_rows, self._num_row_shards)
            gramian = self._col_gramian_cache
        else:
            left = self._col_factors
            right_factors = self._row_factors_cache
            row_wt = self._col_wt_cache
            col_wt = self._row_wt_cache
            sharding_func = WALSModel._get_sharding_func(
                self._input_cols, self._num_col_shards)
            gramian = self._row_gramian_cache
            # The column update reuses the row-update code below with the
            # sense of transpose_input flipped.
            transpose_input = not transpose_input

        # Note that the row indices of sp_input are based on the original full input
        # Here we reindex the rows and give them contiguous ids starting at 0.
        # We use tf.unique to achieve this reindexing. Note that this is done so
        # that the downstream kernel can assume that the input is "dense" along the
        # row dimension.
        row_ids, col_ids = array_ops.split(value=sp_input.indices,
                                           num_or_size_splits=2,
                                           axis=1)
        # unique() yields (distinct original ids, position of each entry in
        # that distinct list); the positions become the new contiguous ids.
        update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
        update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
        col_ids = array_ops.expand_dims(
            math_ops.cast(all_col_ids, dtypes.int64), 1)
        row_ids = array_ops.expand_dims(
            math_ops.cast(all_row_ids, dtypes.int64), 1)

        # update_indices: original ids of the factors that get new values.
        # gather_indices: original ids of the opposite-side factors to look up.
        # Note that row_shape holds the number of distinct ids on the gathered
        # side, while col_shape (below) holds the number being updated.
        if transpose_input:
            update_indices = update_col_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_row_indices)[0], dtypes.int64)
            ]
            gather_indices = update_row_indices
        else:
            update_indices = update_row_indices
            row_shape = [
                math_ops.cast(
                    array_ops.shape(update_col_indices)[0], dtypes.int64)
            ]
            gather_indices = update_col_indices

        num_rows = math_ops.cast(
            array_ops.shape(update_indices)[0], dtypes.int64)
        col_shape = [num_rows]
        right = embedding_ops.embedding_lookup(right_factors,
                                               gather_indices,
                                               partition_strategy="div")
        # Rebuild the sparse input over the compacted ids; its dense shape is
        # (gathered, updated) when transposed and (updated, gathered) otherwise.
        new_sp_indices = array_ops.concat_v2([row_ids, col_ids], 1)
        new_sp_shape = (array_ops.concat_v2([row_shape, col_shape], 0)
                        if transpose_input else array_ops.concat_v2(
                            [col_shape, row_shape], 0))
        new_sp_input = sparse_tensor.SparseTensor(indices=new_sp_indices,
                                                  values=sp_input.values,
                                                  dense_shape=new_sp_shape)

        # Compute lhs and rhs of the normal equations
        total_lhs = (self._unobserved_weight * gramian)
        if self._regularization is not None:
            total_lhs += self._regularization
        if self._row_weights is None:
            # Special case of ALS. Use a much simpler update rule.
            total_rhs = (self._unobserved_weight *
                         sparse_ops.sparse_tensor_dense_matmul(
                             new_sp_input, right, adjoint_a=transpose_input))
            # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
            # transposing explicitly.
            # TODO(rmlarsen): multi-thread tf.matrix_solve.
            new_left_values = array_ops.transpose(
                linalg_ops.matrix_solve(total_lhs,
                                        array_ops.transpose(total_rhs)))
        else:
            if row_weights is None:
                # TODO(yifanchen): Add special handling for single shard without using
                # embedding_lookup and perform benchmarks for those cases. Same for
                # col_weights lookup below.
                row_weights_slice = embedding_ops.embedding_lookup(
                    row_wt, update_indices, partition_strategy="div")
            else:
                # Caller-supplied weights must have rank <= 1. A scalar is
                # broadcast to one weight per updated index; a vector is only
                # cast to float32.
                with ops.control_dependencies([
                        check_ops.assert_less_equal(
                            array_ops.rank(row_weights), 1)
                ]):
                    row_weights_slice = control_flow_ops.cond(
                        math_ops.equal(array_ops.rank(row_weights), 0), lambda:
                        (array_ops.ones([array_ops.shape(update_indices)[0]]) *
                         row_weights),
                        lambda: math_ops.cast(row_weights, dtypes.float32))

            col_weights = embedding_ops.embedding_lookup(
                col_wt, gather_indices, partition_strategy="div")
            partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
                right,
                col_weights,
                self._unobserved_weight,
                row_weights_slice,
                new_sp_input.indices,
                new_sp_input.values,
                num_rows,
                transpose_input,
                name="wals_compute_partial_lhs_rhs")
            # The shared lhs gets a leading axis so it broadcasts against
            # partial_lhs; the rhs gets a trailing axis for matrix_solve, which
            # is squeezed away again from the solution.
            total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
            total_rhs = array_ops.expand_dims(total_rhs, -1)
            new_left_values = array_ops.squeeze(
                linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

        return (new_left_values,
                self.scatter_update(left, update_indices, new_left_values,
                                    sharding_func))
Exemple #15
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    # Despite the None default, sp_input is mandatory: anything other than a
    # SparseTensor fails this assertion.
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Pick which side is being solved for ("left") and which cached factors,
    # weights and gramian from the other side feed the solve. total_rows and
    # total_cols are the full input dimensions as seen from the updated side.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      # The column update reuses the row-update code below with the sense of
      # transpose_input flipped.
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    # unique() yields (distinct original ids, position of each entry in that
    # distinct list); the positions become the new contiguous ids.
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: original ids of the factors that get new values.
    # gather_indices: original ids of the opposite-side factors to look up.
    # Note that row_shape holds the number of distinct ids on the gathered
    # side, while col_shape (below) holds the number being updated.
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    # Rebuild the sparse input over the compacted ids; its dense shape is
    # (gathered, updated) when transposed and (updated, gathered) otherwise.
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights must have rank <= 1. A scalar is broadcast
        # to one weight per updated index; a vector is only cast to float32.
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      # The shared lhs gets a leading axis so it broadcasts against
      # partial_lhs; the rhs gets a trailing axis for matrix_solve, which is
      # squeezed away again from the solution.
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    # The loss is always evaluated with the updated side along the rows, so a
    # transposed input is transposed back first.
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    # With no model weights, a scalar 0 stands in for the weight matrix so the
    # weighted residual term below contributes nothing.
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    # NOTE(review): col_weights is only assigned in the weighted branch above
    # (taken when self._row_weights is not None), yet this expression reads it
    # whenever self._col_weights is not None. If a model can have column
    # weights without row weights this raises a NameError -- confirm that the
    # constructor rules that combination out.
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    # Scales minibatch quantities by (full row count / rows in this batch).
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    # Every cell of the full matrix carries the unobserved weight; the
    # per-entry row*column weights of the observed cells are added on top only
    # when the model has both kinds of weights.
    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
 def body(t, prev):
     """Returns `t + 1` and a new sparse-dense product ordered after `prev`.

     The product is built from the enclosing scope's `sp_x`, `y`, `adjoint_a`
     and `adjoint_b`; both results are created under a control dependency on
     `prev`.
     """
     with tf.control_dependencies([prev]):
         next_t = t + 1
         product = sparse_ops.sparse_tensor_dense_matmul(
             sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
     return (next_t, product)
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
    # Despite the None default, sp_input is mandatory: anything other than a
    # SparseTensor fails this assertion.
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Pick which side is being solved for ("left") and which cached factors,
    # weights and gramian from the other side feed the solve.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      # The column update reuses the row-update code below with the sense of
      # transpose_input flipped.
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    # unique() yields (distinct original ids, position of each entry in that
    # distinct list); the positions become the new contiguous ids.
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: original ids of the factors that get new values.
    # gather_indices: original ids of the opposite-side factors to look up.
    # Note that row_shape holds the number of distinct ids on the gathered
    # side, while col_shape (below) holds the number being updated.
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    # Rebuild the sparse input over the compacted ids; its dense shape is
    # (gathered, updated) when transposed and (updated, gathered) otherwise.
    new_sp_indices = array_ops.concat_v2([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat_v2([row_shape, col_shape], 0) if
                    transpose_input else
                    array_ops.concat_v2([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization is not None:
      total_lhs += self._regularization
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (self._unobserved_weight *
                   sparse_ops.sparse_tensor_dense_matmul(
                       new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights must have rank <= 1. A scalar is broadcast
        # to one weight per updated index; a vector is only cast to float32.
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([array_ops.shape(update_indices)[0]]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
          right,
          col_weights,
          self._unobserved_weight,
          row_weights_slice,
          new_sp_input.indices,
          new_sp_input.values,
          num_rows,
          transpose_input,
          name="wals_compute_partial_lhs_rhs")
      # The shared lhs gets a leading axis so it broadcasts against
      # partial_lhs; the rhs gets a trailing axis for matrix_solve, which is
      # squeezed away again from the solution.
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    return (new_left_values, self.scatter_update(left, update_indices,
                                                 new_left_values,
                                                 sharding_func))
Exemple #18
0
def _ExtractImagePatchesGrad(op, grad):
    """Computes the gradient of an image-patch extraction op w.r.t. its input.

    A sparse matrix of ones that maps every (output position, kernel offset)
    pair to the input pixel it was read from is built in Python, then
    multiplied with the flattened incoming gradient.

    Args:
        op: The forward op. Its first input's shape and the 'ksizes',
            'strides', 'rates' and 'padding' attrs are read.
        grad: Gradient with respect to the op's output.

    Returns:
        A one-element list holding the gradient with respect to the op's
        input.

    Raises:
        ValueError: If the 'padding' attr is neither SAME nor VALID.
    """
    batch_size, rows_in, cols_in, channels = [
        dim.value for dim in op.inputs[0].get_shape()
    ]
    # The index table below is built in Python, so rows_in and cols_in must be
    # statically known. Batch and channel sizes are only used as reshape
    # arguments, so fall back to the runtime shape when they are not known.
    if batch_size is None or channels is None:
        input_bhwc = array_ops.shape(op.inputs[0])
        if batch_size is None:
            batch_size = input_bhwc[0]
        if channels is None:
            channels = input_bhwc[3]

    _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
    _, stride_r, stride_c, _ = op.get_attr('strides')
    _, rate_r, rate_c, _ = op.get_attr('rates')

    # String attrs are returned as bytes on Python 3; normalise so that the
    # comparisons below match under both Python 2 and Python 3.
    padding = op.get_attr('padding')
    if isinstance(padding, bytes):
        padding = padding.decode('ascii')

    # Extent covered by a dilated kernel along each axis.
    ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
    ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

    # -(-a // b) is ceil(a / b) in pure integer arithmetic, so the result does
    # not depend on whether true division is enabled in this module.
    if padding == 'SAME':
        rows_out = -(-rows_in // stride_r)
        cols_out = -(-cols_in // stride_c)
        pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
        pad_cols = ((cols_out - 1) * stride_c + ksize_c_eff - cols_in) // 2
    elif padding == 'VALID':
        rows_out = -(-(rows_in - ksize_r_eff + 1) // stride_r)
        cols_out = -(-(cols_in - ksize_c_eff + 1) // stride_c)
        pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
        pad_cols = (cols_out - 1) * stride_c + ksize_c_eff - cols_in
    else:
        raise ValueError('Unsupported padding: %r' % (padding,))

    pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)

    # Move batch next to channels so both fold into a single column axis.
    grad_expanded = array_ops.transpose(
        array_ops.reshape(
            grad,
            (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
        (1, 2, 3, 4, 0, 5))
    grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

    patch_size = ksize_r * ksize_c
    idx = []
    for i in range(rows_out):
        r_low = i * stride_r - pad_rows
        r_high = r_low + ksize_r_eff
        for j in range(cols_out):
            c_low = j * stride_c - pad_cols
            c_high = c_low + ksize_c_eff
            out_base = (i * cols_out + j) * patch_size

            # One (flat input pixel, flat output element) pair per kernel
            # tap; taps that land in the padding are dropped.
            idx.extend([
                (r * cols_in + c, out_base + ri * ksize_c + ci)
                for (ri, r) in enumerate(range(r_low, r_high, rate_r))
                for (ci, c) in enumerate(range(c_low, c_high, rate_c))
                if 0 <= r < rows_in and 0 <= c < cols_in
            ])

    sp_shape = (rows_in * cols_in, rows_out * cols_out * patch_size)

    sp_mat = ops.SparseTensor(
        array_ops.constant(idx, dtype=ops.dtypes.int64),
        array_ops.ones((len(idx), ), dtype=ops.dtypes.float32), sp_shape)

    jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

    grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
    grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

    return [grad_out]
 def body(t, prev):
   # Both outputs are created under a control dependency on `prev`, so the new
   # product is ordered after the previous one.
   with tf.control_dependencies([prev]):
     step = t + 1
     product = sparse_ops.sparse_tensor_dense_matmul(
         sp_x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
     return step, product
Exemple #20
0
def _ExtractVolumePatchesGrad(op, grad):
    """Back-propagates `grad` through a volume-patch extraction op.

    Input positions are numbered from 1 and pushed through the same patch
    extraction as the forward op. Pairing the result with a running output
    index gives the coordinates of a sparse matrix of ones, which is then
    multiplied with the flattened incoming gradient.

    Args:
        op: The forward op; its first input, its first output and the
            "ksizes", "strides" and "padding" attrs are read.
        grad: Gradient with respect to the op's output.

    Returns:
        A one-element list holding the gradient with respect to the op's
        input.
    """
    volume = op.inputs[0]
    # Spatial extents are read statically; batch and channel sizes are taken
    # from the runtime shape.
    _, planes_in, rows_in, cols_in, _ = [
        axis.value for axis in volume.shape.dims
    ]
    input_bphwc = array_ops.shape(volume)
    batch_size = input_bphwc[0]
    channels = input_bphwc[4]

    kernel_sizes = op.get_attr("ksizes")

    # Id 0 is kept free for padding locations, so real input positions are
    # numbered 1 through planes_in * rows_in * cols_in.
    input_indices_num = 1 + planes_in * rows_in * cols_in
    numbered_input = math_ops.range(
        1, input_indices_num, dtype=ops.dtypes.int64)
    input_idx = array_ops.reshape(
        numbered_input, (1, planes_in, rows_in, cols_in, 1))
    input_idx_patched = gen_array_ops.extract_volume_patches(
        input_idx, kernel_sizes, op.get_attr("strides"),
        op.get_attr("padding"))

    # Output elements are numbered from 0.
    _, planes_out, rows_out, cols_out, _ = [
        axis.value for axis in op.outputs[0].shape.dims
    ]
    _, ksize_p, ksize_r, ksize_c, _ = kernel_sizes
    per_patch = ksize_p * ksize_r * ksize_c
    output_indices_num = planes_out * rows_out * cols_out * per_patch
    numbered_output = math_ops.range(
        output_indices_num, dtype=ops.dtypes.int64)
    output_idx = array_ops.reshape(
        numbered_output, (1, planes_out, rows_out, cols_out, per_patch))

    # Build the (input id, output id) coordinate list.
    from_ids = array_ops.expand_dims(input_idx_patched, axis=-1)
    to_ids = array_ops.expand_dims(output_idx, axis=-1)
    idx_map = array_ops.reshape(
        array_ops.concat([from_ids, to_ids], axis=-1), (-1, 2))

    ones = array_ops.ones([output_indices_num], dtype=grad.dtype)
    sp_mat_full = sparse_tensor.SparseTensor(
        idx_map, ones, (input_indices_num, output_indices_num))
    # Row 0 collects the padding locations; slice it away.
    sp_mat = sparse_ops.sparse_slice(
        sp_mat_full, (1, 0), (input_indices_num - 1, output_indices_num))

    dense_grad = _IndexedSlicesToTensorNoWarning(grad)
    grad_unpacked = array_ops.reshape(
        dense_grad, (batch_size, planes_out, rows_out, cols_out, ksize_p,
                     ksize_r, ksize_c, channels))
    # Move batch next to channels so both fold into a single column axis.
    grad_expanded = array_ops.transpose(
        grad_unpacked, (1, 2, 3, 4, 5, 6, 0, 7))
    grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

    jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

    # Unfold the rows into (planes, rows, cols) and restore batch-major order.
    jac_5d = array_ops.reshape(
        jac, (planes_in, rows_in, cols_in, batch_size, channels))
    return [array_ops.transpose(jac_5d, (3, 0, 1, 2, 4))]
Exemple #21
0
def _ExtractImagePatchesGrad(op, grad):
  """Computes the gradient of the image-patch extraction op w.r.t. its input.

  A constant sparse 0/1 matrix is built with one row per input pixel position
  (`rows_in * cols_in`) and one column per output patch element
  (`rows_out * cols_out * ksize_r * ksize_c`). Multiplying it with the
  flattened incoming gradient sums, for every input position, the gradients of
  all patch elements that were read from that position.

  The spatial input dimensions are used in Python-level arithmetic and loops,
  so they have to be statically known; batch and channel sizes are read from
  the runtime shape.

  Args:
    op: The op being differentiated. `op.inputs[0]` is unpacked as a 4-D
      (batch, rows, cols, channels) tensor and the attributes `ksizes`,
      `strides`, `rates` and `padding` are read from it.
    grad: Incoming gradient; it is reshaped to
      (batch, rows_out, cols_out, ksize_r, ksize_c, channels).

  Returns:
    A one-element list with the gradient for `op.inputs[0]`, laid out as
    (batch, rows_in, cols_in, channels).

  Raises:
    ValueError: If the `padding` attribute is neither b'SAME' nor b'VALID'.
  """
  # Only the spatial sizes are needed as Python ints.
  _, rows_in, cols_in, _ = [
      dim.value for dim in op.inputs[0].get_shape()
  ]
  # Batch and channel sizes come from the runtime shape tensor.
  input_bhwc = array_ops.shape(op.inputs[0])
  batch_size = input_bhwc[0]
  channels = input_bhwc[3]

  _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
  _, stride_r, stride_c, _ = op.get_attr('strides')
  _, rate_r, rate_c, _ = op.get_attr('rates')
  padding = op.get_attr('padding')

  # Effective kernel extent once the dilation rate is taken into account.
  ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
  ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)

  if padding == b'SAME':
    rows_out = int(ceil(rows_in / stride_r))
    cols_out = int(ceil(cols_in / stride_c))
    pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
    pad_cols = ((cols_out - 1) * stride_c + ksize_c_eff - cols_in) // 2
  elif padding == b'VALID':
    rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
    cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_c))
    pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
    pad_cols = (cols_out - 1) * stride_c + ksize_c_eff - cols_in
  else:
    # Without this branch an unknown padding surfaced as a NameError below.
    raise ValueError('Unsupported padding: %r' % (padding,))

  pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)

  # Move batch next to channels so both can be folded into the column axis of
  # the dense operand of the sparse matmul.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(grad, (batch_size, rows_out,
                               cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5)
  )
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  # Collect (flat input position, flat output patch element) pairs. Positions
  # that fall outside the input (i.e. into the padding) are left out.
  idx = []
  for i in range(rows_out):
    r_low = i * stride_r - pad_rows
    r_high = r_low + ksize_r_eff
    for j in range(cols_out):
      c_low = j * stride_c - pad_cols
      c_high = c_low + ksize_c_eff

      idx.extend([(r * cols_in + c,
                   i * (cols_out * ksize_r * ksize_c) +
                   j * (ksize_r * ksize_c) +
                   ri * ksize_c + ci)
                  for (ri, r) in enumerate(range(r_low, r_high, rate_r))
                  for (ci, c) in enumerate(range(c_low, c_high, rate_c))
                  if 0 <= r < rows_in and 0 <= c < cols_in])

  sp_shape = (rows_in * cols_in,
              rows_out * cols_out * ksize_r * ksize_c)

  # The values must share the gradient's dtype; a fixed float32 here breaks
  # the matmul below for gradients of any other dtype.
  sp_mat = sparse_tensor.SparseTensor(
      array_ops.constant(idx, dtype=ops.dtypes.int64),
      array_ops.ones((len(idx),), dtype=grad.dtype),
      sp_shape
  )

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Undo the earlier flattening and put batch back in front.
  grad_out = array_ops.reshape(
      jac, (rows_in, cols_in, batch_size, channels)
  )
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
 def dense_matmul(sp, w):
     """Forward `sp` and `w` to `sparse_ops.sparse_tensor_dense_matmul`.

     Returns whatever that call returns, unchanged.
     """
     product = sparse_ops.sparse_tensor_dense_matmul(sp, w)
     return product