Exemple #1
0
    def testIncompatibleSharedBarrierErrors(self):
        """Reusing a shared_name with a mismatched signature must fail.

        For every case a first barrier is created and its ``barrier_ref`` is
        evaluated. Evaluating the ``barrier_ref`` of a second barrier that
        uses the same ``shared_name`` but different component types or shapes
        is then expected to raise an op error naming the mismatch.
        """
        f32, i32 = dtypes.float32, dtypes.int32
        # Each entry: (shared_name, first types, first kwargs,
        #              second types, second kwargs, expected error text).
        cases = (
            # Same arity, different dtype.
            ("b_a", (f32, ), {"shapes": ()},
             (i32, ), {"shapes": ()},
             "component types"),
            # Different number of components.
            ("b_b", (f32, ), {"shapes": ()},
             (f32, i32), {"shapes": ((), ())},
             "component types"),
            # Shapes given for the first barrier, omitted for the second.
            ("b_c", (f32, f32), {"shapes": ((2, 2), (8, ))},
             (f32, f32), {},
             "component shapes"),
            # Scalar shapes versus non-scalar shapes.
            ("b_d", (f32, f32), {"shapes": ((), ())},
             (f32, f32), {"shapes": ((2, 2), (8, ))},
             "component shapes"),
            # Same rank, one differing dimension.
            ("b_e", (f32, f32), {"shapes": ((2, 2), (8, ))},
             (f32, f32), {"shapes": ((2, 5), (8, ))},
             "component shapes"),
        )
        with self.cached_session():
            for name, types_1, kwargs_1, types_2, kwargs_2, error in cases:
                first = data_flow_ops.Barrier(types_1,
                                              shared_name=name,
                                              **kwargs_1)
                second = data_flow_ops.Barrier(types_2,
                                               shared_name=name,
                                               **kwargs_2)
                first.barrier_ref.eval()
                with self.assertRaisesOpError(error):
                    second.barrier_ref.eval()
Exemple #2
0
    def testParallelTakeMany(self):
        """Take ten ready elements through ten independent take_many(1) ops.

        All ten keys are inserted with a single ``insert_many`` call and the
        ten take ops are fetched together in one ``sess.run``. The test checks
        that every returned index equals ``-2**63`` and that the taken
        (key, value) pairs match the inserted ones regardless of order.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier(dtypes.float32, shapes=())
            size_t = b.ready_size()
            keys = [str(x).encode("ascii") for x in range(10)]
            values = [float(x) for x in range(10)]
            insert_op = b.insert_many(0, keys, values)
            take_t = [b.take_many(1) for _ in keys]

            insert_op.run()
            # assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.
            self.assertEqual(size_t.eval(), [10])

            # Fetch layout: all indices, then all keys, then all values, so
            # the flat result can be sliced back apart below.
            index_fetches = []
            key_fetches = []
            value_fetches = []
            for ix_t, k_t, v_t in take_t:
                index_fetches.append(ix_t)
                key_fetches.append(k_t)
                value_fetches.append(v_t[0])
            vals = sess.run(index_fetches + key_fetches + value_fetches)

        index_vals = vals[:len(keys)]
        key_vals = vals[len(keys):2 * len(keys)]
        value_vals = vals[2 * len(keys):]

        self.assertAllEqual(np.hstack(index_vals), [-2**63] * 10)

        # Each take returned a batch of one, hence the [0] on key and value.
        self.assertItemsEqual(zip(keys, values),
                              [(k[0], v[0])
                               for k, v in zip(key_vals, value_vals)])
Exemple #3
0
    def testUseBarrierWithShape(self):
        """Insert and take elements whose components have non-scalar shapes.

        The barrier holds a (2, 2) float component and an (8,) float
        component. After both components are inserted for three keys, a
        ``take_many(3)`` must return values whose shapes agree with the
        take tensors and whose contents match what was inserted per key.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32),
                                      shapes=((2, 2), (8, )),
                                      name="B")
            size_t = b.ready_size()
            keys = [b"a", b"b", b"c"]
            values_0 = np.array(
                [[[10.0] * 2] * 2, [[20.0] * 2] * 2, [[30.0] * 2] * 2],
                np.float32)
            values_1 = np.array([[100.0] * 8, [200.0] * 8, [300.0] * 8],
                                np.float32)
            insert_0_op = b.insert_many(0, keys, values_0)
            insert_1_op = b.insert_many(1, keys, values_1)
            take_t = b.take_many(3)

            insert_0_op.run()
            insert_1_op.run()
            # assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.
            self.assertEqual(size_t.eval(), [3])

            indices_val, keys_val, values_0_val, values_1_val = sess.run(
                [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
            self.assertAllEqual(indices_val, [-2**63] * 3)
            self.assertShapeEqual(keys_val, take_t[1])
            self.assertShapeEqual(values_0_val, take_t[2][0])
            self.assertShapeEqual(values_1_val, take_t[2][1])

        # The taken order is not assumed; rows are located through the key.
        for k, v0, v1 in zip(keys, values_0, values_1):
            idx = keys_val.tolist().index(k)
            self.assertAllEqual(values_0_val[idx], v0)
            self.assertAllEqual(values_1_val[idx], v1)
Exemple #4
0
    def testTakeMany(self):
        """Insert both components for three keys, then take all three.

        Checks that the ready size is 3 after both inserts, that every
        returned index is ``-2**63``, and that the values returned for each
        key match what was inserted.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32),
                                      shapes=((), ()),
                                      name="B")
            size_t = b.ready_size()
            keys = [b"a", b"b", b"c"]
            values_0 = [10.0, 20.0, 30.0]
            values_1 = [100.0, 200.0, 300.0]
            insert_0_op = b.insert_many(0, keys, values_0)
            insert_1_op = b.insert_many(1, keys, values_1)
            take_t = b.take_many(3)

            insert_0_op.run()
            insert_1_op.run()
            # assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.
            self.assertEqual(size_t.eval(), [3])

            indices_val, keys_val, values_0_val, values_1_val = sess.run(
                [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])

        self.assertAllEqual(indices_val, [-2**63] * 3)
        # The taken order is not assumed; rows are located through the key.
        for k, v0, v1 in zip(keys, values_0, values_1):
            idx = keys_val.tolist().index(k)
            self.assertEqual(values_0_val[idx], v0)
            self.assertEqual(values_1_val[idx], v1)
Exemple #5
0
 def testConstructorWithShapes(self):
     """Check the NodeDef of a Barrier constructed with explicit shapes.

     The barrier is built in a fresh graph with two float components of
     shapes (1, 2, 3) and (8,). Its ``barrier_ref`` must be a Tensor whose
     op's node_def matches the expected text proto.
     """
     with ops.Graph().as_default():
         b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32),
                                   shapes=((1, 2, 3), (8, )),
                                   shared_name="B",
                                   name="B")
     # assertIsInstance reports the offending object and expected type on
     # failure, whereas assertTrue(isinstance(...)) only says "False is not
     # true".
     self.assertIsInstance(b.barrier_ref, ops.Tensor)
     self.assertProtoEquals(
         """
   name:'B' op:'Barrier'
   attr {
     key: "capacity"
     value {
       i: -1
     }
   }
   attr { key: 'component_types'
          value { list { type: DT_FLOAT type: DT_FLOAT } } }
   attr {
     key: 'shapes'
     value {
       list {
         shape {
           dim { size: 1 } dim { size: 2 } dim { size: 3 }
         }
         shape {
           dim { size: 8 }
         }
       }
     }
   }
   attr { key: 'container' value { s: "" } }
   attr { key: 'shared_name' value: { s: 'B' } }
   """, b.barrier_ref.op.node_def)
Exemple #6
0
 def testInsertManyEmptyTensor(self):
   """Constructing a Barrier with a zero-sized component shape must fail.

   The second component is declared with shape (0,); the constructor is
   expected to raise ValueError with a message naming that shape and index.
   """
   expected_error = ("Empty tensors are not supported, but received shape "
                     r"\'\(0,\)\' at index 1")
   component_types = (dtypes.float32, dtypes.float32)
   component_shapes = ((1,), (0,))
   with self.cached_session():
     with self.assertRaisesRegex(ValueError, expected_error):
       data_flow_ops.Barrier(
           component_types, shapes=component_shapes, name="B")
Exemple #7
0
    def testBlockingTakeMany(self):
        """Start a take_many(10) in a thread before any element is inserted.

        The take runs in a checked background thread. The main thread sleeps
        briefly and then runs ten single-element inserts one after another.
        The thread asserts that each returned index equals the integer value
        of its key offset by ``-2**63`` and that the taken (key, value)
        pairs match the inserted ones regardless of order.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier(dtypes.float32, shapes=())
            keys = [str(x).encode("ascii") for x in range(10)]
            values = [float(x) for x in range(10)]
            # One insert op per key so the elements arrive one at a time.
            insert_ops = [
                b.insert_many(0, [k], [v]) for k, v in zip(keys, values)
            ]
            take_t = b.take_many(10)

            def take():
                # Runs in the background thread; assertion failures are
                # surfaced through checkedThread when the thread is joined.
                indices_val, keys_val, values_val = sess.run(
                    [take_t[0], take_t[1], take_t[2][0]])
                # Keys are "0".."9", so the key text doubles as the expected
                # offset of the index from -2**63.
                self.assertAllEqual(
                    indices_val,
                    [int(x.decode("ascii")) - 2**63 for x in keys_val])
                self.assertItemsEqual(zip(keys, values),
                                      zip(keys_val, values_val))

            t = self.checkedThread(target=take)
            t.start()
            # Presumably gives the take a chance to start waiting before the
            # first insert happens -- timing based, not a hard guarantee.
            time.sleep(0.1)
            for insert_op in insert_ops:
                insert_op.run()
            t.join()
Exemple #8
0
 def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
   """Taking from a closed, empty barrier fails even with small batches.

   Args:
     cancel: Passed as the first positional argument to ``Barrier.close``.
   """
   with self.cached_session():
     barrier = data_flow_ops.Barrier(
         (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
     small_batch_take = barrier.take_many(1, allow_small_batch=True)
     close_op = barrier.close(cancel)
     self.evaluate(close_op)
     # Nothing was ever inserted, so there is nothing to hand out.
     with self.assertRaisesOpError("is closed and has insufficient elements"):
       self.evaluate(small_batch_take)
Exemple #9
0
  def testClose(self):
    """Exercise a barrier closed without cancelling pending enqueues.

    Only the first component is inserted before the close. The test then
    checks that inserting a new key fails, that completing the existing keys
    still works, that a take larger than the remaining elements fails, that
    a take of the remaining elements succeeds, and that a further take
    fails.
    """
    with self.cached_session() as sess:
      barrier = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
      ready_size = barrier.ready_size()
      incomplete_size = barrier.incomplete_size()
      keys = [b"a", b"b", b"c"]
      first_component = [10.0, 20.0, 30.0]
      second_component = [100.0, 200.0, 300.0]
      insert_first = barrier.insert_many(0, keys, first_component)
      insert_second = barrier.insert_many(1, keys, second_component)
      close_op = barrier.close()
      rejected_insert = barrier.insert_many(0, ["f"], [60.0])
      take_three = barrier.take_many(3)
      take_four = barrier.take_many(4)

      def check_sizes(ready, incomplete):
        # Ready size is evaluated first, then incomplete size.
        self.assertEqual(self.evaluate(ready_size), [ready])
        self.assertEqual(self.evaluate(incomplete_size), [incomplete])

      check_sizes(0, 0)
      insert_first.run()
      check_sizes(0, 3)
      close_op.run()

      # A brand-new key is rejected once the barrier is closed.
      with self.assertRaisesOpError("is closed"):
        rejected_insert.run()

      # Completing already-known keys still works because the close did not
      # cancel pending enqueues.
      insert_second.run()
      check_sizes(3, 0)

      # The new key is still rejected after the existing ones completed.
      with self.assertRaisesOpError("is closed"):
        rejected_insert.run()

      # Asking for more than incomplete + ready elements must fail.
      with self.assertRaisesOpError(r"is closed and has insufficient elements "
                                    r"\(requested 4, total size 3\)"):
        sess.run(take_four[0])  # Sufficient to request just the indices

      # The three completed elements can still be taken.
      fetches = [take_three[0], take_three[1], take_three[2][0],
                 take_three[2][1]]
      indices_val, keys_val, values_0_val, values_1_val = sess.run(fetches)
      self.assertAllEqual(indices_val, [-2**63] * 3)
      taken_keys = keys_val.tolist()
      for key, expected_0, expected_1 in zip(keys, first_component,
                                             second_component):
        row = taken_keys.index(key)
        self.assertEqual(values_0_val[row], expected_0)
        self.assertEqual(values_1_val[row], expected_1)

      # Nothing is left and the barrier is closed, so another take fails.
      with self.assertRaisesOpError("is closed and has insufficient elements"):
        sess.run(take_three[0])
Exemple #10
0
    def testCancel(self):
        """Exercise a barrier closed with cancel_pending_enqueues=True.

        Two of three keys are completed before the cancelling close. The test
        checks that inserting a new key fails, that completing the remaining
        key fails, that a take larger than the ready elements fails, that a
        take of the two ready elements succeeds, and that a further take
        fails.
        """
        # NOTE(review): test_session() is deprecated in newer TensorFlow
        # releases in favour of cached_session() -- confirm the TF version
        # this file targets before switching.
        with self.test_session() as sess:
            b = data_flow_ops.Barrier((tf.float32, tf.float32),
                                      shapes=((), ()),
                                      name="B")
            size_t = b.ready_size()
            incomplete_t = b.incomplete_size()
            keys = [b"a", b"b", b"c"]
            values_0 = [10.0, 20.0, 30.0]
            values_1 = [100.0, 200.0, 300.0]
            insert_0_op = b.insert_many(0, keys, values_0)
            insert_1_op = b.insert_many(1, keys[0:2], values_1[0:2])
            insert_2_op = b.insert_many(1, keys[2:], values_1[2:])
            cancel_op = b.close(cancel_pending_enqueues=True)
            fail_insert_op = b.insert_many(0, ["f"], [60.0])
            take_t = b.take_many(2)
            take_too_many_t = b.take_many(3)

            # assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.
            self.assertEqual(size_t.eval(), [0])
            insert_0_op.run()
            insert_1_op.run()
            self.assertEqual(size_t.eval(), [2])
            self.assertEqual(incomplete_t.eval(), [1])
            cancel_op.run()

            # This op should fail because the queue is closed.
            with self.assertRaisesOpError("is closed"):
                fail_insert_op.run()

            # This op should fail because the queue is cancelled.
            with self.assertRaisesOpError("is closed"):
                insert_2_op.run()

            # This op should fail because we requested more elements than are
            # available in incomplete + ready queue.
            with self.assertRaisesOpError(
                    r"is closed and has insufficient elements "
                    r"\(requested 3, total size 2\)"):
                # Sufficient to request just the indices
                sess.run(take_too_many_t[0])

            # This op should succeed because there are still completed elements
            # to process.
            indices_val, keys_val, values_0_val, values_1_val = sess.run(
                [take_t[0], take_t[1], take_t[2][0], take_t[2][1]])
            self.assertAllEqual(indices_val, [-2**63] * 2)
            for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
                idx = keys_val.tolist().index(k)
                self.assertEqual(values_0_val[idx], v0)
                self.assertEqual(values_1_val[idx], v1)

            # This op should fail because there are no more completed elements and
            # the queue is closed.
            with self.assertRaisesOpError(
                    "is closed and has insufficient elements"):
                sess.run(take_t[0])
Exemple #11
0
 def testInsertManyEmptyTensorUnknown(self):
   """Inserting zero-element values into a shapeless barrier fails at run time.

   The barrier is created without ``shapes``, so building the insert op
   succeeds; the error is only expected once the op is run.
   """
   component_types = (dtypes.float32, dtypes.float32)
   with self.cached_session():
     barrier = data_flow_ops.Barrier(component_types, name="B")
     ready_size = barrier.ready_size()
     self.assertEqual([], ready_size.get_shape())
     keys = [b"a", b"b", b"c"]
     # Three rows with no columns: one (empty) value per key.
     empty_values = np.array([[], [], []], np.float32)
     empty_insert = barrier.insert_many(0, keys, empty_values)
     self.assertEqual(self.evaluate(ready_size), [0])
     with self.assertRaisesOpError(
         ".*Tensors with no elements are not supported.*"):
       empty_insert.run()
Exemple #12
0
  def testInsertMany(self):
    """Ready size stays 0 until both components of the keys are inserted.

    Checks that ``ready_size`` is a scalar tensor, that it reports 0 before
    and after only the first component is inserted, and that it reports 3
    once the second component has been inserted for all three keys.
    """
    # NOTE(review): test_session() is deprecated in newer TensorFlow releases
    # in favour of cached_session() -- confirm the TF version this file
    # targets before switching.
    with self.test_session():
      ba = data_flow_ops.Barrier(
          (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
      size_t = ba.ready_size()
      self.assertEqual([], size_t.get_shape())
      keys = [b"a", b"b", b"c"]
      insert_0_op = ba.insert_many(0, keys, [10.0, 20.0, 30.0])
      insert_1_op = ba.insert_many(1, keys, [100.0, 200.0, 300.0])

      # assertEquals is a deprecated alias that was removed in Python 3.12;
      # assertEqual is the supported spelling.
      self.assertEqual(size_t.eval(), [0])
      insert_0_op.run()
      self.assertEqual(size_t.eval(), [0])
      insert_1_op.run()
      self.assertEqual(size_t.eval(), [3])
Exemple #13
0
  def testParallelInsertMany(self):
    """Run ten single-element inserts in one evaluate call, then take all.

    Checks that ten elements are ready afterwards, that the returned indices
    are ``-2**63 + x`` for x in 0..9, and that each key maps to the value it
    was inserted with.
    """
    with self.cached_session() as sess:
      barrier = data_flow_ops.Barrier(dtypes.float32, shapes=())
      ready_size = barrier.ready_size()
      keys = [str(n).encode("ascii") for n in range(10)]
      values = [float(n) for n in range(10)]
      single_inserts = [
          barrier.insert_many(0, [key], [value])
          for key, value in zip(keys, values)
      ]
      take_all = barrier.take_many(10)

      # All ten insert ops are handed to a single evaluate call.
      self.evaluate(single_inserts)
      self.assertEqual(self.evaluate(ready_size), [10])

      fetches = [take_all[0], take_all[1], take_all[2][0]]
      indices_val, keys_val, values_val = sess.run(fetches)

    expected_indices = [-2**63 + offset for offset in range(10)]
    self.assertAllEqual(indices_val, expected_indices)
    # The taken order is not assumed; rows are located through the key.
    taken_keys = keys_val.tolist()
    for key, value in zip(keys, values):
      row = taken_keys.index(key)
      self.assertEqual(values_val[row], value)
Exemple #14
0
    def testTakeManySmallBatch(self):
        """take_many(3, allow_small_batch=True) on a closed barrier.

        The first component is inserted for four keys and the second only for
        "a" and "b" before the barrier is closed. The same take op is then
        run twice and must return the two ready elements first and, after
        "c" is completed, that single element. Finally the test checks which
        inserts are still accepted after the plain close and after a second
        close with ``cancel_pending_enqueues=True``.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32),
                                      shapes=((), ()),
                                      name="B")
            size_t = b.ready_size()
            size_i = b.incomplete_size()
            keys = [b"a", b"b", b"c", b"d"]
            values_0 = [10.0, 20.0, 30.0, 40.0]
            values_1 = [100.0, 200.0, 300.0, 400.0]
            insert_0_op = b.insert_many(0, keys, values_0)
            # Split adding of the second component into two independent operations.
            # After insert_1_1_op, we'll have two ready elements in the barrier,
            # 2 will still be incomplete.
            insert_1_1_op = b.insert_many(1, keys[0:2],
                                          values_1[0:2])  # add "a", "b"
            insert_1_2_op = b.insert_many(1, keys[2:3],
                                          values_1[2:3])  # add "c"
            insert_1_3_op = b.insert_many(1, keys[3:], values_1[3:])  # add "d"
            insert_empty_op = b.insert_many(0, [], [])
            close_op = b.close()
            close_op_final = b.close(cancel_pending_enqueues=True)
            index_t, key_t, value_list_t = b.take_many(3,
                                                       allow_small_batch=True)
            insert_0_op.run()
            insert_1_1_op.run()
            close_op.run()
            # Now we have a closed barrier with 2 ready elements. Running take_t
            # should return a reduced batch with 2 elements only.
            # (assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.)
            self.assertEqual(size_i.eval(),
                             [2])  # assert that incomplete size = 2
            self.assertEqual(size_t.eval(), [2])  # assert that ready size = 2
            _, keys_val, values_0_val, values_1_val = sess.run(
                [index_t, key_t, value_list_t[0], value_list_t[1]])
            # Check that correct values have been returned.
            for k, v0, v1 in zip(keys[0:2], values_0[0:2], values_1[0:2]):
                idx = keys_val.tolist().index(k)
                self.assertEqual(values_0_val[idx], v0)
                self.assertEqual(values_1_val[idx], v1)

            # The next insert completes the element with key "c". The next take_t
            # should return a batch with just 1 element.
            insert_1_2_op.run()
            self.assertEqual(size_i.eval(),
                             [1])  # assert that incomplete size = 1
            self.assertEqual(size_t.eval(), [1])  # assert that ready size = 1
            _, keys_val, values_0_val, values_1_val = sess.run(
                [index_t, key_t, value_list_t[0], value_list_t[1]])
            # Check that correct values have been returned.
            for k, v0, v1 in zip(keys[2:3], values_0[2:3], values_1[2:3]):
                idx = keys_val.tolist().index(k)
                self.assertEqual(values_0_val[idx], v0)
                self.assertEqual(values_1_val[idx], v1)

            # Adding nothing ought to work, even if the barrier is closed.
            insert_empty_op.run()

            # currently keys "a" and "b" are not in the barrier, adding them
            # again after it has been closed, ought to cause failure.
            with self.assertRaisesOpError("is closed"):
                insert_1_1_op.run()
            close_op_final.run()

            # These ops should fail because the barrier has now been closed with
            # cancel_pending_enqueues = True.
            with self.assertRaisesOpError("is closed"):
                insert_empty_op.run()
            with self.assertRaisesOpError("is closed"):
                insert_1_3_op.run()
Exemple #15
0
    def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(
            self, cancel):
        """Close a barrier between inserting its first and second component.

        All first-component inserts run to completion, then the take threads
        are started, the barrier is closed, and only afterwards are the
        second-component inserts started. With ``cancel`` true every take is
        expected to yield 0 elements; otherwise every take is expected to
        yield 10.

        Args:
            cancel: Passed as ``cancel_pending_enqueues`` to ``Barrier.close``;
                also selects which errors the worker threads tolerate.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.int64),
                                      shapes=((), (2, )))
            num_iterations = 100
            keys = [str(x) for x in range(10)]
            values_0 = np.asarray(range(10), dtype=np.float32)
            values_1 = np.asarray([[x + 1, x + 2] for x in range(10)],
                                  dtype=np.int64)
            # Keys are prefixed with the iteration number so that every
            # iteration inserts its own set of ten distinct keys.
            keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii")
                                for k in keys]
            insert_0_ops = [
                b.insert_many(0,
                              keys_i(i),
                              values_0 + i,
                              name="insert_0_%d" % i)
                for i in range(num_iterations)
            ]

            close_op = b.close(cancel_pending_enqueues=cancel)

            take_ops = [
                b.take_many(10, name="take_%d" % i)
                for i in range(num_iterations)
            ]
            # insert_1_ops will only run after closure
            insert_1_ops = [
                b.insert_many(1,
                              keys_i(i),
                              values_1 + i,
                              name="insert_1_%d" % i)
                for i in range(num_iterations)
            ]

            def take(sess, i, taken):
                # Records how many elements take op i returned. When
                # cancelling, an OutOfRangeError is recorded as 0 elements;
                # otherwise any error propagates and fails the thread.
                if cancel:
                    try:
                        indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run(
                            [
                                take_ops[i][0], take_ops[i][1],
                                take_ops[i][2][0], take_ops[i][2][1]
                            ])
                        taken.append(len(indices_val))
                    except errors_impl.OutOfRangeError:
                        taken.append(0)
                else:
                    indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run(
                        [
                            take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
                            take_ops[i][2][1]
                        ])
                    taken.append(len(indices_val))

            def insert_0(sess, i):
                # Inserts the first component for iteration i.
                insert_0_ops[i].run(session=sess)

            def insert_1(sess, i):
                # Inserts the second component for iteration i. When
                # cancelling, a CancelledError is tolerated.
                if cancel:
                    try:
                        insert_1_ops[i].run(session=sess)
                    except errors_impl.CancelledError:
                        pass
                else:
                    insert_1_ops[i].run(session=sess)

            # Shared result list appended to by every take thread.
            taken = []

            take_threads = [
                self.checkedThread(target=take, args=(sess, i, taken))
                for i in range(num_iterations)
            ]
            insert_0_threads = [
                self.checkedThread(target=insert_0, args=(sess, i))
                for i in range(num_iterations)
            ]
            insert_1_threads = [
                self.checkedThread(target=insert_1, args=(sess, i))
                for i in range(num_iterations)
            ]

            # Phase 1: every first-component insert finishes before any take
            # thread is started.
            for t in insert_0_threads:
                t.start()
            for t in insert_0_threads:
                t.join()
            for t in take_threads:
                t.start()

            # Phase 2: close while the take threads are outstanding.
            close_op.run()

            # Phase 3: second-component inserts start only after the close.
            for t in insert_1_threads:
                t.start()
            for t in take_threads:
                t.join()
            for t in insert_1_threads:
                t.join()

            if cancel:
                self.assertEqual(taken, [0] * num_iterations)
            else:
                self.assertEqual(taken, [10] * num_iterations)
Exemple #16
0

# >>> np.square(wte - wte2).mean()
# 4.7295091816307575e-05

# >>> np.square(wpe - wpe2).mean()
# 0.00022108801059667483


from importlib import reload
import tf_tools as tft
reload(tft)


from tensorflow.python.ops import data_flow_ops
# Scratch demo of Barrier: two scalar components (a string and an int32).
# NOTE(review): `tf` and `r` are not defined in the visible part of this file,
# and the bare `.run()` calls presumably rely on a default session being
# active -- confirm against the surrounding script.
barrier = data_flow_ops.Barrier((tf.string, tf.int32), shapes=((), ()))
# Component 0 is inserted for k1 and k2; component 1 then arrives for k1,
# k3 gets both components, and k2 receives its component 1 last.
barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
barrier.insert_many(1, keys=["k1"], values=[1]).run()
barrier.insert_many(0, keys=["k3"], values=["c"]).run()
barrier.insert_many(1, keys=["k3"], values=[3]).run()
barrier.insert_many(1, keys=["k2"], values=[2]).run()
# Request two elements; `r` presumably evaluates and displays the result.
r(barrier.take_many(2))



# Scratch demo of the project's SparseSum helper, used here in place of the
# commented-out tf.SparseConditionalAccumulator.
#acc1 = tf.SparseConditionalAccumulator(dtype=tf.float32)
acc1 = tft.SparseSum()
# Two apply_grad calls with overlapping first arguments (42 and 69), followed
# by a take(); presumably the first argument is indices and the second values
# -- TODO confirm against tf_tools.
r(acc1.apply_grad([42, 69], [42.0, 420.69]))
r(acc1.apply_grad([42, 69, 128], [42.0, 420.69, 4.0]))
r(acc1.take())
Exemple #17
0
    def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
        """Close a barrier after half of the insert threads have finished.

        All take threads are started first. The first half of the insert
        threads runs to completion, the barrier is closed, and then the
        second half of the insert threads is started. The test expects
        exactly half of the takes to yield 10 elements and the other half to
        yield 0.

        Args:
            cancel: Passed as ``cancel_pending_enqueues`` to ``Barrier.close``.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.int64),
                                      shapes=((), (2, )))
            num_iterations = 50
            keys = [str(x) for x in range(10)]
            values_0 = np.asarray(range(10), dtype=np.float32)
            values_1 = np.asarray([[x + 1, x + 2] for x in range(10)],
                                  dtype=np.int64)
            # Keys are prefixed with the iteration number so that every
            # iteration inserts its own set of ten distinct keys.
            keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii")
                                for k in keys]
            insert_0_ops = [
                b.insert_many(0, keys_i(i), values_0 + i)
                for i in range(num_iterations)
            ]
            insert_1_ops = [
                b.insert_many(1, keys_i(i), values_1 + i)
                for i in range(num_iterations)
            ]
            take_ops = [b.take_many(10) for _ in range(num_iterations)]
            close_op = b.close(cancel_pending_enqueues=cancel)

            def take(sess, i, taken):
                # Records how many elements take op i returned; an
                # OutOfRangeError is recorded as 0 elements.
                try:
                    indices_val, unused_keys_val, unused_val_0, unused_val_1 = sess.run(
                        [
                            take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
                            take_ops[i][2][1]
                        ])
                    taken.append(len(indices_val))
                except errors_impl.OutOfRangeError:
                    taken.append(0)

            def insert(sess, i):
                # Inserts both components of iteration i in one run call; a
                # CancelledError is tolerated regardless of `cancel`.
                try:
                    sess.run([insert_0_ops[i], insert_1_ops[i]])
                except errors_impl.CancelledError:
                    pass

            # Shared result list appended to by every take thread.
            taken = []

            take_threads = [
                self.checkedThread(target=take, args=(sess, i, taken))
                for i in range(num_iterations)
            ]
            insert_threads = [
                self.checkedThread(target=insert, args=(sess, i))
                for i in range(num_iterations)
            ]

            first_half_insert_threads = insert_threads[:num_iterations // 2]
            second_half_insert_threads = insert_threads[num_iterations // 2:]

            # Phase 1: start every take, then run the first half of the
            # inserts to completion.
            for t in take_threads:
                t.start()
            for t in first_half_insert_threads:
                t.start()
            for t in first_half_insert_threads:
                t.join()

            # Phase 2: close the barrier.
            close_op.run()

            # Phase 3: the remaining inserts start only after the close.
            for t in second_half_insert_threads:
                t.start()
            for t in take_threads:
                t.join()
            for t in second_half_insert_threads:
                t.join()

            # Thread completion order is arbitrary, hence the sort.
            self.assertEqual(sorted(taken), [0] * (num_iterations // 2) +
                             [10] * (num_iterations // 2))
Exemple #18
0
    def testParallelInsertManyTakeMany(self):
        """Run 100 insert threads and 100 take_many(10) threads concurrently.

        Iteration ``i`` inserts the keys ``"i:k"`` for k in 0..9 with first
        component ``k + i`` and second component ``[k + 1 + i, k + 2 + i]``.
        The test checks that every take completed, that the union of taken
        indices and keys equals what was inserted, and that the values of
        each taken element are consistent with the numbers encoded in its
        key.
        """
        with self.cached_session() as sess:
            b = data_flow_ops.Barrier((dtypes.float32, dtypes.int64),
                                      shapes=((), (2, )))
            num_iterations = 100
            keys = [str(x) for x in range(10)]
            values_0 = np.asarray(range(10), dtype=np.float32)
            values_1 = np.asarray([[x + 1, x + 2] for x in range(10)],
                                  dtype=np.int64)
            # Keys are prefixed with the iteration number so that every
            # iteration inserts its own set of ten distinct keys.
            keys_i = lambda i: [("%d:%s" % (i, k)).encode("ascii")
                                for k in keys]
            insert_0_ops = [
                b.insert_many(0, keys_i(i), values_0 + i)
                for i in range(num_iterations)
            ]
            insert_1_ops = [
                b.insert_many(1, keys_i(i), values_1 + i)
                for i in range(num_iterations)
            ]
            take_ops = [b.take_many(10) for _ in range(num_iterations)]

            def take(sess, i, taken):
                # Runs take op i and records everything it returned.
                indices_val, keys_val, values_0_val, values_1_val = sess.run([
                    take_ops[i][0], take_ops[i][1], take_ops[i][2][0],
                    take_ops[i][2][1]
                ])
                taken.append({
                    "indices": indices_val,
                    "keys": keys_val,
                    "values_0": values_0_val,
                    "values_1": values_1_val
                })

            def insert(sess, i):
                # Inserts both components of iteration i in one run call.
                sess.run([insert_0_ops[i], insert_1_ops[i]])

            # Shared result list appended to by every take thread.
            taken = []

            take_threads = [
                self.checkedThread(target=take, args=(sess, i, taken))
                for i in range(num_iterations)
            ]
            insert_threads = [
                self.checkedThread(target=insert, args=(sess, i))
                for i in range(num_iterations)
            ]

            for t in take_threads:
                t.start()
            # Presumably gives the takes a chance to start waiting before
            # the inserts begin -- timing based, not a hard guarantee.
            time.sleep(0.1)
            for t in insert_threads:
                t.start()
            for t in take_threads:
                t.join()
            for t in insert_threads:
                t.join()

            # assertEquals is a deprecated alias that was removed in
            # Python 3.12; assertEqual is the supported spelling.
            self.assertEqual(len(taken), num_iterations)
            flatten = lambda l: [item for sublist in l for item in sublist]
            all_indices = sorted(flatten([t_i["indices"] for t_i in taken]))
            all_keys = sorted(flatten([t_i["keys"] for t_i in taken]))

            expected_keys = sorted(
                flatten([keys_i(i) for i in range(num_iterations)]))
            # Ten elements are expected for each of the num_iterations
            # distinct index values.
            expected_indices = sorted(
                flatten([-2**63 + j] * 10 for j in range(num_iterations)))

            self.assertAllEqual(all_indices, expected_indices)
            self.assertAllEqual(all_keys, expected_keys)

            for taken_i in taken:
                # Keys look like b"<iteration>:<k>"; recover both numbers to
                # recompute the values that must accompany each key.
                outer_indices_from_keys = np.array([
                    int(k.decode("ascii").split(":")[0])
                    for k in taken_i["keys"]
                ])
                inner_indices_from_keys = np.array([
                    int(k.decode("ascii").split(":")[1])
                    for k in taken_i["keys"]
                ])
                self.assertAllEqual(
                    taken_i["values_0"],
                    outer_indices_from_keys + inner_indices_from_keys)
                expected_values_1 = np.vstack(
                    (1 + outer_indices_from_keys + inner_indices_from_keys,
                     2 + outer_indices_from_keys + inner_indices_from_keys)).T
                self.assertAllEqual(taken_i["values_1"], expected_values_1)