Example #1
  def create_sampling_ops(self, use_staging):
    """Creates the ops necessary to sample from the replay buffer.

    Creates the transition dictionary containing the sampling tensors.

    Args:
      use_staging: bool, if True, use a staging area to prefetch the next
        sampling batch.
    """
    with tf.name_scope('sample_replay'):
      with tf.device('/cpu:*'):
        transition_type = self.memory.get_transition_elements()
        transition_tensors = tf.py_func(
            self.memory.sample_transition_batch, [],
            [return_entry.type for return_entry in transition_type],
            name='replay_sample_py_func')
        self._set_transition_shape(transition_tensors, transition_type)
        if use_staging:
          transition_tensors = self._set_up_staging(transition_tensors)
          self._set_transition_shape(transition_tensors, transition_type)

        # Unpack sample transition into member variables.
        self.unpack_transition(transition_tensors, transition_type)
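The helper _set_transition_shape referenced above is not included in this excerpt. A minimal sketch of what it might look like, assuming each entry in transition_type exposes a .shape attribute alongside the .type used above (that attribute is an assumption of this sketch):

  def _set_transition_shape(self, transition, transition_type):
    """Restores the static shape information that tf.py_func discards.

    Args:
      transition: list of tf.Tensors returned by the sampling py_func.
      transition_type: list of element descriptors; each is assumed to carry
        a `.shape` attribute describing the corresponding tensor.
    """
    for element, element_type in zip(transition, transition_type):
      element.set_shape(element_type.shape)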
Example #2
    def __init__(self,
                 num_actions,
                 observation_size,
                 stack_size,
                 use_staging=True,
                 replay_capacity=1000000,
                 batch_size=32,
                 update_horizon=1,
                 gamma=1.0,
                 wrapped_memory=None):
        """Initializes a graph wrapper for the python replay memory.

        Args:
          num_actions: int, number of possible actions.
          observation_size: int, size of an input frame.
          stack_size: int, number of frames to use in state stack.
          use_staging: bool, if True, use a staging area to prefetch the next
            sampling batch.
          replay_capacity: int, number of transitions to keep in memory.
          batch_size: int, size of the sampled batch.
          update_horizon: int, length of update ('n' in n-step update).
          gamma: float, the discount factor.
          wrapped_memory: The 'inner' memory data structure. Defaults to None,
            which creates the standard DQN replay memory.

        Raises:
          ValueError: If update_horizon is not positive.
          ValueError: If discount factor is not in [0, 1].
        """
        if replay_capacity < update_horizon + 1:
            raise ValueError(
                'Update horizon (%i) should be significantly smaller '
                'than replay capacity (%i).' %
                (update_horizon, replay_capacity))
        if not update_horizon >= 1:
            raise ValueError('Update horizon must be positive.')
        if not 0.0 <= gamma <= 1.0:
            raise ValueError('Discount factor (gamma) must be in [0, 1].')

        # Allow subclasses to create self.memory.
        if wrapped_memory is not None:
            self.memory = wrapped_memory
        else:
            self.memory = OutOfGraphReplayMemory(num_actions, observation_size,
                                                 stack_size, replay_capacity,
                                                 batch_size, update_horizon,
                                                 gamma)

        with tf.name_scope('replay'):
            with tf.name_scope('add_placeholders'):
                self.add_obs_ph = tf.placeholder(tf.uint8, [observation_size],
                                                 name='add_obs_ph')
                self.add_action_ph = tf.placeholder(tf.int32, [],
                                                    name='add_action_ph')
                self.add_reward_ph = tf.placeholder(tf.float32, [],
                                                    name='add_reward_ph')
                self.add_terminal_ph = tf.placeholder(tf.uint8, [],
                                                      name='add_terminal_ph')
                self.add_legal_actions_ph = tf.placeholder(
                    tf.float32, [num_actions], name='add_legal_actions_ph')

            add_transition_ph = [
                self.add_obs_ph, self.add_action_ph, self.add_reward_ph,
                self.add_terminal_ph, self.add_legal_actions_ph
            ]

            with tf.device('/cpu:*'):
                self.add_transition_op = tf.py_func(self.memory.add,
                                                    add_transition_ph, [],
                                                    name='replay_add_py_func')

                self.transition = tf.py_func(
                    self.memory.sample_transition_batch, [], [
                        tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8,
                        tf.int32, tf.float32
                    ],
                    name='replay_sample_py_func')

                if use_staging:
                    # To hide the py_func latency, use a staging area to
                    # prefetch the next batch of transitions.
                    (states, actions, rewards, next_states, terminals, indices,
                     next_legal_actions) = self.transition
                    # StagingArea requires all the shapes to be defined.
                    states.set_shape(
                        [batch_size, observation_size, stack_size])
                    actions.set_shape([batch_size])
                    rewards.set_shape([batch_size])
                    next_states.set_shape(
                        [batch_size, observation_size, stack_size])
                    terminals.set_shape([batch_size])
                    indices.set_shape([batch_size])
                    next_legal_actions.set_shape([batch_size, num_actions])

                    # Create the staging area in CPU.
                    prefetch_area = tf.contrib.staging.StagingArea([
                        tf.uint8, tf.int32, tf.float32, tf.uint8, tf.uint8,
                        tf.int32, tf.float32
                    ])

                    self.prefetch_batch = prefetch_area.put(
                        (states, actions, rewards, next_states, terminals,
                         indices, next_legal_actions))
                else:
                    self.prefetch_batch = tf.no_op()

            if use_staging:
                # Get the sampled transition batch on the GPU. prefetch_area.get()
                # performs the copy from CPU to GPU.
                self.transition = prefetch_area.get()

            (self.states, self.actions, self.rewards, self.next_states,
             self.terminals, self.indices,
             self.next_legal_actions) = self.transition

            # Since these are py_func tensors, no information about their shape is
            # present. Set the shape only for the tensors that need it.
            self.states.set_shape([None, observation_size, stack_size])
            self.next_states.set_shape([None, observation_size, stack_size])
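For orientation, here is a minimal usage sketch of the wrapper initialized above, assuming it is exposed as a class named WrappedReplayMemory (the class name is not part of this excerpt) and run under TensorFlow 1.x:

import numpy as np
import tensorflow as tf

# Hypothetical class name; only the __init__ body is shown above.
memory = WrappedReplayMemory(
    num_actions=4, observation_size=84, stack_size=1, use_staging=False)

with tf.Session() as sess:
    # Store one transition through the feed-dict placeholders.
    sess.run(memory.add_transition_op,
             feed_dict={
                 memory.add_obs_ph: np.zeros(84, dtype=np.uint8),
                 memory.add_action_ph: 0,
                 memory.add_reward_ph: 1.0,
                 memory.add_terminal_ph: 0,
                 memory.add_legal_actions_ph: np.ones(4, dtype=np.float32),
             })
    # Once enough transitions have been added, evaluating memory.states yields
    # a [batch_size, observation_size, stack_size] uint8 batch:
    # states = sess.run(memory.states)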
Example #3
    def __init__(self,
                 num_actions=None,
                 observation_size=None,
                 num_players=None,
                 gamma=0.99,
                 update_horizon=1,
                 min_replay_history=500,
                 update_period=4,
                 stack_size=1,
                 target_update_period=500,
                 epsilon_fn=linearly_decaying_epsilon,
                 epsilon_train=0.02,
                 epsilon_eval=0.001,
                 epsilon_decay_period=1000,
                 graph_template=dqn_template,
                 tf_device='/cpu:*',
                 use_staging=True,
                 optimizer=tf.train.RMSPropOptimizer(learning_rate=.0025,
                                                     decay=0.95,
                                                     momentum=0.0,
                                                     epsilon=1e-6,
                                                     centered=True)):
        """Initializes the agent and constructs its graph.

        Args:
          num_actions: int, number of actions the agent can take at any state.
          observation_size: int, size of observation vector.
          num_players: int, number of players playing this game.
          gamma: float, discount factor as commonly used in the RL literature.
          update_horizon: int, horizon at which updates are performed, the 'n'
            in n-step update.
          min_replay_history: int, number of stored transitions before training.
          update_period: int, period between DQN updates.
          stack_size: int, number of observations to use as state.
          target_update_period: int, update period for the target network.
          epsilon_fn: function expecting 4 parameters: (decay_period, step,
            warmup_steps, epsilon), and which returns the epsilon value used
            for exploration during training.
          epsilon_train: float, final epsilon for training.
          epsilon_eval: float, epsilon during evaluation.
          epsilon_decay_period: int, number of steps for epsilon to decay.
          graph_template: function for building the neural network graph.
          tf_device: str, TensorFlow device on which to run computations.
          use_staging: bool, if True, use a staging area to prefetch the next
            sampling batch.
          optimizer: Optimizer instance used for learning.
        """

        self.partial_reload = False

        tf.logging.info('Creating %s agent with the following parameters:',
                        self.__class__.__name__)
        tf.logging.info('\t gamma: %f', gamma)
        tf.logging.info('\t update_horizon: %f', update_horizon)
        tf.logging.info('\t min_replay_history: %d', min_replay_history)
        tf.logging.info('\t update_period: %d', update_period)
        tf.logging.info('\t target_update_period: %d', target_update_period)
        tf.logging.info('\t epsilon_train: %f', epsilon_train)
        tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
        tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
        tf.logging.info('\t tf_device: %s', tf_device)
        tf.logging.info('\t use_staging: %s', use_staging)
        tf.logging.info('\t optimizer: %s', optimizer)

        # Global variables.
        self.num_actions = num_actions
        self.observation_size = observation_size
        self.num_players = num_players
        self.gamma = gamma
        self.update_horizon = update_horizon
        self.cumulative_gamma = math.pow(gamma, update_horizon)
        self.min_replay_history = min_replay_history
        self.target_update_period = target_update_period
        self.epsilon_fn = epsilon_fn
        self.epsilon_train = epsilon_train
        self.epsilon_eval = epsilon_eval
        self.epsilon_decay_period = epsilon_decay_period
        self.update_period = update_period
        self.eval_mode = False
        self.training_steps = 0
        self.batch_staged = False
        self.optimizer = optimizer

        with tf.device(tf_device):
            # Calling online_convnet will generate a new graph as defined in
            # graph_template using whatever input is passed, but will always share
            # the same weights.
            online_convnet = tf.make_template('Online', graph_template)
            target_convnet = tf.make_template('Target', graph_template)
            # The state of the agent. The last axis is the number of past observations
            # that make up the state.
            states_shape = (1, observation_size, stack_size)
            self.state = np.zeros(states_shape)
            self.state_ph = tf.placeholder(tf.uint8,
                                           states_shape,
                                           name='state_ph')
            self.legal_actions_ph = tf.placeholder(tf.float32,
                                                   [self.num_actions],
                                                   name='legal_actions_ph')
            self._q = online_convnet(state=self.state_ph,
                                     num_actions=self.num_actions)
            self._replay = self._build_replay_memory(use_staging)
            self._replay_qs = online_convnet(self._replay.states,
                                             self.num_actions)
            self._replay_next_qt = target_convnet(self._replay.next_states,
                                                  self.num_actions)
            self._train_op = self._build_train_op()
            self._sync_qt_ops = self._build_sync_op()

            self._q_argmax = tf.argmax(self._q + self.legal_actions_ph,
                                       axis=1)[0]

        # Set up a session and initialize variables.
        self._sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        self._init_op = tf.global_variables_initializer()
        self._sess.run(self._init_op)

        self._saver = tf.train.Saver(max_to_keep=3)

        # This keeps track of the observed transitions during play, for each
        # player.
        self.transitions = [[] for _ in range(num_players)]
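The default exploration schedule, linearly_decaying_epsilon, is referenced above but not shown. A sketch consistent with the documented signature (decay_period, step, warmup_steps, epsilon) might look like the following; the exact decay shape is an assumption here:

import numpy as np


def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon):
    """Linearly anneals the exploration rate from 1.0 down to `epsilon`.

    The value stays at 1.0 for the first `warmup_steps` steps, then decreases
    linearly over `decay_period` steps, and is held at `epsilon` afterwards.
    This is a sketch of the documented interface, not the original code.
    """
    steps_left = decay_period + warmup_steps - step
    bonus = (1.0 - epsilon) * steps_left / decay_period
    bonus = np.clip(bonus, 0.0, 1.0 - epsilon)
    return epsilon + bonus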