Example #1
def create_agent(sess, environment, summary_writer=None):
    """Creates a DQN agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: An Atari 2600 Gym environment.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list.
  """
    if not FLAGS.debug_mode:
        summary_writer = None
    if FLAGS.agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=5,
                                  summary_writer=summary_writer)
    elif FLAGS.agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(sess,
                                          num_actions=5,
                                          summary_writer=summary_writer)
    elif FLAGS.agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess, num_actions=5, summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
Example #2
def create_agent(sess,
                 environment,
                 agent_name=None,
                 summary_writer=None,
                 debug_mode=False):
    """Creates an agent.

    Args:
      sess: A `tf.compat.v1.Session` object for running associated ops.
      environment: A gym environment (e.g. Atari 2600).
      agent_name: str, name of the agent to create.
      summary_writer: A TensorFlow summary writer to pass to the agent
        for in-agent training statistics in TensorBoard.
      debug_mode: bool, whether to output TensorBoard summaries. If set to
        True, the agent will output in-episode statistics to TensorBoard.
        Disabled by default as this results in slower training.

    Returns:
      agent: An RL agent.

    Raises:
      ValueError: If `agent_name` is not in the supported list.
    """
    assert agent_name is not None
    if not debug_mode:
        summary_writer = None
    if agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess,
                                  num_actions=environment.action_space.n,
                                  summary_writer=summary_writer)
    elif agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess,
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_dqn':
        return jax_dqn_agent.JaxDQNAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_quantile':
        return jax_quantile_agent.JaxQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_rainbow':
        return jax_rainbow_agent.JaxRainbowAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    elif agent_name == 'jax_implicit_quantile':
        return jax_implicit_quantile_agent.JaxImplicitQuantileAgent(
            num_actions=environment.action_space.n,
            summary_writer=summary_writer)
    else:
        raise ValueError('Unknown agent: {}'.format(agent_name))
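A minimal, hypothetical call of the function above; the environment name is a placeholder, and the JAX branches never use the TensorFlow session, so it may be None for those agents:
import gym
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # the TF-based agents build a v1-style graph
env = gym.make('CartPole-v0')       # placeholder environment
with tf.compat.v1.Session() as sess:
    tf_agent = create_agent(sess, env, agent_name='rainbow')
jax_agent = create_agent(None, env, agent_name='jax_dqn')
# Note: all of these agents default to networks built for Atari-style
# 84x84 frame observations.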
Example #3
    def testCreateAgentWithDefaults(self):
        # Verifies that we can create and train an agent with the default values.
        with tf.Session() as sess:
            agent = rainbow_agent.RainbowAgent(sess, num_actions=4)
            sess.run(tf.global_variables_initializer())
            observation = np.ones([84, 84, 1])
            agent.begin_episode(observation)
            agent.step(reward=1, observation=observation)
            agent.end_episode(reward=1)
Example #4
    def testStoreTransitionWithPrioritizedSampling(self):
        with tf.Session() as sess:
            agent = rainbow_agent.RainbowAgent(
                sess, num_actions=4, replay_scheme='prioritized')
            dummy_frame = np.zeros((84, 84))
            # Add transitions with priorities: default, 10., default.
            agent._store_transition(dummy_frame, 0, 0, False)
            agent._store_transition(dummy_frame, 0, 0, False, 10.)
            agent._store_transition(dummy_frame, 0, 0, False)
            returned_priorities = agent._replay.memory.get_priority(
                np.arange(self.stack_size - 1, self.stack_size + 2,
                          dtype=np.int32))
            expected_priorities = [1., 10., 10.]
            self.assertAllEqual(returned_priorities, expected_priorities)
Example #5
    def __init__(self, **kwargs):
        super(RainbowWrapper, self).__init__(**kwargs)
        args, num_inputs, num_outputs, factor = self.get_args(kwargs)
        self.minmax = default_value_arg(kwargs, 'minmax', None)
        if args.true_environment and args.state_forms[0] == 'raw':
            network = atari_lib.rainbow_network
            print(network)
            observation_shape = (84, 84)
        else:
            network = create_rainbow_network(self.minmax)
            print("num inputs", num_inputs)
            observation_shape = (num_inputs, )

        self.sess = tf.Session(
            '', config=tf.ConfigProto(allow_soft_placement=True))
        # local_device_protos = device_lib.list_local_devices()
        # print([x.name for x in local_device_protos])
        # print(atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
        #            atari_lib.NATURE_DQN_DTYPE,
        #            atari_lib.NATURE_DQN_STACK_SIZE,)
        # print(num_inputs)

        self.dope_rainbow = rainbow_agent.RainbowAgent(
            self.sess,
            num_outputs,
            observation_shape=observation_shape,
            observation_dtype=tf.float32,
            stack_size=args.num_stack,
            network=network,
            num_atoms=51,
            vmax=args.value_bounds[1],
            gamma=args.gamma,
            update_horizon=1,
            min_replay_history=20000,
            update_period=4,
            target_update_period=8000,
            epsilon_fn=dqn_agent.linearly_decaying_epsilon,
            epsilon_train=0.01,
            epsilon_eval=0.001,
            epsilon_decay_period=250000,
            replay_scheme='prioritized',
            tf_device='/gpu:' + str(args.gpu),
            use_staging=True,
            optimizer=tf.train.AdamOptimizer(learning_rate=.00025,
                                             epsilon=0.0003125),
            summary_writer=None,
            summary_writing_frequency=500)
        self.dope_rainbow.eval_mode = False
        self.sess.run(tf.global_variables_initializer())
Example #6
def create_agent(sess, agent_name, num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 summary_writer=None):
  """Creates an agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    agent_name: str, name of the agent to create.
    num_actions: int, number of actions the agent can take at any state.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations. Note
      that if your inputs are continuous, you should set this to tf.float32.
    stack_size: int, number of frames to use in state stack.
    summary_writer: A TensorFlow summary writer to pass to the agent
      for in-agent training statistics in TensorBoard.
  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in supported list or one of the
      GAIRL submodules is not in supported list when the chosen agent is GAIRL.
  """
  if agent_name == 'dqn':
    return dqn_agent.DQNAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'rainbow':
    return rainbow_agent.RainbowAgent(
      sess, num_actions, observation_shape=observation_shape,
      observation_dtype=observation_dtype, stack_size=stack_size,
      summary_writer=summary_writer
    )
  elif agent_name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
      sess, num_actions, summary_writer=summary_writer
    )
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
Example #7
def create_agent(sess, environment):
    """Creates a DQN agent.

    Args:
      sess: A `tf.Session` object for running associated ops.
      environment: An Atari 2600 Gym environment.

    Returns:
      agent: An RL agent.

    Raises:
      ValueError: If `agent_name` is not in supported list.
    """
    if FLAGS.agent_name == 'dqn':
        return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'rainbow':
        return rainbow_agent.RainbowAgent(
            sess, num_actions=environment.action_space.n)
    elif FLAGS.agent_name == 'implicit_quantile':
        return implicit_quantile_agent.ImplicitQuantileAgent(
            sess, num_actions=environment.action_space.n)
    else:
        raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
Example #8
def create_agent_fn(sess, env, summary_writer):
    return rainbow_agent.RainbowAgent(sess=sess,
                                      num_actions=env.action_space.n,
                                      summary_writer=summary_writer)
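The function above matches the create_agent_fn(sess, environment, summary_writer) signature that Dopamine's run_experiment.Runner expects, so a hypothetical experiment setup (base directory and game name are placeholders) might look like:
import functools
from dopamine.discrete_domains import atari_lib
from dopamine.discrete_domains import run_experiment

runner = run_experiment.Runner(
    '/tmp/rainbow_test',  # placeholder base_dir for checkpoints and logs
    create_agent_fn,
    create_environment_fn=functools.partial(
        atari_lib.create_atari_environment, game_name='Pong'))
runner.run_experiment()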
Example #9
    def _create_test_agent(self, sess):
        stack_size = self.stack_size

        # This dummy network allows us to deterministically anticipate that
        # action 0 will be selected by an argmax.

        # In Rainbow we are dealing with a distribution over Q-values,
        # which are represented as num_atoms bins, ranging from -vmax to vmax.
        # The output layer will have num_actions * num_atoms elements,
        # so each group of num_atoms weights represent the logits for a
        # particular action. By setting 1s everywhere, except for the first
        # num_atoms (representing the logits for the first action), which are
        # set to np.arange(num_atoms), we are ensuring that the first action
        # places higher weight on higher Q-values; this results in the first
        # action being chosen.
        class MockRainbowNetwork(tf.keras.Model):
            """Custom tf.keras.Model used in tests."""
            def __init__(self, num_actions, num_atoms, support, **kwargs):
                super(MockRainbowNetwork, self).__init__(**kwargs)
                self.num_actions = num_actions
                self.num_atoms = num_atoms
                self.support = support
                first_row = np.tile(np.ones(self.num_atoms),
                                    self.num_actions - 1)
                first_row = np.concatenate(
                    (np.arange(self.num_atoms), first_row))
                bottom_rows = np.tile(
                    np.ones(self.num_actions * self.num_atoms),
                    (stack_size - 1, 1))
                weights_initializer = np.concatenate(
                    ([first_row], bottom_rows))
                self.layer = tf.keras.layers.Dense(
                    self.num_actions * self.num_atoms,
                    kernel_initializer=tf.constant_initializer(
                        weights_initializer),
                    bias_initializer=tf.ones_initializer())

            def call(self, state):
                inputs = tf.constant(np.zeros((state.shape[0], stack_size)),
                                     dtype=tf.float32)
                net = self.layer(inputs)
                logits = tf.reshape(net,
                                    [-1, self.num_actions, self.num_atoms])
                probabilities = tf.keras.activations.softmax(logits)
                qs = tf.reduce_sum(self.support * probabilities, axis=2)
                return atari_lib.RainbowNetworkType(qs, logits, probabilities)

        agent = rainbow_agent.RainbowAgent(
            sess=sess,
            network=MockRainbowNetwork,
            num_actions=self._num_actions,
            num_atoms=self._num_atoms,
            vmax=self._vmax,
            min_replay_history=self._min_replay_history,
            epsilon_fn=lambda w, x, y, z: 0.0,  # No exploration.
            epsilon_eval=0.0,
            epsilon_decay_period=self._epsilon_decay_period)
        # This ensures non-random action choices (since epsilon_eval = 0.0) and
        # skips the train_step.
        agent.eval_mode = True
        sess.run(tf.global_variables_initializer())
        return agent
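A hypothetical test method built on the helper above; the method name, observation shape, and expected action are assumptions based on how the mock network is constructed (action 0 wins the argmax):
    def testSelectsFirstActionInEvalMode(self):
        # Sketch only: relies on the fixture attributes used by
        # _create_test_agent (e.g. self._num_actions, self._num_atoms).
        with tf.Session() as sess:
            agent = self._create_test_agent(sess)
            observation = np.ones([84, 84, 1])
            # eval_mode = True and epsilon_eval = 0.0 force a greedy choice,
            # and the mock network places the highest Q-value on action 0.
            self.assertEqual(agent.begin_episode(observation), 0)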