def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy_network,
    adder: Optional[adders.Adder] = None,
    variable_source: Optional[core.VariableSource] = None
) -> acme.Actor:
  # Variable client that periodically pulls the latest 'network' variables
  # from the variable source onto the actor's CPU device.
  variable_client = variable_utils.VariableClient(
      client=variable_source,
      key='network',
      update_period=1000,
      device='cpu')
  return acting.IMPALAActor(
      forward_fn=policy_network.forward_fn,
      initial_state_init_fn=policy_network.initial_state_init_fn,
      initial_state_fn=policy_network.initial_state_fn,
      variable_client=variable_client,
      adder=adder,
      rng=hk.PRNGSequence(random_key),
  )
def make_actor(
    self,
    random_key: networks_lib.PRNGKey,
    policy: impala_networks.IMPALANetworks,
    environment_spec: specs.EnvironmentSpec,
    variable_source: Optional[core.VariableSource] = None,
    adder: Optional[adders.Adder] = None,
) -> acme.Actor:
  del environment_spec  # Unused.
  # Variable client that pulls the latest 'network' variables from the
  # variable source onto CPU every `variable_update_period` update() calls.
  variable_client = variable_utils.VariableClient(
      client=variable_source,
      key='network',
      update_period=self._config.variable_update_period,
      device='cpu')
  return acting.IMPALAActor(
      forward_fn=policy.forward_fn,
      initial_state_fn=policy.initial_state_fn,
      variable_client=variable_client,
      adder=adder,
      rng=hk.PRNGSequence(random_key),
  )
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    forward_fn: networks.PolicyValueRNN,
    unroll_fn: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
  # Data is handled by the reverb replay queue.
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')
  extra_spec = {
      'core_state': hk.without_apply_rng(
          hk.transform(initial_state_fn)).apply(None),
      'logits': np.ones(shape=(num_actions,), dtype=np.float32)
  }
  reverb_queue = replay.make_reverb_online_queue(
      environment_spec=environment_spec,
      extra_spec=extra_spec,
      max_queue_size=max_queue_size,
      sequence_length=sequence_length,
      sequence_period=sequence_period,
      batch_size=batch_size,
  )
  self._server = reverb_queue.server
  self._can_sample = reverb_queue.can_sample

  # Make the learner.
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_gradient_norm),
      optax.adam(learning_rate),
  )
  key_learner, key_actor = jax.random.split(jax.random.PRNGKey(seed))
  self._learner = learning.IMPALALearner(
      obs_spec=environment_spec.observations,
      unroll_fn=unroll_fn,
      initial_state_fn=initial_state_fn,
      iterator=reverb_queue.data_iterator,
      random_key=key_learner,
      counter=counter,
      logger=logger,
      optimizer=optimizer,
      discount=discount,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_abs_reward=max_abs_reward,
  )

  # Make the actor.
  variable_client = variable_utils.VariableClient(self._learner, key='policy')
  transformed = hk.without_apply_rng(hk.transform(forward_fn))
  self._actor = acting.IMPALAActor(
      forward_fn=jax.jit(transformed.apply, backend='cpu'),
      initial_state_fn=initial_state_fn,
      rng=hk.PRNGSequence(key_actor),
      adder=reverb_queue.adder,
      variable_client=variable_client,
  )
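# A minimal sketch (an assumption, not code from this repository) of the three
# Haiku functions the constructor above expects: forward_fn applies one
# recurrent step for the actor, unroll_fn applies the same core over a
# time-major sequence for the learner, and initial_state_fn builds the
# recurrent state used by both. NUM_ACTIONS, the hidden size, and the
# flat-observation assumption are all hypothetical.
import haiku as hk
import jax.numpy as jnp

NUM_ACTIONS = 6  # Hypothetical action count.


def forward_fn(observation, state):
  # Single-step policy/value computation used by the actor.
  embedding, new_state = hk.LSTM(128)(observation, state)
  logits = hk.Linear(NUM_ACTIONS)(embedding)
  value = jnp.squeeze(hk.Linear(1)(embedding), axis=-1)
  return (logits, value), new_state


def unroll_fn(observations, state):
  # Sequence unroll used by the learner; it shares parameters with forward_fn
  # because the modules are created in the same order and get the same names.
  embeddings, new_state = hk.static_unroll(hk.LSTM(128), observations, state)
  logits = hk.Linear(NUM_ACTIONS)(embeddings)
  values = jnp.squeeze(hk.Linear(1)(embeddings), axis=-1)
  return (logits, values), new_state


def initial_state_fn() -> hk.LSTMState:
  # Parameter-free, so hk.transform(initial_state_fn).apply(None) works.
  return hk.LSTM(128).initial_state(None)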
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], networks.RNNState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  queue = reverb.Table.queue(
      name=adders.DEFAULT_PRIORITY_TABLE, max_size=max_queue_size)
  self._server = reverb.Server([queue], port=None)
  self._can_sample = lambda: queue.can_sample(batch_size)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=sequence_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  extra_spec = {
      'core_state': hk.transform(initial_state_fn).apply(None),
      'logits': np.ones(shape=(num_actions,), dtype=np.float32)
  }
  # Remove batch dimensions.
  dataset = datasets.make_reverb_dataset(
      client=reverb.TFClient(address),
      environment_spec=environment_spec,
      batch_size=batch_size,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  rng = hk.PRNGSequence(seed)
  optimizer = optix.chain(
      optix.clip_by_global_norm(max_gradient_norm),
      optix.adam(learning_rate),
  )
  self._learner = learning.IMPALALearner(
      obs_spec=environment_spec.observations,
      network=network,
      initial_state_fn=initial_state_fn,
      iterator=dataset.as_numpy_iterator(),
      rng=rng,
      counter=counter,
      logger=logger,
      optimizer=optimizer,
      discount=discount,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_abs_reward=max_abs_reward,
  )

  variable_client = jax_variable_utils.VariableClient(
      self._learner, key='policy')
  self._actor = acting.IMPALAActor(
      network=network,
      initial_state_fn=initial_state_fn,
      rng=rng,
      adder=adder,
      variable_client=variable_client,
  )
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    forward_fn: impala_types.PolicyValueFn,
    unroll_init_fn: impala_types.PolicyValueInitFn,
    unroll_fn: impala_types.PolicyValueFn,
    initial_state_init_fn: impala_types.RecurrentStateInitFn,
    initial_state_fn: impala_types.RecurrentStateFn,
    config: impala_config.IMPALAConfig,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
):
  networks = impala_networks.IMPALANetworks(
      forward_fn=forward_fn,
      unroll_init_fn=unroll_init_fn,
      unroll_fn=unroll_fn,
      initial_state_init_fn=initial_state_init_fn,
      initial_state_fn=initial_state_fn,
  )
  self._config = config

  # Data is handled by the reverb replay queue.
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  key, key_initial_state = jax.random.split(
      jax.random.PRNGKey(self._config.seed))
  params = initial_state_init_fn(key_initial_state)
  extra_spec = {
      'core_state': initial_state_fn(params),
      'logits': np.ones(shape=(num_actions,), dtype=np.float32)
  }
  reverb_queue = replay.make_reverb_online_queue(
      environment_spec=environment_spec,
      extra_spec=extra_spec,
      max_queue_size=self._config.max_queue_size,
      sequence_length=self._config.sequence_length,
      sequence_period=self._config.sequence_period,
      batch_size=self._config.batch_size,
  )
  self._server = reverb_queue.server
  self._can_sample = reverb_queue.can_sample

  # Make the learner.
  optimizer = optax.chain(
      optax.clip_by_global_norm(self._config.max_gradient_norm),
      optax.adam(self._config.learning_rate),
  )
  key_learner, key_actor = jax.random.split(key)
  self._learner = learning.IMPALALearner(
      networks=networks,
      iterator=reverb_queue.data_iterator,
      random_key=key_learner,
      counter=counter,
      logger=logger,
      optimizer=optimizer,
      discount=self._config.discount,
      entropy_cost=self._config.entropy_cost,
      baseline_cost=self._config.baseline_cost,
      max_abs_reward=self._config.max_abs_reward,
  )

  # Make the actor.
  variable_client = variable_utils.VariableClient(self._learner, key='policy')
  self._actor = acting.IMPALAActor(
      forward_fn=jax.jit(forward_fn, backend='cpu'),
      initial_state_init_fn=initial_state_init_fn,
      initial_state_fn=initial_state_fn,
      rng=hk.PRNGSequence(key_actor),
      adder=reverb_queue.adder,
      variable_client=variable_client,
  )
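# Hedged sketch of one way the pure init/apply pairs consumed above could be
# produced from plain Haiku functions (reusing the hypothetical forward_fn,
# unroll_fn and initial_state_fn sketched earlier): each pair is simply the
# .init/.apply of a hk.without_apply_rng(hk.transform(...)). The constructor
# then calls initial_state_init_fn(key) to create the (empty) parameters and
# initial_state_fn(params) to build the dummy core state for extra_spec.
import haiku as hk

forward = hk.without_apply_rng(hk.transform(forward_fn))
unroll = hk.without_apply_rng(hk.transform(unroll_fn))
initial_state = hk.without_apply_rng(hk.transform(initial_state_fn))

# Keyword arguments for the constructor above:
#   forward_fn=forward.apply,
#   unroll_init_fn=unroll.init,
#   unroll_fn=unroll.apply,
#   initial_state_init_fn=initial_state.init,
#   initial_state_fn=initial_state.apply,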
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    forward_fn: networks.PolicyValueRNN,
    unroll_fn: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  extra_spec = {
      'core_state': hk.without_apply_rng(
          hk.transform(initial_state_fn, apply_rng=True)).apply(None),
      'logits': np.ones(shape=(num_actions,), dtype=np.float32)
  }
  signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
  queue = reverb.Table.queue(
      name=adders.DEFAULT_PRIORITY_TABLE,
      max_size=max_queue_size,
      signature=signature)
  self._server = reverb.Server([queue], port=None)
  self._can_sample = lambda: queue.can_sample(batch_size)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=sequence_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from. We don't use
  # datasets.make_reverb_dataset() here in order to avoid interleaving and
  # prefetching, which don't work well with the can_sample() check on update.
  dataset = reverb.ReplayDataset.from_table_signature(
      server_address=address,
      table=adders.DEFAULT_PRIORITY_TABLE,
      max_in_flight_samples_per_worker=1,
      sequence_length=sequence_length,
      emit_timesteps=False)
  dataset = dataset.batch(batch_size, drop_remainder=True)

  optimizer = optax.chain(
      optax.clip_by_global_norm(max_gradient_norm),
      optax.adam(learning_rate),
  )
  self._learner = learning.IMPALALearner(
      obs_spec=environment_spec.observations,
      unroll_fn=unroll_fn,
      initial_state_fn=initial_state_fn,
      iterator=dataset.as_numpy_iterator(),
      rng=hk.PRNGSequence(seed),
      counter=counter,
      logger=logger,
      optimizer=optimizer,
      discount=discount,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_abs_reward=max_abs_reward,
  )

  variable_client = variable_utils.VariableClient(self._learner, key='policy')
  self._actor = acting.IMPALAActor(
      forward_fn=jax.jit(
          hk.without_apply_rng(hk.transform(forward_fn, apply_rng=True)).apply,
          backend='cpu'),
      initial_state_fn=initial_state_fn,
      rng=hk.PRNGSequence(seed),
      adder=adder,
      variable_client=variable_client,
  )
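# Hedged sketch (an assumption about the surrounding agent class, not code
# shown here) of how the _can_sample predicate built above is typically used:
# each call to update() drains the online queue with learner steps and then
# refreshes the actor's copy of the parameters. The dataset above avoids
# interleaving and prefetching precisely so this check reflects what can
# actually be sampled.
def update(self):
  should_update_actor = False
  # Run learner (gradient) steps while a full batch is available to sample.
  while self._can_sample():
    self._learner.step()
    should_update_actor = True
  if should_update_actor:
    # Pull the latest policy parameters from the learner into the actor.
    self._actor.update(wait=True)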