def testCachedCall(self):
  """Checks that a function invoked through `ops.cached_call` is replayed.

  The wrapped function pulls from a Python iterator that yields a different
  pair of matrices on every pull (parameterized by i = 1, 2, ...). The test
  expects every `sess.run` to return the i == 1 pair, i.e. the outputs are
  produced once and then reused on subsequent runs.
  """
  # Lazily yields a fresh (x, y) float32 pair for each i in [1, 1000).
  pairs = ((np.array([[0, i], [i, 0]]).astype(np.float32),
            np.array([[i, 0], [0, -i]]).astype(np.float32))
           for i in range(1, 1000))

  @function.Defun()
  def MyFn():
    # Every execution of this py_func advances `pairs` by one element.
    return tf.py_func(lambda: next(pairs), [], [tf.float32, tf.float32])

  graph = tf.Graph()
  with graph.as_default():
    # NOTE(review): reading `.name` inside the graph context presumably forces
    # the Defun to be defined in this graph before use — confirm.
    _ = MyFn.name
    u, v = ops.cached_call(MyFn, [tf.float32, tf.float32])

  # Values corresponding to the first generated pair (i == 1).
  expected_x = [[0, 1], [1, 0]]
  expected_y = [[1, 0], [0, -1]]
  with self.session(graph=graph) as sess:
    for _ in range(10):
      x, y = sess.run([u, v])
      self.assertAllEqual(x, expected_x)
      self.assertAllEqual(y, expected_y)
def _InputBatch(self):
  """Builds one input batch by sampling a checkpoint-resident dataset.

  The data and label tensors named by the params are restored from `p.ckpt`
  through `ops.cached_call` and converted to float32. Examples are then
  selected with `ops.random_permutation_sequence`.

  Returns:
    A `py_utils.NestedMap` with the following fields:
      raw: gathered data, padded or trimmed to [batch_size] + p.data_shape.
      data: `self._Preprocess(raw)`.
      label: gathered labels, padded or trimmed to [batch_size].
      weight: 1.0 for each sampled id, padded or trimmed to [batch_size].
    `sample_ids` is also included when `py_utils.use_tpu()` is False.
  """
  p = self.params

  @tf.function
  def ReadData():
    raw_data, raw_label = io_ops.restore_v2(p.ckpt, [p.data, p.label],
                                            [''] * 2,
                                            [p.data_dtype, p.label_dtype])
    # Downstream code always works in float32, whatever the stored dtype.
    return tf.cast(raw_data, tf.float32), tf.cast(raw_label, tf.float32)

  # Restore data and label once and keep them resident in memory.
  all_data, all_labels = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  batch_size = self.InfeedBatchSize()
  example_shape = list(p.data_shape)
  all_data = tf.reshape(all_data, [-1] + example_shape)
  all_labels = tf.reshape(all_labels, [-1])
  # There must be exactly one label per example.
  all_labels = py_utils.HasShape(all_labels, [tf.shape(all_data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  # The number of ids actually produced; only these get weight 1.
  num_sampled = tf.shape(sample_ids)[0]

  raw = py_utils.PadOrTrimTo(
      tf.gather(all_data, sample_ids), [batch_size] + example_shape)
  preprocessed = self._Preprocess(raw)
  labels = py_utils.PadOrTrimTo(
      tf.gather(all_labels, sample_ids), [batch_size])
  weights = py_utils.PadOrTrimTo(
      tf.ones([num_sampled], dtype=tf.float32), [batch_size])

  batch = py_utils.NestedMap(
      raw=raw, data=preprocessed, label=labels, weight=weights)
  if not py_utils.use_tpu():
    batch['sample_ids'] = sample_ids
  return batch