def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)

  policy = hk.transform(policy, apply_rng=True)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.FeedForwardActor(
      policy.apply,
      rng=hk.PRNGSequence(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
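# A minimal sketch of the _make_fake_env helper the test above assumes. It
# uses Acme's fakes.Environment test double; the spec shapes here are
# illustrative assumptions and only need to match the policy network
# (flattened observations, a discrete action space).
import dm_env
import numpy as np
from acme import specs
from acme.testing import fakes


def _make_fake_env() -> dm_env.Environment:
  env_spec = specs.EnvironmentSpec(
      observations=specs.Array(shape=(10, 5), dtype=np.float32),
      actions=specs.DiscreteArray(num_values=3),
      rewards=specs.Array(shape=(), dtype=np.float32),
      discounts=specs.BoundedArray(
          shape=(), dtype=np.float32, minimum=0., maximum=1.),
  )
  return fakes.Environment(env_spec, episode_length=10)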
def actor_evaluator(
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""

  # Create the actor, loading its weights from the variable source.
  actor = actors.FeedForwardActor(
      policy=evaluator_network,
      random_key=random_key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          variable_source, 'policy', device='cpu'))

  # Create the logger.
  logger = loggers.make_default_logger('evaluator', steps_key='evaluator_steps')

  # Create the evaluation environment.
  environment = environment_factory(False)

  # Create the counter.
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
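# Hedged usage sketch: actor_evaluator closes over evaluator_network,
# random_key and environment_factory, which are assumed to be bound by an
# enclosing factory. Given a learner that implements core.VariableSource
# (`learner` here is an assumed name), the loop can be built and run directly.
eval_loop = actor_evaluator(variable_source=learner,
                            counter=counting.Counter())
eval_loop.run(num_episodes=10)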
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: DQNConfig,
):
  """Initialize the agent."""

  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      random_key=key_learner,
      optimizer=optimizer,
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
      target_update_period=config.target_update_period,
      iterator=reverb_replay.data_iterator,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(key_actor),
      variable_client=variable_utils.VariableClient(learner, ''),
      adder=reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
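# A minimal sketch of the DQNConfig read by the constructor above. Only the
# fields actually referenced there are listed; the defaults shown are
# illustrative assumptions (chosen to match the explicit-argument variants
# below), not the library's canonical values.
import dataclasses
import numpy as np


@dataclasses.dataclass
class DQNConfig:
  epsilon: float = 0.05
  seed: int = 1
  learning_rate: float = 1e-3
  discount: float = 0.99
  n_step: int = 5
  target_update_period: int = 100
  max_gradient_norm: float = np.inf
  batch_size: int = 256
  min_replay_size: int = 1000
  max_replay_size: int = 1_000_000
  importance_sampling_exponent: float = 0.2
  priority_exponent: float = 0.6
  samples_per_insert: float = 0.5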
def main(_):
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      demonstrations=dataset,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  evaluator = actors.FeedForwardActor(
      policy=evaluator_network,
      random_key=key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          learner, 'policy', device='cpu'))

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
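# A quick numeric check of logp_fn above (a hedged illustration, not part of
# the example): for a single transition with logits [2., 0.] and action 0,
# the result is the log-probability of action 0 under softmax(logits),
# i.e. 2 - logsumexp([2., 0.]) which is approximately -0.1269.
import jax.numpy as jnp
from jax.scipy import special  # assumed to be the `special` used above

logits = jnp.array([[2.0, 0.0]])
actions = jnp.array([0])
expected = logits[0, 0] - special.logsumexp(logits[0])
# logp_fn(logits, actions)[0] should equal `expected`.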
def make_actor(
    self,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  assert variable_source is not None
  key, self._random_key = jax.random.split(self._random_key)
  return actors.FeedForwardActor(
      policy=policy_network,
      random_key=key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          variable_source, '', device='cpu'),
      adder=adder,
  )
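# Hedged usage sketch: an experiment runner would typically call make_actor
# with the learner as the variable source. Passing an adder produces a
# training actor that writes to replay; omitting it gives an evaluation
# actor. `builder`, `policy`, `reverb_replay` and `learner` are assumed names.
train_actor = builder.make_actor(
    policy_network=policy.apply,
    adder=reverb_replay.adder,
    variable_source=learner)
eval_actor = builder.make_actor(
    policy_network=policy.apply,
    variable_source=learner)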
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 32.0,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    seed: int = 1,
):
  """Initialize the agent."""

  # Create a replay server to add data to. This uses no limiter behavior in
  # order to allow the Agent interface to handle it.
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=adders.NStepTransitionAdder.signature(
          environment_spec=environment_spec))
  self._server = reverb.Server([replay_table], port=None)

  # The adder is used to insert observations into replay.
  address = f'localhost:{self._server.port}'
  adder = adders.NStepTransitionAdder(
      client=reverb.Client(address),
      n_step=n_step,
      discount=discount)

  # The dataset provides an interface to sample from replay.
  dataset = datasets.make_reverb_dataset(
      server_address=address,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      transition_adder=True)

  def policy(params: hk.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(epsilon).sample(key, action_values)

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      rng=hk.PRNGSequence(seed),
      optimizer=optax.adam(learning_rate),
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      target_update_period=target_update_period,
      iterator=dataset.as_numpy_iterator(),
      replay_client=reverb.Client(address),
  )

  variable_client = variable_utils.VariableClient(learner, '')

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(seed),
      variable_client=variable_client,
      adder=adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
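# Worked example of the flow control handed to the Agent interface above,
# using the constructor defaults: learning starts only after
# min_observations environment steps have been added, and then runs one
# learner step per observations_per_step new observations.
batch_size, samples_per_insert, min_replay_size = 256, 32.0, 1000
min_observations = max(batch_size, min_replay_size)      # -> 1000
observations_per_step = batch_size / samples_per_insert  # -> 8.0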
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build the demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env
  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)

  # Convert the demonstration episodes into n-step transitions.
  transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = tfds.as_numpy(dataset)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)
  policy_network = hk.without_apply_rng(hk.transform(policy_network))

  # If the agent is non-autoregressive, epsilon=0 yields a greedy policy.
  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    action_values = policy_network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      optimizer=optax.adam(FLAGS.learning_rate),
      obs_spec=environment.observation_spec(),
      dataset=dataset,
      counter=learner_counter,
      rng=hk.PRNGSequence(FLAGS.seed))

  # Create the actor which defines how we take actions.
  variable_client = variable_utils.VariableClient(learner, '')
  evaluator = actors.FeedForwardActor(
      evaluator_network,
      variable_client=variable_client,
      rng=hk.PRNGSequence(FLAGS.seed))

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # Run the environment loop, interleaving learning and evaluation.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
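# A minimal sketch of the make_policy_network helper assumed above: a small
# MLP head sized to the discrete action spec. The layer sizes are
# illustrative assumptions.
import haiku as hk
import jax.numpy as jnp
from acme import specs


def make_policy_network(action_spec: specs.DiscreteArray):
  def network(obs: jnp.ndarray) -> jnp.ndarray:
    return hk.Sequential([
        hk.Flatten(),
        hk.nets.MLP([64, 64, action_spec.num_values]),
    ])(obs)
  return network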
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    batch_size: int = 256,
    prefetch_size: int = 4,
    target_update_period: int = 100,
    samples_per_insert: float = 0.5,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    n_step: int = 5,
    epsilon: float = 0.05,
    learning_rate: float = 1e-3,
    discount: float = 0.99,
    seed: int = 1,
):
  """Initialize the agent."""

  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=n_step,
      batch_size=batch_size,
      max_replay_size=max_replay_size,
      min_replay_size=min_replay_size,
      priority_exponent=priority_exponent,
      discount=discount,
  )
  self._server = reverb_replay.server

  # The learner updates the parameters (and initializes them).
  learner = learning.DQNLearner(
      network=network,
      obs_spec=environment_spec.observations,
      rng=hk.PRNGSequence(seed),
      optimizer=optax.adam(learning_rate),
      discount=discount,
      importance_sampling_exponent=importance_sampling_exponent,
      target_update_period=target_update_period,
      iterator=reverb_replay.data_iterator,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  def policy(params: hk.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(epsilon).sample(key, action_values)

  actor = actors.FeedForwardActor(
      policy=policy,
      rng=hk.PRNGSequence(seed),
      variable_client=variable_utils.VariableClient(learner, ''),
      adder=reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(batch_size, min_replay_size),
      observations_per_step=float(batch_size) / samples_per_insert)
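# A hedged illustration of the epsilon-greedy policy used above: with
# epsilon=0.05 the sampled action is argmax(action_values) with probability
# 0.95 + 0.05/num_actions, and uniform over all actions otherwise.
import jax
import jax.numpy as jnp
import rlax

key = jax.random.PRNGKey(0)
action_values = jnp.array([0.1, 0.9, 0.3])
action = rlax.epsilon_greedy(0.05).sample(key, action_values)  # usually 1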