Example #1
0
    def testCachedCall(self):
        """Checks that ops.cached_call reuses the first result of its function.

        MyFn pulls a fresh pair of matrices from a Python generator on every
        invocation. If the call were re-executed on each `sess.run`, later runs
        would see i == 2, 3, ...; the assertions require that every one of the
        ten runs observes the pair produced for i == 1.
        """

        def Gen():
            # Yields a distinct pair of 2x2 float32 matrices for each i.
            for i in range(1, 1000):
                first = np.array([[0, i], [i, 0]]).astype(np.float32)
                second = np.array([[i, 0], [0, -i]]).astype(np.float32)
                yield first, second

        pairs = Gen()

        # Wraps the generator in a defun; every invocation advances `pairs`.
        @function.Defun()
        def MyFn():
            return tf.py_func(lambda: next(pairs), [], [tf.float32, tf.float32])

        graph = tf.Graph()
        with graph.as_default():
            # NOTE(review): reading `.name` presumably forces the Defun to be
            # defined in this graph before cached_call references it - confirm.
            _ = MyFn.name
            u, v = ops.cached_call(MyFn, [tf.float32, tf.float32])

        expected_x = [[0, 1], [1, 0]]
        expected_y = [[1, 0], [0, -1]]
        with self.session(graph=graph) as sess:
            for _ in range(10):
                x, y = sess.run([u, v])
                self.assertAllEqual(x, expected_x)
                self.assertAllEqual(y, expected_y)
    def _InputBatch(self):
        """Builds one input batch from data/label tensors stored in a checkpoint.

        The tensors named `p.data` and `p.label` are restored from `p.ckpt`,
        cast to float32, and loaded through `ops.cached_call`. The data is
        reshaped to `[-1] + p.data_shape` and the labels are flattened and
        checked to have one entry per example. A batch of sample ids is then
        drawn with `ops.random_permutation_sequence`, and the selected rows are
        padded or trimmed to `self.InfeedBatchSize()`.

        Returns:
          A py_utils.NestedMap with:
            raw: gathered examples, padded/trimmed to [batch] + p.data_shape.
            data: `self._Preprocess(raw)`.
            label: gathered labels, padded/trimmed to [batch].
            weight: ones for each drawn sample id, padded/trimmed to [batch]
              (padding value presumably 0 - confirm against PadOrTrimTo).
            sample_ids: the drawn ids; only present when not running on TPU.
        """
        p = self.params

        @tf.function
        def ReadData():
            names = [p.data, p.label]
            dtypes = [p.data_dtype, p.label_dtype]
            stored_data, stored_label = io_ops.restore_v2(
                p.ckpt, names, [''] * 2, dtypes)
            # Always convert to float32.
            return tf.cast(stored_data, tf.float32), tf.cast(stored_label,
                                                             tf.float32)

        # Loads data and labels once and keeps them resident in memory.
        all_data, all_labels = ops.cached_call(
            f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])
        batch_size = self.InfeedBatchSize()
        example_shape = list(p.data_shape)
        all_data = tf.reshape(all_data, [-1] + example_shape)
        all_labels = tf.reshape(all_labels, [-1])
        num_examples = tf.shape(all_data)[0]
        all_labels = py_utils.HasShape(all_labels, [num_examples])

        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=batch_size,
            repeat=p.repeat,
            seed=p.random_seed or 0)
        num_ids = tf.shape(sample_ids)[0]

        raw = py_utils.PadOrTrimTo(
            tf.gather(all_data, sample_ids), [batch_size] + example_shape)
        preprocessed = self._Preprocess(raw)
        labels = py_utils.PadOrTrimTo(
            tf.gather(all_labels, sample_ids), [batch_size])
        weights = py_utils.PadOrTrimTo(
            tf.ones([num_ids], dtype=tf.float32), [batch_size])

        batch = py_utils.NestedMap(
            raw=raw, data=preprocessed, label=labels, weight=weights)
        if not py_utils.use_tpu():
            batch['sample_ids'] = sample_ids
        return batch