def make_actor(self,
               random_key,
               policy_network,
               adder=None,
               variable_source=None,
               force_eval_with_q_filter=False):
  assert variable_source is not None
  if self._config.eval_with_q_filter or force_eval_with_q_filter:
    params_to_get = ['policy', 'all_q']
  else:
    params_to_get = ['policy']
  if self._config.use_img_encoder:
    params_to_get.append('img_encoder')
  return actors.GenericActor(
      actor=policy_network,
      random_key=random_key,
      # Inference happens on CPU, so it's better to move variables there.
      variable_client=variable_utils.VariableClient(
          variable_source, params_to_get, device='cpu'),
      adder=adder,
  )
def make_actor(self,
               random_key,
               policy_network,
               adder=None,
               variable_source=None):
  assert variable_source is not None
  params_to_get = ['policy']
  if self._config.use_img_encoder:
    params_to_get.append('img_encoder')
  return actors.GenericActor(
      actor=policy_network,
      random_key=random_key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          variable_source, params_to_get, device='cpu'),
      adder=adder,
  )
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = td3.make_networks(environment_spec)

  # Create the learner.
  learner = td3.TD3Learner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      use_sarsa_target=FLAGS.use_sarsa_target,
      bc_alpha=FLAGS.bc_alpha,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    del key
    return networks.policy_network.apply(params, observation)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy: Tuple[str, networks_lib.FeedForwardNetwork],
    environment_spec: specs.EnvironmentSpec,
    variable_source: Optional[core.VariableSource] = None,
    adder: Optional[adders.Adder] = None,
) -> acme.Actor:
  del environment_spec
  assert variable_source is not None
  kname, policy = policy
  normalization_apply_fn = (
      running_statistics.normalize
      if self._config.normalize_observations else (lambda a, b: a))
  policy_to_run = get_policy(policy, normalization_apply_fn)
  actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
      policy_to_run)
  variable_client = variable_utils.VariableClient(
      variable_source, kname, device='cpu')
  return actors.GenericActor(
      actor_core,
      random_key,
      variable_client,
      adder,
      backend='cpu',
      per_episode_update=True)
def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""

  # Create the actor loading the weights from variable source.
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  actor = actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')

  # Create the logger.
  logger = loggers.make_default_logger(
      'evaluator', steps_key='evaluator_steps')

  # Create the evaluation environment.
  environment = environment_factory(False)

  # Create the counter.
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(environment, actor, counter, logger)
def make_actor(actor_core: ActorCore,
               random_key: networks_lib.PRNGKey,
               variable_source: core.VariableSource,
               adder: Optional[adders.Adder] = None) -> core.Actor:
  """Creates an MBOP actor from an actor core.

  Args:
    actor_core: An MBOP actor core.
    random_key: JAX random key.
    variable_source: The source to get network parameters from.
    adder: An adder to add experiences to. The `extras` of the adder holds the
      state of the recurrent policy. If `has_extras=True` then the `extras`
      part returned from the recurrent policy is appended to the state before
      being added to the adder.

  Returns:
    A recurrent actor.
  """
  variable_client = variable_utils.VariableClient(
      client=variable_source,
      key=[
          'world_model-policy', 'policy_prior-policy', 'n_step_return-policy'
      ])
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend=None)
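# A minimal usage sketch for the MBOP `make_actor` factory above (not part of
# the original source). It assumes an already-constructed `actor_core`, a
# `learner` that serves the three parameter collections requested by the
# VariableClient, and an `environment`; the EnvironmentLoop wiring mirrors the
# other snippets in this file.
#
#   actor = make_actor(
#       actor_core=actor_core,
#       random_key=jax.random.PRNGKey(0),
#       variable_source=learner)
#   loop = environment_loop.EnvironmentLoop(environment, actor)
#   loop.run(num_episodes=10)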
def main(_):
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      demonstrations=dataset,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset.
  transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                               FLAGS.num_demonstrations)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions_iterator, key=key_demonstrations,
      batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = cql.make_networks(environment_spec)

  # Create the learner.
  learner = cql.CQLLearner(
      batch_size=FLAGS.batch_size,
      networks=networks,
      random_key=key_learner,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
      cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
      demonstrations=demonstrations,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = networks.policy_network.apply(params, observation)
    return networks.sample_eval(dist_params, key)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: dqn_config.DQNConfig,
):
  """Initialize the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
  )
  learner = learning_lib.SGDLearner(
      network=network,
      loss_fn=loss_fn,
      data_iterator=reverb_replay.data_iterator,
      optimizer=optimizer,
      target_update_period=config.target_update_period,
      random_key=key_learner,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to the policy.
  # A single epsilon is expected here, not a sequence of epsilons.
  assert not isinstance(config.epsilon, Sequence)

  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
  variable_client = variable_utils.VariableClient(learner, '')
  actor = actors.GenericActor(
      actor_core, key_actor, variable_client, reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None) -> acme.Actor:
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      policy_network)
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend='cpu')
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy: actor_core_lib.FeedForwardPolicy,
    environment_spec: specs.EnvironmentSpec,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  del environment_spec
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  assert variable_source is not None
  actor = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
      policy_network)
  variable_client = variable_utils.VariableClient(
      variable_source,
      'network',
      device='cpu',
      update_period=self._config.variable_update_period)
  return actors.GenericActor(
      actor, random_key, variable_client, adder, backend='cpu')
def test_recurrent(self):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)
  output_size = env_spec.actions.num_values
  obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  rng = hk.PRNGSequence(1)

  @_transform_without_rng
  def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])(inputs, state)

  @_transform_without_rng
  def initial_state(batch_size: Optional[int] = None):
    network = hk.DeepRNN(
        [hk.Reshape([-1], preserve_dims=1),
         hk.LSTM(output_size)])
    return network.initial_state(batch_size)

  initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
  params = network.init(next(rng), obs, initial_state)

  def policy(
      params: jnp.ndarray,
      key: jnp.ndarray,
      observation: jnp.ndarray,
      core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
    del key  # Unused for test-case deterministic policy.
    action_values, core_state = network.apply(params, observation, core_state)
    actions = jnp.argmax(action_values, axis=-1)
    return actions, core_state

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  actor_core = actor_core_lib.batched_recurrent_to_actor_core(
      policy, initial_state)
  actor = actors.GenericActor(
      actor_core, jax.random.PRNGKey(1), variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      policy_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend='cpu')
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy: r2d2_actor.R2D2Policy,
    environment_spec: specs.EnvironmentSpec,
    variable_source: Optional[core.VariableSource] = None,
    adder: Optional[adders.Adder] = None,
) -> acme.Actor:
  del environment_spec
  # Create variable client.
  variable_client = variable_utils.VariableClient(
      variable_source,
      key='actor_variables',
      update_period=self._config.variable_update_period)
  return actors.GenericActor(
      policy, random_key, variable_client, adder, backend='cpu')
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network: dqn_actor.EpsilonPolicy,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  assert variable_source is not None
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(
      variable_source, '', device='cpu')
  epsilon = self._config.epsilon
  # Accept either a single epsilon or a sequence of epsilons to alternate over.
  epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,)
  actor_core = dqn_actor.alternating_epsilons_actor_core(
      policy_network, epsilons=epsilons)
  return actors.GenericActor(
      actor=actor_core,
      random_key=random_key,
      variable_client=variable_client,
      adder=adder,
      backend=self._actor_backend)
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None
) -> acme.Actor:
  # Create variable client.
  variable_client = variable_utils.VariableClient(
      variable_source,
      key='actor_variables',
      update_period=self._config.variable_update_period)

  # TODO(b/186613827) move this to
  # - the actor __init__ function - this is a good place if it is specific
  #   for R2D2.
  # - the EnvironmentLoop - this is a good place if it potentially applies
  #   for all actors.
  #
  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  initial_state_key1, initial_state_key2, random_key = jax.random.split(
      random_key, 3)
  actor_initial_state_params = self._networks.initial_state.init(
      initial_state_key1, 1)
  actor_initial_state = self._networks.initial_state.apply(
      actor_initial_state_params, initial_state_key2, 1)

  actor_core = r2d2_actor.get_actor_core(policy_network, actor_initial_state,
                                         self._config.num_epsilons)
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend='cpu')
def test_feedforward(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    action_values = hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
    ])(inputs)
    action = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return action, (action_values,)
    else:
      return action

  policy = hk.transform(policy)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  if has_extras:
    actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
        policy.apply)
  else:
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy.apply)
  actor = actors.GenericActor(
      actor_core,
      random_key=jax.random.PRNGKey(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)