def _set_up_staging(self, transition):
    """Build staging-area ops that prefetch the next batch of transitions.

    A staging area lets the py_func that produces `transition` run one step
    ahead of the consumer, hiding its latency: each read dequeues a batch that
    was enqueued earlier while a fresh batch is enqueued.

    Args:
      transition: tuple of tf.Tensors matching the elements of
        self.memory.get_transition_elements().

    Returns:
      prefetched_transition: tuple of tf.Tensors matching the elements of
        self.memory.get_transition_elements(), read from the staging area
        after this step's put op(s) have run.
    """
    element_specs = self.memory.get_transition_elements()
    staged_dtypes = [spec.type for spec in element_specs]

    # NOTE(review): no tf.device scope is set here, so placement follows the
    # caller's device context (presumably CPU) -- confirm at the call site.
    staging = StagingArea(staged_dtypes)

    # Kept as a private attribute for tests; users should not run
    # _prefetch_batch themselves.
    self._prefetch_batch = staging.put(transition)

    def _extra_put():
        return staging.put(transition)

    buffer_is_empty = tf.equal(staging.size(), 0)
    initial_prefetch = tf.cond(buffer_is_empty, _extra_put, tf.no_op)

    # Reading always triggers the regular put; when the buffer is empty the
    # extra put fires too, so one batch remains staged after the get.
    # NOTE(review): the size check is not ordered relative to the regular put;
    # only the get is gated by these control dependencies.
    with tf.control_dependencies([self._prefetch_batch, initial_prefetch]):
        prefetched_transition = staging.get()
    return prefetched_transition
# Example #2
import os

# Presumably meant to expose only GPUs 0 and 1 to TensorFlow; CUDA honours this
# only if it is set before TensorFlow initialises its devices.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# NOTE(review): `tf` (TensorFlow 1.x graph API) and `StagingArea` are used below
# but are not imported in this snippet; they must be brought into scope elsewhere.

length = 10
dataset_range = tf.data.Dataset.range(length)
# Named `iterator` rather than `iter` so the builtin iter() is not shadowed.
iterator = dataset_range.make_one_shot_iterator()
next_item = iterator.get_next()

with tf.device('cpu:0'):
    a = tf.Variable([1], dtype=tf.int64)

    area = StagingArea(dtypes=[tf.int64])
    area_put = area.put([next_item])  # pulls the next dataset element and stages it
    area_get = area.get()[0]          # dequeues one staged element
    area_size = area.size()

    # `area_get` is a single op, so it runs at most once per sess.run: both
    # additions see the same dequeued element, i.e. c == a + 2 * element.
    b = a + area_get
    c = b + area_get

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Prime the staging area with the first element.
    print('put:', sess.run(area_put))
    for i in range(length):
        # Consume the element staged on the previous step ...
        print(sess.run(c))
        # ... then stage the next one while the dataset still has elements.
        # Skipping the final put avoids OutOfRangeError from the exhausted
        # iterator while still letting every element be consumed.
        if i < length - 1:
            print('put:', sess.run(area_put))
        print('size:', sess.run(area_size))