示例#1
0
    def __init__(self, batch_size=1):
        """Prepare empty recurrent-state bookkeeping and register each state.

        Args:
          batch_size: Kept as ``self._batch_size`` and used in place of a
            leading dimension of size 0 in every state's feed shape.
        """
        self._state_feeds = {}
        self._state_fetches = []
        self._state_feed_names = []
        self._batch_size = batch_size
        # Remember the graph that is the TF default when this object is built.
        self._graph = tf.get_default_graph()

        # Walk every registered recurrent-state descriptor and record the
        # name of its feed op together with its fetch name.
        saver = bookkeeper.recurrent_state()
        descriptors = saver.GetStateDescriptors()
        for descriptor in six.itervalues(descriptors):
            # NOTE(review): `dims` is never used after this point and
            # self._state_feeds is never populated in this block; a feed
            # array built from `dims` looks like it is missing — confirm.
            dims = [dim.size for dim in descriptor['feed_shape'].dim]
            if dims[0] == 0:
                dims[0] = self._batch_size
            self._state_feed_names.append(descriptor['feed_op'].name)
            self._state_fetches.append(descriptor['fetch_name'])
  def __init__(self, batch_size=1):
    """Set up bookkeeping for recurrent-state feeds and fetches.

    Initializes an empty state-feed dict plus empty fetch and feed-name
    lists, stores the batch size and the current default TF graph, then
    records the feed-op name and fetch name of every descriptor returned by
    `bookkeeper.recurrent_state().GetStateDescriptors()`.

    Args:
      batch_size: Stored as `self._batch_size`; also substituted for a
        leading dimension of size 0 in each state's feed shape.
    """
    self._state_feeds = {}
    self._state_fetches = []
    self._state_feed_names = []
    self._batch_size = batch_size
    # Capture whichever graph is the TF default at construction time.
    self._graph = tf.get_default_graph()

    # Store the feeds and fetches for recurrent states.
    statesaver = bookkeeper.recurrent_state()
    for state in six.itervalues(statesaver.GetStateDescriptors()):
      # Each descriptor is indexed by 'feed_shape', 'feed_op' and
      # 'fetch_name'; 'feed_shape' presumably is a TensorShapeProto whose
      # `dim` entries expose `size` — TODO confirm.
      shape = [d.size for d in state['feed_shape'].dim]
      # A leading size of 0 is replaced with the configured batch size.
      if shape[0] == 0:
        shape[0] = batch_size
      # NOTE(review): within the visible lines `shape` is never used and
      # self._state_feeds is never written; if the method ends here, the
      # feed-array assignment appears to be missing — confirm.
      feed_name = state['feed_op'].name
      self._state_feed_names.append(feed_name)
      self._state_fetches.append(state['fetch_name'])