示例#1
0
 def test_sparse_value_concatenation(self):
     """Check that appending two SparseTensorValues yields one combined value.

     Each input has dense shape [1, 1] with a single entry at index [0, 0].
     The assertions expect the second input's entry to end up at [1, 0],
     the values to be kept in input order, and the dense shape to be [2, 1].
     """
     first = sparse_tensor.SparseTensorValue([[0, 0]], [1], [1, 1])
     second = sparse_tensor.SparseTensorValue([[0, 0]], [2], [1, 1])

     merged = training_utils_v1._append_composite_tensor(first, second)

     expected_indices = [[0, 0], [1, 0]]
     expected_values = [1, 2]
     expected_dense_shape = [2, 1]
     self.assertAllEqual(merged.indices, expected_indices)
     self.assertAllEqual(merged.values, expected_values)
     self.assertAllEqual(merged.dense_shape, expected_dense_shape)
示例#2
0
    def test_ragged_value_concatenation(self):
        """Check that appending two RaggedTensorValues yields one combined value.

        The first input holds values [0, 1, 2] with row splits [0, 1, 3]; the
        second holds values [3, 4, 5] with row splits [0, 2, 3]. The assertions
        expect the flat values to be joined in input order and the row splits
        of the second input to continue after those of the first.
        """
        first = ragged_tensor_value.RaggedTensorValue(
            np.array([0, 1, 2]),
            np.array([0, 1, 3], dtype=np.int64),
        )
        second = ragged_tensor_value.RaggedTensorValue(
            np.array([3, 4, 5]),
            np.array([0, 2, 3], dtype=np.int64),
        )

        merged = training_utils_v1._append_composite_tensor(first, second)

        expected_values = [0, 1, 2, 3, 4, 5]
        expected_row_splits = [0, 1, 3, 5, 6]
        self.assertAllEqual(merged.values, expected_values)
        self.assertAllEqual(merged.row_splits, expected_row_splits)