Example #1
        def _fn():
            num_rows = np.shape(np_matrix)[0]
            num_cols = np.shape(np_matrix)[1]
            row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
            col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
            sp_mat = self.np_array_to_sparse(np_matrix)
            sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
            row_batch = input_lib.batch([row_ids, sp_mat],
                                        batch_size=min(batch_size, num_rows),
                                        capacity=10,
                                        enqueue_many=True)
            col_batch = input_lib.batch([col_ids, sp_mat_t],
                                        batch_size=min(batch_size, num_cols),
                                        capacity=10,
                                        enqueue_many=True)

            features = extract_features(row_batch, col_batch,
                                        sp_mat.dense_shape)
            if projection_weights is not None:
                weights_batch = input_lib.batch(projection_weights,
                                                batch_size=batch_size,
                                                capacity=10,
                                                enqueue_many=True)
                features[wals_lib.WALSMatrixFactorization.
                         PROJECTION_WEIGHTS] = (weights_batch)
            if project_row is not None:
                features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
                    constant_op.constant(project_row))

            labels = None
            return features, labels
Example #2
    def _fn():
      num_rows = np.shape(np_matrix)[0]
      num_cols = np.shape(np_matrix)[1]
      row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
      col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
      sp_mat = self.np_array_to_sparse(np_matrix)
      sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
      row_batch = input_lib.batch(
          [row_ids, sp_mat],
          batch_size=min(batch_size, num_rows),
          capacity=10,
          enqueue_many=True)
      col_batch = input_lib.batch(
          [col_ids, sp_mat_t],
          batch_size=min(batch_size, num_cols),
          capacity=10,
          enqueue_many=True)

      features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
      if projection_weights is not None:
        weights_batch = input_lib.batch(
            projection_weights,
            batch_size=batch_size,
            capacity=10,
            enqueue_many=True)
        features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
            weights_batch)
      if project_row is not None:
        features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
            constant_op.constant(project_row))

      labels = None
      return features, labels
Example #3
 def testTransposePreservesShape(self):
     with ops.Graph().as_default():
         t = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                        values=[0.],
                                        dense_shape=[3, 4])
         self.assertTrue(t.shape.is_fully_defined())
         transposed = sparse_ops.sparse_transpose(t)
         self.assertAllEqual(transposed.shape, [4, 3])
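For orientation, here is a minimal standalone sketch (not one of the collected examples, and assuming TensorFlow 2.x eager execution with the public tf.sparse API) of what sparse_transpose does: it permutes both the indices and the dense_shape of a SparseTensor.

import tensorflow as tf

# A 2x3 sparse tensor with two nonzero entries.
st = tf.SparseTensor(indices=[[0, 1], [1, 2]], values=[10.0, 20.0], dense_shape=[2, 3])
st_t = tf.sparse.transpose(st)  # default perm reverses the dimensions
print(st_t.dense_shape)         # [3, 2]
print(st_t.indices)             # [[1, 0], [2, 1]]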
Example #4
def _sparse_dense_dense_grad(op, grad):
  # Gradient of an op that produces the values of a @ transpose(b) at the given
  # sparse `indices`, i.e. <a_i, b_j> for each (i, j) in `indices`.
  a = op.inputs[0]        # dense matrix of shape [m, k]
  b = op.inputs[1]        # dense matrix of shape [n, k]
  indices = op.inputs[2]  # [nnz, 2] indices of the sparse [m, n] result
  print(a.get_shape(), b.get_shape(), grad)
  result_shape = tf.cast(tf.stack([tf.shape(a)[0], tf.shape(b)[0]]), dtype=dtypes.int64)
  # Scatter the incoming gradient values back into a sparse [m, n] tensor.
  grad = tf.SparseTensor(indices=indices, values=grad, dense_shape=result_shape)

  grad_T = sparse_ops.sparse_transpose(grad)
  grad_a = sparse_ops.sparse_tensor_dense_matmul(grad, b)    # [m, n] @ [n, k] -> [m, k]
  grad_b = sparse_ops.sparse_tensor_dense_matmul(grad_T, a)  # [n, m] @ [m, k] -> [n, k]
  return [grad_a, grad_b, None]  # no gradient with respect to the indices input
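A gradient function like the one above only takes effect once it is registered for the forward op. The following is a hedged sketch of that registration; the op name "SparseDenseDenseMatmul" is hypothetical and not taken from the example.

import tensorflow as tf

@tf.RegisterGradient("SparseDenseDenseMatmul")  # hypothetical op name, for illustration only
def _sparse_dense_dense_grad_registered(op, grad):
  # Delegate to the gradient function defined in the example above.
  return _sparse_dense_dense_grad(op, grad)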
Example #5
 def testTranspose(self):
   with self.test_session(use_gpu=False):
     np.random.seed(1618)
     shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)]
     for shape in shapes:
       for dtype in [np.int32, np.int64, np.float32, np.float64]:
         dn_input = np.random.randn(*shape).astype(dtype)
         rank = array_ops.rank(dn_input).eval()
         perm = np.random.choice(rank, rank, False)
         sp_input, unused_a_nnz = _sparsify(dn_input)
         sp_trans = sparse_ops.sparse_transpose(sp_input, perm=perm)
         dn_trans = sparse_ops.sparse_tensor_to_dense(sp_trans).eval()
         expected_trans = array_ops.transpose(dn_input, perm=perm).eval()
         self.assertAllEqual(dn_trans, expected_trans)
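The testTranspose variants in this section call a _sparsify helper that is not shown. The sketch below is only an assumption about what such a helper does (zero out small entries, then build a SparseTensor and report its nonzero count); the actual helper in the TensorFlow test suite may differ in its details.

import numpy as np
import tensorflow as tf

def _sparsify_sketch(x, thresh=0.5):
  x = np.array(x, copy=True)
  x[x < thresh] = 0                                  # drop small entries
  non_zero = np.where(x != 0)
  indices = np.vstack(non_zero).T.astype(np.int64)   # [nnz, rank]
  values = x[non_zero]
  return tf.SparseTensor(indices, values, x.shape), len(values)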
Example #6
 def testTranspose(self):
   with self.test_session(use_gpu=False):
     np.random.seed(1618)
     shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)]
     for shape in shapes:
       for dtype in [np.int32, np.int64, np.float32, np.float64]:
         dn_input = np.random.randn(*shape).astype(dtype)
         rank = array_ops.rank(dn_input).eval()
         perm = np.random.choice(rank, rank, False)
         sp_input, unused_a_nnz = _sparsify(dn_input)
         sp_trans = sparse_ops.sparse_transpose(sp_input, perm=perm)
         dn_trans = sparse_ops.sparse_tensor_to_dense(sp_trans).eval()
         expected_trans = array_ops.transpose(dn_input, perm=perm).eval()
         self.assertAllEqual(dn_trans, expected_trans)
Example #7
  def testTranspose(self):
    if np.__version__ == "1.13.0":
      self.skipTest("numpy 1.13.0 bug")

    with test_util.force_cpu():
      np.random.seed(1618)
      shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)]
      for shape in shapes:
        for dtype in [np.int32, np.int64, np.float32, np.float64]:
          dn_input = np.random.randn(*shape).astype(dtype)
          rank = self.evaluate(array_ops.rank(dn_input))
          perm = np.random.choice(rank, rank, False)
          sp_input, unused_a_nnz = _sparsify(dn_input)
          sp_trans = sparse_ops.sparse_transpose(sp_input, perm=perm)
          dn_trans = sparse_ops.sparse_tensor_to_dense(sp_trans)
          expected_trans = array_ops.transpose(dn_input, perm=perm)
          self.assertAllEqual(expected_trans.shape, sp_trans.get_shape())
          self.assertAllEqual(dn_trans, expected_trans)
Example #8
  def testTranspose(self):
    if np.__version__ == "1.13.0":
      self.skipTest("numpy 1.13.0 bug")

    with test_util.force_cpu():
      np.random.seed(1618)
      shapes = [np.random.randint(1, 10, size=rank) for rank in range(1, 6)]
      for shape in shapes:
        for dtype in [np.int32, np.int64, np.float32, np.float64]:
          dn_input = np.random.randn(*shape).astype(dtype)
          rank = self.evaluate(array_ops.rank(dn_input))
          perm = np.random.choice(rank, rank, False)
          sp_input, unused_a_nnz = _sparsify(dn_input)
          sp_trans = sparse_ops.sparse_transpose(sp_input, perm=perm)
          dn_trans = sparse_ops.sparse_tensor_to_dense(sp_trans)
          expected_trans = array_ops.transpose(dn_input, perm=perm)
          self.assertAllEqual(expected_trans.shape, sp_trans.get_shape())
          self.assertAllEqual(dn_trans, expected_trans)
Example #9
    def _fn():
      num_rows = np.shape(np_matrix)[0]
      num_cols = np.shape(np_matrix)[1]
      row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
      col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
      sp_mat = self.np_array_to_sparse(np_matrix)
      sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
      row_batch = input_lib.batch(
          [row_ids, sp_mat],
          batch_size=min(batch_size, num_rows),
          capacity=10,
          enqueue_many=True)
      col_batch = input_lib.batch(
          [col_ids, sp_mat_t],
          batch_size=min(batch_size, num_cols),
          capacity=10,
          enqueue_many=True)

      features = extract_features(row_batch, col_batch, num_rows, num_cols)

      if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
        self.assertTrue(
            project_row is not None,
            msg='project_row must be specified in INFER or EVAL mode.')
        features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
            constant_op.constant(project_row))

      if mode == model_fn.ModeKeys.INFER and projection_weights is not None:
        weights_batch = input_lib.batch(
            projection_weights,
            batch_size=batch_size,
            capacity=10,
            enqueue_many=True)
        features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
            weights_batch)

      labels = None
      return features, labels
Example #10
    def _fn():
      num_rows = np.shape(np_matrix)[0]
      num_cols = np.shape(np_matrix)[1]
      row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
      col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
      sp_mat = self.np_array_to_sparse(np_matrix)
      sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
      row_batch = input_lib.batch(
          [row_ids, sp_mat],
          batch_size=min(batch_size, num_rows),
          capacity=10,
          enqueue_many=True)
      col_batch = input_lib.batch(
          [col_ids, sp_mat_t],
          batch_size=min(batch_size, num_cols),
          capacity=10,
          enqueue_many=True)

      features = extract_features(row_batch, col_batch, num_rows, num_cols)

      if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
        self.assertTrue(
            project_row is not None,
            msg='project_row must be specified in INFER or EVAL mode.')
        features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
            constant_op.constant(project_row))

      if mode == model_fn.ModeKeys.INFER and projection_weights is not None:
        weights_batch = input_lib.batch(
            projection_weights,
            batch_size=batch_size,
            capacity=10,
            enqueue_many=True)
        features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
            weights_batch)

      labels = None
      return features, labels
Example #11
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO (rmlarsen): handle transposing in tf.matrix_solve instead of id:1138
      # https://github.com/imdone/tensorflow/issues/1139
      # transposing explicitly.
      # TODO (rmlarsen): multi-thread tf.matrix_solve. id:1249
      # https://github.com/imdone/tensorflow/issues/1250
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO (yifanchen): Add special handling for single shard without using id:845
        # https://github.com/imdone/tensorflow/issues/846
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
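Read as math rather than graph ops, unregularized_loss above is the standard WALS objective. The dense NumPy restatement below is a sketch of that reading (it ignores the minibatch normalization_factor and the regularization term, and is not code from the library).

import numpy as np

def wals_loss_dense(a, observed, u, v, w0, row_w, col_w):
  # a: [m, n] input matrix, observed: [m, n] boolean mask,
  # u: [m, k] row factors, v: [n, k] column factors,
  # w0: unobserved weight, row_w: [m] row weights, col_w: [n] column weights.
  approx = u @ v.T                              # <u_i, v_j> for every (i, j)
  resid = np.where(observed, a - approx, 0.0)   # residuals on observed cells only
  cell_w = np.outer(row_w, col_w)               # per-cell weight w_i * c_j
  unobserved_term = w0 * np.sum(np.where(observed, 0.0, approx) ** 2)
  observed_term = np.sum((w0 + cell_w) * resid ** 2)
  return unobserved_term + observed_term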
Example #12
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
Example #13
 def test_fn(tensor):
     tensor = sparse_ops.sparse_transpose(tensor)
     self.assertEqual(tensor.shape.rank, 2)
     return tensor
Example #14
    def __init__(self,
                 loss,
                 penalty,
                 At,
                 D,
                 b,
                 tau=None,
                 sigma=None,
                 sess=None,
                 dtype=dtypes.float32,
                 devices='',
                 aggregate=False,
                 init_var=None):
        if not tau:
            raise ValueError("Must set tau")
        if not sigma:
            raise ValueError("Must set sigma")
        if sess is None:
            sess = session.Session()

        self.tau = tau
        self.sigma = sigma
        self.aggregate = aggregate

        #print(type(tau), type(sigma))

        #assert type(tau)==type(sigma)

        if isinstance(self.tau, float):
            self.parammode = 'static'
        else:
            self.parammode = 'variable'

        self.sess = sess
        self.loss = loss
        self.penalty = penalty
        self.dtype = dtype
        self.devices = devices
        self.init_var = True if init_var else False
        if isinstance(self.devices, list):
            self.master = self.devices[0]
        else:
            self.master = self.devices

        if not isinstance(devices, list):
            self.matmul = tf.matmul
            self.spmatmul = tf.sparse_tensor_dense_matmul
        else:
            self.matmul = distops.matmul
            self.spmatmul = distops.spmatmul

        # check shape.
        self.m, self.n = At.T.shape
        self.l, _ = D.shape
        print(At.T.shape)
        print(D.shape)
        print(self.n, D.shape[1])
        assert (self.n == D.shape[1])

        # setup variables.
        # for A, we need to consider the case where A is larger than 2GB. (to be distributed)
        if not isinstance(At, distmat.DistMat):
            if not isinstance(devices, list):
                with tf.device(self.devices):
                    Ap = tf.placeholder(dtype, shape=At.shape)
                    self.At = variables.Variable(Ap)
                sess.run(self.At.initializer, feed_dict={Ap: At})
                self.A = None
            else:
                self.At = distmat.DistMat.from_dataset(At,
                                                       devices=self.devices,
                                                       sess=sess,
                                                       dtype=self.dtype)
                self.A = None
        else:
            assert all([d1 == d2 for d1, d2 in zip(At.devices, self.devices)])
            self.At = At

        if not isinstance(D, distmat.DistSpMat):

            # D should be COO format.
            # dtype induced from original matrix.
            # avoid recomputation of transpose.
            if isinstance(D, coo_matrix):
                pass
            elif isinstance(D, spmatrix):
                D = D.tocoo()
            else:
                raise ValueError("must be a scipy sparse matrix")
            if not isinstance(devices, list):
                D_tensor = coo_to_sparsetensor(D)
                D_sorted_op = sparse_ops.sparse_reorder(D_tensor)
                Dt_sorted_op = sparse_ops.sparse_reorder(
                    sparse_ops.sparse_transpose(D_tensor))
                D_sorted, Dt_sorted = sess.run([D_sorted_op, Dt_sorted_op])
                with tf.device(self.devices):
                    self.D = sparse_tensor.SparseTensor.from_value(D_sorted)
                    self.Dt = sparse_tensor.SparseTensor.from_value(Dt_sorted)

                    # b is a constant. (to duplicate)
            else:
                self.D = distmat.DistSpMat.from_spmatrix(
                    D, devices_r=self.devices)
                self.Dt = None
        else:
            assert all([d1 == d2 for d1, d2 in zip(D.devices_r, self.devices)])
            self.D = D
            self.Dt = None

        with tf.device(self.master):
            self.b = constant_op.constant(b, dtype=dtype)

            self._setup_variables(init_var)
            self._setup_evals()
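The constructor above converts D to a SparseTensor once and also pre-computes its reordered transpose Dt, so the transpose is not rebuilt on every iteration. The snippet below is a small sketch of that conversion using the public TF 2.x API; coo_to_sparsetensor is this project's own helper, so the hand-rolled equivalent here is an assumption rather than its actual implementation.

import numpy as np
import tensorflow as tf
from scipy.sparse import coo_matrix

def coo_to_sparsetensor_sketch(d):
  # Stack COO row/col indices into the [nnz, 2] layout SparseTensor expects.
  indices = np.stack([d.row, d.col], axis=1).astype(np.int64)
  return tf.SparseTensor(indices, d.data, d.shape)

d = coo_matrix(np.diag([1.0, 2.0, 3.0]))
d_tensor = coo_to_sparsetensor_sketch(d)
# Transpose, then reorder so the indices are back in canonical row-major order.
dt_sorted = tf.sparse.reorder(tf.sparse.transpose(d_tensor))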