def testInsertingNonScalarFails(self): with self.cached_session() as sess: input_priority = array_ops.placeholder(dtypes.int64) input_other = array_ops.placeholder(dtypes.string) q = data_flow_ops.PriorityQueue(2000, (dtypes.string, ), (())) with self.assertRaisesRegex( errors_impl.InvalidArgumentError, r"Shape mismatch in tuple component 0. Expected \[\], got \[2\]" ): sess.run( [q.enqueue((input_priority, input_other))], feed_dict={ input_priority: np.array([0, 2], dtype=np.int64), input_other: np.random.rand(3, 5).astype(bytes) }) with self.assertRaisesRegex( errors_impl.InvalidArgumentError, r"Shape mismatch in tuple component 0. Expected \[2\], got \[2,2\]" ): sess.run( [q.enqueue_many((input_priority, input_other))], feed_dict={ input_priority: np.array([[0, 2], [3, 4]], dtype=np.int64), input_other: np.random.rand(2, 3).astype(bytes) })
def testRoundTripInsertReadOnceSorts(self): with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ((), ())) elem = np.random.randint(-5, 5, size=100).astype(np.int64) side_value_0 = np.random.rand(100).astype(bytes) side_value_1 = np.random.rand(100).astype(bytes) enq_list = [ q.enqueue( (e, constant_op.constant(v0), constant_op.constant(v1))) for e, v0, v1 in zip(elem, side_value_0, side_value_1) ] for enq in enq_list: enq.run() deq = q.dequeue_many(100) deq_elem, deq_value_0, deq_value_1 = self.evaluate(deq) allowed = {} missed = set() for e, v0, v1 in zip(elem, side_value_0, side_value_1): if e not in allowed: allowed[e] = set() allowed[e].add((v0, v1)) missed.add((v0, v1)) self.assertAllEqual(deq_elem, sorted(elem)) for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): self.assertTrue((dv0, dv1) in allowed[e]) missed.remove((dv0, dv1)) self.assertEqual(missed, set())
def testRoundTripInsertOnceReadManySorts(self): with self.cached_session(): q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) elem = np.random.randint(-100, 100, size=1000).astype(np.int64) q.enqueue_many((elem, elem)).run() deq_values = np.hstack((q.dequeue_many(100)[0].eval() for _ in range(10))) self.assertAllEqual(deq_values, sorted(elem))
def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self): # We need each thread to keep its own device stack or the device scopes # won't be properly nested. ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) num_threads = 40 enqueue_counts = np.random.randint(10, size=num_threads) enqueue_values = [ np.random.randint( 5, size=count) for count in enqueue_counts ] enqueue_ops = [ q.enqueue_many((values, values)) for values in enqueue_values ] shuffled_counts = copy.deepcopy(enqueue_counts) random.shuffle(shuffled_counts) dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts] all_enqueued_values = np.hstack(enqueue_values) dequeue_wait = threading.Condition() # Run one producer thread for each element in elems. def enqueue(enqueue_op): sess.run(enqueue_op) def dequeue(dequeue_op, dequeued): (dequeue_indices, dequeue_values) = sess.run(dequeue_op) self.assertAllEqual(dequeue_indices, dequeue_values) dequeue_wait.acquire() dequeued.extend(dequeue_indices) dequeue_wait.release() dequeued = [] enqueue_threads = [ self.checkedThread( target=enqueue, args=(op,)) for op in enqueue_ops ] dequeue_threads = [ self.checkedThread( target=dequeue, args=(op, dequeued)) for op in dequeue_ops ] for t in enqueue_threads: t.start() for t in enqueue_threads: t.join() # Dequeue and check for t in dequeue_threads: t.start() for t in dequeue_threads: t.join() # We can't guarantee full sorting because we can't guarantee # that the dequeued.extend() call runs immediately after the # sess.run() call. Here we're just happy everything came out. self.assertAllEqual(set(dequeued), set(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self): # We need each thread to keep its own device stack or the device scopes # won't be properly nested. ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ((), ())) elem = np.random.randint(-5, 5, size=100).astype(np.int64) side_value_0 = np.random.rand(100).astype(bytes) side_value_1 = np.random.rand(100).astype(bytes) batch = 5 enqueue_ops = [ q.enqueue_many((elem[i * batch:(i + 1) * batch], side_value_0[i * batch:(i + 1) * batch], side_value_1[i * batch:(i + 1) * batch])) for i in range(20) ] # Run one producer thread for each element in elems. def enqueue(enqueue_op): self.evaluate(enqueue_op) dequeue_op = q.dequeue_many(100) enqueue_threads = [ self.checkedThread(target=enqueue, args=(op, )) for op in enqueue_ops ] for t in enqueue_threads: t.start() deq_elem, deq_value_0, deq_value_1 = self.evaluate(dequeue_op) for t in enqueue_threads: t.join() allowed = {} missed = set() for e, v0, v1 in zip(elem, side_value_0, side_value_1): if e not in allowed: allowed[e] = set() allowed[e].add((v0, v1)) missed.add((v0, v1)) self.assertAllEqual(deq_elem, sorted(elem)) for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): self.assertTrue((dv0, dv1) in allowed[e]) missed.remove((dv0, dv1)) self.assertEqual(missed, set())
def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self): # We need each thread to keep its own device stack or the device scopes # won't be properly nested. ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (())) num_threads = 40 enqueue_counts = np.random.randint(10, size=num_threads) enqueue_values = [ np.random.randint( 5, size=count) for count in enqueue_counts ] enqueue_ops = [ q.enqueue_many((values, values)) for values in enqueue_values ] shuffled_counts = copy.deepcopy(enqueue_counts) random.shuffle(shuffled_counts) dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts] all_enqueued_values = np.hstack(enqueue_values) # Run one producer thread for each element in elems. def enqueue(enqueue_op): sess.run(enqueue_op) dequeued = [] def dequeue(dequeue_op): (dequeue_indices, dequeue_values) = sess.run(dequeue_op) self.assertAllEqual(dequeue_indices, dequeue_values) dequeued.extend(dequeue_indices) enqueue_threads = [ self.checkedThread( target=enqueue, args=(op,)) for op in enqueue_ops ] dequeue_threads = [ self.checkedThread( target=dequeue, args=(op,)) for op in dequeue_ops ] # Dequeue and check for t in dequeue_threads: t.start() for t in enqueue_threads: t.start() for t in enqueue_threads: t.join() for t in dequeue_threads: t.join() self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self): with self.test_session() as sess: q = data_flow_ops.PriorityQueue(2000, (tf.string, tf.string), ((), ())) elem = np.random.randint(-5, 5, size=100).astype(np.int64) side_value_0 = np.random.rand(100).astype(bytes) side_value_1 = np.random.rand(100).astype(bytes) batch = 5 enqueue_ops = [ q.enqueue_many((elem[i * batch:(i + 1) * batch], side_value_0[i * batch:(i + 1) * batch], side_value_1[i * batch:(i + 1) * batch])) for i in range(20) ] # Run one producer thread for each element in elems. def enqueue(enqueue_op): sess.run(enqueue_op) dequeue_op = q.dequeue_many(100) enqueue_threads = [ self.checkedThread(target=enqueue, args=(op, )) for op in enqueue_ops ] for t in enqueue_threads: t.start() deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op) for t in enqueue_threads: t.join() allowed = {} missed = set() for e, v0, v1 in zip(elem, side_value_0, side_value_1): if e not in allowed: allowed[e] = set() allowed[e].add((v0, v1)) missed.add((v0, v1)) self.assertAllEqual(deq_elem, sorted(elem)) for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): self.assertTrue((dv0, dv1) in allowed[e]) missed.remove((dv0, dv1)) self.assertEqual(missed, set())
def testRoundTripInsertOnceReadOnceSorts(self): with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( (), ())) elem = np.random.randint(-100, 100, size=1000).astype(np.int64) side_value_0 = np.random.rand(1000).astype(bytes) side_value_1 = np.random.rand(1000).astype(bytes) q.enqueue_many((elem, side_value_0, side_value_1)).run() deq = q.dequeue_many(1000) deq_elem, deq_value_0, deq_value_1 = sess.run(deq) allowed = {} for e, v0, v1 in zip(elem, side_value_0, side_value_1): if e not in allowed: allowed[e] = set() allowed[e].add((v0, v1)) self.assertAllEqual(deq_elem, sorted(elem)) for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): self.assertTrue((dv0, dv1) in allowed[e])
def testInsertingNonInt64Fails(self): with self.cached_session(): q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (())) with self.assertRaises(TypeError): q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run()