def create_agent_fn(sess, environment, summary_writer):
  """Creates the appropriate agent.

  Dispatches on the free variable `agent_type` (presumably bound at module
  level via flags/gin outside this function — confirm against the caller)
  and instantiates the matching agent class.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A Gym environment; `environment.action_space.n` supplies
      the agent's action count.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent instance of the requested type.

  Raises:
    ValueError: If `agent_type` is not a supported agent name.
  """
  agent_classes = {
      'dqn': dqn_agent.DQNAgent,
      'iqn': implicit_quantile_agent.ImplicitQuantileAgent,
      'al_dqn': al_dqn.ALDQNAgent,
      'al_iqn': al_iqn.ALImplicitQuantileAgent,
      'sail_dqn': sail_dqn.SAILDQNAgent,
      'sail_iqn': sail_iqn.SAILImplicitQuantileAgent,
  }
  if agent_type not in agent_classes:
    raise ValueError('Wrong agent %s' % agent_type)
  return agent_classes[agent_type](
      sess=sess,
      num_actions=environment.action_space.n,
      summary_writer=summary_writer)
def create_agent(sess, environment, summary_writer=None):
  """Creates a DQN-family agent with a fixed 5-action space.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: An Atari 2600 Gym environment. NOTE(review): unused here —
      the action count is hard-coded to 5 rather than read from
      `environment.action_space.n`; presumably intentional for this task,
      but verify.
    summary_writer: A Tensorflow summary writer to pass to the agent for
      in-agent training statistics in Tensorboard. Dropped unless
      `FLAGS.debug_mode` is set.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `FLAGS.agent_name` is not in the supported list.
  """
  # Summaries are only kept in debug mode; they slow down training.
  if not FLAGS.debug_mode:
    summary_writer = None
  name = FLAGS.agent_name
  if name == 'dqn':
    return dqn_agent.DQNAgent(
        sess, num_actions=5, summary_writer=summary_writer)
  if name == 'rainbow':
    return rainbow_rgb_agent.RainbowRGBAgent(
        sess, num_actions=5, summary_writer=summary_writer)
  if name == 'implicit_quantile':
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions=5, summary_writer=summary_writer)
  raise ValueError('Unknown agent: {}'.format(name))
def create_agent(sess, environment, agent_name=None, summary_writer=None, debug_mode=False): """Creates an agent. Args: sess: A `tf.compat.v1.Session` object for running associated ops. environment: A gym environment (e.g. Atari 2600). agent_name: str, name of the agent to create. summary_writer: A Tensorflow summary writer to pass to the agent for in-agent training statistics in Tensorboard. debug_mode: bool, whether to output Tensorboard summaries. If set to true, the agent will output in-episode statistics to Tensorboard. Disabled by default as this results in slower training. Returns: agent: An RL agent. Raises: ValueError: If `agent_name` is not in supported list. """ assert agent_name is not None if not debug_mode: summary_writer = None if agent_name == 'dqn': return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'rainbow': return rainbow_agent.RainbowAgent( sess, num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'implicit_quantile': return implicit_quantile_agent.ImplicitQuantileAgent( sess, num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'jax_dqn': return jax_dqn_agent.JaxDQNAgent( num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'jax_quantile': return jax_quantile_agent.JaxQuantileAgent( num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'jax_rainbow': return jax_rainbow_agent.JaxRainbowAgent( num_actions=environment.action_space.n, summary_writer=summary_writer) elif agent_name == 'jax_implicit_quantile': return jax_implicit_quantile_agent.JaxImplicitQuantileAgent( num_actions=environment.action_space.n, summary_writer=summary_writer) else: raise ValueError('Unknown agent: {}'.format(agent_name))
def testCreateAgentWithDefaults(self):
  """Smoke-tests that a default-configured agent can run a short episode."""
  # CPU-only session keeps the test cheap and deterministic across machines.
  with self.test_session(use_gpu=False) as sess:
    agent = implicit_quantile_agent.ImplicitQuantileAgent(sess, num_actions=4)
    sess.run(tf.compat.v1.global_variables_initializer())
    frame = np.ones([84, 84, 1])
    # Drive one begin -> step -> end cycle through the agent.
    agent.begin_episode(frame)
    agent.step(reward=1, observation=frame)
    agent.end_episode(reward=1)
def create_agent(sess, agent_name, num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 summary_writer=None):
  """Builds the requested RL agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    agent_name: str, name of the agent to create.
    num_actions: int, number of actions the agent can take at any state.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations.
      Note that if your inputs are continuous, you should set this to
      tf.float32.
    stack_size: int, number of frames to use in state stack.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is not in the supported list.
  """
  if agent_name == 'implicit_quantile':
    # NOTE(review): unlike the other branches, this one does not forward
    # observation_shape / observation_dtype / stack_size — presumably
    # because ImplicitQuantileAgent's constructor does not accept them;
    # confirm before changing.
    return implicit_quantile_agent.ImplicitQuantileAgent(
        sess, num_actions,
        summary_writer=summary_writer)
  if agent_name == 'dqn':
    agent_cls = dqn_agent.DQNAgent
  elif agent_name == 'rainbow':
    agent_cls = rainbow_agent.RainbowAgent
  else:
    raise ValueError('Unknown agent: {}'.format(agent_name))
  return agent_cls(
      sess,
      num_actions,
      observation_shape=observation_shape,
      observation_dtype=observation_dtype,
      stack_size=stack_size,
      summary_writer=summary_writer)
def _create_test_agent(self, sess):
  """Builds an ImplicitQuantileAgent wired to a deterministic mock network.

  The mock network ignores the actual state values and always produces
  all-ones quantile values, so tests can anticipate the agent's outputs
  exactly.
  """
  # This dummy network allows us to deterministically anticipate that the
  # state-action-quantile outputs will be equal to sum of the
  # corresponding quantile inputs.
  # State/Quantile shapes will be k x 1, (N x batch_size) x 1,
  # or (N' x batch_size) x 1.
  class MockImplicitQuantileNetwork(tf.keras.Model):
    """Custom tf.keras.Model used in tests."""

    def __init__(self, num_actions, quantile_embedding_dim, **kwargs):
      # All-ones kernel and zero bias make every action's output identical,
      # so the argmax tie-breaks to action 0 deterministically.
      # (quantile_embedding_dim is accepted to match the real network's
      # constructor signature but is unused by this mock.)
      super(MockImplicitQuantileNetwork, self).__init__(**kwargs)
      self.num_actions = num_actions
      self.layer = tf.keras.layers.Dense(
          self.num_actions,
          kernel_initializer=tf.ones_initializer(),
          bias_initializer=tf.zeros_initializer())

    def call(self, state, num_quantiles):
      # Only the batch dimension of `state` is used; its contents are
      # ignored in favor of a constant all-ones input tensor.
      batch_size = state.get_shape().as_list()[0]
      inputs = tf.constant(
          np.ones((batch_size * num_quantiles, self.num_actions)),
          dtype=tf.float32)
      # Quantiles are likewise constant ones, shaped
      # (num_quantiles * batch_size, 1) to match the agent's expectations.
      quantiles_shape = [num_quantiles * batch_size, 1]
      quantiles = tf.ones(quantiles_shape)
      return atari_lib.ImplicitQuantileNetworkType(
          self.layer(inputs), quantiles)

  # Small tau/quantile sample counts keep the test graph tiny.
  agent = implicit_quantile_agent.ImplicitQuantileAgent(
      sess=sess,
      network=MockImplicitQuantileNetwork,
      num_actions=self._num_actions,
      kappa=1.0,
      num_tau_samples=2,
      num_tau_prime_samples=3,
      num_quantile_samples=4)
  # This ensures non-random action choices (since epsilon_eval = 0.0) and
  # skips the train_step.
  agent.eval_mode = True
  sess.run(tf.global_variables_initializer())
  return agent
def create_agent(sess, environment):
  """Creates a DQN agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: An Atari 2600 Gym environment;
      `environment.action_space.n` supplies the agent's action count.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `FLAGS.agent_name` is not in supported list.
  """
  agent_classes = {
      'dqn': dqn_agent.DQNAgent,
      'rainbow': rainbow_agent.RainbowAgent,
      'implicit_quantile': implicit_quantile_agent.ImplicitQuantileAgent,
  }
  agent_cls = agent_classes.get(FLAGS.agent_name)
  if agent_cls is None:
    raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
  # `environment` is only touched on the success path, matching the
  # original control flow.
  return agent_cls(sess, num_actions=environment.action_space.n)