    def testShapeInference(self):
        x = np.random.rand(10, 10)
        x[np.abs(x) < 0.5] = 0  # Make it sparse
        y = np.random.randn(10, 20)
        x_indices = np.vstack(np.where(x)).astype(np.int64).T
        x_values = x[np.where(x)]
        x_shape = x.shape
        x_st = tf.SparseTensor(x_indices, x_values, x_shape)
        result = tf.sparse_tensor_dense_matmul(x_st, y)
        self.assertEqual(result.get_shape(), (10, 20))

        x_shape_unknown = tf.placeholder(dtype=tf.int64, shape=None)
        x_st_shape_unknown = tf.SparseTensor(x_indices, x_values, x_shape_unknown)
        result_left_shape_unknown = tf.sparse_tensor_dense_matmul(x_st_shape_unknown, y)
        self.assertEqual(result_left_shape_unknown.get_shape().as_list(), [None, 20])

        x_shape_inconsistent = [10, 15]
        x_st_shape_inconsistent = tf.SparseTensor(x_indices, x_values, x_shape_inconsistent)
        with self.assertRaisesRegexp(ValueError, "Dimensions must be equal"):
            tf.sparse_tensor_dense_matmul(x_st_shape_inconsistent, y)
    def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype, use_gpu=False):
        n, k, m = np.random.randint(1, 10, size=3)
        sp_t = self._randomTensor([n, k], np_dtype, adjoint=adjoint_a, sparse=True)
        dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)

        matmul = tf.sparse_tensor_dense_matmul(sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)

        with self.test_session(use_gpu=use_gpu):
            dense_t_shape = [m, k] if adjoint_b else [k, m]
            err = tf.test.compute_gradient_error(dense_t, dense_t_shape, matmul, [n, m])
            print("%s gradient err = %s" % (name, err))
            self.assertLess(err, 1e-3)
    def _process_input_helper(self, update_row_factors, sp_input=None, transpose_input=False):
        """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update the row_factors, else update the
        column factors.
      sp_input: Please refer to comments for update_row_factors and
        update_col_factors.
      transpose_input: If true, logically transpose the input.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
        assert isinstance(sp_input, ops.SparseTensor)

        if update_row_factors:
            left = self._row_factors
            right = self._col_factors_cache
            row_weights = self._row_wt_cache
            col_weights = self._col_wt_cache
            sharding_func = WALSModel._get_sharding_func(self._input_rows, self._num_row_shards)
            right_length = self._input_cols
        else:
            left = self._col_factors
            right = self._row_factors_cache
            row_weights = self._col_wt_cache
            col_weights = self._row_wt_cache
            sharding_func = WALSModel._get_sharding_func(self._input_cols, self._num_col_shards)
            right_length = self._input_rows
            transpose_input = not transpose_input

        # Note that the row indices of sp_input are based on the original full
        # input. Here we reindex the rows and give them contiguous ids starting
        # at 0. We use tf.unique to achieve this reindexing, so that the
        # downstream kernel can assume that the input is "dense" along the row
        # dimension.
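        # For example, if the slice touches rows [4, 9, 4, 20], tf.unique
        # returns update_indices = [4, 9, 20] and ids = [0, 1, 0, 2], so the
        # reindexed slice uses contiguous row ids 0..2 (illustrative values
        # only).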
        row_ids, col_ids = tf.split(1, 2, sp_input.indices)

        if transpose_input:
            update_indices, all_ids = tf.unique(col_ids[:, 0])
            col_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1)
        else:
            update_indices, all_ids = tf.unique(row_ids[:, 0])
            row_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1)

        num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64)
        row_shape = tf.constant([right_length], tf.int64)
        col_shape = [num_rows]

        new_sp_indices = tf.concat(1, [row_ids, col_ids])
        new_sp_shape = tf.concat(0, [row_shape, col_shape]) if transpose_input else tf.concat(0, [col_shape, row_shape])
        new_sp_input = tf.SparseTensor(indices=new_sp_indices, values=sp_input.values, shape=new_sp_shape)

        # Compute lhs and rhs of the normal equations
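        # In the unweighted (ALS) branch below, each updated row factor x_i
        # solves (w_0 * R^T R + regularization) x_i = w_0 * R^T a_i, where R is
        # the cached right factor matrix, a_i is the corresponding row of the
        # sparse slice, and w_0 is the unobserved weight. The weighted (WALS)
        # branch adds per-row and per-entry weight corrections to both sides.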
        total_lhs = self._unobserved_weight * tf.matmul(right, right, transpose_a=True)
        if self._regularization is not None:
            total_lhs += self._regularization
        if self._row_weights is None:
            # Special case of ALS. Use a much simpler update rule.
            total_rhs = self._unobserved_weight * tf.sparse_tensor_dense_matmul(
                new_sp_input, right, adjoint_a=transpose_input
            )
            # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
            # transposing explicitly.
            # TODO(rmlarsen): multi-thread tf.matrix_solve.
            new_left_values = tf.transpose(tf.matrix_solve(total_lhs, tf.transpose(total_rhs)))
        else:
            row_weights_slice = tf.gather(row_weights, update_indices)
            partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
                right,
                col_weights,
                self._unobserved_weight,
                row_weights_slice,
                new_sp_input.indices,
                new_sp_input.values,
                num_rows,
                transpose_input,
                name="wals_compute_partial_lhs_rhs",
            )
            total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs
            total_rhs = tf.expand_dims(total_rhs, -1)
            new_left_values = tf.squeeze(tf.batch_matrix_solve(total_lhs, total_rhs), [2])

        return (new_left_values, self.scatter_update(left, update_indices, new_left_values, sharding_func))
    def _process_input_helper(self, update_row_factors, sp_input=None, transpose_input=False, row_weights=None):
        """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        of columns to be updated/projected.

    Returns:
      A tuple consisting of the following two elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
    """
        assert isinstance(sp_input, tf.SparseTensor)

        if update_row_factors:
            left = self._row_factors
            right_factors = self._col_factors_cache
            row_wt = self._row_wt_cache
            col_wt = self._col_wt_cache
            sharding_func = WALSModel._get_sharding_func(self._input_rows, self._num_row_shards)
            gramian = self._col_gramian_cache
        else:
            left = self._col_factors
            right_factors = self._row_factors_cache
            row_wt = self._col_wt_cache
            col_wt = self._row_wt_cache
            sharding_func = WALSModel._get_sharding_func(self._input_cols, self._num_col_shards)
            gramian = self._row_gramian_cache
            transpose_input = not transpose_input

        # Note that the row indices of sp_input are based on the original full
        # input. Here we reindex the rows and give them contiguous ids starting
        # at 0. We use tf.unique to achieve this reindexing, so that the
        # downstream kernel can assume that the input is "dense" along the row
        # dimension.
        row_ids, col_ids = tf.split(value=sp_input.indices, num_or_size_splits=2, axis=1)
        update_row_indices, all_row_ids = tf.unique(row_ids[:, 0])
        update_col_indices, all_col_ids = tf.unique(col_ids[:, 0])
        col_ids = tf.expand_dims(tf.cast(all_col_ids, tf.int64), 1)
        row_ids = tf.expand_dims(tf.cast(all_row_ids, tf.int64), 1)

        if transpose_input:
            update_indices = update_col_indices
            row_shape = [tf.cast(tf.shape(update_row_indices)[0], tf.int64)]
            gather_indices = update_row_indices
        else:
            update_indices = update_row_indices
            row_shape = [tf.cast(tf.shape(update_col_indices)[0], tf.int64)]
            gather_indices = update_col_indices

        num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64)
        col_shape = [num_rows]
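        # Gather the right factors for the columns (rows, if transpose_input is
        # set) present in this slice, in the order of the contiguous ids
        # assigned above, so they line up with new_sp_input constructed below.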
        right = embedding_ops.embedding_lookup(right_factors, gather_indices, partition_strategy="div")
        new_sp_indices = tf.concat_v2([row_ids, col_ids], 1)
        new_sp_shape = (
            tf.concat_v2([row_shape, col_shape], 0) if transpose_input else tf.concat_v2([col_shape, row_shape], 0)
        )
        new_sp_input = tf.SparseTensor(indices=new_sp_indices, values=sp_input.values, dense_shape=new_sp_shape)

        # Compute lhs and rhs of the normal equations
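        # The Gramian of the full right factor matrix is precomputed and cached
        # by the model (the *_gramian_cache variables), so it does not need to
        # be recomputed from the factors here.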
        total_lhs = self._unobserved_weight * gramian
        if self._regularization is not None:
            total_lhs += self._regularization
        if self._row_weights is None:
            # Special case of ALS. Use a much simpler update rule.
            total_rhs = self._unobserved_weight * tf.sparse_tensor_dense_matmul(
                new_sp_input, right, adjoint_a=transpose_input
            )
            # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
            # transposing explicitly.
            # TODO(rmlarsen): multi-thread tf.matrix_solve.
            new_left_values = tf.transpose(tf.matrix_solve(total_lhs, tf.transpose(total_rhs)))
        else:
            if row_weights is None:
                # TODO(yifanchen): Add special handling for single shard without using
                # embedding_lookup and perform benchmarks for those cases. Same for
                # col_weights lookup below.
                row_weights_slice = embedding_ops.embedding_lookup(row_wt, update_indices, partition_strategy="div")
            else:
                with ops.control_dependencies([tf.assert_less_equal(tf.rank(row_weights), 1)]):
                    row_weights_slice = tf.cond(
                        tf.equal(tf.rank(row_weights), 0),
                        lambda: (tf.ones([tf.shape(update_indices)[0]]) * row_weights),
                        lambda: tf.cast(row_weights, tf.float32),
                    )
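            # Either way, row_weights_slice ends up holding one weight per
            # updated row: a scalar is broadcast, a rank-1 tensor is used as-is
            # (cast to float32).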

            col_weights = embedding_ops.embedding_lookup(col_wt, gather_indices, partition_strategy="div")
            partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
                right,
                col_weights,
                self._unobserved_weight,
                row_weights_slice,
                new_sp_input.indices,
                new_sp_input.values,
                num_rows,
                transpose_input,
                name="wals_compute_partial_lhs_rhs",
            )
            total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs
            total_rhs = tf.expand_dims(total_rhs, -1)
            new_left_values = tf.squeeze(tf.matrix_solve(total_lhs, total_rhs), [2])

        return (new_left_values, self.scatter_update(left, update_indices, new_left_values, sharding_func))
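A minimal driver sketch for this helper, written here for illustration only: it assumes an already constructed WALSModel instance named model whose factor, weight, and gramian caches have been initialized elsewhere, plus a sparse input slice sp_slice; everything except the helper's own signature is hypothetical.

# Hypothetical usage of _process_input_helper; model, sp_slice, and the
# cache-preparation ops are assumed to exist and to have been run already.
new_values, update_op = model._process_input_helper(
    update_row_factors=True, sp_input=sp_slice)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op)  # writes the newly solved row factors back in place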
    def _train_fprop(self, state_below):
        # state_below is an (indices, values) pair describing a sparse batch.
        idx, val = state_below
        X = tf.SparseTensor(tf.cast(idx, "int64"), val, shape=[self.batchsize, self.prev_dim])
        # Reorder the indices into canonical row-major order before the matmul.
        X_order = tf.sparse_reorder(X)
        XW = tf.sparse_tensor_dense_matmul(X_order, self.W, adjoint_a=False, adjoint_b=False)
        return tf.add(XW, self.b)
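A minimal feeding sketch for the layer above, for illustration only: layer stands for a constructed instance with batchsize, prev_dim, W, and b already defined; the placeholder names and the example indices/values are made up.

import tensorflow as tf

# Hypothetical placeholders for the (indices, values) pair passed as state_below.
idx_ph = tf.placeholder(tf.int64, shape=[None, 2])   # (row, col) index pairs
val_ph = tf.placeholder(tf.float32, shape=[None])    # matching nonzero values
out = layer._train_fprop((idx_ph, val_ph))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {idx_ph: [[0, 1], [1, 3]], val_ph: [0.5, 2.0]}  # must fit [batchsize, prev_dim]
    print(sess.run(out, feed_dict=feed))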