def concat_along_batch_dimension(outputs):
  """Concatenates a sequence of prediction outputs along the batch axis (0).

  The concatenation routine is chosen from the type of the first element:
  a `SparseTensor` uses `sparse_ops.sparse_concat_v2`, a `RaggedTensor` uses
  `array_ops.concat`, and anything else is handed to `np.concatenate`.

  Args:
    outputs: Non-empty indexable sequence of per-batch outputs. Only
      `outputs[0]` is inspected, so the elements are presumably all of the
      same kind (TODO confirm with callers). An empty sequence fails on the
      `outputs[0]` lookup.

  Returns:
    The outputs joined along axis 0, in the container type produced by the
    selected routine.
  """
  first = outputs[0]
  if isinstance(first, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
  elif isinstance(first, ragged_tensor.RaggedTensor):
    return array_ops.concat(outputs, axis=0)
  else:
    # Fallback for non-TF outputs; axis=0 is np.concatenate's default.
    return np.concatenate(outputs, axis=0)
def call(self, inputs):
  """Crosses `inputs` at each configured depth and concatenates the crosses.

  The depths come from `self._depth_tuple` when `self.depth` is truthy;
  otherwise the only depth is `len(inputs)`. For each depth, every
  combination of that many inputs (via `itertools.combinations`) is passed to
  `self.partial_crossing`, and all partial results are concatenated along
  axis 1.

  Ragged handling: either every input is ragged, or none is. When all are
  ragged, they are converted to sparse first and `ragged_out` is set. When
  no input is ragged but at least one is a `SparseTensor`, `sparse_out` is
  set. Both flags are forwarded to `partial_crossing`.

  Args:
    inputs: Sequence of input tensors to cross.

  Returns:
    The concatenation of all partial crosses along axis 1. When `sparse_out`
    is set this uses `sparse_ops.sparse_concat_v2`; otherwise it uses
    `array_ops.concat`.

  Raises:
    ValueError: If some but not all inputs are ragged, or if a depth exceeds
      the number of inputs. The depth check runs per depth, so crosses for
      earlier depths may already have been built when it raises.
  """
  depths = self._depth_tuple if self.depth else (len(inputs),)

  ragged_flags = [ragged_tensor.is_ragged(t) for t in inputs]
  ragged_out = False
  sparse_out = False
  if all(ragged_flags):
    # (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
    # cause kernel failure. Investigate and find a more efficient
    # implementation
    inputs = [t.to_sparse() for t in inputs]
    ragged_out = True
  elif any(ragged_flags):
    raise ValueError(
        'Inputs must be either all `RaggedTensor`, or none of them should '
        'be `RaggedTensor`, got {}'.format(inputs))
  elif any(isinstance(t, sparse_tensor.SparseTensor) for t in inputs):
    sparse_out = True

  crossed = []
  for depth in depths:
    if depth > len(inputs):
      raise ValueError(
          'Number of inputs cannot be less than depth, got {} input tensors, '
          'and depth {}'.format(len(inputs), depth))
    crossed.extend(
        self.partial_crossing(combo, ragged_out, sparse_out)
        for combo in itertools.combinations(inputs, depth))

  if sparse_out:
    return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=crossed)
  return array_ops.concat(crossed, axis=1)
def testSparseConcatStaticShape(self):
  """Checks that `sparse_concat_v2` along axis 1 yields a static result shape.

  Two sparse placeholders with static shapes (2, 1) and (2, 2) are
  concatenated along axis 1. The graph-time shape of the result should be
  (2, 3) without running the graph.
  """
  if context.executing_eagerly():
    # Sparse placeholders only exist when building a graph.
    self.skipTest('sparse_placeholder is only available in graph context.')
  input_a = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 1))
  input_b = array_ops.sparse_placeholder(dtypes.float32, shape=(2, 2))
  result = sparse_ops.sparse_concat_v2(axis=1, sp_inputs=[input_a, input_b])
  self.assertEqual(result.shape, [2, 3])
def call(self, inputs):
  """Crosses `inputs` at each configured depth and concatenates the crosses.

  The depths come from `self._depth_tuple` when `self.depth` is truthy;
  otherwise the only depth is `len(inputs)`. For each depth, every
  combination of that many inputs (via `itertools.combinations`) is passed to
  `self.partial_crossing` together with the output-kind flags, and all
  partial results are concatenated along axis 1.

  Output kind:
    - ragged (`ragged_out`) if any input is ragged;
    - otherwise sparse (`sparse_out`) if any input is a `SparseTensor`;
    - otherwise neither flag is set.

  Args:
    inputs: Sequence of input tensors to cross.

  Returns:
    The concatenation of all partial crosses along axis 1. When `sparse_out`
    is set this uses `sparse_ops.sparse_concat_v2`; otherwise it uses
    `array_ops.concat`.

  Raises:
    ValueError: If a depth exceeds the number of inputs. The check runs per
      depth, so crosses for earlier depths may already have been built.
  """
  if self.depth:
    depths = self._depth_tuple
  else:
    depths = (len(inputs),)

  ragged_out = any([ragged_tensor.is_ragged(t) for t in inputs])
  # Sparse output is only considered when no input is ragged.
  sparse_out = (not ragged_out) and any(
      isinstance(t, sparse_tensor.SparseTensor) for t in inputs)

  crossed = []
  for depth in depths:
    if depth > len(inputs):
      raise ValueError(
          'Number of inputs cannot be less than depth, got {} input tensors, '
          'and depth {}'.format(len(inputs), depth))
    crossed.extend(
        self.partial_crossing(combo, ragged_out, sparse_out)
        for combo in itertools.combinations(inputs, depth))

  if sparse_out:
    return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=crossed)
  return array_ops.concat(crossed, axis=1)