def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  # Create the actor, loading the weights from the variable source.
  actor = actors.FeedForwardActor(
      policy=evaluator_network,
      random_key=random_key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          variable_source, 'policy', device='cpu'))

  # Logger.
  logger = loggers.make_default_logger('evaluator',
                                       steps_key='evaluator_steps')

  # Create the environment.
  environment = environment_factory(False)

  # Create the counter.
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
def actor(self, random_key, replay, variable_source, counter, actor_id):
  """The actor process."""
  adder = self._builder.make_adder(replay)

  environment_key, actor_key = jax.random.split(random_key)
  # Create environment and policy core.
  # Environments normally require uint32 as a seed.
  environment = self._environment_factory(
      utils.sample_uint32(environment_key))

  networks = self._network_factory(specs.make_environment_spec(environment))
  policy_network = self._policy_network(networks)
  actor = self._builder.make_actor(actor_key, policy_network, adder,
                                   variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'actor')
  # Only actor #0 will write to bigtable in order not to spam it too much.
  logger = self._actor_logger_fn(actor_id)

  # Create the loop to connect environment and agent.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=self._observers)
def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  # Create the actor, loading the weights from the variable source.
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  actor = actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')

  # Logger.
  logger = loggers.make_default_logger('evaluator',
                                       steps_key='evaluator_steps')

  # Create the environment.
  environment = environment_factory(False)

  # Create the counter.
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
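# A hedged companion sketch of what `batched_feed_forward_to_actor_core` (used
# above) expects: a batched feed-forward policy with roughly the signature
# (params, key, batched_observations) -> batched_actions. The body below is
# purely illustrative; the real `evaluator_network` and `policy_apply_fn` come
# from the enclosing factory and are not shown in these snippets.
def evaluator_network(params, key, observations):
  del key  # A deterministic evaluation policy can ignore its key.
  action_values = policy_apply_fn(params, observations)  # Hypothetical apply fn.
  return jnp.argmax(action_values, axis=-1)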
def evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    make_actor: MakeActorFn,
):
  """The evaluation process."""
  # Create the environment and evaluator networks.
  environment_key, actor_key = jax.random.split(random_key)
  # Environments normally require uint32 as a seed.
  environment = environment_factory(utils.sample_uint32(environment_key))
  networks = network_factory(specs.make_environment_spec(environment))

  actor = make_actor(actor_key, policy_factory(networks), variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  if logger_fn is not None:
    logger = logger_fn('evaluator', 'actor_steps')
  else:
    logger = loggers.make_default_logger(
        'evaluator', log_to_bigtable, steps_key='actor_steps')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=observers)
def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)

  policy = hk.transform(policy, apply_rng=True)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.FeedForwardActor(
      policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def evaluator(
    random_key: types.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    make_actor: MakeActorFn,
):
  """The evaluation process."""
  # Create the environment and evaluator networks.
  environment_key, actor_key = jax.random.split(random_key)
  # Environments normally require uint32 as a seed.
  environment = environment_factory(utils.sample_uint32(environment_key))
  environment_spec = specs.make_environment_spec(environment)

  networks = network_factory(environment_spec)
  policy = policy_factory(networks, environment_spec, True)
  actor = make_actor(actor_key, policy, environment_spec, variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  logger = logger_factory('evaluator', 'actor_steps', 0)

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=observers)
def test_environment_loop(self):
  # Create the actor/environment and stick them in a loop.
  environment = fakes.DiscreteEnvironment(episode_length=10)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)

  # Run the loop. There should be one update call per environment step, i.e.
  # episode_length per episode, so 10 episodes of length 10 give 100 updates.
  loop.run(num_episodes=10)
  self.assertEqual(actor.num_updates, 100)
def test_recurrent(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  network = snt.DeepRNN([
      snt.Flatten(),
      snt.Linear(env_spec.actions.num_values),
      lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype),
  ])

  actor = actors_tf2.RecurrentActor(network)
  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def test_feedforward(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  network = snt.Sequential([
      snt.Flatten(),
      snt.Linear(env_spec.actions.num_values),
      lambda x: tf.argmax(x, axis=-1, output_type=env_spec.actions.dtype),
  ])

  actor = actors_tf2.FeedForwardActor(network)
  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def test_recurrent(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  @_transform_without_rng
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])(inputs, state)

  @_transform_without_rng
  def initial_state(batch_size: Optional[int] = None):
    network = hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])
    return network.initial_state(batch_size)

  initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState,
  ) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    actions = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return (actions, (action_values,)), core_state
    else:
      return actions, core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.RecurrentActor(
      policy,
      jax.random.PRNGKey(1),
      initial_state,
      variable_client,
      has_extras=has_extras)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None,
                         reward_spec: Optional[types.NestedSpec] = None):
  """Common setup code that, unlike self.setUp, takes arguments.

  Args:
    discount_spec: None, or a (nested) specs.BoundedArray.
    reward_spec: None, or a (nested) specs.Array.

  Returns:
    actor, loop
  """
  env_kwargs = {'episode_length': EPISODE_LENGTH}
  if discount_spec:
    env_kwargs['discount_spec'] = discount_spec
  if reward_spec:
    env_kwargs['reward_spec'] = reward_spec

  environment = fakes.DiscreteEnvironment(**env_kwargs)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)
  return actor, loop
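# A hedged usage sketch for the helper above (assuming `numpy` is imported as
# `np`): exercise the loop with a multi-dimensional reward spec, which the fake
# environment will zero-fill. The exact spec shape is an illustrative assumption.
actor, loop = _parameterized_setup(reward_spec=specs.Array((2, 3), np.float32))
loop.run(num_episodes=2)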
def __init__(
    self,
    eval_actor,
    environment,
    num_episodes,
    counter,
    logger,
    eval_sync=None,
    progress_counter_name='actor_steps',
    min_steps_between_evals=None,
    self_cleanup=False,
    observers=(),
):
  super().__init__()
  assert num_episodes >= 1
  self._eval_actor = eval_actor
  self._num_episodes = num_episodes
  self._counter = counter
  self._logger = logger

  # Create the evaluation loop over the given environment.
  self._env = environment
  self._environment_loop = environment_loop.EnvironmentLoop(
      environment, eval_actor, should_update=False, observers=observers)

  self._eval_sync = eval_sync or (lambda _: None)
  self._progress_counter_name = progress_counter_name
  self._last_steps = None
  self._eval_every_steps = min_steps_between_evals
  self._pending_tear_down = False

  if self_cleanup:
    # Do not rely on the instance owner to clean up this evaluator.
    # Register a signal handler to perform some resource cleanup.
    try:
      signal.signal(signal.SIGTERM, self._signal_handler)  # pytype: disable=wrong-arg-types
    except ValueError:
      logging.warning(
          'Caught ValueError when registering signal handler. '
          'This probably means we are not running in the main thread.')
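# A hedged sketch of the `_signal_handler` referenced above; its real body is
# not shown in this snippet. Given the `_pending_tear_down` flag initialised in
# __init__, a minimal handler would simply mark the evaluator for teardown:
def _signal_handler(self, signum, frame):
  del signum, frame  # Only the fact that SIGTERM arrived matters here.
  self._pending_tear_down = True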
def test_feedforward(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    action_values = hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
    ])(inputs)
    action = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return action, (action_values,)
    else:
      return action

  policy = hk.transform(policy)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  if has_extras:
    actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
        policy.apply)
  else:
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy.apply)
  actor = actors.GenericActor(
      actor_core,
      random_key=jax.random.PRNGKey(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def test_recurrent(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  @hk.transform
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

  @hk.transform
  def initial_state(batch_size: int):
    network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
    return network.initial_state(batch_size)

  initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState,
  ) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    return jnp.argmax(action_values, axis=-1), core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor = actors.RecurrentActor(
      policy, hk.PRNGSequence(1), initial_state, variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def build_actor(
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    actor_id: ActorId,
) -> environment_loop.EnvironmentLoop:
  """The actor process."""
  environment_key, actor_key = jax.random.split(random_key)
  # Create environment and policy core.
  # Environments normally require uint32 as a seed.
  environment = experiment.environment_factory(
      utils.sample_uint32(environment_key))
  environment_spec = specs.make_environment_spec(environment)

  networks = experiment.network_factory(environment_spec)
  policy_network = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=False)
  adder = experiment.builder.make_adder(replay, environment_spec,
                                        policy_network)
  actor = experiment.builder.make_actor(actor_key, policy_network,
                                        environment_spec, variable_source,
                                        adder)

  # Create logger and counter.
  counter = counting.Counter(counter, 'actor')
  logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                     actor_id)

  # Create the loop to connect environment and agent.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger,
                                          observers=experiment.observers)
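# A hedged usage sketch: in a distributed experiment, each actor process
# receives handles to the replay buffer, the learner (acting as variable
# source) and a parent counter, then simply runs its loop. `replay_client`,
# `learner` and `parent_counter` are assumptions, not defined in this snippet.
loop = build_actor(
    jax.random.PRNGKey(0), replay_client, learner, parent_counter, actor_id=0)
loop.run()  # With no episode or step limit, the actor runs until terminated.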
# Create a logger for the agent and environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

# Create the D4PG agent.
agent = d4pg.D4PG(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
    observation_network=observation_network,
    sigma=1.0,
    logger=agent_logger,
    checkpoint=False)

# Create a loop connecting this agent to the environment created above.
env_loop = environment_loop.EnvironmentLoop(
    environment, agent, logger=env_loop_logger)

# Run `num_episodes` training episodes.
# Rerun this cell until the agent has learned the given task.
env_loop.run(num_episodes=5000)


@tf.function(input_signature=[tf.TensorSpec(shape=(1, 32), dtype=np.float32)])
def policy_inference(x):
  return policy_network(x)


p_save = snt.Module()
p_save.inference = policy_inference
p_save.all_variables = list(policy_network.variables)
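# A minimal follow-up sketch: `p_save` above has the usual shape for export via
# TensorFlow's SavedModel API; the destination path is an assumption.
tf.saved_model.save(p_save, '/tmp/policy_snapshot')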
def setUp(self):
  super().setUp()
  # Create the actor/environment and stick them in a loop.
  environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
  self.actor = fakes.Actor(specs.make_environment_spec(environment))
  self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
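# A hedged companion sketch: a test method over the fixtures built in setUp,
# reusing the one-update-per-environment-step arithmetic asserted elsewhere in
# this suite. The method name is illustrative.
def test_runs_two_episodes(self):
  self.loop.run(num_episodes=2)
  self.assertEqual(self.actor.num_updates, 2 * EPISODE_LENGTH)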