def InputBatch(self):
  """Builds one randomly sampled batch from data cached out of a checkpoint.

  The data and label tensors named by `params.data` / `params.label` are
  restored from `params.ckpt`, converted to float32 and fetched through
  `cached_call`. A batch of sample ids is then drawn with
  `random_permutation_sequence` and used to gather the examples.

  Returns:
    A `NestedMap` with:
      raw: gathered examples, padded or trimmed to `[batch] + data_shape`.
      data: `self._Preprocess(raw)`.
      label: gathered labels, padded or trimmed to `[batch]`.
      weight: ones for each sampled id, padded or trimmed to `[batch]`.
      sample_ids: the sampled ids; only present when not running on TPU.
  """
  params = self.params

  @function.Defun()
  def ReadData():
    restored = io_ops.restore_v2(
        params.ckpt,
        [params.data, params.label],
        [''] * 2,
        [params.data_dtype, params.label_dtype])
    examples, targets = restored
    # Both tensors are always handed out as float32.
    return tf.to_float(examples), tf.to_float(targets)

  # Fetch the whole dataset once through cached_call so it stays around.
  all_data, all_labels = py_x_ops.cached_call(
      f=ReadData, T=[tf.float32, tf.float32])

  batch_size = self.InputBatchSize()
  example_shape = list(params.data_shape)

  # One row per example; labels are flattened and must have one entry per row.
  all_data = tf.reshape(all_data, [-1] + example_shape)
  flat_labels = tf.reshape(all_labels, [-1])
  all_labels = py_utils.HasShape(flat_labels, [tf.shape(all_data)[0]])

  sample_ids = py_x_ops.random_permutation_sequence(
      num=params.num_samples,
      batch=batch_size,
      repeat=params.repeat,
      seed=params.random_seed or 0)
  num_sampled = tf.shape(sample_ids)[0]

  # The sampler may return a different number of ids than `batch_size`, so
  # every field is padded or trimmed to the full batch dimension.
  raw = py_utils.PadOrTrimTo(
      tf.gather(all_data, sample_ids), [batch_size] + example_shape)
  preprocessed = self._Preprocess(raw)
  labels = py_utils.PadOrTrimTo(
      tf.gather(all_labels, sample_ids), [batch_size])
  weights = py_utils.PadOrTrimTo(
      tf.ones([num_sampled], dtype=tf.float32), [batch_size])

  batch = py_utils.NestedMap(
      raw=raw, data=preprocessed, label=labels, weight=weights)
  if py_utils.use_tpu():
    return batch

  batch['sample_ids'] = sample_ids
  return batch
def testCachedCall(self):
  """Checks that `cached_call` keeps returning the first result of its function.

  The wrapped function yields a different pair of matrices every time it is
  actually invoked, so seeing the first pair on each of several `sess.run`
  calls shows the result was cached.
  """

  def Gen():
    # Each invocation produces a new pair of matrices parameterised by i.
    for i in range(1, 1000):
      first = np.array([[0, i], [i, 0]], dtype=np.float32)
      second = np.array([[i, 0], [0, -i]], dtype=np.float32)
      yield first, second

  values = Gen()

  # Expose the Python generator to the graph as a Defun around a py_func.
  @function.Defun()
  def MyFn():
    return tf.py_func(
        func=lambda: next(values), inp=[], Tout=[tf.float32, tf.float32])

  # Build a graph whose only computation is the cached call to MyFn.
  graph = tf.Graph()
  with graph.as_default():
    # NOTE(review): touching `.name` presumably forces the Defun to be
    # instantiated before cached_call uses it -- confirm.
    _ = MyFn.name
    u, v = py_x_ops.cached_call(f=MyFn, T=[tf.float32, tf.float32])

  # These are the values the generator yields for i == 1.
  expected_u = [[0, 1], [1, 0]]
  expected_v = [[1, 0], [0, -1]]

  with self.session(graph=graph) as sess:
    for _ in range(10):
      actual_u, actual_v = sess.run([u, v])
      self.assertAllEqual(actual_u, expected_u)
      self.assertAllEqual(actual_v, expected_v)
def _InputBatchFromCKPT(self):
  """Builds one window of ids/labels from an id stream stored in a checkpoint.

  The tensor named `params.data` is restored from `params.ckpt` and fetched
  through `cached_call`. It is truncated to `total_batches * batch_size`
  elements and reshaped to `[batch_size, total_batches]`. Each call selects up
  to `params.num_steps` consecutive columns `[start, end)` as ids and the
  columns shifted by one, `[start + 1, label_end)`, as labels.

  The window index comes from the global step when training, or from a
  repeating `tf.data` counter when `params.eval` is set.

  Returns:
    A `NestedMap` with:
      ids: `[batch_size, end - start]` window of the stream.
      labels: the next-column targets, padded or trimmed to the ids' shape.
      weights: 1.0 for every real label position, padded or trimmed to the
        ids' shape.
      paddings: `1.0 - weights`.
      word_count: `batch_size *` number of real label columns.
      take_last_state: whether this is not the first window of the stream.
  """
  p = self.params

  @function.Defun()
  def ReadData():
    x, = io_ops.restore_v2(p.ckpt, [p.data], [''], [p.data_dtype])
    return x

  # Fetch the whole id stream once through cached_call so it stays around.
  data, = py_x_ops.cached_call(f=ReadData, T=[p.data_dtype])

  b = p.batch_size
  total_length = p.data_shape[0]
  total_batches = total_length // b
  # Number of windows needed to cover all columns, rounding up so a trailing
  # partial window is still visited.
  total_steps = total_batches // p.num_steps
  if total_batches % p.num_steps > 0:
    total_steps += 1

  if p.eval:
    # Evaluation does not advance the global step, so use a private counter.
    dataset = tf.data.Dataset.range(total_steps).repeat()
    iterator = dataset.make_one_shot_iterator()
    global_step = iterator.get_next()
  else:
    global_step = py_utils.GetOrCreateGlobalStep() - 1
  batch_id = tf.to_int32(global_step % total_steps)

  # Drop the tail that does not fill a full column, then lay the stream out
  # as `b` parallel rows of `total_batches` columns.
  data = data[:total_batches * b]
  data = tf.reshape(data, [b, total_batches])

  start = p.num_steps * batch_id
  end = tf.minimum(tf.to_int32(total_batches), start + p.num_steps)
  raw = tf.gather(
      data, tf.range(start, end, dtype=tf.int32), axis=1, name='ids')

  # Labels are the ids shifted by one column; the last window has no label
  # for its final id because the stream ends there.
  label_end = tf.minimum(end + 1, tf.to_int32(total_batches))
  label = tf.gather(
      data,
      tf.range(start + 1, label_end, dtype=tf.int32),
      axis=1,
      name='labels')
  raw = py_utils.PadOrTrimTo(raw, [b, end - start])

  # Number of columns that carry a real label. The weights must be sized from
  # this, not from `label_end - start`, otherwise the padded label column of
  # the final window gets weight 1 and disagrees with `word_count`.
  num_valid_labels = label_end - start - 1

  ret = py_utils.NestedMap()
  ret.ids = raw
  ret.labels = py_utils.PadOrTrimTo(label, [b, end - start])
  ret.weights = py_utils.PadOrTrimTo(
      tf.ones([b, num_valid_labels], dtype=tf.float32), [b, end - start])
  ret.paddings = 1.0 - ret.weights
  ret.word_count = b * num_valid_labels
  ret.take_last_state = batch_id > 0
  return ret