def _verifySolve(self, x, y, batch_dims=None):
    for adjoint in False, True:
      for np_type in [np.float32, np.float64]:
        a = x.astype(np_type)
        b = y.astype(np_type)
        if adjoint:
          a_np = np.conj(np.transpose(a))
        else:
          a_np = a
        if batch_dims is not None:
          a = np.tile(a, batch_dims + [1, 1])
          a_np = np.tile(a_np, batch_dims + [1, 1])
          b = np.tile(b, batch_dims + [1, 1])

        np_ans = np.linalg.solve(a_np, b)
        with self.test_session():
          # Test the batch version, which works for ndim >= 2
          tf_ans = tf.batch_matrix_solve(a, b, adjoint=adjoint)
          out = tf_ans.eval()
          self.assertEqual(tf_ans.get_shape(), out.shape)
          self.assertEqual(np_ans.shape, out.shape)
          self.assertAllClose(np_ans, out)

          if a.ndim == 2:
            # Test the simple version
            tf_ans = tf.matrix_solve(a, b, adjoint=adjoint)
            out = tf_ans.eval()
            self.assertEqual(out.shape, tf_ans.get_shape())
            self.assertEqual(np_ans.shape, out.shape)
            self.assertAllClose(np_ans, out)
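The test above checks tf.batch_matrix_solve (and, for 2-D inputs, tf.matrix_solve) against np.linalg.solve, where adjoint=True means solving conj(A)^T X = B rather than A X = B. For reference, a minimal NumPy-only sketch of the quantity the test expects, assuming a NumPy recent enough for np.linalg.solve and np.matmul to broadcast over leading batch dimensions:

import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(2, 3, 4, 4)   # a batch of 4x4 systems
b = rng.randn(2, 3, 4, 1)   # matching right-hand sides

x_plain = np.linalg.solve(a, b)            # A X = B
a_adj = np.conj(np.swapaxes(a, -1, -2))    # batched adjoint of A
x_adjoint = np.linalg.solve(a_adj, b)      # A^H X = B

# Residuals vanish up to round-off for both variants.
assert np.allclose(np.matmul(a, x_plain), b)
assert np.allclose(np.matmul(a_adj, x_adjoint), b)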
Example #3
  def testBatchResultSize(self):
    # 3x3x3 matrices, 3x3x1 right-hand sides.
    matrix = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.] * 3).reshape(3, 3, 3)
    rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)
    answer = tf.batch_matrix_solve(matrix, rhs)
    ls_answer = tf.batch_matrix_solve_ls(matrix, rhs)
    self.assertEqual(ls_answer.get_shape(), [3, 3, 1])
    self.assertEqual(answer.get_shape(), [3, 3, 1])
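This test only checks the statically inferred shapes: a (3, 3, 3) batch of matrices and (3, 3, 1) right-hand sides should yield a [3, 3, 1] result from both solvers, and nothing is ever evaluated (the 1..9 matrix is in fact singular). A small NumPy sketch of the same batched shape behaviour, using invertible matrices so the solve can actually run:

import numpy as np

matrix = np.tile(2.0 * np.eye(3), (3, 1, 1))   # (3, 3, 3): three invertible 3x3 systems
rhs = np.arange(9.0).reshape(3, 3, 1)          # (3, 3, 1): one column per system

x = np.linalg.solve(matrix, rhs)               # solved independently per batch entry
assert x.shape == (3, 3, 1)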
Example #4
Example #5
  def test_solve(self):
    with self.test_session():
      for batch_shape in [(), (2, 3,)]:
        for k in [1, 4]:
          operator, mat = self._build_operator_and_mat(batch_shape, k)

          # Work with 5 simultaneous systems.  5 is arbitrary.
          x = self._rng.randn(*(batch_shape + (k, 5)))

          self._compare_results(
              expected=tf.batch_matrix_solve(mat, x).eval(),
              actual=operator.solve(x))
Example #6
 def _verifySolve(self, x, y):
   for np_type in [np.float32, np.float64]:
     a = x.astype(np_type)
     b = y.astype(np_type)
     with self.test_session():
       if a.ndim == 2:
         tf_ans = tf.matrix_solve(a, b)
       else:
         tf_ans = tf.batch_matrix_solve(a, b)
       out = tf_ans.eval()
     np_ans = np.linalg.solve(a, b)
     self.assertEqual(np_ans.shape, out.shape)
     self.assertAllClose(np_ans, out)
Example #8
  def test_sqrt_solve(self):
    # Square roots are not unique, but we should still have
    # S^{-T} S^{-1} x = A^{-1} x.
    # In our case, we should have S = S^T, so then S^{-1} S^{-1} x = A^{-1} x.
    with self.test_session():
      for batch_shape in [(), (2, 3,)]:
        for k in [1, 4]:
          operator, mat = self._build_operator_and_mat(batch_shape, k)

          # Work with 5 simultaneous systems.  5 is arbitrary.
          x = self._rng.randn(*(batch_shape + (k, 5)))

          self._compare_results(
              expected=tf.batch_matrix_solve(mat, x).eval(),
              actual=operator.sqrt_solve(operator.sqrt_solve(x)))
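The identity the comment relies on: if A = S S^T, then solving with S twice (here S is symmetric, so S^{-1} S^{-1} x) reproduces A^{-1} x. A quick NumPy check of that identity, using a symmetric positive-definite square root built from an eigendecomposition:

import numpy as np

rng = np.random.RandomState(0)
m = rng.randn(4, 4)
a = m.dot(m.T) + 4.0 * np.eye(4)        # symmetric positive definite

w, v = np.linalg.eigh(a)                # a = v diag(w) v^T
s = (v * np.sqrt(w)).dot(v.T)           # symmetric square root: s = s^T, s s = a

x = rng.randn(4, 5)                     # 5 simultaneous systems, as in the test
via_sqrt = np.linalg.solve(s, np.linalg.solve(s, x))   # S^{-1} S^{-1} x
assert np.allclose(via_sqrt, np.linalg.solve(a, x))    # equals A^{-1} x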
Example #11
  def test_BatchMatrixSolve(self):
    t = tf.batch_matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)))
    self.check(t)
Example #13
  def _batch_solve(self, rhs):
    return tf.batch_matrix_solve(self._pos_def_matrix, rhs)
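This one-liner is typical of operator classes that keep a batch of positive-definite matrices and delegate solve() to a dense batched solver. A hypothetical, NumPy-only sketch of that pattern (the class and attribute names here are illustrative, not the TensorFlow operator API):

import numpy as np

class DensePosDefOperator(object):
  """Illustrative stand-in for the operator above: stores a batch of
  positive-definite matrices and solves against them densely."""

  def __init__(self, pos_def_matrix):
    self._pos_def_matrix = np.asarray(pos_def_matrix)

  def solve(self, rhs):
    # np.linalg.solve plays the role of tf.batch_matrix_solve here.
    return np.linalg.solve(self._pos_def_matrix, rhs)

# Usage: two SPD 3x3 systems, four right-hand sides each.
rng = np.random.RandomState(0)
m = rng.randn(2, 3, 3)
spd = np.matmul(m, np.swapaxes(m, -1, -2)) + 3.0 * np.eye(3)
op = DensePosDefOperator(spd)
rhs = rng.randn(2, 3, 4)
x = op.solve(rhs)
assert np.allclose(np.matmul(spd, x), rhs)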
Example #15
  def _process_input_helper(self, update_row_factors,
                            sp_input=None, transpose_input=False):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update the row_factors, else update the
        column factors.
      sp_input: Please refer to comments for update_row_factors and
        update_col_factors.
      transpose_input: If true, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
    assert isinstance(sp_input, ops.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = tf.split(1, 2, sp_input.indices)
    update_row_indices, all_row_ids = tf.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = tf.unique(col_ids[:, 0])
    col_ids = tf.expand_dims(tf.cast(all_col_ids, tf.int64), 1)
    row_ids = tf.expand_dims(tf.cast(all_row_ids, tf.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [tf.cast(tf.shape(update_row_indices)[0], tf.int64)]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [tf.cast(tf.shape(update_col_indices)[0], tf.int64)]
      gather_indices = update_col_indices

    num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(right_factors, gather_indices,
                                           partition_strategy='div')
    new_sp_indices = tf.concat(1, [row_ids, col_ids])
    new_sp_shape = (tf.concat(0, [row_shape, col_shape]) if transpose_input
                    else tf.concat(0, [col_shape, row_shape]))
    new_sp_input = tf.SparseTensor(indices=new_sp_indices,
                                   values=sp_input.values, shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization is not None:
      total_lhs += self._regularization
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (self._unobserved_weight *
                   tf.sparse_tensor_dense_matmul(new_sp_input, right,
                                                 adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = tf.transpose(tf.matrix_solve(total_lhs,
                                                     tf.transpose(total_rhs)))
    else:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases.
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy='div')
      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy='div')
      partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
          right,
          col_weights,
          self._unobserved_weight,
          row_weights_slice,
          new_sp_input.indices,
          new_sp_input.values,
          num_rows,
          transpose_input,
          name="wals_compute_partial_lhs_rhs")
      total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = tf.expand_dims(total_rhs, -1)
      new_left_values = tf.squeeze(tf.batch_matrix_solve(total_lhs, total_rhs),
                                   [2])

    return (new_left_values,
            self.scatter_update(left,
                                update_indices,
                                new_left_values,
                                sharding_func))
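The weighted branch above assembles, for every row being updated, a k x k normal-equations system (unobserved_weight * gramian + regularization + partial_lhs) v = rhs, solves all of them in a single tf.batch_matrix_solve call, and squeezes the trailing singleton dimension to obtain the new factor rows. A NumPy-only sketch of just that batched solve, with illustrative values standing in for the WALSModel caches:

import numpy as np

rng = np.random.RandomState(0)
num_rows, k = 6, 4                       # rows being updated, factor dimension

gramian = 2.0 * np.eye(k)                # stand-in for the cached column Gramian
regularization = 0.1 * np.eye(k)
unobserved_weight = 0.5
partial_lhs = np.stack([np.dot(m, m.T) for m in rng.randn(num_rows, k, k)])
rhs = rng.randn(num_rows, k)

# One system per updated row, mirroring expand_dims + batch_matrix_solve + squeeze.
total_lhs = (unobserved_weight * gramian + regularization)[np.newaxis] + partial_lhs
new_left_values = np.linalg.solve(total_lhs, rhs[..., np.newaxis])[..., 0]
assert new_left_values.shape == (num_rows, k)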