Example #1
    def actor_evaluator(
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor = actors.FeedForwardActor(
            policy=evaluator_network,
            random_key=random_key,
            # Inference happens on CPU, so it's better to move variables there too.
            variable_client=variable_utils.VariableClient(variable_source,
                                                          'policy',
                                                          device='cpu'))

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create environment and evaluator networks
        environment = environment_factory(False)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #2
  def actor(self, random_key, replay,
            variable_source, counter,
            actor_id):
    """The actor process."""
    adder = self._builder.make_adder(replay)

    environment_key, actor_key = jax.random.split(random_key)
    # Create environment and policy core.

    # Environments normally require uint32 as a seed.
    environment = self._environment_factory(
        utils.sample_uint32(environment_key))

    networks = self._network_factory(specs.make_environment_spec(environment))
    policy_network = self._policy_network(networks)
    actor = self._builder.make_actor(actor_key, policy_network, adder,
                                     variable_source)

    # Create logger and counter.
    counter = counting.Counter(counter, 'actor')
    # Only actor #0 will write to bigtable in order not to spam it too much.
    logger = self._actor_logger_fn(actor_id)
    # Create the loop to connect environment and agent.
    return environment_loop.EnvironmentLoop(environment, actor, counter,
                                            logger, observers=self._observers)
Example #3
    def actor_evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
            evaluator_network)
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client = variable_utils.VariableClient(variable_source,
                                                        'policy',
                                                        device='cpu')
        actor = actors.GenericActor(actor_core,
                                    random_key,
                                    variable_client,
                                    backend='cpu')

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create environment and evaluator networks
        environment = environment_factory(False)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #4
    def evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        networks = network_factory(specs.make_environment_spec(environment))

        actor = make_actor(actor_key, policy_factory(networks),
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        if logger_fn is not None:
            logger = logger_fn('evaluator', 'actor_steps')
        else:
            logger = loggers.make_default_logger('evaluator',
                                                 log_to_bigtable,
                                                 steps_key='actor_steps')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #5
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #6
    def evaluator(
        random_key: types.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        make_actor: MakeActorFn,
    ):
        """The evaluation process."""

        # Create environment and evaluator networks
        environment_key, actor_key = jax.random.split(random_key)
        # Environments normally require uint32 as a seed.
        environment = environment_factory(utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)
        networks = network_factory(environment_spec)
        policy = policy_factory(networks, environment_spec, True)
        actor = make_actor(actor_key, policy, environment_spec,
                           variable_source)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        logger = logger_factory('evaluator', 'actor_steps', 0)

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=observers)
Example #7
    def test_environment_loop(self):
        # Create the actor/environment and stick them in a loop.
        environment = fakes.DiscreteEnvironment(episode_length=10)
        actor = fakes.Actor(specs.make_environment_spec(environment))
        loop = environment_loop.EnvironmentLoop(environment, actor)

        # Run the loop. There should be episode_length update calls per episode.
        loop.run(num_episodes=10)
        self.assertEqual(actor.num_updates, 100)
Example #8
    def test_recurrent(self):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        network = snt.DeepRNN([
            snt.Flatten(),
            snt.Linear(env_spec.actions.num_values),
            lambda x: tf.argmax(
                x, axis=-1, output_type=env_spec.actions.dtype),
        ])

        actor = actors_tf2.RecurrentActor(network)
        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #9
    def test_feedforward(self):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        network = snt.Sequential([
            snt.Flatten(),
            snt.Linear(env_spec.actions.num_values),
            lambda x: tf.argmax(
                x, axis=-1, output_type=env_spec.actions.dtype),
        ])

        actor = actors_tf2.FeedForwardActor(network)
        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #10
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return (actions, (action_values, )), core_state
            else:
                return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      jax.random.PRNGKey(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #11
def _parameterized_setup(discount_spec: Optional[types.NestedSpec] = None,
                         reward_spec: Optional[types.NestedSpec] = None):
  """Common setup code that, unlike self.setUp, takes arguments.

  Args:
    discount_spec: None, or a (nested) specs.BoundedArray.
    reward_spec: None, or a (nested) specs.Array.
  Returns:
    actor, loop
  """
  env_kwargs = {'episode_length': EPISODE_LENGTH}
  if discount_spec:
    env_kwargs['discount_spec'] = discount_spec
  if reward_spec:
    env_kwargs['reward_spec'] = reward_spec

  environment = fakes.DiscreteEnvironment(**env_kwargs)
  actor = fakes.Actor(specs.make_environment_spec(environment))
  loop = environment_loop.EnvironmentLoop(environment, actor)
  return actor, loop
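
A possible way to call this helper from a test body (hypothetical; the nested reward names are made up, and `specs`/`np` are assumed to be imported in the test module as in the other snippets):

actor, loop = _parameterized_setup(
    reward_spec={'score': specs.Array((), np.float32),
                 'lives': specs.Array((), np.float32)})
loop.run(num_episodes=1)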
Example #12
    def __init__(
            self,
            eval_actor,
            environment,
            num_episodes,
            counter,
            logger,
            eval_sync=None,
            progress_counter_name='actor_steps',
            min_steps_between_evals=None,
            self_cleanup=False,
            observers=(),
    ):
        super().__init__()
        assert num_episodes >= 1
        self._eval_actor = eval_actor
        self._num_episodes = num_episodes
        self._counter = counter
        self._logger = logger
        # Create the run loop and return it.
        self._env = environment
        self._environment_loop = environment_loop.EnvironmentLoop(
            environment, eval_actor, should_update=False, observers=observers)
        self._eval_sync = eval_sync or (lambda _: None)
        self._progress_counter_name = progress_counter_name
        self._last_steps = None
        self._eval_every_steps = min_steps_between_evals

        self._pending_tear_down = False
        if self_cleanup:
            # Do not rely on the instance owner to cleanup this evaluator.
            # Register a signal handler to perform some resource cleanup.
            try:
                signal.signal(signal.SIGTERM, self._signal_handler)  # pytype: disable=wrong-arg-types
            except ValueError:
                logging.warning(
                    'Caught ValueError when registering signal handler. '
                    'This probably means we are not running in the main thread. '
                )
Example #13
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #14
  def test_recurrent(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

    @hk.transform
    def initial_state(batch_size: int):
      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
      return network.initial_state(batch_size)

    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: jnp.ndarray,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
      del key  # Unused for test-case deterministic policy.
      action_values, core_state = network.apply(params, observation, core_state)
      return jnp.argmax(action_values, axis=-1), core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #15
    def build_actor(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        variable_source: core.VariableSource,
        counter: counting.Counter,
        actor_id: ActorId,
    ) -> environment_loop.EnvironmentLoop:
        """The actor process."""
        environment_key, actor_key = jax.random.split(random_key)
        # Create environment and policy core.

        # Environments normally require uint32 as a seed.
        environment = experiment.environment_factory(
            utils.sample_uint32(environment_key))
        environment_spec = specs.make_environment_spec(environment)

        networks = experiment.network_factory(environment_spec)
        policy_network = config.make_policy(experiment=experiment,
                                            networks=networks,
                                            environment_spec=environment_spec,
                                            evaluation=False)
        adder = experiment.builder.make_adder(replay, environment_spec,
                                              policy_network)
        actor = experiment.builder.make_actor(actor_key, policy_network,
                                              environment_spec,
                                              variable_source, adder)

        # Create logger and counter.
        counter = counting.Counter(counter, 'actor')
        logger = experiment.logger_factory('actor', counter.get_steps_key(),
                                           actor_id)
        # Create the loop to connect environment and agent.
        return environment_loop.EnvironmentLoop(environment,
                                                actor,
                                                counter,
                                                logger,
                                                observers=experiment.observers)
Example #16
# Create a logger for the agent and environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)

# Create the D4PG agent.
agent = d4pg.D4PG(environment_spec=environment_spec,
                  policy_network=policy_network,
                  critic_network=critic_network,
                  observation_network=observation_network,
                  sigma=1.0,
                  logger=agent_logger,
                  checkpoint=False)

# Create a loop connecting this agent to the environment created above.
env_loop = environment_loop.EnvironmentLoop(environment,
                                            agent,
                                            logger=env_loop_logger)

# Run `num_episodes` training episodes.
# Rerun this cell until the agent has learned the given task.
env_loop.run(num_episodes=5000)


@tf.function(input_signature=[tf.TensorSpec(shape=(1, 32), dtype=np.float32)])
def policy_inference(x):
    return policy_network(x)


p_save = snt.Module()
p_save.inference = policy_inference
p_save.all_variables = list(policy_network.variables)
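
A natural follow-up (not part of the original snippet) is to export this module as a TensorFlow SavedModel; a minimal sketch, assuming a writable '/tmp/policy' path:

# Persist the policy module built above; the export path is illustrative only.
tf.saved_model.save(p_save, '/tmp/policy')

# The saved policy can be reloaded and queried without the original Sonnet code.
loaded_policy = tf.saved_model.load('/tmp/policy')
action = loaded_policy.inference(np.zeros((1, 32), dtype=np.float32))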
Example #17
  def setUp(self):
    super().setUp()
    # Create the actor/environment and stick them in a loop.
    environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
    self.actor = fakes.Actor(specs.make_environment_spec(environment))
    self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
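
For completeness, a hypothetical test method that could accompany this setUp fixture; the expected count mirrors the arithmetic asserted in Example #7 (one actor update per environment step):

  def test_update_count(self):
    # Run two episodes through the fixture built in setUp.
    self.loop.run(num_episodes=2)
    # fakes.Actor is updated once per environment step (cf. Example #7).
    self.assertEqual(self.actor.num_updates, 2 * EPISODE_LENGTH)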