def _InputBatch(self):
  """Assembles one batch of samples from the checkpointed data and labels.

  The data and label tensors named by `p.data` and `p.label` are restored from
  `p.ckpt`, cast to float32, and sampled with the ids produced by
  `ops.random_permutation_sequence`.

  Returns:
    A `py_utils.NestedMap` with the fields `raw`, `data` (the result of
    `self._Preprocess(raw)`), `label` and `weight`, each padded or trimmed to
    the infeed batch size along the leading dimension. When not running on
    TPU, the map also carries `sample_ids`.
  """
  p = self.params

  @tf.function
  def ReadData():
    tensor_names = [p.data, p.label]
    tensor_dtypes = [p.data_dtype, p.label_dtype]
    x, y = io_ops.restore_v2(p.ckpt, tensor_names, ['', ''], tensor_dtypes)
    # Always convert to float32.
    return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

  # Loads data and label into memory and keep it around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  batch_size = self.InfeedBatchSize()
  sample_shape = list(p.data_shape)

  # Flatten to [num_samples, ...] data and a rank-1 label vector.
  data = tf.reshape(data, [-1] + sample_shape)
  label = tf.reshape(label, [-1])
  # Ties the number of labels to the number of data samples.
  label = py_utils.HasShape(label, [tf.shape(data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  # Number of ids actually produced for this step.
  num_ids = tf.shape(sample_ids)[0]

  raw = py_utils.PadOrTrimTo(
      tf.gather(data, sample_ids), [batch_size] + sample_shape)
  preprocessed = self._Preprocess(raw)
  batch_label = py_utils.PadOrTrimTo(
      tf.gather(label, sample_ids), [batch_size])
  # Weight is 1.0 for each produced id; the rest of the batch is padding.
  batch_weight = py_utils.PadOrTrimTo(
      tf.ones([num_ids], dtype=tf.float32), [batch_size])

  ret = py_utils.NestedMap(
      raw=raw, data=preprocessed, label=batch_label, weight=batch_weight)
  if py_utils.use_tpu():
    return ret
  ret['sample_ids'] = sample_ids
  return ret
def testRandomPermutationSequenceNoRepeat(self):
  """With repeat=False, 3 steps yield ids 0..19 once each, then OutOfRange."""
  with self.session() as sess:
    out = ops.random_permutation_sequence(num=20, batch=7, repeat=False)
    # Each epoch takes exactly 3 steps.
    seen = []
    for _ in range(3):
      seen.extend(sess.run(out).tolist())
    self.assertEqual(list(range(20)), sorted(seen))
    # repeat=False. We should see OutOfRange error.
    with self.assertRaises(tf.errors.OutOfRangeError):
      sess.run(out)
def testRandomPermutationSequenceRepeat(self):
  """With repeat=True, each 3-step window finishes the prior epoch cleanly.

  Every window of 3 steps must contain all ids left over from the previous
  epoch, and the ids beyond those must be free of duplicates.
  """
  with self.session() as sess:
    out = ops.random_permutation_sequence(num=20, batch=7, repeat=True)
    all_ids = set(range(20))
    remaining = list(range(20))
    for _ in range(10):
      # Each epoch takes exactly 3 steps.
      vals = []
      for _step in range(3):
        vals.extend(sess.run(out).tolist())
      self.assertEqual(len(vals), 21)
      # Contains all the remaining values from previous epoch.
      for leftover in remaining:
        vals.remove(leftover)  # Raises exception if leftover is not in vals.
      # Remaining items have no duplicates.
      unique_vals = set(vals)
      self.assertEqual(len(vals), len(unique_vals))
      remaining = list(all_ids - unique_vals)