def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""

  # Build environment, model, network.
  environment = self._environment_factory()
  network = self._network_factory(self._env_spec.actions)
  model = self._model_factory(self._env_spec)

  # Create variable client for communicating with the learner.
  tf2_utils.create_variables(network, [self._env_spec.observations])
  variable_client = tf2_variable_utils.VariableClient(
      client=variable_source,
      variables={'policy': network.trainable_variables},
      update_period=self._variable_update_period)

  # Create the agent.
  actor = acting.MCTSActor(
      environment_spec=self._env_spec,
      model=model,
      network=network,
      discount=self._discount,
      variable_client=variable_client,
      num_simulations=self._num_simulations,
  )

  # Create the run loop and return it.
  logger = loggers.make_default_logger('evaluator')
  return acme.EnvironmentLoop(
      environment, actor, counter=counter, logger=logger)
def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  # Create the actor, loading its weights from the variable source.
  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  # Inference happens on CPU, so it's better to move variables there too.
  variable_client = variable_utils.VariableClient(
      variable_source, 'policy', device='cpu')
  actor = actors.GenericActor(
      actor_core, random_key, variable_client, backend='cpu')

  # Create the environment.
  environment = environment_factory(False)

  # Create logger and counter.
  logger = loggers.make_default_logger(
      'evaluator', steps_key='evaluator_steps')
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
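# For context, a hedged sketch of the `evaluator_network` closed over above.
# `batched_feed_forward_to_actor_core` expects a feed-forward policy of the
# form (params, key, batched_observation) -> batched_action; the greedy
# argmax head below is an illustrative assumption, not the source's
# definition.
def evaluator_network(params: networks_lib.Params,
                      key: networks_lib.PRNGKey,
                      observation: jnp.ndarray) -> jnp.ndarray:
  del key  # A greedy evaluator needs no randomness.
  logits = network.apply(params, observation)  # `network` is built elsewhere.
  return jnp.argmax(logits, axis=-1)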
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
) -> acme.EnvironmentLoop:
  """The actor process."""
  action_spec = self._environment_spec.actions
  observation_spec = self._environment_spec.observations

  # Create the environment and the behavior networks to act with.
  environment = self._environment_factory(False)
  agent_networks = self._network_factory(action_spec, self._num_critic_heads)

  # Make sure the observation network is defined.
  observation_network = agent_networks.get('observation', tf.identity)

  # Create a stochastic behavior policy.
  behavior_network = snt.Sequential([
      observation_network,
      agent_networks['policy'],
      networks.StochasticSamplingHead(),
  ])

  # Ensure network variables are created.
  tf2_utils.create_variables(behavior_network, [observation_spec])
  policy_variables = {'policy': behavior_network.variables}

  # Create the variable client responsible for keeping the actor up-to-date.
  variable_client = tf2_variable_utils.VariableClient(
      variable_source, policy_variables, update_period=1000)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Component to add things into replay.
  adder = adders.NStepTransitionAdder(
      client=replay,
      n_step=self._n_step,
      max_in_flight_items=self._max_in_flight_items,
      discount=self._additional_discount)

  # Create the agent.
  actor = actors.FeedForwardActor(
      policy_network=behavior_network,
      adder=adder,
      variable_client=variable_client)

  # Create logger and counter; actors will not spam bigtable.
  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor',
      save_data=False,
      time_delta=self._log_every,
      steps_key='actor_steps')

  # Create the run loop and return it.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def learner(self, queue: reverb.Client, counter: counting.Counter):
  """The Learning part of the agent."""
  # Create the networks.
  network = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(network, [self._environment_spec.observations])

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(
      server_address=queue.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size)

  logger = loggers.make_default_logger('learner', steps_key='learner_steps')
  counter = counting.Counter(counter, 'learner')

  # Create the learning agent.
  learner = learning.IMPALALearner(
      environment_spec=self._environment_spec,
      network=network,
      dataset=dataset,
      discount=self._discount,
      learning_rate=self._learning_rate,
      entropy_cost=self._entropy_cost,
      baseline_cost=self._baseline_cost,
      max_abs_reward=self._max_abs_reward,
      max_gradient_norm=self._max_gradient_norm,
      counter=counter,
      logger=logger,
  )

  # Wrap the learner so that it periodically checkpoints its state.
  return tf2_savers.CheckpointingRunner(
      learner, time_delta_minutes=5, subdirectory='impala_learner')
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  environment = self._environment_factory(True)
  network = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(network, [self._obs_spec])

  policy_network = snt.DeepRNN([
      network,
      lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32),
  ])

  variable_client = tf2_variable_utils.VariableClient(
      client=variable_source,
      variables={'policy': policy_network.variables},
      update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Create the agent.
  actor = actors.RecurrentActor(
      policy_network=policy_network, variable_client=variable_client)

  # Create the run loop and return it.
  logger = loggers.make_default_logger(
      'evaluator', save_data=True, steps_key='evaluator_steps')
  counter = counting.Counter(counter, 'evaluator')
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
    make_actor: MakeActorFn,
):
  """The evaluation process."""
  # Create environment and evaluator networks.
  environment_key, actor_key = jax.random.split(random_key)
  # Environments normally require uint32 as a seed.
  environment = environment_factory(utils.sample_uint32(environment_key))
  networks = network_factory(specs.make_environment_spec(environment))
  actor = make_actor(actor_key, policy_factory(networks), variable_source)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  if logger_fn is not None:
    logger = logger_fn('evaluator', 'actor_steps')
  else:
    logger = loggers.make_default_logger(
        'evaluator', log_to_bigtable, steps_key='actor_steps')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment, actor, counter, logger, observers=observers)
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
    logger: Optional[loggers.Logger] = None,
):
  """The evaluation process."""
  # Create the behavior policy.
  networks = self._network_factory(self._environment_spec.actions)
  networks.init(self._environment_spec)
  policy_network = networks.make_policy(self._environment_spec)

  # Create the agent.
  actor = self._builder.make_actor(
      policy_network=policy_network,
      variable_source=variable_source,
  )

  # Make the environment.
  environment = self._environment_factory(True)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  logger = logger or loggers.make_default_logger(
      'evaluator',
      time_delta=self._log_every,
      steps_key='evaluator_steps',
  )

  # Create the run loop and return it.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  # Create the networks to optimize (online) and target networks.
  online_networks = self._network_factory(self._environment_spec.actions)
  target_networks = copy.deepcopy(online_networks)

  # Initialize the networks.
  online_networks.init(self._environment_spec)
  target_networks.init(self._environment_spec)

  dataset = self._builder.make_dataset_iterator(replay)
  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner', time_delta=self._log_every, steps_key='learner_steps')

  return self._builder.make_learner(
      networks=(online_networks, target_networks),
      dataset=dataset,
      counter=counter,
      logger=logger,
  )
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
) -> acme.EnvironmentLoop:
  """The actor process."""
  # Create the behavior policy.
  networks = self._network_factory(self._environment_spec.actions)
  networks.init(self._environment_spec)
  policy_network = networks.make_policy(
      environment_spec=self._environment_spec,
      sigma=self._sigma,
  )

  # Create the agent.
  actor = self._builder.make_actor(
      policy_network=policy_network,
      adder=self._builder.make_adder(replay),
      variable_source=variable_source,
  )

  # Create the environment.
  environment = self._environment_factory(False)

  # Create logger and counter; actors will not spam bigtable.
  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor',
      save_data=False,
      time_delta=self._log_every,
      steps_key='actor_steps')

  # Create the loop to connect environment and agent.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  # If we are running on multiple accelerator devices, this replicates
  # weights and updates across devices.
  replicator = agent.get_replicator(self._accelerator)

  with replicator.scope():
    # Create the networks to optimize (online) and target networks.
    online_networks = self._network_factory(self._environment_spec.actions)
    target_networks = copy.deepcopy(online_networks)

    # Initialize the networks.
    online_networks.init(self._environment_spec)
    target_networks.init(self._environment_spec)

  dataset = self._builder.make_dataset_iterator(replay)
  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner', time_delta=self._log_every, steps_key='learner_steps')

  return self._builder.make_learner(
      networks=(online_networks, target_networks),
      dataset=dataset,
      counter=counter,
      logger=logger,
      checkpoint=True,
  )
def learner(
    self,
    random_key: networks_lib.PRNGKey,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  # Counter and logger.
  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner',
      self._save_logs,
      time_delta=self._log_every,
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')

  # Create the learner.
  networks = self._network_factory()
  learner = self._make_learner(random_key, networks, counter, logger)

  # Only append a unique identifier to the checkpoint directory when running
  # under the default working directory.
  kwargs = {'directory': self._workdir, 'add_uid': self._workdir == '~/acme'}

  # Return the learning agent, wrapped so that it periodically checkpoints.
  return savers.CheckpointingRunner(
      learner, subdirectory='learner', time_delta_minutes=5, **kwargs)
def run_ppo_agent(self, make_networks_fn):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)
  distribution_value_networks = make_networks_fn(spec)
  ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
  config = ppo.PPOConfig(unroll_length=4, num_epochs=2, num_minibatches=2)
  workdir = self.create_tempdir()
  counter = counting.Counter()
  logger = loggers.make_default_logger('learner')

  # Construct the agent.
  agent = ppo.PPO(
      spec=spec,
      networks=ppo_networks,
      config=config,
      seed=0,
      workdir=workdir.full_path,
      normalize_input=True,
      counter=counter,
      logger=logger,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=20)
def evaluator(self, variable_source: acme.VariableSource,
              counter: counting.Counter):
  """The evaluation process."""
  environment = self._environment_factory(True)
  network = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(network, [self._environment_spec.observations])

  variable_client = tf2_variable_utils.VariableClient(
      client=variable_source,
      variables={'policy': network.variables},
      update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Create the agent.
  actor = acting.IMPALAActor(
      network=network, variable_client=variable_client)

  # Create the run loop and return it.
  logger = loggers.make_default_logger(
      'evaluator', steps_key='evaluator_steps')
  counter = counting.Counter(counter, 'evaluator')
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def make_experiment_logger(label: str,
                           steps_key: Optional[str] = None,
                           task_instance: int = 0) -> loggers.Logger:
  del task_instance
  if steps_key is None:
    steps_key = f'{label}_steps'
  return loggers.make_default_logger(label=label, steps_key=steps_key)
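# A minimal usage sketch for the factory above; the metric names written
# below are illustrative, not part of the source.
learner_logger = make_experiment_logger('learner')  # steps_key -> 'learner_steps'
learner_logger.write({'learner_steps': 1, 'loss': 0.5})
evaluator_logger = make_experiment_logger('evaluator', steps_key='actor_steps')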
def actor_evaluator(
    random_key: networks_lib.PRNGKey,
    variable_source: core.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  # Create the actor, loading its weights from the variable source.
  actor = actors.FeedForwardActor(
      policy=evaluator_network,
      random_key=random_key,
      # Inference happens on CPU, so it's better to move variables there too.
      variable_client=variable_utils.VariableClient(
          variable_source, 'policy', device='cpu'))

  # Create the environment.
  environment = environment_factory(False)

  # Create logger and counter.
  logger = loggers.make_default_logger(
      'evaluator', steps_key='evaluator_steps')
  counter = counting.Counter(counter, 'evaluator')

  # Create the run loop and return it.
  return environment_loop.EnvironmentLoop(
      environment,
      actor,
      counter,
      logger,
  )
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The actor process."""
  action_spec = self._environment_spec.actions
  observation_spec = self._environment_spec.observations

  # Create environment and behavior networks.
  environment = self._environment_factory(False)
  agent_networks = self._network_factory(action_spec)

  # Create behavior network by adding some random dithering.
  behavior_network = snt.Sequential([
      agent_networks.get('observation', tf.identity),
      agent_networks.get('policy'),
      networks.ClippedGaussian(self._sigma),
  ])

  # Ensure network variables are created.
  tf2_utils.create_variables(behavior_network, [observation_spec])
  variables = {'policy': behavior_network.variables}

  # Create the variable client responsible for keeping the actor up-to-date.
  variable_client = tf2_variable_utils.VariableClient(
      variable_source, variables, update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Component to add things into replay.
  adder = adders.NStepTransitionAdder(
      client=replay, n_step=self._n_step, discount=self._discount)

  # Create the agent.
  actor = actors.FeedForwardActor(
      behavior_network, adder=adder, variable_client=variable_client)

  # Create logger and counter; actors will not spam bigtable.
  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor',
      save_data=False,
      time_delta=self._log_every,
      steps_key='actor_steps')

  # Create the loop to connect environment and agent.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  action_spec = self._environment_spec.actions
  observation_spec = self._environment_spec.observations

  # Create environment and target networks to act with.
  environment = self._environment_factory(True)
  agent_networks = self._network_factory(action_spec)

  # Make sure observation network is defined.
  observation_network = agent_networks.get('observation', tf.identity)

  # Create a deterministic evaluation policy by taking the mean of the
  # stochastic policy.
  evaluator_network = snt.Sequential([
      observation_network,
      agent_networks['policy'],
      networks.StochasticMeanHead(),
  ])

  # Ensure network variables are created.
  tf2_utils.create_variables(evaluator_network, [observation_spec])
  policy_variables = {'policy': evaluator_network.variables}

  # Create the variable client responsible for keeping the actor up-to-date.
  variable_client = tf2_variable_utils.VariableClient(
      variable_source, policy_variables,
      update_period=self._variable_update_period)

  # Make sure not to evaluate a random actor by assigning variables before
  # running the environment loop.
  variable_client.update_and_wait()

  # Create the agent.
  evaluator = actors.FeedForwardActor(
      policy_network=evaluator_network, variable_client=variable_client)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  logger = loggers.make_default_logger(
      'evaluator', time_delta=self._log_every, steps_key='evaluator_steps')
  observers = self._make_observers() if self._make_observers else ()

  # Create the run loop and return it.
  return acme.EnvironmentLoop(
      environment, evaluator, counter, logger, observers=observers)
def __init__(self,
             direct_rl_learner_factory: Callable[
                 [Any, Iterator[reverb.ReplaySample]], acme.Learner],
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optax.GradientTransformation,
             rnd_network: rnd_networks.RNDNetworks,
             rng_key: jnp.ndarray,
             grad_updates_per_batch: int,
             is_sequence_based: bool,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  self._is_sequence_based = is_sequence_based

  target_key, predictor_key = jax.random.split(rng_key)
  target_params = rnd_network.target.init(target_key)
  predictor_params = rnd_network.predictor.init(predictor_key)
  optimizer_state = optimizer.init(predictor_params)

  self._state = RNDTrainingState(
      optimizer_state=optimizer_state,
      params=predictor_params,
      target_params=target_params,
      steps=0)

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  loss = functools.partial(rnd_loss, networks=rnd_network)
  self._update = functools.partial(
      rnd_update_step, loss_fn=loss, optimizer=optimizer)
  self._update = utils.process_multiple_batches(
      self._update, grad_updates_per_batch)
  self._update = jax.jit(self._update)

  self._get_reward = jax.jit(
      functools.partial(rnd_networks.compute_rnd_reward,
                        networks=rnd_network))

  # Generator expression that works the same as an iterator.
  # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
  updated_iterator = (self._process_sample(sample) for sample in iterator)

  self._direct_rl_learner = direct_rl_learner_factory(
      rnd_network.direct_rl_networks, updated_iterator)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
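# A standalone illustration of the lazy-mapping trick used above: wrapping an
# iterator in a generator expression applies a per-sample transform without
# materializing the stream or advancing it until it is consumed.
base = iter(range(3))
doubled = (x * 2 for x in base)  # No work happens yet.
assert next(doubled) == 0        # Pulls exactly one item from `base`.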
def __init__(self, environment, actor, rewarder, counter=None, logger=None):
  self._environment = environment
  self._actor = actor
  self._rewarder = rewarder
  self._counter = counter or counting.Counter()
  # make_default_logger requires a label; 'environment_loop' matches the
  # default label used by the other loops in this collection.
  self._logger = logger or loggers.make_default_logger('environment_loop')
def __init__(
    self,
    environment,
    actor,
    counter=None,
    logger=None,
    label='environment_loop',
):
  # Internalize agent and environment.
  self._environment = environment
  self._actor = actor
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(label)
def __init__(self, environment_spec: specs.EnvironmentSpec,
             action_spec: specs.BoundedArray, z_dim: int) -> None:
  self._z_dim = z_dim
  z_spec = specs.BoundedArray((z_dim,), np.float64, minimum=0, maximum=1)

  # Modify the environment_spec to also include the latent variable
  # observation (z).
  self._obs_space = environment_spec.observations
  assert len(self._obs_space.shape) == 1, (
      'Only vector observations are supported for now. '
      f'Observations shape passed: {self._obs_space.shape}')
  updated_observations = specs.BoundedArray(
      (self._obs_space.shape[0] + z_dim,),
      dtype=environment_spec.observations.dtype,
      name=environment_spec.observations.name,
      minimum=np.append(environment_spec.observations.minimum, [0] * z_dim),
      # The appended z dimensions are bounded in [0, 1], matching z_spec.
      maximum=np.append(environment_spec.observations.maximum, [1] * z_dim),
  )
  environment_spec = specs.EnvironmentSpec(
      observations=updated_observations,
      actions=environment_spec.actions,
      rewards=environment_spec.rewards,
      discounts=environment_spec.discounts,
  )

  self._agent_networks = make_feed_forward_networks(action_spec, z_spec)
  self._agent = dmpo.DistributionalMPO(
      environment_spec=environment_spec,
      policy_network=self._agent_networks['policy'],
      critic_network=self._agent_networks['critic'],
      observation_network=self._agent_networks['observation'],  # pytype: disable=wrong-arg-types
      extra_modules_to_save={
          'discriminator': self._agent_networks['discriminator'],
      },
      return_action_entropy=True,
  )

  # Uniform categorical distribution over the latent variable z.
  self._z_distribution = tfd.Categorical([1] * z_dim)
  self._current_z = self._z_distribution.sample()

  # Create discriminator optimizer.
  self._discriminator_optimizer = snt.optimizers.Adam(1e-4)
  self._discriminator_logger = loggers.make_default_logger('discriminator')

  # Create variables for the discriminator.
  tf2_utils.create_variables(self._agent_networks['discriminator'],
                             [self._obs_space])
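# A hedged sketch of how the latent z might be appended to observations at
# act time, consistent with the updated spec above; the method name and
# one-hot encoding are assumptions, not part of the source.
def _augment_observation(self, observation: np.ndarray) -> np.ndarray:
  z_one_hot = np.zeros(self._z_dim, dtype=observation.dtype)
  z_one_hot[int(self._current_z.numpy())] = 1.0
  return np.concatenate([observation, z_one_hot])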
def __init__(
    self,
    environment: dm_env.Environment,
    actor: core.Actor,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    label: str = 'environment_loop',
):
  # Internalize agent and environment.
  self._environment = environment
  self._actor = actor
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(label)
  self.begin_time = time.time()
def __init__(
    self,
    environment: dm_env.Environment,
    actor: core.Actor,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    should_update: bool = True,
    label: str = 'environment_loop',
):
  # Internalize agent and environment.
  self._environment = environment
  self._actor = actor
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(label)
  self._should_update = should_update
def __init__(self,
             spec: specs.EnvironmentSpec,
             networks: networks_lib.FeedForwardNetwork,
             rng: networks_lib.PRNGKey,
             config: ars_config.ARSConfig,
             iterator: Iterator[reverb.ReplaySample],
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  self._config = config
  self._lock = threading.Lock()

  # General learner book-keeping and loggers.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(
      'learner',
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key=self._counter.get_steps_key())

  # Iterator on demonstration transitions.
  self._iterator = iterator

  if self._config.normalize_observations:
    normalizer_params = running_statistics.init_state(spec.observations)
    self._normalizer_update_fn = running_statistics.update
  else:
    normalizer_params = ()
    self._normalizer_update_fn = lambda a, b: a

  rng1, rng2, tmp = jax.random.split(rng, 3)

  # Create initial state.
  self._training_state = TrainingState(
      key=rng1,
      policy_params=networks.init(tmp),
      normalizer_params=normalizer_params,
      training_iteration=0)

  self._evaluation_state = EvaluationState(
      key=rng2,
      evaluation_queue=collections.deque(),
      received_results={},
      noises=[])

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
    epsilon: float,
) -> acme.EnvironmentLoop:
  """The actor process."""
  environment = self._environment_factory(False)
  network = self._network_factory(self._environment_spec.actions)
  tf2_utils.create_variables(network, [self._obs_spec])

  # Epsilon-greedy policy over the recurrent network's Q-values.
  policy_network = snt.DeepRNN([
      network,
      lambda qs: tf.cast(trfl.epsilon_greedy(qs, epsilon).sample(), tf.int32),
  ])

  # Component to add things into replay.
  sequence_length = self._burn_in_length + self._trace_length + 1
  adder = adders.SequenceAdder(
      client=replay,
      period=self._replay_period,
      sequence_length=sequence_length,
      delta_encoded=True,
  )

  variable_client = tf2_variable_utils.VariableClient(
      client=variable_source,
      variables={'policy': policy_network.variables},
      update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Create the agent.
  actor = actors.RecurrentActor(
      policy_network=policy_network,
      variable_client=variable_client,
      adder=adder)

  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor', save_data=False, steps_key='actor_steps')

  # Create the loop to connect environment and agent.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def __init__(
    self,
    environment: dm_env.Environment,
    executor: mava.core.Executor,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    should_update: bool = True,
    label: str = "parallel_environment_loop",
):
  # Internalize agent and environment.
  self._environment = environment
  self._executor = executor
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(label)
  self._should_update = should_update
  self._running_statistics: Dict[str, float] = {}
def __init__(
    self,
    actor_id,
    environment_module,
    environment_fn_name,
    environment_kwargs,
    network_module,
    network_fn_name,
    network_kwargs,
    adder_module,
    adder_fn_name,
    adder_kwargs,
    replay_server_address,
    variable_server_name,
    variable_server_address,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
):
  # Counter and logger.
  self._actor_id = actor_id
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(f'actor_{actor_id}')

  # Create the environment.
  self._environment = getattr(environment_module,
                              environment_fn_name)(**environment_kwargs)
  env_spec = acme.make_environment_spec(self._environment)

  # Create the actor's network.
  self._network = getattr(network_module, network_fn_name)(**network_kwargs)
  tf2_utils.create_variables(self._network, [env_spec.observations])
  self._variables = tree.flatten(self._network.variables)
  self._policy = tf.function(self._network)

  # The adder is used to insert observations into replay.
  self._adder = getattr(adder_module, adder_fn_name)(
      reverb.Client(replay_server_address), **adder_kwargs)

  # Dataset used to fetch the latest network variables from the server.
  variable_client = reverb.TFClient(variable_server_address)
  self._variable_dataset = variable_client.dataset(
      table=variable_server_name,
      dtypes=[tf.float32 for _ in self._variables],
      shapes=[v.shape for v in self._variables])
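# A hedged sketch of how the variable dataset above might be consumed to
# refresh this actor's weights; the method name and one-sample-per-update
# cadence are assumptions, not part of the source.
def _update_variables(self):
  for sample in self._variable_dataset.take(1):
    for variable, value in zip(self._variables, sample.data):
      variable.assign(value)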
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
  """The evaluation process."""
  action_spec = self._environment_spec.actions
  observation_spec = self._environment_spec.observations

  # Create environment and evaluator networks.
  environment = self._environment_factory(True)
  agent_networks = self._network_factory(action_spec)

  # Create evaluator network.
  evaluator_network = snt.Sequential([
      agent_networks.get('observation', tf.identity),
      agent_networks.get('policy'),
  ])

  # Ensure network variables are created.
  tf2_utils.create_variables(evaluator_network, [observation_spec])
  variables = {'policy': evaluator_network.variables}

  # Create the variable client responsible for keeping the actor up-to-date.
  variable_client = tf2_variable_utils.VariableClient(
      variable_source, variables, update_period=self._variable_update_period)

  # Make sure not to evaluate a random actor by assigning variables before
  # running the environment loop.
  variable_client.update_and_wait()

  # Create the evaluator; note it will not add experience to replay.
  evaluator = actors.FeedForwardActor(
      evaluator_network, variable_client=variable_client)

  # Create logger and counter.
  counter = counting.Counter(counter, 'evaluator')
  logger = loggers.make_default_logger(
      'evaluator', time_delta=self._log_every, steps_key='evaluator_steps')

  # Create the run loop and return it.
  return acme.EnvironmentLoop(environment, evaluator, counter, logger)
def __init__(
    self,
    environment: dm_env.Environment,
    actor: core.Actor,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    should_update: bool = True,
    label: str = 'environment_loop',
    observers: Sequence[observers_lib.EnvLoopObserver] = (),
):
  # Internalize agent and environment.
  self._environment = environment
  self._actor = actor
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger(label)
  self._should_update = should_update
  self._observers = observers
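# A minimal usage sketch for a loop constructed this way, assuming the
# surrounding class is the EnvironmentLoop shown elsewhere in this collection
# and that `environment` and `actor` are concrete dm_env.Environment /
# core.Actor instances built elsewhere; the observer below is a hypothetical
# illustration of the EnvLoopObserver hooks.
class EpisodeLengthObserver:
  """Counts steps per episode and reports them as a metric."""

  def observe_first(self, env: dm_env.Environment,
                    timestep: dm_env.TimeStep) -> None:
    self._length = 0

  def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep,
              action: np.ndarray) -> None:
    self._length += 1

  def get_metrics(self):
    return {'episode_length': self._length}

loop = EnvironmentLoop(
    environment, actor, observers=(EpisodeLengthObserver(),))
loop.run(num_episodes=10)  # Logs per-episode results via the default logger.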
def learner(
    self,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  # Create online and target networks.
  online_networks = self._network_factory(self._environment_spec)
  target_networks = self._network_factory(self._environment_spec)

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(
      server_address=replay.server_address,
      batch_size=self._batch_size,
      prefetch_size=self._prefetch_size,
  )

  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger('learner', time_delta=self._log_every)

  # Create policy loss module if a factory is passed.
  if self._policy_loss_factory:
    policy_loss_module = self._policy_loss_factory()
  else:
    policy_loss_module = None

  # Return the learning agent.
  return learning.MoGMPOLearner(
      policy_network=online_networks['policy'],
      critic_network=online_networks['critic'],
      observation_network=online_networks['observation'],
      target_policy_network=target_networks['policy'],
      target_critic_network=target_networks['critic'],
      target_observation_network=target_networks['observation'],
      discount=self._additional_discount,
      num_samples=self._num_samples,
      policy_evaluation_config=self._policy_evaluation_config,
      target_policy_update_period=self._target_policy_update_period,
      target_critic_update_period=self._target_critic_update_period,
      policy_loss_module=policy_loss_module,
      dataset=dataset,
      counter=counter,
      logger=logger)