Example #1
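All nine snippets below exercise sqss.SequenceQueueingStateSaver and call
helpers such as self.cached_session() and self.checkedThread(), so they appear
to be methods excerpted from a single TensorFlow 1.x test class. For them to
run, a module-level preamble along these lines is assumed (the sqss alias and
the exact contrib module path are assumptions, not shown in the snippets):

import time

import numpy as np

from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops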
  def _testStateSaverFailsIfCapacityTooSmall(self, batch_size):
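    # Builds a saver with a small fixed capacity (10); callers presumably pass
    # batch_size values larger than the capacity so that prefetching fails.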
    with self.cached_session() as sess:
      num_unroll = 2
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5)),
          "seq2": array_ops.placeholder(
              dtypes.float32, shape=(None,))
      }
      context = {}
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states,
          capacity=10)

      sess.run([state_saver.prefetch_op],
               feed_dict={
                   length: 1,
                   key: "key",
                   sequences["seq1"]: np.random.rand(num_unroll, 5),
                   sequences["seq2"]: np.random.rand(num_unroll),
                   initial_states["state1"]: 1.0
               })
Example #2
 def testStateSaverScopeNames(self):
   batch_size = constant_op.constant(2)
   sqss_scope_name = "unique_scope_name_for_sqss"
   num_unroll = 2
   length = 3
   key = string_ops.string_join([
       "key_", string_ops.as_string(
           math_ops.cast(10000 * random_ops.random_uniform(()), dtypes.int32))
   ])
   padded_length = 4
   sequences = {
       "seq1": np.random.rand(padded_length, 5),
       "seq2": np.random.rand(padded_length, 4, 2)
   }
   context = {"context1": [3, 4]}
   initial_states = {
       "state1": np.random.rand(6, 7),
       "state2": np.random.rand(8)
   }
   state_saver = sqss.SequenceQueueingStateSaver(
       batch_size=batch_size,
       num_unroll=num_unroll,
       input_length=length,
       input_key=key,
       input_sequences=sequences,
       input_context=context,
       initial_states=initial_states,
       name=sqss_scope_name)
   prefetch_op = state_saver.prefetch_op
   next_batch = state_saver.next_batch
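   # Every op the saver creates, including its barrier, should live under the
   # name scope passed as sqss_scope_name.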
   self.assertTrue(
       state_saver.barrier.barrier_ref.name.startswith("%s/" %
                                                       sqss_scope_name))
   self.assertTrue(prefetch_op.name.startswith("%s/" % sqss_scope_name))
   self.assertTrue(next_batch.key.name.startswith("%s/" % sqss_scope_name))
Example #3
  def testStateSaverFailsIfPaddedLengthIsNotMultipleOfNumUnroll(self):
    with self.cached_session() as sess:
      batch_size = constant_op.constant(32)
      num_unroll = 17
      bad_padded_length = 3
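      # 3 is not a multiple of num_unroll (17), so prefetching a sequence
      # padded to this length must fail.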
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5))
      }
      context = {}
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states)

      with self.assertRaisesOpError(
          "should be a multiple of: 17, but saw value: %d" % bad_padded_length):
        sess.run([state_saver.prefetch_op],
                 feed_dict={
                     length: 1,
                     key: "key",
                     sequences["seq1"]: np.random.rand(bad_padded_length, 5),
                     initial_states["state1"]: 1.0
                 })
Example #4
 def testStateSaverFailsIfInconsistentWriteState(self):
   # TODO(b/26910386): Identify why this infrequently causes timeouts.
   with self.cached_session() as sess:
     batch_size = constant_op.constant(1)
     num_unroll = 17
     length = array_ops.placeholder(dtypes.int32)
     key = array_ops.placeholder(dtypes.string)
     sequences = {
         "seq1": array_ops.placeholder(
             dtypes.float32, shape=(None, 5))
     }
     context = {}
     initial_states = {
         "state1": array_ops.placeholder(
             dtypes.float32, shape=())
     }
     state_saver = sqss.SequenceQueueingStateSaver(
         batch_size=batch_size,
         num_unroll=num_unroll,
         input_length=length,
         input_key=key,
         input_sequences=sequences,
         input_context=context,
         initial_states=initial_states)
     next_batch = state_saver.next_batch
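     # Exercise three failure modes: an undeclared state name, a bad rank at
     # graph-construction time, and a bad shape at run time.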
     with self.assertRaisesRegex(KeyError, "state was not declared: state2"):
       save_op = next_batch.save_state("state2", None)
     with self.assertRaisesRegex(ValueError, "Rank check failed for.*state1"):
       save_op = next_batch.save_state("state1", np.random.rand(1, 1))
     with self.assertRaisesOpError(
         r"convert_state1:0 should be: 1, shape received:\] \[1 1\]"):
       state_input = array_ops.placeholder(dtypes.float32)
       with ops.control_dependencies([state_saver.prefetch_op]):
         save_op = next_batch.save_state("state1", state_input)
       sess.run([save_op],
                feed_dict={
                    length: 1,
                    key: "key",
                    sequences["seq1"]: np.random.rand(num_unroll, 5),
                    initial_states["state1"]: 1.0,
                    state_input: np.random.rand(1, 1)
                })
Example #5
  def testStateSaverFailsIfInconsistentPaddedLength(self):
    with self.cached_session() as sess:
      batch_size = constant_op.constant(32)
      num_unroll = 17
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5)),
          "seq2": array_ops.placeholder(
              dtypes.float32, shape=(None,))
      }
      context = {}
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states)

      with self.assertRaisesOpError(
          "Dimension 0 of tensor labeled sorted_sequences_seq2 "
          "should be: %d, shape received: %d" % (num_unroll, 2 * num_unroll)):
        sess.run([state_saver.prefetch_op],
                 feed_dict={
                     length: 1,
                     key: "key",
                     sequences["seq1"]: np.random.rand(num_unroll, 5),
                     sequences["seq2"]: np.random.rand(2 * num_unroll),
                     initial_states["state1"]: 1.0
                 })
Example #6
  def testStateSaverCanHandleVariableBatchsize(self):
    with self.cached_session() as sess:
      batch_size = array_ops.placeholder(dtypes.int32)
      num_unroll = 17
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5))
      }
      context = {"context1": array_ops.placeholder(dtypes.string, shape=(3, 4))}
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states)
      next_batch = state_saver.next_batch

      update = next_batch.save_state("state1", 1 + next_batch.state("state1"))
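      # Each dequeue also saves state("state1") + 1, so the stored state counts
      # how many segments of an example have been processed so far.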

      for insert_key in range(128):
        # Insert varying length inputs
        sess.run([state_saver.prefetch_op],
                 feed_dict={
                     length: np.random.randint(2 * num_unroll),
                     key: "%05d" % insert_key,
                     sequences["seq1"]: np.random.rand(2 * num_unroll, 5),
                     context["context1"]: np.random.rand(3, 4).astype(np.str),
                     initial_states["state1"]: 0.0
                 })

      all_received_indices = []
      # Pull out and validate batch sizes 0, 1, ..., 7
      for batch_size_value in range(8):
        got_keys, input_index, context1, seq1, state1, _ = sess.run(
            [
                next_batch.key, next_batch.insertion_index,
                next_batch.context["context1"], next_batch.sequences["seq1"],
                next_batch.state("state1"), update
            ],
            feed_dict={batch_size: batch_size_value})
        # Indices may have come in out of order within the batch
        all_received_indices.append(input_index.tolist())
        self.assertEqual(got_keys.size, batch_size_value)
        self.assertEqual(input_index.size, batch_size_value)
        self.assertEqual(context1.shape, (batch_size_value, 3, 4))
        self.assertEqual(seq1.shape, (batch_size_value, num_unroll, 5))
        self.assertEqual(state1.shape, (batch_size_value,))

      # Each input was split into 2 iterations (sequences size == 2*num_unroll)
      expected_indices = [[], [0], [0, 1], [1, 2, 3], [2, 3, 4, 5],
                          [4, 5, 6, 7, 8], [6, 7, 8, 9, 10, 11],
                          [9, 10, 11, 12, 13, 14, 15]]
      self.assertEqual(len(all_received_indices), len(expected_indices))
      for received, expected in zip(all_received_indices, expected_indices):
        self.assertAllEqual([x + 2**63 for x in received], expected)
Example #7
  def testStateSaverWithTwoSimpleSteps(self):
    with self.cached_session() as sess:
      batch_size_value = 2
      batch_size = constant_op.constant(batch_size_value)
      num_unroll = 2
      length = 3
      key = string_ops.string_join([
          "key_", string_ops.as_string(
              math_ops.cast(10000 * random_ops.random_uniform(()),
                            dtypes.int32))
      ])
      padded_length = 4
      sequences = {
          "seq1": np.random.rand(padded_length, 5),
          "seq2": np.random.rand(padded_length, 4, 2)
      }
      context = {"context1": [3, 4]}
      initial_states = {
          "state1": np.random.rand(6, 7),
          "state2": np.random.rand(8)
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states,
          capacity=100)

      initial_key_value_0, _ = sess.run((key, state_saver.prefetch_op))
      initial_key_value_1, _ = sess.run((key, state_saver.prefetch_op))

      initial_key_value_0 = initial_key_value_0.decode("ascii")
      initial_key_value_1 = initial_key_value_1.decode("ascii")
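      # The two keys generated above are random; capture them so the expected
      # per-segment keys can be reconstructed below.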

      # Step 1
      next_batch = state_saver.next_batch
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      expected_first_keys = set(
          ("00000_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_second_keys = set(
          ("00001_of_00002:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))
      expected_final_keys = set(
          ("STOP:%s" % x).encode("ascii")
          for x in (initial_key_value_0, initial_key_value_1))

      self.assertEqual(set(key_value), expected_first_keys)
      self.assertEqual(set(next_key_value), expected_second_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 0:2, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 0:2, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value,
                          np.tile(initial_states["state1"],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value,
                          np.tile(initial_states["state2"],
                                  (batch_size_value, 1)))
      self.assertAllEqual(length_value, [2, 2])

      # Step 2
      (key_value, next_key_value, seq1_value, seq2_value, context1_value,
       state1_value, state2_value, length_value, _, _) = sess.run(
           (next_batch.key, next_batch.next_key, next_batch.sequences["seq1"],
            next_batch.sequences["seq2"], next_batch.context["context1"],
            next_batch.state("state1"), next_batch.state("state2"),
            next_batch.length,
            next_batch.save_state("state1", next_batch.state("state1") + 1),
            next_batch.save_state("state2", next_batch.state("state2") - 1)))

      self.assertEqual(set(key_value), expected_second_keys)
      self.assertEqual(set(next_key_value), expected_final_keys)
      self.assertAllEqual(context1_value,
                          np.tile(context["context1"], (batch_size_value, 1)))
      self.assertAllEqual(seq1_value,
                          np.tile(sequences["seq1"][np.newaxis, 2:4, :],
                                  (batch_size_value, 1, 1)))
      self.assertAllEqual(seq2_value,
                          np.tile(sequences["seq2"][np.newaxis, 2:4, :, :],
                                  (batch_size_value, 1, 1, 1)))
      self.assertAllEqual(state1_value, 1 + np.tile(initial_states["state1"],
                                                    (batch_size_value, 1, 1)))
      self.assertAllEqual(state2_value, -1 + np.tile(initial_states["state2"],
                                                     (batch_size_value, 1)))
      self.assertAllEqual(length_value, [1, 1])

      # Finished.  Let's make sure there's nothing left in the barrier.
      self.assertEqual(0, state_saver.barrier.ready_size().eval())
Example #8
  def testStateSaverProcessesExamplesInOrder(self):
    with self.cached_session() as sess:
      batch_size_value = 32
      batch_size = constant_op.constant(batch_size_value)
      num_unroll = 17
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5))
      }
      context = {"context1": array_ops.placeholder(dtypes.string, shape=(3, 4))}
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states)
      next_batch = state_saver.next_batch

      update = next_batch.save_state("state1", 1 + next_batch.state("state1"))
      get_ready_size = state_saver.barrier.ready_size()
      get_incomplete_size = state_saver.barrier.incomplete_size()
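      # The barrier's ready/incomplete sizes are checked at the end to confirm
      # the queue is fully drained.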

      global_insert_key = [0]

      def insert(insert_key):
        # Insert varying length inputs
        sess.run([state_saver.prefetch_op],
                 feed_dict={
                     length: np.random.randint(2 * num_unroll),
                     key: "%05d" % insert_key[0],
                     sequences["seq1"]: np.random.rand(2 * num_unroll, 5),
                     context["context1"]: np.random.rand(3, 4).astype(np.str),
                     initial_states["state1"]: 0.0
                 })
        insert_key[0] += 1

      for _ in range(batch_size_value * 100):
        insert(global_insert_key)

      def process_and_validate(check_key):
        true_step = int(check_key[0] / 2)  # Each entry has two slices
        check_key[0] += 1
        got_keys, input_index, _ = sess.run(
            [next_batch.key, next_batch.insertion_index, update])
        decoded_keys = [int(x.decode("ascii").split(":")[-1]) for x in got_keys]
        min_key = min(decoded_keys)
        min_index = int(min(input_index))  # numpy scalar
        max_key = max(decoded_keys)
        max_index = int(max(input_index))  # numpy scalar
        # Keys and insertion indices in this batch must form exactly the next
        # contiguous block of batch_size_value entries.
        self.assertEqual(min_key, true_step * batch_size_value)
        self.assertEqual(max_key, (true_step + 1) * batch_size_value - 1)
        self.assertEqual(2**63 + min_index, true_step * batch_size_value)
        self.assertEqual(2**63 + max_index,
                         (true_step + 1) * batch_size_value - 1)

      # There are now (batch_size * 100 * 2) / batch_size = 200 full steps
      global_step_key = [0]
      for _ in range(200):
        process_and_validate(global_step_key)

      # Processed everything in the queue
      self.assertEqual(get_incomplete_size.eval(), 0)
      self.assertEqual(get_ready_size.eval(), 0)
Example #9
  def testStateSaverWithManyInputsReadWriteThread(self):
    batch_size_value = 32
    num_proc_threads = 100
    with self.cached_session() as sess:
      batch_size = constant_op.constant(batch_size_value)
      num_unroll = 17
      length = array_ops.placeholder(dtypes.int32)
      key = array_ops.placeholder(dtypes.string)
      sequences = {
          "seq1": array_ops.placeholder(
              dtypes.float32, shape=(None, 5)),
          "seq2": array_ops.placeholder(
              dtypes.float32, shape=(None, 4, 2)),
          "seq3": array_ops.placeholder(
              dtypes.float64, shape=(None,))
      }
      context = {
          "context1": array_ops.placeholder(
              dtypes.string, shape=(3, 4)),
          "context2": array_ops.placeholder(
              dtypes.int64, shape=())
      }
      initial_states = {
          "state1": array_ops.placeholder(
              dtypes.float32, shape=(6, 7)),
          "state2": array_ops.placeholder(
              dtypes.int32, shape=())
      }
      state_saver = sqss.SequenceQueueingStateSaver(
          batch_size=batch_size,
          num_unroll=num_unroll,
          input_length=length,
          input_key=key,
          input_sequences=sequences,
          input_context=context,
          initial_states=initial_states)
      next_batch = state_saver.next_batch
      cancel_op = state_saver.close(cancel_pending_enqueues=True)
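      # Closing with cancel_pending_enqueues=True makes blocked reader threads
      # see OutOfRangeError so they can exit cleanly (handled below).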

      update_1 = next_batch.save_state("state1", 1 + next_batch.state("state1"))
      update_2 = next_batch.save_state("state2",
                                       -1 + next_batch.state("state2"))

      original_values = {}

      def insert(which):
        for i in range(20):
          # Insert varying length inputs
          pad_i = num_unroll * (1 + (i % 10))
          length_i = int(np.random.rand() * pad_i)
          key_value = "key_%02d_%04d" % (which, i)
          stored_state = {
              "length": length_i,
              "seq1": np.random.rand(pad_i, 5),
              "seq2": np.random.rand(pad_i, 4, 2),
              "seq3": np.random.rand(pad_i),
              "context1": np.random.rand(3, 4).astype(np.str),
              "context2": np.asarray(
                  100 * np.random.rand(), dtype=np.int32),
              "state1": np.random.rand(6, 7),
              "state2": np.asarray(
                  100 * np.random.rand(), dtype=np.int32)
          }
          original_values[key_value] = stored_state
          sess.run([state_saver.prefetch_op],
                   feed_dict={
                       length: stored_state["length"],
                       key: key_value,
                       sequences["seq1"]: stored_state["seq1"],
                       sequences["seq2"]: stored_state["seq2"],
                       sequences["seq3"]: stored_state["seq3"],
                       context["context1"]: stored_state["context1"],
                       context["context2"]: stored_state["context2"],
                       initial_states["state1"]: stored_state["state1"],
                       initial_states["state2"]: stored_state["state2"]
                   })

      processed_count = [0]

      def process_and_check_state():
        next_batch = state_saver.next_batch
        while True:
          try:
            (got_key, next_key, length, total_length, sequence, sequence_count,
             context1, context2, seq1, seq2, seq3, state1, state2, _,
             _) = (sess.run([
                 next_batch.key, next_batch.next_key, next_batch.length,
                 next_batch.total_length, next_batch.sequence,
                 next_batch.sequence_count, next_batch.context["context1"],
                 next_batch.context["context2"], next_batch.sequences["seq1"],
                 next_batch.sequences["seq2"], next_batch.sequences["seq3"],
                 next_batch.state("state1"), next_batch.state("state2"),
                 update_1, update_2
             ]))

          except errors_impl.OutOfRangeError:
            # SQSS has been closed
            break

          self.assertEqual(len(got_key), batch_size_value)

          processed_count[0] += len(got_key)

          for i in range(batch_size_value):
            key_name = got_key[i].decode("ascii").split(":")[1]
            # The key must correspond to an example inserted above
            self.assertTrue(key_name in original_values)
            # The unique key matches next_key
            self.assertEqual(key_name,
                             next_key[i].decode("ascii").split(":")[1])
            # Pull out the random values we used to create this example
            stored_state = original_values[key_name]
            self.assertEqual(total_length[i], stored_state["length"])
            self.assertEqual("%05d_of_%05d:%s" %
                             (sequence[i], sequence_count[i], key_name),
                             got_key[i].decode("ascii"))
            expected_length = max(
                0,
                min(num_unroll,
                    stored_state["length"] - sequence[i] * num_unroll))
            self.assertEqual(length[i], expected_length)
            expected_state1 = stored_state["state1"] + sequence[i]
            expected_state2 = stored_state["state2"] - sequence[i]
            expected_sequence1 = stored_state["seq1"][sequence[i] * num_unroll:(
                sequence[i] + 1) * num_unroll]
            expected_sequence2 = stored_state["seq2"][sequence[i] * num_unroll:(
                sequence[i] + 1) * num_unroll]
            expected_sequence3 = stored_state["seq3"][sequence[i] * num_unroll:(
                sequence[i] + 1) * num_unroll]

            self.assertAllClose(state1[i], expected_state1)
            self.assertAllEqual(state2[i], expected_state2)
            # context1 is strings, which come back as bytes
            self.assertAllEqual(context1[i].astype(str),
                                stored_state["context1"])
            self.assertAllEqual(context2[i], stored_state["context2"])
            self.assertAllClose(seq1[i], expected_sequence1)
            self.assertAllClose(seq2[i], expected_sequence2)
            self.assertAllClose(seq3[i], expected_sequence3)

      # Total number of inserts will be a multiple of batch_size
      insert_threads = [
          self.checkedThread(
              insert, args=(which,)) for which in range(batch_size_value)
      ]
      process_threads = [
          self.checkedThread(process_and_check_state)
          for _ in range(num_proc_threads)
      ]

      for t in insert_threads:
        t.start()
      for t in process_threads:
        t.start()
      for t in insert_threads:
        t.join()

      time.sleep(3)  # Allow the threads to run and process for a while
      cancel_op.run()

      for t in process_threads:
        t.join()

      # The threads should have processed more than two segments per inserted
      # entry on average (20 entries from each of batch_size_value threads)
      self.assertGreater(processed_count[0], 2 * 20 * batch_size_value)