def _get_extra_specs(self) -> Any:
    """Helper to establish specs for extra information.

    Returns:
        Dict[str, Any]: dictionary containing extra specs.
    """
    agents = self._environment_spec.get_agent_ids()
    core_state_specs = {}
    core_message_specs = {}
    networks = self._network_factory(  # type: ignore
        environment_spec=self._environment_spec
    )
    for agent in agents:
        agent_type = agent.split("_")[0]
        core_state_specs[agent] = (
            tf2_utils.squeeze_batch_dim(
                networks["q_networks"][agent_type].initial_state(1)
            ),
        )
        if self._communication_module_fn is not None:
            core_message_specs[agent] = (
                tf2_utils.squeeze_batch_dim(
                    networks["q_networks"][agent_type].initial_message(1)
                ),
            )
    extras = {
        "core_states": core_state_specs,
        "core_messages": core_message_specs,
    }
    return extras
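The helper above leans on `tf2_utils.squeeze_batch_dim` stripping the leading batch dimension from a batch-of-one recurrent state, so the recorded spec matches the unbatched extras written per timestep. A minimal, self-contained sketch of that behavior (the LSTM here is illustrative, not the snippet's network):

import sonnet as snt
from acme.tf import utils as tf2_utils

core = snt.LSTM(hidden_size=4)
state = core.initial_state(1)                    # hidden/cell tensors of shape [1, 4]
state_spec = tf2_utils.squeeze_batch_dim(state)  # hidden/cell tensors of shape [4]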
def replay(self) -> List[reverb.Table]:
    """The replay storage."""
    network = self._network_factory(self._environment_spec.actions)
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    if self._samples_per_insert:
        limiter = reverb.rate_limiters.SampleToInsertRatio(
            min_size_to_sample=self._min_replay_size,
            samples_per_insert=self._samples_per_insert,
            error_buffer=self._batch_size)
    else:
        limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
    table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(self._priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=self._max_replay_size,
        rate_limiter=limiter,
        signature=adders.SequenceAdder.signature(
            self._environment_spec,
            extra_spec,
            sequence_length=self._burn_in_length + self._trace_length + 1))
    return [table]
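The tables returned by `replay()` still need to be served; the agents later in this section do the same wiring inline. A hedged sketch of that step (`builder` is a stand-in for whatever object defines `replay()`):

import reverb

tables = builder.replay()
server = reverb.Server(tables, port=None)
client = reverb.Client(f'localhost:{server.port}')  # later handed to a SequenceAdder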
def make_replay_tables(
    self,
    environment_spec: specs.MAEnvironmentSpec,
) -> List[reverb.Table]:
    """Create tables to insert data into.

    Args:
        environment_spec (specs.MAEnvironmentSpec): description of the action and
            observation spaces etc. for each agent in the system.

    Returns:
        List[reverb.Table]: a list of data tables for inserting data.
    """
    agent_specs = environment_spec.get_agent_specs()
    extras_spec: Dict[str, Dict[str, acme_types.NestedArray]] = {"log_probs": {}}
    for agent, spec in agent_specs.items():
        # Make dummy log_probs.
        extras_spec["log_probs"][agent] = tf.ones(shape=(1,), dtype=tf.float32)

    # Squeeze the batch dim.
    extras_spec = tf2_utils.squeeze_batch_dim(extras_spec)

    replay_table = reverb.Table.queue(
        name=self._config.replay_table_name,
        max_size=self._config.max_queue_size,
        signature=reverb_adders.ParallelSequenceAdder.signature(
            environment_spec, extras_spec=extras_spec
        ),
    )
    return [replay_table]
def spec(output: tf.Tensor) -> tf.TensorSpec:
    # If the output is not a Tensor, return None as spec is ill-defined.
    if not isinstance(output, tf.Tensor):
        return None
    # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
    if tf.rank(output) > 0:
        output = squeeze_batch_dim(output)
    return tf.TensorSpec(output.shape, output.dtype)
def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Builds the network with dummy inputs to create the necessary variables.

    Args:
        network: Sonnet Module whose variables are to be created.
        input_spec: list of input specs to the network. The length of this list
            should match the number of arguments expected by `network`.

    Returns:
        output_spec: only returns an output spec if the output is a tf.Tensor,
            else it doesn't return anything (None); e.g. if the output is a
            tfp.distributions.Distribution.
    """
    # Create a dummy observation with no batch dimension.
    dummy_input = [
        OLT(
            observation=zeros_like(in_spec.observation),
            legal_actions=ones_like(in_spec.legal_actions),
            terminal=zeros_like(in_spec.terminal),
        )
        for in_spec in input_spec
    ]

    # If we have an RNNCore the hidden state will be an additional input.
    if isinstance(network, snt.RNNCore):
        initial_state = squeeze_batch_dim(network.initial_state(1))
        dummy_input += [initial_state]

    # Forward pass of the network which will create variables as a side effect.
    dummy_output = network(*add_batch_dim(dummy_input))

    # Evaluate the input signature by converting the dummy input into a
    # TensorSpec. We then save the signature as a property of the network. This
    # is done so that we can later use it when creating snapshots. We do this
    # here because the snapshot code may not have access to the precise form of
    # the inputs.
    input_signature = tree.map_structure(
        lambda t: tf.TensorSpec((None,) + t.shape, t.dtype), dummy_input)
    network._input_signature = input_signature  # pylint: disable=protected-access

    def spec(output: tf.Tensor) -> tf.TensorSpec:
        # If the output is not a Tensor, return None as spec is ill-defined.
        if not isinstance(output, tf.Tensor):
            return None
        # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
        if tf.rank(output) > 0:
            output = squeeze_batch_dim(output)
        return tf.TensorSpec(output.shape, output.dtype)

    return tree.map_structure(spec, dummy_output)
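A hedged usage sketch for `create_variables`, assuming an `OLT`-typed observation spec pulled from a hypothetical per-agent environment spec (every name except `create_variables` itself is an assumption):

# Hypothetical names throughout; only create_variables comes from above.
obs_spec = environment_spec.get_agent_specs()["agent_0"].observations  # an OLT
q_network = snt.nets.MLP([64, num_actions])
output_spec = create_variables(q_network, [obs_spec])
# output_spec is a tf.TensorSpec for tensor outputs, or None for e.g. a
# tfp.distributions.Distribution output.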
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    queue: adder.Adder,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    n_step_horizon: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
    verbose_level: Optional[int] = 0,
):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    extra_spec = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    tf2_utils.create_variables(network, [environment_spec.observations])

    actor = acting.A2CActor(
        environment_spec=environment_spec,
        verbose_level=verbose_level,
        network=network,
        queue=queue)
    learner = learning.A2CLearner(
        environment_spec=environment_spec,
        network=network,
        dataset=queue,
        counter=counter,
        logger=logger,
        discount=discount,
        learning_rate=learning_rate,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_gradient_norm=max_gradient_norm,
        max_abs_reward=max_abs_reward,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=0,
        observations_per_step=n_step_horizon)
def _get_extra_specs(self) -> Any:
    """Helper to establish specs for extra information.

    Returns:
        Dict[str, Any]: dictionary containing extra specs.
    """
    agents = self._environment_spec.get_agent_ids()
    core_state_specs = {}
    networks = self._network_factory(  # type: ignore
        environment_spec=self._environment_spec)
    for agent in agents:
        agent_type = agent.split("_")[0]
        core_state_specs[agent] = (
            tf2_utils.squeeze_batch_dim(
                networks["policies"][agent_type].initial_state(1)),
        )
    return {"core_states": core_state_specs}
def queue(self):
    """The queue."""
    num_actions = self._environment_spec.actions.num_values
    network = self._network_factory(self._environment_spec.actions)
    extra_spec = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    signature = adders.SequenceAdder.signature(
        self._environment_spec,
        extra_spec,
        sequence_length=self._sequence_length)
    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE,
        max_size=self._max_queue_size,
        signature=signature)
    return [queue]
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    target_network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    demonstration_generator: iter,
    demonstration_ratio: float,
    model_directory: str,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    log_to_bigtable: bool = False,
    log_name: str = 'agent',
    checkpoint: bool = True,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
):
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

    # Replay table.
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(0.8),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))

    # Demonstration table.
    demonstration_table = reverb.Table(
        name='demonstration_table',
        sampler=reverb.selectors.Prioritized(0.8),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))

    # Launch the server.
    self._server = reverb.Server([replay_table, demonstration_table], port=None)
    address = f'localhost:{self._server.port}'
    sequence_length = burn_in_length + trace_length + 1

    # Components to add things into replay and the demonstration table.
    sequence_kwargs = dict(
        period=replay_period,
        sequence_length=sequence_length,
    )
    adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs)
    priority_function = {demonstration_table.name: lambda x: 1.}
    demo_adder = adders.SequenceAdder(
        client=reverb.Client(address),
        priority_fns=priority_function,
        **sequence_kwargs)

    # Play demonstrations and write them into the demonstration table,
    # exhausting the generator.
    # TODO: MAX REPLAY SIZE
    _prev_action = 1  # This has to come from the spec.
    _add_first = True  # Include this to make the datasets equivalent.
    numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
    for ts, action in demonstration_generator:
        if _add_first:
            demo_adder.add_first(ts)
            _add_first = False
        else:
            demo_adder.add(_prev_action, ts, extras=(numpy_state,))
        _prev_action = action
        # Reset at the start of a new episode.
        if ts.last():
            _prev_action = None
            _add_first = True

    # Replay dataset.
    max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
    dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=adders.DEFAULT_PRIORITY_TABLE,
        max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
        # Memory/perf improvement attempt; see
        # https://github.com/deepmind/acme/issues/33.
        num_workers_per_iterator=2,
        sequence_length=sequence_length,
        emit_timesteps=sequence_length is None)

    # Demonstration dataset.
    d_dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=demonstration_table.name,
        max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
        num_workers_per_iterator=2,
        sequence_length=sequence_length,
        emit_timesteps=sequence_length is None)

    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, d_dataset],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=dataset,
        reverb_client=reverb.TFClient(address),
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        directory=model_directory,
        subdirectory='r2d2_learner_v1',
        time_delta_minutes=15,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save=None,
        time_delta_minutes=15000.,
        directory=model_directory)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
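The constructor above mixes the replay and demonstration streams with `sample_from_datasets` at a fixed ratio. A tiny standalone illustration of that weighting, using toy datasets rather than the agent's:

import tensorflow as tf

replay = tf.data.Dataset.from_tensors(0).repeat()
demos = tf.data.Dataset.from_tensors(1).repeat()
# Roughly 25% of the mixed stream's elements come from the demonstration side.
mixed = tf.data.experimental.sample_from_datasets([replay, demos], [0.75, 0.25])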
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    policy_network: snt.Module,
    critic_network: snt.Module,
    encoder_network: types.TensorTransformation = tf.identity,
    entropy_coeff: float = 0.01,
    target_update_period: int = 0,
    discount: float = 0.99,
    batch_size: int = 256,
    policy_learn_rate: float = 3e-4,
    critic_learn_rate: float = 5e-4,
    prefetch_size: int = 4,
    min_replay_size: int = 1000,
    max_replay_size: int = 250000,
    samples_per_insert: float = 64.0,
    n_step: int = 5,
    sigma: float = 0.5,
    clipping: bool = True,
    logger: loggers.Logger = None,
    counter: counting.Counter = None,
    checkpoint: bool = True,
    replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
):
    """Initialize the agent.

    Args:
        environment_spec: description of the actions, observations, etc.
        policy_network: the online (optimized) policy.
        critic_network: the online critic.
        encoder_network: optional network to transform the observations before
            they are fed into any other network.
        entropy_coeff: weight of the entropy term in the policy loss.
        target_update_period: number of learner steps to perform before updating
            the target networks; 0 disables target networks.
        discount: discount to use for TD updates.
        batch_size: batch size for updates.
        policy_learn_rate: learning rate for the policy optimizer.
        critic_learn_rate: learning rate for the critic optimizer.
        prefetch_size: size to prefetch from replay.
        min_replay_size: minimum replay size before updating.
        max_replay_size: maximum replay size.
        samples_per_insert: number of samples to take from replay for every
            insert that is made.
        n_step: number of steps to squash into a single transition.
        sigma: standard deviation of zero-mean, Gaussian exploration noise.
        clipping: whether to clip gradients by global norm.
        logger: logger object to be used by learner.
        counter: counter object used to keep track of steps.
        checkpoint: boolean indicating whether to checkpoint the learner.
        replay_table_name: string indicating what name to give the replay table.
    """
    # Create a replay server to add data to. This uses no limiter behavior in
    # order to allow the Agent interface to handle it.
    dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
    extra_spec = {
        'logP': tf.ones(shape=(1,), dtype=tf.float32),
        'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=adders.NStepTransitionAdder.signature(
            environment_spec, extras_spec=extra_spec))
    self._server = reverb.Server([replay_table], port=None)

    # The adder is used to insert observations into replay.
    address = f'localhost:{self._server.port}'
    adder = adders.NStepTransitionAdder(
        priority_fns={replay_table_name: lambda x: 1.},
        client=reverb.Client(address),
        n_step=n_step,
        discount=discount)

    # The dataset provides an interface to sample from replay.
    dataset = datasets.make_reverb_dataset(
        table=replay_table_name,
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size)

    # Make sure the observation network is a Sonnet Module.
    observation_network = model.MDPNormalization(environment_spec, encoder_network)

    # Get observation and action specs.
    act_spec = environment_spec.actions
    obs_spec = environment_spec.observations

    # Create the behavior policy.
    sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
    self._behavior_network = model.PolicyValueBehaviorNet(
        snt.Sequential([observation_network, policy_network]), sampling_head)

    # Create variables.
    emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
    tf2_utils.create_variables(policy_network, [emb_spec])
    tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

    # Create the actor which defines how we take actions.
    actor = model.SACFeedForwardActor(self._behavior_network, adder)

    if target_update_period > 0:
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])
    else:
        target_policy_network = policy_network
        target_critic_network = critic_network
        target_observation_network = observation_network

    # Create optimizers.
    policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
    critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

    # The learner updates the parameters (and initializes them).
    learner = learning.SACLearner(
        policy_network=policy_network,
        critic_network=critic_network,
        sampling_head=sampling_head,
        observation_network=observation_network,
        target_policy_network=target_policy_network,
        target_critic_network=target_critic_network,
        target_observation_network=target_observation_network,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        target_update_period=target_update_period,
        learning_rate=policy_learn_rate,
        clipping=clipping,
        entropy_coeff=entropy_coeff,
        discount=discount,
        dataset=dataset,
        counter=counter,
        logger=logger,
        checkpoint=checkpoint,
    )

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(batch_size, min_replay_size),
        observations_per_step=float(batch_size) / samples_per_insert)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon_init: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_schedule_timesteps: float = 20000,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
    if store_lstm_state:
        extra_spec = {
            'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)),
        }
    else:
        extra_spec = ()

    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1
    # Component to add things into replay.
    self._adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    dataset = make_reverb_dataset(
        server_address=address,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        sequence_length=sequence_length)

    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb.TFClient(address),
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._saver = tf2_savers.Saver(learner.state)

    policy_network = snt.DeepRNN([
        network,
        EpsilonGreedyExploration(
            epsilon_init=epsilon_init,
            epsilon_final=epsilon_final,
            epsilon_schedule_timesteps=epsilon_schedule_timesteps)
    ])
    actor = actors.RecurrentActor(
        policy_network, self._adder, store_recurrent_state=store_lstm_state)

    max_Q_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
    ])
    self._deterministic_actor = actors.RecurrentActor(
        max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
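The `_deterministic_actor` built above wraps the same network with a greedy (epsilon=0) head and is presumably meant for evaluation. A hedged sketch of an evaluation episode, assuming a dm_env-style `environment` and an `agent` instance of this class (the loop itself is not part of the snippet):

# Illustrative evaluation episode using the greedy actor.
timestep = environment.reset()
agent._deterministic_actor.observe_first(timestep)
while not timestep.last():
    action = agent._deterministic_actor.select_action(timestep.observation)
    timestep = environment.step(action)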
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    target_network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    demonstration_dataset: tf.data.Dataset,
    demonstration_ratio: float,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    log_to_bigtable: bool = False,
    log_name: str = 'agent',
    checkpoint: bool = True,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
):
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'
    sequence_length = burn_in_length + trace_length + 1

    # Component to add things into replay.
    sequence_kwargs = dict(
        period=replay_period,
        sequence_length=sequence_length,
    )
    adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs)

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=address, sequence_length=sequence_length)

    # Combine with demonstration dataset.
    transition = functools.partial(
        _sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs)
    dataset_demos = demonstration_dataset.map(transition)
    dataset = tf.data.experimental.sample_from_datasets(
        [dataset, dataset_demos],
        [1 - demonstration_ratio, demonstration_ratio])

    # Batch and prefetch.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        dataset=dataset,
        reverb_client=reverb.TFClient(address),
        counter=counter,
        logger=logger,
        sequence_length=sequence_length,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=False,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'
    sequence_length = burn_in_length + trace_length + 1

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    extra_spec = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE,
        max_size=max_queue_size,
        signature=adders.SequenceAdder.signature(
            environment_spec,
            extras_spec=extra_spec,
            sequence_length=sequence_length))
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=address, batch_size=batch_size)

    tf2_utils.create_variables(network, [environment_spec.observations])

    self._actor = acting.IMPALAActor(network, adder)
    self._learner = learning.IMPALALearner(
        environment_spec=environment_spec,
        network=network,
        dataset=dataset,
        counter=counter,
        logger=logger,
        discount=discount,
        learning_rate=learning_rate,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_gradient_norm=max_gradient_norm,
        max_abs_reward=max_abs_reward,
    )
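Because the queue only permits sampling once a full batch is available, `self._can_sample` is the natural gate for learner steps. A minimal sketch of an `update` method using it, mirroring the pattern in Acme's IMPALA agent (the method itself is an assumption, not part of the snippet above):

def update(self) -> None:
    # Run learner steps while a full batch can be dequeued; otherwise return
    # so the actor can keep filling the queue.
    while self._can_sample():
        self._learner.step()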