Beispiel #1
0
  def InputBatch(self):
    """Builds one batch of examples sampled from data held in memory.

    The tensors named `p.data` and `p.label` are restored from the checkpoint
    `p.ckpt`, cast to float32 and obtained through `py_x_ops.cached_call`.
    Rows are then selected with the ids produced by
    `py_x_ops.random_permutation_sequence` and padded or trimmed to the batch
    size `b = self.InputBatchSize()`.

    Returns:
      A `py_utils.NestedMap` with fields:
        raw: The gathered data rows, padded or trimmed to
          `[b] + p.data_shape`.
        data: `self._Preprocess(raw)`.
        label: The gathered labels, padded or trimmed to `[b]`.
        weight: float32 ones, one per gathered sample id, padded or trimmed to
          `[b]`.
        sample_ids: The gathered sample ids. Only present when
          `py_utils.use_tpu()` is false.
    """
    p = self.params

    @function.Defun()
    def ReadData():
      # Restores both tensors from the checkpoint in a single op; the empty
      # strings are the per-tensor slice specs.
      x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                               [p.data_dtype, p.label_dtype])
      # Always convert to float32. tf.cast is used instead of the deprecated
      # tf.to_float.
      return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

    # Loads data and label into memory and keep it around.
    data, label = py_x_ops.cached_call(f=ReadData, T=[tf.float32, tf.float32])
    b, shape = self.InputBatchSize(), list(p.data_shape)
    # Leading dimension is the example index; the rest is one example's shape.
    data = tf.reshape(data, [-1] + shape)
    label = tf.reshape(label, [-1])
    # NOTE(review): presumably asserts one label per data row -- confirm
    # against py_utils.HasShape.
    label = py_utils.HasShape(label, [tf.shape(data)[0]])
    sample_ids = py_x_ops.random_permutation_sequence(
        num=p.num_samples,
        batch=b,
        repeat=p.repeat,
        seed=p.random_seed or 0)
    # Number of ids actually produced for this step; it is computed dynamically
    # because it is not assumed to equal b.
    n = tf.shape(sample_ids)[0]
    raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
    # NOTE(review): weight presumably ends up 1 for real samples and 0 for
    # padding, assuming PadOrTrimTo pads with zeros -- confirm.
    ret = py_utils.NestedMap(
        raw=raw,
        data=self._Preprocess(raw),
        label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
        weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
    if not py_utils.use_tpu():
      ret['sample_ids'] = sample_ids
    return ret
Beispiel #2
0
  def testRandomPermutationSequenceNoRepeat(self):
    """Checks a non-repeating sequence yields every id once and then ends."""
    num_ids, batch_size, steps_per_epoch = 20, 7, 3
    with self.session() as sess:
      ids_op = py_x_ops.random_permutation_sequence(
          num=num_ids, batch=batch_size, repeat=False)

      # 20 ids drawn in batches of at most 7: one epoch is exactly 3 steps.
      seen = []
      for _ in range(steps_per_epoch):
        seen.extend(sess.run(ids_op).tolist())
      self.assertEqual(list(range(num_ids)), sorted(seen))

      # repeat=False, so one more step must raise an OutOfRange error.
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(ids_op)
Beispiel #3
0
  def testRandomPermutationSequenceRepeat(self):
    """Checks a repeating sequence finishes each epoch before starting anew."""
    num_ids, batch_size, steps = 20, 7, 3
    all_ids = set(range(num_ids))
    with self.session() as sess:
      ids_op = py_x_ops.random_permutation_sequence(
          num=num_ids, batch=batch_size, repeat=True)

      # Ids of the current epoch that have not been produced yet.
      pending = list(range(num_ids))
      for _ in range(10):
        # Three steps of 7 ids each give 21 ids, one more than a full epoch.
        drawn = []
        for _step in range(steps):
          drawn.extend(sess.run(ids_op).tolist())
        self.assertEqual(len(drawn), steps * batch_size)

        # Every id still owed by the previous epoch must show up here;
        # list.remove raises ValueError when an id is missing.
        for sample_id in pending:
          drawn.remove(sample_id)

        # Whatever is left belongs to the next epoch and must be unique.
        self.assertEqual(len(drawn), len(set(drawn)))

        pending = list(all_ids.difference(drawn))