Ejemplo n.º 1
0
  def actor(self, random_key, replay, variable_source, counter, actor_id):
    """Builds the environment loop run by one actor process.

    Args:
      random_key: JAX PRNG key. It is split in two: one half seeds the
        environment, the other is handed to the actor.
      replay: Passed to the builder's `make_adder`.
      variable_source: Passed to the builder's `make_actor`.
      counter: Parent counter; wrapped in a new `counting.Counter` created
        with the 'actor' argument.
      actor_id: Forwarded to the actor logger function.

    Returns:
      An `environment_loop.EnvironmentLoop` wired with the environment, the
      actor, the wrapped counter, the logger and `self._observers`.
    """
    builder = self._builder
    adder = builder.make_adder(replay)

    env_key, policy_key = jax.random.split(random_key)

    # The environment factory is given a uint32 seed derived from its key.
    env_seed = utils.sample_uint32(env_key)
    environment = self._environment_factory(env_seed)

    # Networks are built from the spec of the freshly created environment.
    env_spec = specs.make_environment_spec(environment)
    policy_network = self._policy_network(self._network_factory(env_spec))
    agent = builder.make_actor(
        policy_key, policy_network, adder, variable_source)

    # Per-actor bookkeeping: a counter wrapping the parent one, and a logger
    # obtained from the actor id.
    actor_counter = counting.Counter(counter, 'actor')
    actor_logger = self._actor_logger_fn(actor_id)

    # The loop that connects the environment and the agent.
    return environment_loop.EnvironmentLoop(
        environment,
        agent,
        actor_counter,
        actor_logger,
        observers=self._observers)
Ejemplo n.º 2
0
    def evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """Builds the environment loop run by the evaluation process.

        Args:
            random_key: JAX PRNG key. It is split in two: one half seeds the
                environment, the other is handed to `make_actor`.
            variable_source: Passed through to `make_actor`.
            counter: Parent counter; wrapped in a new `counting.Counter`
                created with the 'evaluator' argument.
            make_actor: Callable invoked as
                `make_actor(key, policy, variable_source)`.

        Returns:
            An `environment_loop.EnvironmentLoop` wired with the environment,
            the actor, the wrapped counter, the logger and the enclosing
            scope's `observers`.
        """
        env_key, policy_key = jax.random.split(random_key)

        # The environment factory is given a uint32 seed derived from its key.
        env_seed = utils.sample_uint32(env_key)
        environment = environment_factory(env_seed)
        env_spec = specs.make_environment_spec(environment)
        networks = network_factory(env_spec)

        policy = policy_factory(networks)
        eval_actor = make_actor(policy_key, policy, variable_source)

        eval_counter = counting.Counter(counter, 'evaluator')

        # Use the default logger only when no custom logger function was
        # supplied by the enclosing scope.
        if logger_fn is None:
            eval_logger = loggers.make_default_logger(
                'evaluator', log_to_bigtable, steps_key='actor_steps')
        else:
            eval_logger = logger_fn('evaluator', 'actor_steps')

        return environment_loop.EnvironmentLoop(
            environment,
            eval_actor,
            eval_counter,
            eval_logger,
            observers=observers)
Ejemplo n.º 3
0
    def evaluator(
        random_key: types.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """Builds the environment loop run by the evaluation process.

        Args:
            random_key: JAX PRNG key. It is split in two: one half seeds the
                environment, the other is handed to `make_actor`.
            variable_source: Passed through to `make_actor`.
            counter: Parent counter; wrapped in a new `counting.Counter`
                created with the 'evaluator' argument.
            make_actor: Callable invoked as
                `make_actor(key, policy, environment_spec, variable_source)`.

        Returns:
            An `environment_loop.EnvironmentLoop` wired with the environment,
            the actor, the wrapped counter, the logger and the enclosing
            scope's `observers`.
        """
        env_key, policy_key = jax.random.split(random_key)

        # The environment factory is given a uint32 seed derived from its key.
        env_seed = utils.sample_uint32(env_key)
        environment = environment_factory(env_seed)
        env_spec = specs.make_environment_spec(environment)

        # NOTE(review): the trailing `True` is presumably the "evaluation"
        # flag of the policy factory -- confirm against its definition.
        eval_policy = policy_factory(network_factory(env_spec), env_spec, True)
        eval_actor = make_actor(
            policy_key, eval_policy, env_spec, variable_source)

        # Counter wrapping the parent one, plus a logger built from the
        # label, the steps key and 0 as the third argument.
        eval_counter = counting.Counter(counter, 'evaluator')
        eval_logger = logger_factory('evaluator', 'actor_steps', 0)

        return environment_loop.EnvironmentLoop(
            environment,
            eval_actor,
            eval_counter,
            eval_logger,
            observers=observers)
Ejemplo n.º 4
0
    def build_actor(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        actor_id: ActorId,
    ) -> environment_loop.EnvironmentLoop:
        """Builds the environment loop run by one actor process.

        Args:
            random_key: JAX PRNG key. It is split in two: one half seeds the
                environment, the other is handed to the builder's
                `make_actor`.
            replay: Client passed to the builder's `make_adder`.
            variable_source: Passed to the builder's `make_actor`.
            counter: Parent counter; wrapped in a new `counting.Counter`
                created with the 'actor' argument.
            actor_id: Forwarded to the experiment's logger factory.

        Returns:
            An `environment_loop.EnvironmentLoop` wired with the environment,
            the actor, the wrapped counter, the logger and
            `experiment.observers`.
        """
        env_key, policy_key = jax.random.split(random_key)

        # The environment factory is given a uint32 seed derived from its key.
        env_seed = utils.sample_uint32(env_key)
        environment = experiment.environment_factory(env_seed)
        env_spec = specs.make_environment_spec(environment)

        # Non-evaluation policy built from the experiment's networks.
        networks = experiment.network_factory(env_spec)
        policy = config.make_policy(
            experiment=experiment,
            networks=networks,
            environment_spec=env_spec,
            evaluation=False,
        )

        builder = experiment.builder
        adder = builder.make_adder(replay, env_spec, policy)
        agent = builder.make_actor(
            policy_key, policy, env_spec, variable_source, adder)

        # The logger is keyed on the steps key reported by the wrapped counter.
        actor_counter = counting.Counter(counter, 'actor')
        steps_key = actor_counter.get_steps_key()
        actor_logger = experiment.logger_factory('actor', steps_key, actor_id)

        # The loop that connects the environment and the agent.
        return environment_loop.EnvironmentLoop(
            environment,
            agent,
            actor_counter,
            actor_logger,
            observers=experiment.observers)