Example #1
    def __init__(self,
                 network: discrete_networks.DiscreteFilteredQNetwork,
                 dataset: tf.data.Dataset,
                 learning_rate: float,
                 counter: counting.Counter = None,
                 bc_logger: loggers.Logger = None,
                 bcq_logger: loggers.Logger = None,
                 **bcq_learner_kwargs):
        counter = counter or counting.Counter()
        self._bc_logger = bc_logger or loggers.TerminalLogger('bc_learner',
                                                              time_delta=1.)
        self._bcq_logger = bcq_logger or loggers.TerminalLogger('bcq_learner',
                                                                time_delta=1.)

        self._bc_learner = bc.BCLearner(network=network.g_network,
                                        learning_rate=learning_rate,
                                        dataset=dataset,
                                        counter=counting.Counter(
                                            counter, 'bc'),
                                        logger=self._bc_logger,
                                        checkpoint=False)
        self._bcq_learner = _InternalBCQLearner(network=network,
                                                learning_rate=learning_rate,
                                                dataset=dataset,
                                                counter=counting.Counter(
                                                    counter, 'bcq'),
                                                logger=self._bcq_logger,
                                                **bcq_learner_kwargs)
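
Example #1 gives each sub-learner a child of the shared parent counter ('bc' and 'bcq'), so their counts stay separated when logged. A minimal sketch of that namespacing pattern, using only the counting utilities (the prefixing behaviour is the one exercised by the counter tests later in this collection, e.g. Example #4):

from acme.utils import counting

parent = counting.Counter()

# Child counters namespace their counts under a prefix, as in Example #1.
bc_counter = counting.Counter(parent, 'bc')
bcq_counter = counting.Counter(parent, 'bcq')

bc_counter.increment(steps=1)
bcq_counter.increment(steps=3)

# Child counts are pushed to the parent under prefixed keys, e.g.
# {'bc_steps': 1, 'bcq_steps': 3}; syncing is periodic, governed by time_delta.
print(parent.get_counts())
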
Example #2
def main(_):
    # Create an environment and grab the spec.
    environment = atari.environment(FLAGS.game)
    environment_spec = specs.make_environment_spec(environment)

    # Create dataset.
    dataset = atari.dataset(path=FLAGS.dataset_path,
                            game=FLAGS.game,
                            run=FLAGS.run,
                            num_shards=FLAGS.num_shards)
    # Discard extra inputs
    dataset = dataset.map(lambda x: x._replace(data=x.data[:5]))

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Build network.
    g_network = make_network(environment_spec.actions)
    q_network = make_network(environment_spec.actions)
    network = networks.DiscreteFilteredQNetwork(g_network=g_network,
                                                q_network=q_network,
                                                threshold=FLAGS.bcq_threshold)
    tf2_utils.create_variables(network, [environment_spec.observations])

    evaluator_network = snt.Sequential([
        q_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Counters.
    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = bcq.DiscreteBCQLearner(
        network=network,
        dataset=dataset,
        learning_rate=FLAGS.learning_rate,
        discount=FLAGS.discount,
        importance_sampling_exponent=FLAGS.importance_sampling_exponent,
        target_update_period=FLAGS.target_update_period,
        counter=counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
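
Example #2 (and Example #3 below) follows the same offline pattern: step the learner a fixed number of times, record those steps in a namespaced child counter, then run evaluation against the shared parent counter. A minimal sketch of just the counting side, with a hypothetical stand-in for the learner:

from acme.utils import counting


class FakeLearner:
    """Hypothetical stand-in for bcq.DiscreteBCQLearner / bc.BCLearner."""

    def __init__(self):
        self.num_steps = 0

    def step(self):
        self.num_steps += 1


counter = counting.Counter()
learner_counter = counting.Counter(counter, prefix='learner')
learner = FakeLearner()

evaluate_every = 100
for _ in range(evaluate_every):
    learner.step()

# Record the completed learner steps; they sync to the shared parent counter
# under the 'learner' prefix, where the evaluation loop's logger can see them.
learner_counter.increment(learner_steps=evaluate_every)
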
Example #3
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert each demonstration episode into n-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # If the agent is non-autoregressive, epsilon=0 yields a greedy policy.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors_tf2.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 learning_rate=FLAGS.learning_rate,
                                 dataset=dataset,
                                 counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
Example #4
 def test_shared_counts(self):
     # Two counters with shared parent should share counts (modulo namespacing).
     parent = counting.Counter()
     child1 = counting.Counter(parent, 'child1')
     child2 = counting.Counter(parent, 'child2')
     child1.increment(foo=1)
     result = child2.increment(foo=2)
     expected = {'child1_foo': 1, 'child2_foo': 2}
     self.assertEqual(result, expected)
Example #5
    def make_learner(
        self,
        random_key,
        networks,
        dataset,
        logger_fn,
        environment_spec=None,
        replay_client=None,
        counter=None,
    ):
        """Creates the learner."""
        counter = counter or counting.Counter()
        discrete_rl_counter = counting.Counter(counter, 'direct_rl')

        aquadem_learner_key, discrete_rl_learner_key = jax.random.split(
            random_key)

        def discrete_rl_learner_factory(networks, dataset):
            return self._rl_agent.make_learner(
                discrete_rl_learner_key,
                networks,
                dataset,
                logger_fn=logger_fn,
                environment_spec=environment_spec,
                replay_client=replay_client,
                counter=discrete_rl_counter)

        # pytype:disable=attribute-error
        demonstrations_iterator = self._make_demonstrations(
            self._rl_agent._config.batch_size)  # pylint: disable=protected-access
        # pytype:enable=attribute-error

        optimizer = optax.adam(
            learning_rate=self._config.encoder_learning_rate)
        return learning.AquademLearner(
            random_key=aquadem_learner_key,
            discrete_rl_learner_factory=discrete_rl_learner_factory,
            iterator=dataset,
            demonstrations_iterator=demonstrations_iterator,
            optimizer=optimizer,
            networks=networks,
            make_demonstrations=self._make_demonstrations,
            encoder_num_steps=self._config.encoder_num_steps,
            encoder_batch_size=self._config.encoder_batch_size,
            encoder_eval_every=self._config.encoder_eval_every,
            temperature=self._config.temperature,
            num_actions=self._config.num_actions,
            demonstration_ratio=self._config.demonstration_ratio,
            min_demo_reward=self._config.min_demo_reward,
            counter=counter,
            logger=logger_fn('learner'))
Example #6
    def actor(
        self,
        replay: reverb.Client,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
    ) -> acme.EnvironmentLoop:
        """The actor process."""

        # Create the behavior policy.
        networks = self._network_factory(self._environment_spec.actions)
        networks.init(self._environment_spec)
        policy_network = networks.make_policy(
            environment_spec=self._environment_spec,
            sigma=self._sigma,
        )

        # Create the agent.
        actor = self._builder.make_actor(
            policy_network=policy_network,
            adder=self._builder.make_adder(replay),
            variable_source=variable_source,
        )

        # Create the environment.
        environment = self._environment_factory(False)

        # Create logger and counter; actors will not spam bigtable.
        counter = counting.Counter(counter, 'actor')
        logger = loggers.make_default_logger('actor',
                                             save_data=False,
                                             time_delta=self._log_every,
                                             steps_key='actor_steps')

        # Create the loop to connect environment and agent.
        return acme.EnvironmentLoop(environment, actor, counter, logger)
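
Example #6 and the later distributed-agent examples all wrap the shared parent counter in a process-specific child ('actor', 'learner', 'evaluator'); the steps_key arguments passed to the loggers name the resulting prefixed keys. A minimal sketch of that convention, using only the counting utilities:

from acme.utils import counting

# One parent counter shared by every process in the experiment (hypothetical).
program_counter = counting.Counter()

# Each process namespaces its own counts, mirroring the methods above.
actor_counter = counting.Counter(program_counter, 'actor')
learner_counter = counting.Counter(program_counter, 'learner')

actor_counter.increment(episodes=1, steps=200)
learner_counter.increment(steps=50)

# The parent aggregates prefixed keys such as 'actor_steps' and 'learner_steps',
# which is what the steps_key arguments to make_default_logger refer to.
print(program_counter.get_counts())
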
Example #7
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        # Create the networks to optimize (online) and target networks.
        online_networks = self._network_factory(self._environment_spec.actions)
        target_networks = copy.deepcopy(online_networks)

        # Initialize the networks.
        online_networks.init(self._environment_spec)
        target_networks.init(self._environment_spec)

        dataset = self._builder.make_dataset_iterator(replay)
        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        return self._builder.make_learner(
            networks=(online_networks, target_networks),
            dataset=dataset,
            counter=counter,
            logger=logger,
        )
Example #8
    def trainer(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ) -> mava.core.Trainer:
        """System trainer

        Args:
            replay (reverb.Client): replay data table to pull data from.
            counter (counting.Counter): step counter object.

        Returns:
            mava.core.Trainer: system trainer.
        """

        # Create the networks to optimize (online)
        networks = self._network_factory(  # type: ignore
            environment_spec=self._environment_spec,
            shared_weights=self._shared_weights)

        # Create system architecture with target networks.
        architecture = self._architecture(
            environment_spec=self._environment_spec,
            value_networks=networks["q_networks"],
            shared_weights=self._shared_weights,
        )

        if self._builder._replay_stabiliser_fn is not None:
            architecture = self._builder._replay_stabiliser_fn(  # type: ignore
                architecture)

        communication_module = None
        if self._communication_module_fn is not None:
            communication_module = self._communication_module_fn(
                architecture=architecture,
                shared=True,
                channel_size=1,
                channel_noise=0,
            )
            system_networks = communication_module.create_system()
        else:
            system_networks = architecture.create_system()

        # create logger
        trainer_logger_config = {}
        if self._logger_config and "trainer" in self._logger_config:
            trainer_logger_config = self._logger_config["trainer"]
        trainer_logger = self._logger_factory(  # type: ignore
            "trainer", **trainer_logger_config)

        dataset = self._builder.make_dataset_iterator(replay)
        counter = counting.Counter(counter, "trainer")

        return self._builder.make_trainer(
            networks=system_networks,
            dataset=dataset,
            counter=counter,
            communication_module=communication_module,
            logger=trainer_logger,
        )
Example #9
    def evaluator(
        random_key: types.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)
        networks = network_factory(environment_spec)
        policy = policy_factory(networks, environment_spec, True)
        actor = make_actor(actor_key, policy, environment_spec,
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        logger = logger_factory('evaluator', 'actor_steps', 0)

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #10
    def evaluator(
        self,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
        logger: loggers.Logger = None,
    ):
        """The evaluation process."""

        # Create the behavior policy.
        networks = self._network_factory(self._environment_spec.actions)
        networks.init(self._environment_spec)
        policy_network = networks.make_policy(self._environment_spec)

        # Create the agent.
        actor = self._builder.make_actor(
            policy_network=policy_network,
            variable_source=variable_source,
        )

        # Make the environment.
        environment = self._environment_factory(True)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        logger = logger or loggers.make_default_logger(
            'evaluator',
            time_delta=self._log_every,
            steps_key='evaluator_steps',
        )

        # Create the run loop and return it.
        return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #11
  def __init__(
      self,
      adapt_environment: dm_env.Environment,
      test_environment: dm_env.Environment,
      meta_agent: meta_agent.MetaAgent,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      label: str = "meta_train",
      verbose_level: int = 0,
  ):
    # Internalize counter and loggers.
    self._adapt_env = adapt_environment
    self._test_env = test_environment
    self._batch_size = self._adapt_env.batch_size
    self._counter = counter or counting.Counter() # Not used. TODO: consider removing
    self._logger = logger
    self._label = label
    self._verbose_level = verbose_level

    # Create train_loop and test_loop.
    # TODO: This is not pretty; improve the interfaces.
    self._meta_agent = meta_agent
    self._agent = meta_agent.instantiate_adapt_agent()  # Fast agent
    self._test_actor = meta_agent.instantiate_test_actor()  # Test actor (no learning)

    self._adaptation_loop = EnvironmentLoop(
        self._adapt_env, self._agent,
        label='adaptation', verbose_level=self._verbose_level)
    self._test_loop = EnvironmentLoop(
        self._test_env, self._test_actor,
        label='test', verbose_level=self._verbose_level)
Example #12
    def test_value_dice(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)

        spec = specs.make_environment_spec(environment)

        # Create the networks.
        network = value_dice.make_networks(spec)

        config = value_dice.ValueDiceConfig(batch_size=10, min_replay_size=1)
        counter = counting.Counter()
        agent = value_dice.ValueDice(
            spec=spec,
            network=network,
            config=config,
            make_demonstrations=fakes.transition_iterator(environment),
            seed=0,
            counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
Example #13
  def evaluator(
      self,
      variable_source: acme.VariableSource,
      counter: counting.Counter,
  ):
    """The evaluation process."""
    environment = self._environment_factory(True)
    network = self._network_factory(self._environment_spec.actions)

    tf2_utils.create_variables(network, [self._obs_spec])
    policy_network = snt.DeepRNN([
        network,
        lambda qs: tf.cast(tf.argmax(qs, axis=-1), tf.int32),
    ])

    variable_client = tf2_variable_utils.VariableClient(
        client=variable_source,
        variables={'policy': policy_network.variables},
        update_period=self._variable_update_period)

    # Make sure not to use a random policy after checkpoint restoration by
    # assigning variables before running the environment loop.
    variable_client.update_and_wait()

    # Create the agent.
    actor = actors.RecurrentActor(
        policy_network=policy_network, variable_client=variable_client)

    # Create the run loop and return it.
    logger = loggers.make_default_logger(
        'evaluator', save_data=True, steps_key='evaluator_steps')
    counter = counting.Counter(counter, 'evaluator')

    return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #14
    def learner(
        self,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        # If we are running on multiple accelerator devices, this replicates
        # weights and updates across devices.
        replicator = agent.get_replicator(self._accelerator)

        with replicator.scope():
            # Create the networks to optimize (online) and target networks.
            online_networks = self._network_factory(
                self._environment_spec.actions)
            target_networks = copy.deepcopy(online_networks)

            # Initialize the networks.
            online_networks.init(self._environment_spec)
            target_networks.init(self._environment_spec)

        dataset = self._builder.make_dataset_iterator(replay)

        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger('learner',
                                             time_delta=self._log_every,
                                             steps_key='learner_steps')

        return self._builder.make_learner(
            networks=(online_networks, target_networks),
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=True,
        )
Example #15
    def evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        networks = network_factory(specs.make_environment_spec(environment))

        actor = make_actor(actor_key, policy_factory(networks),
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        if logger_fn is not None:
            logger = logger_fn('evaluator', 'actor_steps')
        else:
            logger = loggers.make_default_logger('evaluator',
                                                 log_to_bigtable,
                                                 steps_key='actor_steps')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #16
  def learner(self, queue: reverb.Client, counter: counting.Counter):
    """The Learning part of the agent."""
    # Create the networks.
    network = self._network_factory(self._environment_spec.actions)
    tf2_utils.create_variables(network, [self._environment_spec.observations])

    # The dataset object to learn from.
    dataset = datasets.make_reverb_dataset(
        server_address=queue.server_address,
        batch_size=self._batch_size,
        prefetch_size=self._prefetch_size)

    logger = loggers.make_default_logger('learner', steps_key='learner_steps')
    counter = counting.Counter(counter, 'learner')

    # Return the learning agent.
    learner = learning.IMPALALearner(
        environment_spec=self._environment_spec,
        network=network,
        dataset=dataset,
        discount=self._discount,
        learning_rate=self._learning_rate,
        entropy_cost=self._entropy_cost,
        baseline_cost=self._baseline_cost,
        max_abs_reward=self._max_abs_reward,
        max_gradient_norm=self._max_gradient_norm,
        counter=counter,
        logger=logger,
    )

    return tf2_savers.CheckpointingRunner(learner,
                                          time_delta_minutes=5,
                                          subdirectory='impala_learner')
Example #17
 def test_get_steps_key(self):
     parent = counting.Counter()
     child1 = counting.Counter(parent,
                               'child1',
                               time_delta=0.,
                               return_only_prefixed=False)
     child2 = counting.Counter(parent,
                               'child2',
                               time_delta=0.,
                               return_only_prefixed=True)
     self.assertEqual(child1.get_steps_key(), 'child1_steps')
     self.assertEqual(child2.get_steps_key(), 'steps')
     child1.increment(steps=1)
     child2.increment(steps=2)
     self.assertEqual(child1.get_counts().get(child1.get_steps_key()), 1)
     self.assertEqual(child2.get_counts().get(child2.get_steps_key()), 2)
Example #18
    def actor_evaluator(
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor = actors.FeedForwardActor(
            policy=evaluator_network,
            random_key=random_key,
            # Inference happens on CPU, so it's better to move variables there too.
            variable_client=variable_utils.VariableClient(variable_source,
                                                          'policy',
                                                          device='cpu'))

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create environment and evaluator networks
        environment = environment_factory(False)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #19
    def test_r2d2(self):
        # Create a fake environment to test with.
        environment = fakes.fake_atari_wrapped(oar_wrapper=True)
        spec = specs.make_environment_spec(environment)

        config = r2d2.R2D2Config(batch_size=1,
                                 trace_length=5,
                                 sequence_period=1,
                                 samples_per_insert=0.,
                                 min_replay_size=1,
                                 burn_in_length=1)

        counter = counting.Counter()
        agent = r2d2.R2D2(
            spec=spec,
            networks=r2d2.make_atari_networks(config.batch_size, spec),
            config=config,
            seed=0,
            counter=counter,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
Example #20
    def actor(
        self,
        replay: reverb.Client,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
    ) -> acme.EnvironmentLoop:
        """The actor process."""

        action_spec = self._environment_spec.actions
        observation_spec = self._environment_spec.observations

        # Create environment and target networks to act with.
        environment = self._environment_factory(False)
        agent_networks = self._network_factory(action_spec,
                                               self._num_critic_heads)

        # Make sure observation network is defined.
        observation_network = agent_networks.get('observation', tf.identity)

        # Create a stochastic behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            agent_networks['policy'],
            networks.StochasticSamplingHead(),
        ])

        # Ensure network variables are created.
        tf2_utils.create_variables(behavior_network, [observation_spec])
        policy_variables = {'policy': behavior_network.variables}

        # Create the variable client responsible for keeping the actor up-to-date.
        variable_client = tf2_variable_utils.VariableClient(variable_source,
                                                            policy_variables,
                                                            update_period=1000)

        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        # Component to add things into replay.
        adder = adders.NStepTransitionAdder(
            client=replay,
            n_step=self._n_step,
            max_in_flight_items=self._max_in_flight_items,
            discount=self._additional_discount)

        # Create the agent.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder,
                                        variable_client=variable_client)

        # Create logger and counter; actors will not spam bigtable.
        counter = counting.Counter(counter, 'actor')
        logger = loggers.make_default_logger('actor',
                                             save_data=False,
                                             time_delta=self._log_every,
                                             steps_key='actor_steps')

        # Create the run loop and return it.
        return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #21
  def evaluator(self, variable_source: acme.VariableSource,
                counter: counting.Counter):
    """The evaluation process."""
    environment = self._environment_factory(True)
    network = self._network_factory(self._environment_spec.actions)
    tf2_utils.create_variables(network, [self._environment_spec.observations])

    variable_client = tf2_variable_utils.VariableClient(
        client=variable_source,
        variables={'policy': network.variables},
        update_period=self._variable_update_period)

    # Make sure not to use a random policy after checkpoint restoration by
    # assigning variables before running the environment loop.
    variable_client.update_and_wait()

    # Create the agent.
    actor = acting.IMPALAActor(
        network=network, variable_client=variable_client)

    # Create the run loop and return it.
    logger = loggers.make_default_logger(
        'evaluator', steps_key='evaluator_steps')
    counter = counting.Counter(counter, 'evaluator')
    return acme.EnvironmentLoop(environment, actor, counter, logger)
Example #22
  def actor(self, random_key, replay,
            variable_source, counter,
            actor_id):
    """The actor process."""
    adder = self._builder.make_adder(replay)

    environment_key, actor_key = jax.random.split(random_key)
    # Create environment and policy core.

    # Environments normally require uint32 as a seed.
    environment = self._environment_factory(
        utils.sample_uint32(environment_key))

    networks = self._network_factory(specs.make_environment_spec(environment))
    policy_network = self._policy_network(networks)
    actor = self._builder.make_actor(actor_key, policy_network, adder,
                                     variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    # Only actor #0 will write to bigtable in order not to spam it too much.
    logger = self._actor_logger_fn(actor_id)
    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=self._observers)
Example #23
    def actor_evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
            evaluator_network)
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client = variable_utils.VariableClient(variable_source,
                                                        'policy',
                                                        device='cpu')
        actor = actors.GenericActor(actor_core,
                                    random_key,
                                    variable_client,
                                    backend='cpu')

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create environment and evaluator networks
        environment = environment_factory(False)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #24
  def test_td3_fd(self):
    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, action_dim=3, observation_dim=5, bounded=True)
    spec = specs.make_environment_spec(environment)

    # Create the networks.
    td3_network = td3.make_networks(spec)

    batch_size = 10
    td3_config = td3.TD3Config(
        batch_size=batch_size,
        min_replay_size=1)
    lfd_config = lfd.LfdConfig(initial_insert_count=0,
                               demonstration_ratio=0.2)
    td3_fd_config = lfd.TD3fDConfig(lfd_config=lfd_config,
                                    td3_config=td3_config)
    counter = counting.Counter()
    agent = lfd.TD3fD(
        spec=spec,
        td3_network=td3_network,
        td3_fd_config=td3_fd_config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        counter=counter)

    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent, counter=counter)
    loop.run(num_episodes=20)
Example #25
    def __init__(self,
                 network: snt.Module,
                 learning_rate: float,
                 dataset: tf.data.Dataset,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None):
        """Initializes the learner.

    Args:
      network: the online Q network (the one being optimized)
      learning_rate: learning rate for the q-network update.
      dataset: dataset to learn from.
      counter: Counter object for (potentially distributed) counting.
      logger: Logger object for writing logs to.
    """

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
        # TODO(b/155086959): Fix type stubs and remove.

        self._network = network
        self._optimizer = snt.optimizers.Adam(learning_rate)

        self._variables: List[List[tf.Tensor]] = [network.trainable_variables]
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)
Example #26
    def run_ppo_agent(self, make_networks_fn):
        # Create a fake environment to test with.
        environment = fakes.DiscreteEnvironment(num_actions=5,
                                                num_observations=10,
                                                obs_shape=(10, 5),
                                                obs_dtype=np.float32,
                                                episode_length=10)
        spec = specs.make_environment_spec(environment)

        distribution_value_networks = make_networks_fn(spec)
        ppo_networks = ppo.make_ppo_networks(distribution_value_networks)
        config = ppo.PPOConfig(unroll_length=4,
                               num_epochs=2,
                               num_minibatches=2)
        workdir = self.create_tempdir()
        counter = counting.Counter()
        logger = loggers.make_default_logger('learner')
        # Construct the agent.
        agent = ppo.PPO(
            spec=spec,
            networks=ppo_networks,
            config=config,
            seed=0,
            workdir=workdir.full_path,
            normalize_input=True,
            counter=counter,
            logger=logger,
        )

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=20)
Example #27
    def learner(
        self,
        random_key: networks_lib.PRNGKey,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""
        # Counter and logger.
        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger(
            'learner',
            self._save_logs,
            time_delta=self._log_every,
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key='learner_steps')

        # Create the learner.
        networks = self._network_factory()
        learner = self._make_learner(random_key, networks, counter, logger)

        kwargs = {
            'directory': self._workdir,
            'add_uid': self._workdir == '~/acme'
        }
        # Return the learning agent.
        return savers.CheckpointingRunner(learner,
                                          subdirectory='learner',
                                          time_delta_minutes=5,
                                          **kwargs)
Example #28
    def test_d4pg(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        networks = make_networks(spec)

        config = d4pg.D4PGConfig(
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            samples_per_insert_tolerance_rate=float('inf'))
        counter = counting.Counter()
        agent = d4pg.D4PG(spec,
                          networks,
                          config=config,
                          random_seed=0,
                          counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
Example #29
def main(_):
    # Create an environment, grab the spec, and use it to create networks.
    environment = helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)
    agent_networks = ppo.make_continuous_networks(environment_spec)

    # Construct the agent.
    config = ppo.PPOConfig(unroll_length=FLAGS.unroll_length,
                           num_minibatches=FLAGS.num_minibatches,
                           num_epochs=FLAGS.num_epochs,
                           batch_size=FLAGS.batch_size)

    learner_logger = experiment_utils.make_experiment_logger(
        label='learner', steps_key='learner_steps')
    agent = ppo.PPO(environment_spec,
                    agent_networks,
                    config=config,
                    seed=FLAGS.seed,
                    counter=counting.Counter(prefix='learner'),
                    logger=learner_logger)

    # Create the environment loop used for training.
    train_logger = experiment_utils.make_experiment_logger(
        label='train', steps_key='train_steps')
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      counter=counting.Counter(prefix='train'),
                                      logger=train_logger)

    # Create the evaluation actor and loop.
    eval_logger = experiment_utils.make_experiment_logger(
        label='eval', steps_key='eval_steps')
    eval_actor = agent.builder.make_actor(
        random_key=jax.random.PRNGKey(FLAGS.seed),
        policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
        variable_source=agent)
    eval_env = helpers.make_environment(task=FLAGS.env_name)
    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     counter=counting.Counter(prefix='eval'),
                                     logger=eval_logger)

    assert FLAGS.num_steps % FLAGS.eval_every == 0
    for _ in range(FLAGS.num_steps // FLAGS.eval_every):
        eval_loop.run(num_episodes=5)
        train_loop.run(num_steps=FLAGS.eval_every)
    eval_loop.run(num_episodes=5)
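
Example #29 creates independent top-level counters that carry a prefix but no parent (counting.Counter(prefix='learner'), prefix='train', prefix='eval'). Assuming the prefixing behaviour shown in the counter tests also applies without a parent, each such counter simply reports its own counts under prefixed keys; a tiny sketch:

from acme.utils import counting

train_counter = counting.Counter(prefix='train')
train_counter.increment(episodes=1, steps=250)

# With no parent to sync to, get_counts() returns the local counts with the
# prefix applied, e.g. {'train_episodes': 1, 'train_steps': 250}.
print(train_counter.get_counts())
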
Example #30
 def test_return_only_prefixed(self):
     parent = counting.Counter()
     child1 = counting.Counter(parent,
                               'child1',
                               time_delta=0.,
                               return_only_prefixed=False)
     child2 = counting.Counter(parent,
                               'child2',
                               time_delta=0.,
                               return_only_prefixed=True)
     child1.increment(foo=1)
     child2.increment(bar=1)
     self.assertEqual(child1.get_counts(), {
         'child1_foo': 1,
         'child2_bar': 1
     })
     self.assertEqual(child2.get_counts(), {'bar': 1})