def sparse_split(split_dim, num_split, sp_input, name=None):
  """Split a `SparseTensor` into `num_split` tensors along `split_dim`.

  If `sp_input.shape[split_dim]` is not an integer multiple of `num_split`,
  the first `sp_input.shape[split_dim] % num_split` slices each get one extra
  element along `split_dim`.

  For example, if `split_dim = 1` and `num_split = 2` and the input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a ]
      [b c   ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
    num_split: A Python integer. The number of ways to split.
    sp_input: The `SparseTensor` to split.
    name: A name for the operation (optional).

  Returns:
    A list of `num_split` `SparseTensor` objects resulting from splitting
    `sp_input`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
  """
  # NOTE(review): `sparse_split` is defined again later in this file with the
  # same signature; that later definition shadows this one at import time.
  # Consider removing one of the two copies.
  if not isinstance(sp_input, ops.SparseTensor):
    raise TypeError("Input must be a SparseTensor")

  output_inds, output_vals, output_shapes = gen_sparse_ops._sparse_split(
      split_dim, sp_input.indices, sp_input.values, sp_input.shape,
      num_split, name=name)

  # The op's outputs are indexed per slice; pair the i-th indices, values and
  # shape into the i-th output SparseTensor.
  return [
      ops.SparseTensor(output_inds[i], output_vals[i], output_shapes[i])
      for i in range(num_split)
  ]
def sparse_split(split_dim, num_split, sp_input, name=None):
  """Split a `SparseTensor` into `num_split` tensors along `split_dim`.

  If `sp_input.shape[split_dim]` is not an integer multiple of `num_split`,
  the first `sp_input.shape[split_dim] % num_split` slices each get one extra
  element along `split_dim`.

  For example, if `split_dim = 1` and `num_split = 2` and the input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a ]
      [b c   ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
    num_split: A Python integer. The number of ways to split.
    sp_input: The `SparseTensor` to split.
    name: A name for the operation (optional).

  Returns:
    A list of `num_split` `SparseTensor` objects resulting from splitting
    `sp_input`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
  """
  if not isinstance(sp_input, ops.SparseTensor):
    raise TypeError("Input must be a SparseTensor")

  output_inds, output_vals, output_shapes = gen_sparse_ops._sparse_split(
      split_dim, sp_input.indices, sp_input.values, sp_input.shape,
      num_split, name=name)

  # The op's outputs are indexed per slice; pair the i-th indices, values and
  # shape into the i-th output SparseTensor.
  return [
      ops.SparseTensor(output_inds[i], output_vals[i], output_shapes[i])
      for i in range(num_split)
  ]