Beispiel #1
0
    def testInsertingNonScalarFails(self):
        """Feeding a priority of the wrong rank raises InvalidArgumentError.

        Two paths are covered:
          * `enqueue` with a rank-1 priority (shape [2]) where the error says a
            scalar shape [] is expected;
          * `enqueue_many` with a rank-2 priority (shape [2, 2]) where the
            error says a rank-1 batch of shape [2] is expected.
        """
        with self.cached_session() as sess:
            input_priority = array_ops.placeholder(dtypes.int64)
            input_other = array_ops.placeholder(dtypes.string)
            q = data_flow_ops.PriorityQueue(2000, (dtypes.string, ), (()))

            # The "." after "component 0" is escaped so the pattern matches the
            # literal period in the error message instead of any character.
            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    r"Shape mismatch in tuple component 0\. Expected \[\], got \[2\]"
            ):
                sess.run(
                    [q.enqueue((input_priority, input_other))],
                    feed_dict={
                        input_priority: np.array([0, 2], dtype=np.int64),
                        input_other: np.random.rand(3, 5).astype(bytes)
                    })

            with self.assertRaisesRegex(
                    errors_impl.InvalidArgumentError,
                    r"Shape mismatch in tuple component 0\. Expected \[2\], got \[2,2\]"
            ):
                sess.run(
                    [q.enqueue_many((input_priority, input_other))],
                    feed_dict={
                        input_priority: np.array([[0, 2], [3, 4]],
                                                 dtype=np.int64),
                        input_other: np.random.rand(2, 3).astype(bytes)
                    })
Beispiel #2
0
    def testRoundTripInsertReadOnceSorts(self):
        """Enqueue 100 elements one at a time, then dequeue them in one batch.

        Checks that the priorities come out in ascending order and that the
        dequeued (priority, value_0, value_1) triples are exactly the enqueued
        triples, each appearing the same number of times.
        """
        with self.cached_session():
            q = data_flow_ops.PriorityQueue(2000,
                                            (dtypes.string, dtypes.string),
                                            ((), ()))
            elem = np.random.randint(-5, 5, size=100).astype(np.int64)
            side_value_0 = np.random.rand(100).astype(bytes)
            side_value_1 = np.random.rand(100).astype(bytes)
            enq_list = [
                q.enqueue(
                    (e, constant_op.constant(v0), constant_op.constant(v1)))
                for e, v0, v1 in zip(elem, side_value_0, side_value_1)
            ]
            for enq in enq_list:
                enq.run()

            deq = q.dequeue_many(100)
            deq_elem, deq_value_0, deq_value_1 = self.evaluate(deq)

            self.assertAllEqual(deq_elem, sorted(elem))
            # Compare the full triples as multisets: this catches a lost,
            # duplicated or mis-paired element, and (unlike a set of pairs)
            # stays correct if two enqueued side values happen to coincide.
            self.assertCountEqual(
                list(zip(deq_elem, deq_value_0, deq_value_1)),
                list(zip(elem, side_value_0, side_value_1)))
 def testRoundTripInsertOnceReadManySorts(self):
   """Enqueue 1000 priorities at once, then drain them in ten batches of 100.

   The concatenation of the ten batches must equal the sorted priorities.
   """
   with self.cached_session():
     q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
     elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
     q.enqueue_many((elem, elem)).run()
     # Build the dequeue op once and run it ten times. The test expects each
     # run to yield the next 100 smallest priorities.
     dequeue_priorities = q.dequeue_many(100)[0]
     # np.hstack requires a sequence (list/tuple). Passing a generator is
     # deprecated in NumPy and rejected by newer releases.
     deq_values = np.hstack([dequeue_priorities.eval() for _ in range(10)])
     self.assertAllEqual(deq_values, sorted(elem))
  def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
    """Concurrent enqueue_many calls, then concurrent dequeue_many calls.

    All producer threads are joined before any consumer thread starts. Every
    enqueued value must be dequeued exactly once; ordering across consumer
    threads is not checked.
    """
    # We need each thread to keep its own device stack or the device scopes
    # won't be properly nested.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))

      num_threads = 40
      enqueue_counts = np.random.randint(10, size=num_threads)
      enqueue_values = [
          np.random.randint(
              5, size=count) for count in enqueue_counts
      ]
      enqueue_ops = [
          q.enqueue_many((values, values)) for values in enqueue_values
      ]
      # Dequeue batch sizes are a permutation of the enqueue batch sizes, so
      # the consumers together drain exactly as many elements as were enqueued.
      shuffled_counts = copy.deepcopy(enqueue_counts)
      random.shuffle(shuffled_counts)
      dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts]
      all_enqueued_values = np.hstack(enqueue_values)

      # Guards `dequeued`, which every consumer thread appends to.
      dequeued_lock = threading.Lock()

      # Run one producer thread per enqueue op.
      def enqueue(enqueue_op):
        sess.run(enqueue_op)

      def dequeue(dequeue_op, dequeued):
        (dequeue_indices, dequeue_values) = sess.run(dequeue_op)
        # Every element was enqueued with priority == value.
        self.assertAllEqual(dequeue_indices, dequeue_values)
        with dequeued_lock:
          dequeued.extend(dequeue_indices)

      dequeued = []
      enqueue_threads = [
          self.checkedThread(
              target=enqueue, args=(op,)) for op in enqueue_ops
      ]
      dequeue_threads = [
          self.checkedThread(
              target=dequeue, args=(op, dequeued)) for op in dequeue_ops
      ]

      for t in enqueue_threads:
        t.start()
      for t in enqueue_threads:
        t.join()
      # Dequeue and check
      for t in dequeue_threads:
        t.start()
      for t in dequeue_threads:
        t.join()

      # We can't guarantee full sorting across consumer threads because the
      # dequeued.extend() call need not run immediately after sess.run().
      # Compare sorted multisets instead of sets so that a lost or duplicated
      # element is still detected.
      self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
Beispiel #5
0
    def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
        """Enqueue 20 batches from concurrent threads; dequeue all 100 at once.

        The single `dequeue_many(100)` is issued from the main thread after the
        producer threads are started and before they are joined. Checks that
        the priorities come out sorted and that the dequeued
        (priority, value_0, value_1) triples are exactly the enqueued ones.
        """
        # We need each thread to keep its own device stack or the device scopes
        # won't be properly nested.
        ops.get_default_graph().switch_to_thread_local()
        with self.cached_session() as sess:
            q = data_flow_ops.PriorityQueue(2000,
                                            (dtypes.string, dtypes.string),
                                            ((), ()))
            elem = np.random.randint(-5, 5, size=100).astype(np.int64)
            side_value_0 = np.random.rand(100).astype(bytes)
            side_value_1 = np.random.rand(100).astype(bytes)

            batch = 5
            enqueue_ops = [
                q.enqueue_many((elem[i * batch:(i + 1) * batch],
                                side_value_0[i * batch:(i + 1) * batch],
                                side_value_1[i * batch:(i + 1) * batch]))
                for i in range(20)
            ]

            # Run one producer thread per batch. The op runs on `sess`
            # explicitly instead of through self.evaluate(), which would have
            # to find a default session from inside the worker thread.
            # NOTE(review): TF appears to track the default session per
            # thread in graph mode; confirm.
            def enqueue(enqueue_op):
                sess.run(enqueue_op)

            dequeue_op = q.dequeue_many(100)

            enqueue_threads = [
                self.checkedThread(target=enqueue, args=(op, ))
                for op in enqueue_ops
            ]

            for t in enqueue_threads:
                t.start()

            deq_elem, deq_value_0, deq_value_1 = self.evaluate(dequeue_op)

            for t in enqueue_threads:
                t.join()

            self.assertAllEqual(deq_elem, sorted(elem))
            # Compare the full triples as multisets: this catches a lost,
            # duplicated or mis-paired element, and (unlike a set of pairs)
            # stays correct if two enqueued side values happen to coincide.
            self.assertCountEqual(
                list(zip(deq_elem, deq_value_0, deq_value_1)),
                list(zip(elem, side_value_0, side_value_1)))
  def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
    """Producers and consumers run concurrently against a small queue.

    The queue capacity is 10 while up to 40 * 9 elements are enqueued in
    total, so the queue is expected to fill up while consumers drain it.
    Every enqueued value must come back out exactly once.
    """
    # Each worker thread needs its own device stack; otherwise device scopes
    # would not nest properly.
    ops.get_default_graph().switch_to_thread_local()
    with self.cached_session() as sess:
      q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))

      num_threads = 40
      batch_sizes = np.random.randint(10, size=num_threads)
      batches = [np.random.randint(5, size=size) for size in batch_sizes]
      producer_ops = [q.enqueue_many((batch, batch)) for batch in batches]
      consumer_sizes = copy.deepcopy(batch_sizes)
      random.shuffle(consumer_sizes)
      consumer_ops = [q.dequeue_many(size) for size in consumer_sizes]
      expected = np.hstack(batches)

      received = []

      def produce(op):
        sess.run(op)

      def consume(op):
        priorities, payloads = sess.run(op)
        # Every element was enqueued with priority == payload.
        self.assertAllEqual(priorities, payloads)
        received.extend(priorities)

      producers = [
          self.checkedThread(target=produce, args=(op,))
          for op in producer_ops
      ]
      consumers = [
          self.checkedThread(target=consume, args=(op,))
          for op in consumer_ops
      ]

      # Start the consumers before the producers; join producers first.
      for worker in consumers + producers:
        worker.start()
      for worker in producers + consumers:
        worker.join()

      self.assertAllEqual(sorted(received), sorted(expected))
Beispiel #7
0
    def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
        """Enqueue 20 batches from concurrent threads; dequeue all 100 at once.

        The single `dequeue_many(100)` is issued from the main thread after the
        producer threads are started and before they are joined. Checks that
        the priorities come out sorted and that the dequeued
        (priority, value_0, value_1) triples are exactly the enqueued ones.
        """
        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as sess:
            q = data_flow_ops.PriorityQueue(2000,
                                            (dtypes.string, dtypes.string),
                                            ((), ()))
            elem = np.random.randint(-5, 5, size=100).astype(np.int64)
            side_value_0 = np.random.rand(100).astype(bytes)
            side_value_1 = np.random.rand(100).astype(bytes)

            batch = 5
            enqueue_ops = [
                q.enqueue_many((elem[i * batch:(i + 1) * batch],
                                side_value_0[i * batch:(i + 1) * batch],
                                side_value_1[i * batch:(i + 1) * batch]))
                for i in range(20)
            ]

            # Run one producer thread per batch of `batch` elements.
            def enqueue(enqueue_op):
                sess.run(enqueue_op)

            dequeue_op = q.dequeue_many(100)

            enqueue_threads = [
                self.checkedThread(target=enqueue, args=(op, ))
                for op in enqueue_ops
            ]

            for t in enqueue_threads:
                t.start()

            deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op)

            for t in enqueue_threads:
                t.join()

            self.assertAllEqual(deq_elem, sorted(elem))
            # Compare the full triples as multisets: this catches a lost,
            # duplicated or mis-paired element, and (unlike a set of pairs)
            # stays correct if two enqueued side values happen to coincide.
            self.assertCountEqual(
                list(zip(deq_elem, deq_value_0, deq_value_1)),
                list(zip(elem, side_value_0, side_value_1)))
  def testRoundTripInsertOnceReadOnceSorts(self):
    """Enqueue 1000 elements in one call, then dequeue all of them in one call.

    Checks that the priorities come out in ascending order and that the
    dequeued (priority, value_0, value_1) triples are exactly the enqueued
    triples, each appearing the same number of times.
    """
    with self.cached_session() as sess:
      q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
          (), ()))
      elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
      side_value_0 = np.random.rand(1000).astype(bytes)
      side_value_1 = np.random.rand(1000).astype(bytes)
      q.enqueue_many((elem, side_value_0, side_value_1)).run()
      deq = q.dequeue_many(1000)
      deq_elem, deq_value_0, deq_value_1 = sess.run(deq)

      self.assertAllEqual(deq_elem, sorted(elem))
      # A per-priority membership check alone would not notice an element
      # that is dropped or returned twice. Comparing the full triples as
      # multisets checks pairing, presence and multiplicity at once.
      self.assertCountEqual(
          list(zip(deq_elem, deq_value_0, deq_value_1)),
          list(zip(elem, side_value_0, side_value_1)))
Beispiel #9
0
 def testInsertingNonInt64Fails(self):
     """Enqueueing string priorities instead of int64 ones raises TypeError."""
     with self.cached_session():
         q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (()))
         priorities = ["a", "b", "c"]
         values = ["a", "b", "c"]
         with self.assertRaises(TypeError):
             enqueue_op = q.enqueue_many((priorities, values))
             enqueue_op.run()