def testAddManyTakeManyRoundTripBatched(self):
  """Round-trips two sparse sequences through deconstruct/reconstruct, batched.

  Two string SparseTensors (dense_shape [4, 5]) are deconstructed into
  per-step handles under the same shared_name. Steps 1..3 of each are then
  stacked and reconstructed with batch_size=2, num_unroll=3. The expected
  result keeps only entries whose leading index is in 1..3, re-indexed as
  [batch, step, column], with dense_shape [2, 3, 5].

  NOTE(review): a method with this exact name also appears elsewhere in this
  file; if both live in the same class only the later one is run.
  """

  def _string_sparse_placeholder():
    # int64 indices / string values / int64 dense_shape, all fed at run time.
    return sparse_tensor.SparseTensor(
        array_ops.placeholder(dtypes.int64),
        array_ops.placeholder(dtypes.string),
        array_ops.placeholder(dtypes.int64))

  with self.test_session(use_gpu=False) as sess:
    # N == 4 because shape_value == [4, 5]
    first_indices = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
    first_values = np.array([b"a", b"b", b"c"])
    first_shape = np.array([4, 5], dtype=np.int64)
    first_sp = _string_sparse_placeholder()

    second_indices = np.array([[1, 4], [2, 3]], dtype=np.int64)
    second_values = np.array([b"d", b"e"])
    second_shape = np.array([4, 5], dtype=np.int64)
    second_sp = _string_sparse_placeholder()

    # Both inputs go through the same shared_name; only the first call's
    # keys / tensor list are needed for reconstruction below.
    seq_1, keys_1, tensors_1 = sqss._deconstruct_sparse_tensor_seq(
        {"key": first_sp}, shared_name="a")
    seq_2, _, _ = sqss._deconstruct_sparse_tensor_seq(
        {"key": second_sp}, shared_name="a")

    # Steps 1..3 of each input form the two batch entries (num_unroll=3).
    stacked = array_ops.stack(
        [seq_1["key"][step] for step in (1, 2, 3)] +
        [seq_2["key"][step] for step in (1, 2, 3)])
    batched = {"key": stacked}

    # The result is read back from batched["key"] below, so this call is
    # relied on to replace that entry in place.
    sqss._reconstruct_sparse_tensor_seq(
        batched, keys_1, tensors_1, batch_size=2, num_unroll=3)

    feed = {}
    for sp, idx, vals, shape in (
        (first_sp, first_indices, first_values, first_shape),
        (second_sp, second_indices, second_values, second_shape)):
      feed[sp.indices] = idx
      feed[sp.values] = vals
      feed[sp.dense_shape] = shape

    result = sess.run([batched["key"]], feed_dict=feed)[0]

    self.assertAllEqual(
        result.indices,
        np.array([[0, 1, 0], [1, 0, 4], [1, 1, 3]], dtype=np.int64))
    self.assertAllEqual(result.values, np.array([b"c", b"d", b"e"]))
    self.assertAllEqual(
        result.dense_shape, np.array([2, 3, 5], dtype=np.int64))
def testAddManyTakeManyRoundTripBatched(self):
  """Checks batched reconstruction of two deconstructed sparse sequences.

  Each input is a string SparseTensor with dense_shape [4, 5]. After
  deconstruction (shared_name="a"), handles for steps 1, 2, 3 of both inputs
  are stacked and reconstructed with batch_size=2 and num_unroll=3. The
  assertions pin the resulting [batch, step, column] indices, the surviving
  values, and dense_shape [2, 3, 5].

  NOTE(review): a method with this exact name also appears elsewhere in this
  file; if both live in the same class only the later one is run.
  """
  with self.test_session(use_gpu=False) as sess:
    # N == 4 because shape_value == [4, 5]
    # Each entry: (indices, values, dense_shape) to feed into one placeholder.
    inputs = [
        (np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64),
         np.array([b"a", b"b", b"c"]),
         np.array([4, 5], dtype=np.int64)),
        (np.array([[1, 4], [2, 3]], dtype=np.int64),
         np.array([b"d", b"e"]),
         np.array([4, 5], dtype=np.int64)),
    ]
    placeholders = [
        sparse_tensor.SparseTensor(
            array_ops.placeholder(dtypes.int64),
            array_ops.placeholder(dtypes.string),
            array_ops.placeholder(dtypes.int64))
        for _ in inputs
    ]

    seq_a, keys_a, tensor_list_a = sqss._deconstruct_sparse_tensor_seq(
        {"key": placeholders[0]}, shared_name="a")
    seq_b, _, _ = sqss._deconstruct_sparse_tensor_seq(
        {"key": placeholders[1]}, shared_name="a")

    # Gather steps 1..3 from the first sequence, then from the second.
    step_handles = []
    for handles in (seq_a["key"], seq_b["key"]):
      step_handles.extend(handles[t] for t in range(1, 4))

    batched_dict = {"key": array_ops.stack(step_handles)}
    # batched_dict["key"] is fetched afterwards, so the reconstruction is
    # relied on to overwrite it in place.
    sqss._reconstruct_sparse_tensor_seq(
        batched_dict, keys_a, tensor_list_a, batch_size=2, num_unroll=3)

    feed_dict = {}
    for st, (indices, values, shape) in zip(placeholders, inputs):
      feed_dict.update({
          st.indices: indices,
          st.values: values,
          st.dense_shape: shape,
      })

    (roundtrip,) = sess.run([batched_dict["key"]], feed_dict=feed_dict)

    expected_indices = np.array(
        [[0, 1, 0], [1, 0, 4], [1, 1, 3]], dtype=np.int64)
    expected_values = np.array([b"c", b"d", b"e"])
    expected_shape = np.array([2, 3, 5], dtype=np.int64)
    self.assertAllEqual(roundtrip.indices, expected_indices)
    self.assertAllEqual(roundtrip.values, expected_values)
    self.assertAllEqual(roundtrip.dense_shape, expected_shape)