def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  # Create the actor loading the weights from variable source.
  # `evaluator_network` and `environment_factory` are closed over from the
  # enclosing scope.
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  actor = actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')

  # Create the logger.
  logger = loggers.make_default_logger('evaluator',
                                       steps_key='evaluator_steps')

  # Create the environment.
  environment = environment_factory(False)

  # Create the counter.
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = td3.make_networks(environment_spec)

  # Create the learner.
  learner = td3.TD3Learner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      use_sarsa_target=FLAGS.use_sarsa_target,
      bc_alpha=FLAGS.bc_alpha,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    del key
    return networks.policy_network.apply(params, observation)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      demonstrations=dataset,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset.
  transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                               FLAGS.num_demonstrations)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions_iterator, key=key_demonstrations,
      batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = cql.make_networks(environment_spec)

  # Create the learner.
  learner = cql.CQLLearner(
      batch_size=FLAGS.batch_size,
      networks=networks,
      random_key=key_learner,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
      cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
      demonstrations=demonstrations,
      num_sgd_steps_per_step=1)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = networks.policy_network.apply(params, observation)
    return networks.sample_eval(dist_params, key)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: dqn_config.DQNConfig,
):
  """Initialize the agent."""
  # Data is communicated via reverb replay.
  reverb_replay = replay.make_reverb_prioritized_nstep_replay(
      environment_spec=environment_spec,
      n_step=config.n_step,
      batch_size=config.batch_size,
      max_replay_size=config.max_replay_size,
      min_replay_size=config.min_replay_size,
      priority_exponent=config.priority_exponent,
      discount=config.discount,
  )
  self._server = reverb_replay.server

  optimizer = optax.chain(
      optax.clip_by_global_norm(config.max_gradient_norm),
      optax.adam(config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

  # The learner updates the parameters (and initializes them).
  loss_fn = losses.PrioritizedDoubleQLearning(
      discount=config.discount,
      importance_sampling_exponent=config.importance_sampling_exponent,
  )
  learner = learning_lib.SGDLearner(
      network=network,
      loss_fn=loss_fn,
      data_iterator=reverb_replay.data_iterator,
      optimizer=optimizer,
      target_update_period=config.target_update_period,
      random_key=key_learner,
      replay_client=reverb_replay.client,
  )

  # The actor selects actions according to an epsilon-greedy policy; a single
  # epsilon value (not a schedule/sequence of values) is expected here.
  assert not isinstance(config.epsilon, Sequence)

  def policy(params: networks_lib.Params, key: jnp.ndarray,
             observation: jnp.ndarray) -> jnp.ndarray:
    action_values = network.apply(params, observation)
    return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
  variable_client = variable_utils.VariableClient(learner, '')
  actor = actors.GenericActor(
      actor_core, key_actor, variable_client, reverb_replay.adder)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(config.batch_size, config.min_replay_size),
      observations_per_step=config.batch_size / config.samples_per_insert,
  )
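# Hedged usage sketch (not part of the original sources): how an agent exposing
# the constructor above might be built and driven with an environment loop. The
# class name `DQN` and the network argument are assumptions for illustration;
# the real network factory and class name live elsewhere in the codebase.
def _example_run_dqn(environment, network, num_episodes=10):
  """Builds the agent defined above and runs it for a few episodes."""
  environment_spec = specs.make_environment_spec(environment)
  config = dqn_config.DQNConfig()  # Default hyperparameters; epsilon is a single float.
  agent = DQN(environment_spec, network, config)  # Hypothetical name for the class above.
  loop = environment_loop.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=num_episodes)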
def apply_policy_and_sample(networks, eval_mode=False):
  """Returns a function that computes actions."""
  sample_fn = networks.sample if not eval_mode else networks.sample_eval
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    return sample_fn(networks.policy_network.apply(params, obs), key)

  return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
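# Hedged usage sketch (not part of the original sources): the ActorCore returned
# by apply_policy_and_sample plugs into GenericActor exactly as in the
# make_actor methods elsewhere in this file. Using `learner` directly as the
# variable source is an assumption for illustration.
def _example_make_evaluator(networks, learner, random_key):
  """Wraps the eval-mode policy above into an actor that runs inference on CPU."""
  eval_core = apply_policy_and_sample(networks, eval_mode=True)
  variable_client = variable_utils.VariableClient(learner, 'policy',
                                                  device='cpu')
  return actors.GenericActor(eval_core, random_key, variable_client,
                             backend='cpu')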
def apply_policy_and_sample_with_img_encoder(networks, eval_mode=False):
  """Returns a function that computes actions."""
  sample_fn = networks.sample if not eval_mode else networks.sample_eval
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    img = obs['state_image']
    img_embedding = networks.img_encoder.apply(params[1], img)
    x = dict(state_image=img_embedding, state_dense=obs['state_dense'])
    return sample_fn(networks.policy_network.apply(params[0], x), key)

  return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None) -> acme.Actor:
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      policy_network)
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend='cpu')
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy: actor_core_lib.FeedForwardPolicy,
    environment_spec: specs.EnvironmentSpec,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  del environment_spec
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')
def build_q_filtered_actor(
    networks,
    num_samples,
    with_uniform=True,
):
  """Builds an actor core that picks the highest-Q action among sampled candidates."""

  def select_action(
      params,
      key,
      obs,
  ):
    key, sub_key = jax.random.split(key)
    policy_params = params[0]
    q_params = params[1]

    # Draw candidate actions from the policy.
    dist = networks.policy_network.apply(policy_params, obs)
    acts = dist._sample_n(num_samples, sub_key)
    acts = acts[:, 0, :]  # N x act_dim

    # Optionally add uniformly sampled actions to the candidate set.
    if with_uniform:
      key, sub_key = jax.random.split(sub_key)
      unif_acts = jax.random.uniform(
          sub_key, acts.shape, dtype=acts.dtype, minval=-1., maxval=1.)
      acts = jnp.concatenate([acts, unif_acts], axis=0)

    # Tile the observation so every candidate action is scored against it.
    def obs_tile_fn(t):
      tile_shape = [1] * t.ndim
      tile_shape[0] = acts.shape[0]
      return jnp.tile(t, tile_shape)

    tiled_obs = jax.tree_map(obs_tile_fn, obs)

    # num_candidates x num_critics
    all_q = networks.q_network.apply(q_params, tiled_obs, acts)
    # Pessimistic score: minimum over critics, then pick the best candidate.
    q_score = jnp.min(all_q, axis=-1)
    best_idx = jnp.argmax(q_score)
    return acts[best_idx][None, :]

  return actor_core.batched_feed_forward_to_actor_core(select_action)
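# Hedged usage sketch (not part of the original sources): wrapping the
# Q-filtered actor core above into a GenericActor, mirroring the make_actor
# methods in this file. The variable source is assumed to expose params as
# [policy_params, q_params], matching the indexing in select_action;
# num_samples=32 is an arbitrary illustrative value.
def _example_make_q_filtered_actor(networks, variable_source, random_key):
  """Builds an actor that scores sampled candidate actions with the critics."""
  core = build_q_filtered_actor(networks, num_samples=32, with_uniform=True)
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  return actors.GenericActor(core, random_key, variable_client, backend='cpu')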
def apply_policy_and_sample(networks, eval_mode=False, use_img_encoder=False):
  """Returns a function that computes actions."""
  sample_fn = networks.sample if not eval_mode else networks.sample_eval
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    if use_img_encoder:
      params, encoder_params = params[0], params[1]
      obs = {
          'state_image':
              networks.img_encoder.apply(encoder_params, obs['state_image']),
          'state_dense':
              obs['state_dense']
      }
    return sample_fn(networks.policy_network.apply(params, obs), key)

  return actor_core.batched_feed_forward_to_actor_core(apply_and_sample)
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      policy_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  return actors.GenericActor(
      actor_core, random_key, variable_client, adder, backend='cpu')
def make_actor(self,
               random_key,
               policy_network,
               adder=None,
               variable_source=None):
  assert variable_source is not None
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      policy_network)
  variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                  device='cpu')
  if self._config.use_random_actor:
    ACTOR = contrastive_utils.InitiallyRandomActor  # pylint: disable=invalid-name
  else:
    ACTOR = actors.GenericActor  # pylint: disable=invalid-name
  return ACTOR(
      actor_core, random_key, variable_client, adder, backend='cpu')
def test_feedforward(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    action_values = hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
    ])(inputs)
    action = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return action, (action_values,)
    else:
      return action

  policy = hk.transform(policy)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  if has_extras:
    actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
        policy.apply)
  else:
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy.apply)
  actor = actors.GenericActor(
      actor_core,
      random_key=jax.random.PRNGKey(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def build_q_filtered_actor(
    networks,
    beta,
    num_samples,
    use_img_encoder=False,
    with_uniform=True,
    ensemble_method='deep_ensembles',
    ensemble_size=None,  # Not used for deep ensembles.
    mimo_using_obs_tile=False,
    mimo_using_act_tile=False,
):
  if ensemble_method not in [
      'deep_ensembles',
      'mimo',
  ]:
    raise NotImplementedError()

  def select_action(
      params,
      key,
      obs,
  ):
    key, sub_key = jax.random.split(key)
    policy_params = params[0]
    all_q_params = params[1]

    if use_img_encoder:
      img_encoder_params = params[2]
      obs = {
          'state_image':
              networks.img_encoder.apply(img_encoder_params,
                                         obs['state_image']),
          'state_dense':
              obs['state_dense']
      }

    # Draw candidate actions from the policy.
    dist = networks.policy_network.apply(policy_params, obs)
    acts = dist._sample_n(num_samples, sub_key)
    acts = acts[:, 0, :]  # N x act_dim

    if with_uniform:
      key, sub_key = jax.random.split(sub_key)
      unif_acts = jax.random.uniform(
          sub_key, acts.shape, dtype=acts.dtype, minval=-1., maxval=1.)
      acts = jnp.concatenate([acts, unif_acts], axis=0)

    # Both branches currently build the same pmap/vmap evaluation over the
    # ensemble of Q-functions; they differ only in how Q-values are aggregated
    # below.
    if ensemble_method == 'deep_ensembles':
      get_all_q_values = jax.pmap(
          jax.vmap(networks.q_network.apply,
                   in_axes=(0, None, None),
                   out_axes=0),
          in_axes=(0, None, None),
          out_axes=0)
    elif ensemble_method == 'mimo':
      get_all_q_values = jax.pmap(
          jax.vmap(networks.q_network.apply,
                   in_axes=(0, None, None),
                   out_axes=0),
          in_axes=(0, None, None),
          out_axes=0)
    else:
      raise NotImplementedError()

    def obs_tile_fn(t):
      tile_shape = [1] * t.ndim
      tile_shape[0] = acts.shape[0]
      return jnp.tile(t, tile_shape)

    tiled_obs = jax.tree_map(obs_tile_fn, obs)

    if ensemble_method == 'deep_ensembles':
      # num_devices x num_per_device x batch_size x 2 (because of double-Q)
      all_q = get_all_q_values(all_q_params, tiled_obs, acts)
      # num_devices x num_per_device x batch_size
      all_q = jnp.min(all_q, axis=-1)
      q_mean = jnp.mean(all_q, axis=(0, 1))
      q_std = jnp.std(all_q, axis=(0, 1))
      q_score = q_mean + beta * q_std  # batch_size
      best_idx = jnp.argmax(q_score)
    elif ensemble_method == 'mimo':
      if mimo_using_obs_tile:
        # Variant where the observations are also tiled along the last axis.
        tile_shape = [1] * tiled_obs.ndim
        tile_shape[-1] = ensemble_size
        tiled_obs = jnp.tile(tiled_obs, tile_shape)
      if mimo_using_act_tile:
        # Variant where the actions are tiled along the last axis.
        tile_shape = [1] * acts.ndim
        tile_shape[-1] = ensemble_size
        tiled_acts = jnp.tile(acts, tile_shape)
      else:
        tiled_acts = acts
      # 1 x 1 x batch_size x ensemble_size x (num_qs_per_member)
      all_q = get_all_q_values(all_q_params, tiled_obs, tiled_acts)
      # 1 x 1 x batch_size x ensemble_size
      all_q = jnp.min(all_q, axis=-1)
      q_mean = jnp.mean(all_q, axis=(0, 1, 3))
      q_std = jnp.std(all_q, axis=(0, 1, 3))
      q_score = q_mean + beta * q_std  # batch_size
      best_idx = jnp.argmax(q_score)
    else:
      raise NotImplementedError()

    return acts[best_idx][None, :]

  return actor_core.batched_feed_forward_to_actor_core(select_action)