Example #1
import numpy as np
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor

def concat_along_batch_dimension(outputs):
    """Concatenates prediction outputs along the batch dimension."""
    if isinstance(outputs[0], sparse_tensor.SparseTensor):
        return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
    if isinstance(outputs[0], ragged_tensor.RaggedTensor):
        return array_ops.concat(outputs, axis=0)
    return np.concatenate(outputs)
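
Example #1 picks a concatenation routine based on the element type: `sparse_concat_v2` for `SparseTensor`s, `array_ops.concat` for `RaggedTensor`s, and `np.concatenate` for plain arrays. A minimal eager-mode sketch of the sparse branch, using `tf.sparse.concat` (the public TF 2 counterpart of `sparse_ops.sparse_concat_v2`); the tensor values are made up for illustration:

import tensorflow as tf

# Two sparse "prediction batches", each of dense shape [1, 2].
a = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[1, 2])
b = tf.sparse.SparseTensor(indices=[[0, 1]], values=[2.0], dense_shape=[1, 2])

# Concatenating along axis 0 stacks the batches: dense shape becomes [2, 2].
batched = tf.sparse.concat(axis=0, sp_inputs=[a, b])
print(batched.dense_shape)  # tf.Tensor([2 2], shape=(2,), dtype=int64)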
Example #2
    def call(self, inputs):
        depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
        ragged_out = sparse_out = False
        if all(ragged_tensor.is_ragged(inp) for inp in inputs):
            # TODO(b/144500510): ragged.map_flat_values(sparse_cross_hashed,
            # inputs) causes a kernel failure; investigate and find a more
            # efficient implementation. Until then, fall back to sparse.
            inputs = [inp.to_sparse() for inp in inputs]
            ragged_out = True
        else:
            if any(ragged_tensor.is_ragged(inp) for inp in inputs):
                raise ValueError(
                    'Inputs must be either all `RaggedTensor`, or none of '
                    'them should be `RaggedTensor`, got {}'.format(inputs))

            if any(
                    isinstance(inp, sparse_tensor.SparseTensor)
                    for inp in inputs):
                sparse_out = True

        outputs = []
        for depth in depth_tuple:
            if len(inputs) < depth:
                raise ValueError(
                    'Number of inputs cannot be less than depth, got {} input tensors, '
                    'and depth {}'.format(len(inputs), depth))
            for partial_inps in itertools.combinations(inputs, depth):
                partial_out = self.partial_crossing(partial_inps, ragged_out,
                                                    sparse_out)
                outputs.append(partial_out)
        if sparse_out:
            return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs)
        return array_ops.concat(outputs, axis=1)
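
The loop in Example #2 enumerates every subset of the inputs at each requested depth with `itertools.combinations`, so the number of `partial_crossing` calls grows combinatorially with depth. A standalone illustration of that enumeration (the string inputs stand in for tensors):

import itertools

inputs = ['a', 'b', 'c']
for depth in (1, 2):
    # Each combination becomes one partial_crossing() call in the layer above.
    print(depth, list(itertools.combinations(inputs, depth)))
# 1 [('a',), ('b',), ('c',)]
# 2 [('a', 'b'), ('a', 'c'), ('b', 'c')]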
Example #3
  def testSparseConcatStaticShape(self):
    if context.executing_eagerly():
      self.skipTest('sparse_placeholder is only available in graph context.')
    input_a = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 1))
    input_b = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 2))

    result = sparse_ops.sparse_concat_v2(axis=1, sp_inputs=[input_a, input_b])
    self.assertEqual(result.shape, [2, 3])
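
The test exercises static shape inference only, but the same shape arithmetic can be checked eagerly: concatenating a [2, 1] and a [2, 2] sparse tensor along axis 1 yields a [2, 3] result, as in this sketch with made-up values:

import tensorflow as tf

a = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 1])
b = tf.sparse.SparseTensor(indices=[[1, 1]], values=[2.0], dense_shape=[2, 2])

# Column indices of `b` are shifted past the width of `a` (1 column).
out = tf.sparse.concat(axis=1, sp_inputs=[a, b])
print(out.dense_shape)  # tf.Tensor([2 3], shape=(2,), dtype=int64)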
Example #4
  def call(self, inputs):
    depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
    ragged_out = sparse_out = False
    if any(ragged_tensor.is_ragged(inp) for inp in inputs):
      ragged_out = True
    elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
      sparse_out = True

    outputs = []
    for depth in depth_tuple:
      if len(inputs) < depth:
        raise ValueError(
            'Number of inputs cannot be less than depth, got {} input tensors, '
            'and depth {}'.format(len(inputs), depth))
      for partial_inps in itertools.combinations(inputs, depth):
        partial_out = self.partial_crossing(
            partial_inps, ragged_out, sparse_out)
        outputs.append(partial_out)
    if sparse_out:
      return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs)
    return array_ops.concat(outputs, axis=1)
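
Unlike Example #2, this variant forwards ragged inputs to `partial_crossing` directly rather than converting them to sparse first. If that conversion is needed, `RaggedTensor.to_sparse()` does it in one call; a small sketch with illustrative values:

import tensorflow as tf

rt = tf.ragged.constant([[1, 2], [3]])
sp = rt.to_sparse()
print(sp.indices.numpy())      # [[0 0] [0 1] [1 0]]
print(sp.dense_shape.numpy())  # [2 2]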