예제 #1
0
  def InputBatch(self):
    """Assembles one batch of samples from the tensors stored in `p.ckpt`.

    Returns:
      A `py_utils.NestedMap` with the fields `raw`, `data` (the result of
      `self._Preprocess(raw)`), `label` and `weight`. When not running on
      TPU, it additionally carries `sample_ids`.
    """
    params = self.params

    @function.Defun()
    def ReadData():
      features, targets = io_ops.restore_v2(
          params.ckpt, [params.data, params.label], [''] * 2,
          [params.data_dtype, params.label_dtype])
      # Both tensors are cast to float32 regardless of their stored dtype.
      return tf.to_float(features), tf.to_float(targets)

    # Reads data and label through cached_call so they stay around in memory.
    data, label = py_x_ops.cached_call(f=ReadData, T=[tf.float32, tf.float32])

    batch_size = self.InputBatchSize()
    sample_shape = list(params.data_shape)

    # One leading sample dimension for data; labels are flattened and must
    # have exactly one entry per data sample.
    data = tf.reshape(data, [-1] + sample_shape)
    label = tf.reshape(label, [-1])
    label = py_utils.HasShape(label, [tf.shape(data)[0]])

    sample_ids = py_x_ops.random_permutation_sequence(
        num=params.num_samples,
        batch=batch_size,
        repeat=params.repeat,
        seed=params.random_seed or 0)
    num_ids = tf.shape(sample_ids)[0]

    # Every field is padded or trimmed so its leading dimension is batch_size.
    raw = py_utils.PadOrTrimTo(
        tf.gather(data, sample_ids), [batch_size] + sample_shape)
    preprocessed = self._Preprocess(raw)
    batch_label = py_utils.PadOrTrimTo(
        tf.gather(label, sample_ids), [batch_size])
    # One unit weight per gathered sample id.
    batch_weight = py_utils.PadOrTrimTo(
        tf.ones([num_ids], dtype=tf.float32), [batch_size])

    ret = py_utils.NestedMap(
        raw=raw, data=preprocessed, label=batch_label, weight=batch_weight)
    if not py_utils.use_tpu():
      ret['sample_ids'] = sample_ids
    return ret
예제 #2
0
    def testCachedCall(self):
        """Checks that cached_call keeps returning the first result of MyFn."""

        # Each next() on this generator produces a different pair of matrices.
        def Gen():
            for k in range(1, 1000):
                first = np.array([[0, k], [k, 0]]).astype(np.float32)
                second = np.array([[k, 0], [0, -k]]).astype(np.float32)
                yield first, second

        pairs = Gen()

        def NextPair():
            return next(pairs)

        # Exposes the generator to TensorFlow as a defun around a py_func.
        @function.Defun()
        def MyFn():
            return tf.py_func(NextPair, [], [tf.float32, tf.float32])

        # Builds a graph that invokes MyFn through cached_call.
        graph = tf.Graph()
        with graph.as_default():
            # NOTE(review): reading .name presumably forces the defun to be
            # defined before cached_call uses it -- confirm.
            _ = MyFn.name
            u, v = py_x_ops.cached_call(MyFn, [tf.float32, tf.float32])

        expected_u = [[0, 1], [1, 0]]
        expected_v = [[1, 0], [0, -1]]
        with self.session(graph=graph) as sess:
            # Every run must observe the values of the generator's first item.
            for _ in range(10):
                got_u, got_v = sess.run([u, v])
                self.assertAllEqual(got_u, expected_u)
                self.assertAllEqual(got_v, expected_v)
예제 #3
0
  def _InputBatchFromCKPT(self):
    """Builds one window of ids/labels from a tensor stored in `p.ckpt`.

    The tensor named `p.data` is truncated to a multiple of `p.batch_size`
    and reshaped to `[batch_size, total_batches]`. Each call selects a window
    of up to `p.num_steps` consecutive columns; labels are the same window
    shifted one column to the right.

    Returns:
      A `py_utils.NestedMap` with the fields:
        ids: the selected window of columns.
        labels: the window shifted by one column, padded to the width of ids.
        weights: 1.0 for every real label column, padded to the width of ids.
        paddings: `1.0 - weights`.
        word_count: number of real labels in the batch.
        take_last_state: whether this is not the first window (batch_id > 0).
    """
    p = self.params

    @function.Defun()
    def ReadData():
      x, = io_ops.restore_v2(p.ckpt, [p.data], [''], [p.data_dtype])
      return x

    # Loads the data tensor through cached_call so it is kept around.
    data, = py_x_ops.cached_call(f=ReadData, T=[p.data_dtype])

    b = p.batch_size
    total_length = p.data_shape[0]
    # Number of columns once the data is laid out as b rows.
    total_batches = total_length // b
    # Number of windows of num_steps columns, counting a final partial one.
    total_steps = total_batches // p.num_steps
    if total_batches % p.num_steps > 0:
      total_steps += 1

    if p.eval:
      # In eval the window counter comes from an endlessly repeating range
      # dataset instead of the global step.
      dataset = tf.data.Dataset.range(total_steps).repeat()
      iterator = dataset.make_one_shot_iterator()
      global_step = iterator.get_next()
    else:
      # NOTE(review): the "- 1" presumably compensates for the global step
      # already having been advanced -- confirm against the trainer.
      global_step = py_utils.GetOrCreateGlobalStep() - 1

    batch_id = tf.to_int32(global_step % total_steps)

    # Drops the tail that does not fill a whole column, then lays out b rows.
    data = data[:total_batches * b]
    data = tf.reshape(data, [b, total_batches])

    start = p.num_steps * batch_id
    end = tf.minimum(tf.to_int32(total_batches), start + p.num_steps)
    raw = tf.gather(
        data, tf.range(start, end, dtype=tf.int32), axis=1, name='ids')
    # Labels are shifted by one column; the last window has no column past
    # the end of the data, so it ends up one label short.
    label_end = tf.minimum(end + 1, tf.to_int32(total_batches))
    label = tf.gather(
        data,
        tf.range(start + 1, label_end, dtype=tf.int32),
        axis=1,
        name='labels')

    num_cols = end - start
    # Real label columns: range(start + 1, label_end).
    num_labels = label_end - start - 1

    raw = py_utils.PadOrTrimTo(raw, [b, num_cols])
    ret = py_utils.NestedMap()
    ret.ids = raw
    ret.labels = py_utils.PadOrTrimTo(label, [b, num_cols])
    # Weights cover exactly the real label columns, so a padded label column
    # on the last window is not counted as valid. This agrees with word_count.
    ret.weights = py_utils.PadOrTrimTo(
        tf.ones([b, num_labels], dtype=tf.float32), [b, num_cols])
    ret.paddings = 1.0 - ret.weights
    ret.word_count = b * num_labels
    ret.take_last_state = batch_id > 0

    return ret