Example #1
0
def sparse_split(split_dim, num_split, sp_input, name=None):
    """Split a `SparseTensor` into `num_split` tensors along `split_dim`.

    If `sp_input.shape[split_dim]` is not an integer multiple of `num_split`,
    the first `shape[split_dim] % num_split` slices each get one extra
    element along `split_dim`. For example, if `split_dim = 1` and
    `num_split = 2` and the input is:

        input_tensor = shape = [2, 7]
        [    a   d e  ]
        [b c          ]

    Graphically the output tensors are:

        output_tensor[0] =
        [    a ]
        [b c   ]

        output_tensor[1] =
        [ d e  ]
        [      ]

    Args:
      split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
      num_split: A Python integer. The number of ways to split.
      sp_input: The `SparseTensor` to split.
      name: A name for the operation (optional).

    Returns:
      A list of `num_split` `SparseTensor` objects resulting from splitting
      `sp_input`.

    Raises:
      TypeError: If `sp_input` is not a `SparseTensor`.
    """
    if not isinstance(sp_input, ops.SparseTensor):
        raise TypeError("Input must be a SparseTensor")

    # The generated op returns three parallel lists (indices, values, dense
    # shape), holding one entry per output slice.
    output_inds, output_vals, output_shapes = gen_sparse_ops._sparse_split(
        split_dim,
        sp_input.indices,
        sp_input.values,
        sp_input.shape,
        num_split,
        name=name)
    # Reassemble each (indices, values, shape) triple into a SparseTensor.
    return [
        ops.SparseTensor(inds, vals, shape)
        for inds, vals, shape in zip(output_inds, output_vals, output_shapes)
    ]
Example #2
0
def sparse_split(split_dim, num_split, sp_input, name=None):
  """Split a `SparseTensor` into `num_split` tensors along `split_dim`.

  If `sp_input.shape[split_dim]` is not an integer multiple of `num_split`,
  the first `shape[split_dim] % num_split` slices each get one extra
  element along `split_dim`. For example, if `split_dim = 1` and
  `num_split = 2` and the input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a ]
      [b c   ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
    num_split: A Python integer. The number of ways to split.
    sp_input: The `SparseTensor` to split.
    name: A name for the operation (optional).

  Returns:
    A list of `num_split` `SparseTensor` objects resulting from splitting
    `sp_input`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
  """
  if not isinstance(sp_input, ops.SparseTensor):
    raise TypeError("Input must be a SparseTensor")

  # The generated op returns three parallel lists (indices, values, dense
  # shape), holding one entry per output slice.
  output_inds, output_vals, output_shapes = (
      gen_sparse_ops._sparse_split(split_dim,
                                   sp_input.indices,
                                   sp_input.values,
                                   sp_input.shape,
                                   num_split,
                                   name=name))
  # Reassemble each (indices, values, shape) triple into a SparseTensor.
  return [ops.SparseTensor(inds, vals, shape)
          for inds, vals, shape in zip(output_inds, output_vals,
                                       output_shapes)]