Ejemplo n.º 1
0
 def _runTestEmptyMaskedProduct(self):
     """Verifies that masked_matmul returns no values for an empty mask.

     The mask is an int64 index matrix of shape [0, 2], i.e. it holds zero
     (row, col) pairs, and the evaluated result is expected to be empty.
     """
     with ops.Graph().as_default(), self.test_session() as sess:
         no_pairs = constant_op.constant(0, shape=[0, 2], dtype=dtypes.int64)
         masked = gen_factorization_ops.masked_matmul(
             self._a, self._b, no_pairs, transpose_a=False, transpose_b=False)
         self.assertEqual(len(sess.run(masked)), 0)
  def _runTestMaskedProduct(self, transpose_a, transpose_b):
    """Compares masked_matmul against the precomputed dot products.

    When a transpose flag is set, the corresponding operand is transposed
    before being passed in and masked_matmul is asked to transpose it back,
    so the same expected values (`self._dot_products`) apply in every case.

    Args:
      transpose_a: whether to feed `self._a` transposed.
      transpose_b: whether to feed `self._b` transposed.
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      lhs = array_ops.transpose(self._a) if transpose_a else self._a
      rhs = array_ops.transpose(self._b) if transpose_b else self._b

      product_vals = gen_factorization_ops.masked_matmul(
          lhs, rhs, self._mask_ind, transpose_a, transpose_b)
      got = sparse_tensor.SparseTensor(
          self._mask_ind, product_vals, self._mask_shape)
      want = sparse_tensor.SparseTensor(
          self._mask_ind, self._dot_products, self._mask_shape)

      # Fetch the indices and values of both sparse tensors in a single run.
      got_inds, got_vals, want_inds, want_vals = sess.run(
          [got.indices, got.values, want.indices, want.values])
      self.assertAllClose(got_inds, want_inds)
      self.assertAllClose(got_vals, want_vals)
Ejemplo n.º 3
0
    def _runTestMaskedProduct(self, transpose_a, transpose_b):
        """Checks masked_matmul output against `self._dot_products`.

        An operand whose transpose flag is set is transposed up front and
        masked_matmul is told to transpose it back. The expected sparse result
        is therefore identical for every flag combination.

        Args:
            transpose_a: whether `self._a` is passed in transposed.
            transpose_b: whether `self._b` is passed in transposed.
        """
        with ops.Graph().as_default(), self.test_session() as sess:
            left = array_ops.transpose(self._a) if transpose_a else self._a
            right = array_ops.transpose(self._b) if transpose_b else self._b

            computed = gen_factorization_ops.masked_matmul(
                left, right, self._mask_ind, transpose_a, transpose_b)
            actual = sparse_tensor.SparseTensor(self._mask_ind, computed,
                                                self._mask_shape)
            expected = sparse_tensor.SparseTensor(self._mask_ind,
                                                  self._dot_products,
                                                  self._mask_shape)

            # One session run fetches everything needed for both comparisons.
            fetched = sess.run([actual.indices, actual.values,
                                expected.indices, expected.values])
            act_inds, act_vals, exp_inds, exp_vals = fetched
            self.assertAllClose(act_inds, exp_inds)
            self.assertAllClose(act_vals, exp_vals)
Ejemplo n.º 4
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Builds ops that solve the (optionally weighted) least-squares normal
    equations for the row or column factors touched by `sp_input`, an op that
    scatters the new values into the factor variable, and ops that compute the
    normalized minibatch loss and related statistics.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions. Must be a `SparseTensor` (checked with `assert`).
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected. Only consulted when the model has
        row weights (`self._row_weights is not None`).

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Pick the factor being solved for ("left") and the fixed opposite factor
    # ("right_factors"), together with their weights, gramian and sharding.
    # When updating columns, the roles of rows and columns are swapped, and
    # flipping transpose_input lets the rest of this method use one code path.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # The row and column indices of sp_input refer to the original full input.
    # Here we reindex both and give them contiguous ids starting at 0, using
    # tf.unique: update_*_indices hold the distinct original ids, and
    # all_*_ids map every entry to its new contiguous id. This is done so that
    # the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: original ids of the factors being solved for (used for
    #   the weight lookup and the scatter update below).
    # gather_indices: original ids of the opposite factors they interact with
    #   (used to look up `right`).
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    # Number of distinct factors being solved for in this minibatch.
    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    # In both branches this evaluates to [#distinct row ids, #distinct col ids]
    # of sp_input, matching new_sp_indices; row_shape/col_shape are named from
    # the point of view of the factors being solved for.
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations.
    # The LHS term shared by every solved factor: unobserved_weight * gramian,
    # plus the regularization matrix when one is configured.
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule: a single shared
      # LHS, so all factors are solved at once against the transposed RHS and
      # the solution is transposed back.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # (Tracked at https://github.com/imdone/tensorflow/issues/1139.)
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      # (Tracked at https://github.com/imdone/tensorflow/issues/1250.)
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        # (Tracked at https://github.com/imdone/tensorflow/issues/846.)
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights: rank > 1 is rejected by the assert; a
        # scalar is broadcast to one weight per updated factor; a rank-1
        # tensor is used as-is after casting to float32.
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      # The kernel returns a batched, per-factor LHS contribution (added to the
      # shared total_lhs by broadcasting over a new leading axis below) and
      # the matching RHS.
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      # Batched solve; squeeze drops the trailing unit axis added to the RHS.
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    # Write the solved values into the (sharded) factor variable.
    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    # Without model weights these are 0, which removes the weighted term from
    # the loss below.
    # NOTE(review): row_weights_slice and col_weights are only bound in the
    # weighted branch above (self._row_weights is not None). If
    # self._col_weights could be set while self._row_weights is None, the
    # col_wt_mat expression would raise NameError; presumably the model
    # guarantees both or neither are set -- confirm.
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss: the minibatch terms are scaled by
    # total_rows / num_rows, i.e. by how many full-input rows each solved row
    # stands for.
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    # Assuming `gramian` is the Gram matrix of the opposite factors (TODO
    # confirm), the trace term is the sum of squared approximations over the
    # batch rows' full extent. Swapping the observed approximations for the
    # observed residuals gives the unobserved_weight part. The last term adds
    # row_weight * col_weight * residual^2 over the observed entries.
    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    # Regularization: the solved factors' squared norm (scaled like the loss)
    # plus trace(gramian) for the opposite factors.
    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    # unobserved_weight for every entry of the full total_rows x total_cols
    # matrix, plus (only when both weight sets exist) row_wt * col_wt summed
    # over the observed entries and scaled like the loss.
    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
 def _runTestEmptyMaskedProduct(self):
   """Checks that an empty [0, 2] index mask yields an empty product."""
   with ops.Graph().as_default(), self.test_session() as sess:
     # An int64 index matrix holding zero (row, col) pairs.
     mask_indices = constant_op.constant(
         0, shape=[0, 2], dtype=dtypes.int64)
     product_values = gen_factorization_ops.masked_matmul(
         self._a, self._b, mask_indices, transpose_a=False, transpose_b=False)
     fetched = sess.run(product_values)
     self.assertEqual(len(fetched), 0)
Ejemplo n.º 6
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Builds ops that solve the (optionally weighted) least-squares normal
    equations for the row or column factors touched by `sp_input`, an op that
    scatters the new values into the factor variable, and ops that compute the
    normalized minibatch loss and related statistics.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions. Must be a `SparseTensor` (checked with `assert`).
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected. Only consulted when the model has
        row weights (`self._row_weights is not None`).

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Pick the factor being solved for ("left") and the fixed opposite factor
    # ("right_factors"), together with their weights, gramian and sharding.
    # When updating columns, the roles of rows and columns are swapped, and
    # flipping transpose_input lets the rest of this method use one code path.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # The row and column indices of sp_input refer to the original full input.
    # Here we reindex both and give them contiguous ids starting at 0, using
    # tf.unique: update_*_indices hold the distinct original ids, and
    # all_*_ids map every entry to its new contiguous id. This is done so that
    # the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: original ids of the factors being solved for (used for
    #   the weight lookup and the scatter update below).
    # gather_indices: original ids of the opposite factors they interact with
    #   (used to look up `right`).
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    # Number of distinct factors being solved for in this minibatch.
    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    # In both branches this evaluates to [#distinct row ids, #distinct col ids]
    # of sp_input, matching new_sp_indices; row_shape/col_shape are named from
    # the point of view of the factors being solved for.
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations.
    # The LHS term shared by every solved factor: unobserved_weight * gramian,
    # plus the regularization matrix when one is configured.
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule: a single shared
      # LHS, so all factors are solved at once against the transposed RHS and
      # the solution is transposed back.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights: rank > 1 is rejected by the assert; a
        # scalar is broadcast to one weight per updated factor; a rank-1
        # tensor is used as-is after casting to float32.
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      # The kernel returns a batched, per-factor LHS contribution (added to the
      # shared total_lhs by broadcasting over a new leading axis below) and
      # the matching RHS.
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      # Batched solve; squeeze drops the trailing unit axis added to the RHS.
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    # Write the solved values into the (sharded) factor variable.
    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    # Without model weights these are 0, which removes the weighted term from
    # the loss below.
    # NOTE(review): row_weights_slice and col_weights are only bound in the
    # weighted branch above (self._row_weights is not None). If
    # self._col_weights could be set while self._row_weights is None, the
    # col_wt_mat expression would raise NameError; presumably the model
    # guarantees both or neither are set -- confirm.
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss: the minibatch terms are scaled by
    # total_rows / num_rows, i.e. by how many full-input rows each solved row
    # stands for.
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    # Assuming `gramian` is the Gram matrix of the opposite factors (TODO
    # confirm), the trace term is the sum of squared approximations over the
    # batch rows' full extent. Swapping the observed approximations for the
    # observed residuals gives the unobserved_weight part. The last term adds
    # row_weight * col_weight * residual^2 over the observed entries.
    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    # Regularization: the solved factors' squared norm (scaled like the loss)
    # plus trace(gramian) for the opposite factors.
    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    # unobserved_weight for every entry of the full total_rows x total_cols
    # matrix, plus (only when both weight sets exist) row_wt * col_wt summed
    # over the observed entries and scaled like the loss.
    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
Ejemplo n.º 7
0
    def _run_graph(self,
                   a_shape,
                   b_shape,
                   nnz,
                   num_iters,
                   sort=False,
                   transpose_a=False,
                   transpose_b=False):
        """Benchmarks masked_matmul and reports the mean time per run.

        A fresh graph and session are built for each call. Each iteration:
        - fetches the random operands and a mask from
          `self._make_sparse_mask`;
        - feeds them back through placeholders;
        - times only the `session.run` that executes the masked_matmul op.

        Args:
          a_shape: int list, logical shape of the a matrix (before any
            transpose).
          b_shape: int list, logical shape of the b matrix (before any
            transpose).
          nnz: int, the number of non-zero elements in the mask.
          num_iters: int, the number of timed runs to average over.
          sort: Boolean, whether to sort the indices in the mask.
          transpose_a: boolean, whether a is stored transposed and
            masked_matmul is asked to transpose it back.
          transpose_b: boolean, same as transpose_a but for b.

        Returns:
          The mean wall-clock duration of one masked_matmul run, in seconds
          (0 when num_iters <= 0, since no iteration runs).
        """
        bench_graph = ops.Graph()

        with bench_graph.as_default(), session_lib.Session(
                graph=bench_graph) as session:
            # The mask lives in the logical (untransposed) output space.
            mask_shape = [a_shape[0], b_shape[1]]
            # Store each operand transposed when requested; masked_matmul is
            # then told to undo that transpose.
            if transpose_a:
                a_shape = [a_shape[1], a_shape[0]]
            if transpose_b:
                b_shape = [b_shape[1], b_shape[0]]

            a_init = variables.Variable(random_ops.random_normal(a_shape))
            b_init = variables.Variable(random_ops.random_normal(b_shape))
            indices_feed = array_ops.placeholder(dtypes.int64, shape=[nnz, 2])
            a_feed = array_ops.placeholder(dtypes.float32, shape=a_shape)
            b_feed = array_ops.placeholder(dtypes.float32, shape=b_shape)
            sparse_mask = self._make_sparse_mask(mask_shape, nnz, sort)
            product = gen_factorization_ops.masked_matmul(
                a_feed, b_feed, indices_feed, transpose_a, transpose_b)
            # Running a no-op gated on the product forces the matmul to run
            # without fetching its output back to Python.
            with ops.control_dependencies([product]):
                barrier = control_flow_ops.no_op()

            session.run(variables.global_variables_initializer())
            avg_wall_time = 0
            for _ in range(num_iters):
                a_val, b_val, idx_val = session.run(
                    [a_init, b_init, sparse_mask.indices])
                feeds = {a_feed: a_val, b_feed: b_val, indices_feed: idx_val}
                tic = time.time()
                session.run(barrier, feed_dict=feeds)
                avg_wall_time += (time.time() - tic) / num_iters

            label_fields = dict(nnz=nnz,
                                a_shape=a_shape,
                                b_shape=b_shape,
                                tr_a=int(transpose_a),
                                tr_b=int(transpose_b),
                                sort=int(sort))
            bench_name = (
                "cpu nnz:{nnz} a_shape:{a_shape} b_shape:{b_shape} tr_a:{tr_a} "
                "tr_b:{tr_b} sort:{sort}").format(**label_fields)
            print(bench_name + " - %f secs" % avg_wall_time)
            # Turn the readable label into an identifier-style benchmark name
            # by mapping ", ", ":" and " " (in that order) to "_".
            name = bench_name
            for separator in (", ", ":", " "):
                name = name.replace(separator, "_")
            self.report_benchmark(name=name,
                                  iters=num_iters,
                                  wall_time=avg_wall_time)

        return avg_wall_time
  def _run_graph(self, a_shape, b_shape, nnz, num_iters, sort=False,
                 transpose_a=False, transpose_b=False):
    """Times masked_matmul over several runs and reports the average.

    Builds a private graph and session. Each iteration:
    - pulls fresh operand values and mask indices (the mask comes from
      `self._make_sparse_mask`);
    - feeds them through placeholders;
    - measures only the `session.run` that executes the masked_matmul op.

    Args:
      a_shape: int list, logical shape of the a matrix (before any transpose).
      b_shape: int list, logical shape of the b matrix (before any transpose).
      nnz: int, the number of non-zero elements in the mask.
      num_iters: int, the number of timed runs to average over.
      sort: Boolean, whether to sort the indices in the mask.
      transpose_a: boolean, whether a is stored transposed and masked_matmul
        is asked to transpose it back.
      transpose_b: boolean, same as transpose_a but for b.

    Returns:
      The average duration of one masked_matmul run in seconds (0 when
      num_iters <= 0, since the timing loop does not execute).
    """
    bench_graph = ops.Graph()

    with bench_graph.as_default(), session_lib.Session(
        graph=bench_graph) as session:
      # Mask coordinates are in the logical (untransposed) product space.
      mask_shape = [a_shape[0], b_shape[1]]
      # Operands are stored transposed when requested; the op undoes it.
      if transpose_a:
        a_shape = [a_shape[1], a_shape[0]]
      if transpose_b:
        b_shape = [b_shape[1], b_shape[0]]

      a_init = variables.Variable(random_ops.random_normal(a_shape))
      b_init = variables.Variable(random_ops.random_normal(b_shape))
      indices_feed = array_ops.placeholder(dtypes.int64, shape=[nnz, 2])
      a_feed = array_ops.placeholder(dtypes.float32, shape=a_shape)
      b_feed = array_ops.placeholder(dtypes.float32, shape=b_shape)
      sparse_mask = self._make_sparse_mask(mask_shape, nnz, sort)
      product = gen_factorization_ops.masked_matmul(
          a_feed, b_feed, indices_feed, transpose_a, transpose_b)
      # A no-op that depends on the product makes the timed run execute the
      # matmul without transferring its result back to Python.
      with ops.control_dependencies([product]):
        barrier = control_flow_ops.no_op()

      session.run(variables.global_variables_initializer())
      avg_wall_time = 0
      for _ in range(num_iters):
        a_val, b_val, idx_val = session.run(
            [a_init, b_init, sparse_mask.indices])
        feeds = {a_feed: a_val, b_feed: b_val, indices_feed: idx_val}
        tic = time.time()
        session.run(barrier, feed_dict=feeds)
        avg_wall_time += (time.time() - tic)/num_iters

      label_fields = dict(
          nnz=nnz,
          a_shape=a_shape,
          b_shape=b_shape,
          tr_a=int(transpose_a),
          tr_b=int(transpose_b),
          sort=int(sort))
      bench_name = (
          "cpu nnz:{nnz} a_shape:{a_shape} b_shape:{b_shape} tr_a:{tr_a} "
          "tr_b:{tr_b} sort:{sort}"
      ).format(**label_fields)
      print(bench_name + " - %f secs" % avg_wall_time)
      # Map ", ", ":" and " " (in that order) to "_" for the reported name.
      name = bench_name
      for separator in (", ", ":", " "):
        name = name.replace(separator, "_")
      self.report_benchmark(
          name=name,
          iters=num_iters,
          wall_time=avg_wall_time)

    return avg_wall_time