Example #1
0
    def _InputBatch(self):
        """Builds one input batch from a dataset stored in a checkpoint.

        The tensors named ``p.data`` and ``p.label`` are restored from the
        checkpoint ``p.ckpt`` and cast to float32. The read goes through
        ``ops.cached_call`` so the data is loaded into memory once and kept
        around. Each batch draws sample ids from
        ``ops.random_permutation_sequence`` and gathers the matching rows.

        Returns:
          A ``py_utils.NestedMap`` with the keys below, where ``b`` is
          ``self.InfeedBatchSize()``:
            raw: gathered data, padded/trimmed to ``[b] + p.data_shape``.
            data: ``self._Preprocess(raw)``.
            label: gathered labels, padded/trimmed to ``[b]``.
            weight: float32 ones, one per gathered id, padded/trimmed to
              ``[b]``. The padding value is PadOrTrimTo's default, presumably
              0 (TODO confirm).
            sample_ids: the drawn ids. Present only when not running on TPU.
        """
        params = self.params

        @tf.function
        def ReadData():
            restored_data, restored_label = io_ops.restore_v2(
                params.ckpt, [params.data, params.label], ['', ''],
                [params.data_dtype, params.label_dtype])
            # Always convert to float32.
            return (tf.cast(restored_data, tf.float32),
                    tf.cast(restored_label, tf.float32))

        # Loads data and label into memory and keep it around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32] * 2)
        batch_size = self.InfeedBatchSize()
        example_shape = list(params.data_shape)
        data = tf.reshape(data, [-1] + example_shape)
        # Labels are flattened to one entry per data row. HasShape presumably
        # asserts that the counts agree.
        label = py_utils.HasShape(tf.reshape(label, [-1]),
                                  [tf.shape(data)[0]])

        sample_ids = ops.random_permutation_sequence(
            num=params.num_samples,
            batch=batch_size,
            repeat=params.repeat,
            seed=params.random_seed or 0)
        num_gathered = tf.shape(sample_ids)[0]

        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids),
                                   [batch_size] + example_shape)
        preprocessed = self._Preprocess(raw)
        gathered_label = py_utils.PadOrTrimTo(tf.gather(label, sample_ids),
                                              [batch_size])
        weight = py_utils.PadOrTrimTo(
            tf.ones([num_gathered], dtype=tf.float32), [batch_size])

        batch = py_utils.NestedMap(raw=raw,
                                   data=preprocessed,
                                   label=gathered_label,
                                   weight=weight)
        if not py_utils.use_tpu():
            batch['sample_ids'] = sample_ids
        return batch
Example #2
0
    def testRandomPermutationSequenceNoRepeat(self):
        """repeat=False yields one permutation of 0..19 and then stops."""
        with self.session() as sess:
            out = ops.random_permutation_sequence(num=20,
                                                  batch=7,
                                                  repeat=False)

            # Each epoch takes exactly 3 steps (20 ids at batch size 7).
            vals = []
            for _ in range(3):
                vals.extend(sess.run(out).tolist())
            self.assertEqual(list(range(20)), sorted(vals))

            # The single epoch is exhausted, so one more step must fail.
            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(out)
Example #3
0
    def testRandomPermutationSequenceRepeat(self):
        """repeat=True streams successive permutations of 0..19.

        With num=20 and batch=7, every three steps yield 21 ids. An epoch has
        only 20 ids, so each window of three steps crosses an epoch boundary.
        For ten windows the test checks four things:
          - every id is in range;
          - the ids still owed by the previously started epoch all appear;
          - the ids carried into the next epoch contain no duplicates;
          - the ids still owed by that next epoch are carried forward to the
            following window.
        """
        with self.session() as sess:
            out = ops.random_permutation_sequence(num=20, batch=7, repeat=True)

            # Ids of the current epoch that have not been seen yet.
            remaining = list(range(20))
            for _ in range(10):
                # Three steps at batch 7 give 21 ids, one more than an epoch,
                # so this window always spills into the next epoch.
                vals = sess.run(out).tolist() + sess.run(
                    out).tolist() + sess.run(out).tolist()
                self.assertEqual(len(vals), 21)
                # Every id must index into the 20-element sequence. Without
                # this check, an out-of-range leftover id would go unnoticed.
                self.assertTrue(all(0 <= x < 20 for x in vals), vals)

                # Contains all the remaining values from previous epoch.
                for x in remaining:
                    vals.remove(x)  # Raises exception if x is not in vals.

                # What is left belongs to the next epoch; it must not repeat.
                self.assertEqual(len(vals), len(set(vals)))

                remaining = list(set(range(20)) - set(vals))