Exemplo n.º 1
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Assemble the replay, learner and actor and pass them to the base class.

        Args:
            environment_spec: Environment description. It is forwarded to the
                replay factory, and its `observations` entry is given to the
                learner as `obs_spec`.
            network: Network whose `apply` yields the action values used by
                the acting policy; also handed to the learner.
            config: Hyperparameter container. The fields read here are
                `n_step`, `batch_size`, `max_replay_size`, `min_replay_size`,
                `priority_exponent`, `discount`, `max_gradient_norm`,
                `learning_rate`, `seed`, `importance_sampling_exponent`,
                `target_update_period`, `epsilon` and `samples_per_insert`.
        """
        # Replay: the factory result bundles server, client, adder and iterator.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        # Keep a reference to the server on the agent.
        self._server = reverb_replay.server

        # Optimizer: clip gradients by global norm, then apply Adam.
        gradient_clipping = optax.clip_by_global_norm(config.max_gradient_norm)
        adam = optax.adam(config.learning_rate)
        optimizer = optax.chain(gradient_clipping, adam)

        # One root key, split so learner and actor get different keys.
        root_key = jax.random.PRNGKey(config.seed)
        key_learner, key_actor = jax.random.split(root_key)

        # Learner: consumes batches from the replay iterator.
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        def policy(params: networks_lib.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            """Sample an epsilon-greedy action from the network's action values."""
            q_values = network.apply(params, observation)
            sampler = rlax.epsilon_greedy(config.epsilon)
            return sampler.sample(key, q_values)

        # Actor: obtains variables from the learner through a variable client
        # and writes experience through the replay adder.
        variable_client = variable_utils.VariableClient(learner, '')
        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_client,
            adder=reverb_replay.adder,
        )

        # Scheduling values handed to the base class.
        min_observations = max(config.batch_size, config.min_replay_size)
        observations_per_step = config.batch_size / config.samples_per_insert

        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=min_observations,
            observations_per_step=observations_per_step,
        )
Exemplo n.º 2
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent.

        Builds a local reverb replay server, an n-step transition adder, a
        dataset over the replay table, a `DQNLearner` and a `FeedForwardActor`,
        then passes actor and learner to the base class.

        Args:
            environment_spec: Environment description; used for the replay
                table signature and the dataset, and its `observations` entry
                is given to the learner as `obs_spec`.
            network: Transformed network whose `apply` yields the action
                values used by the acting policy; also handed to the learner.
            batch_size: Batch size of the replay dataset. Also enters
                `min_observations` and `observations_per_step`.
            prefetch_size: Forwarded to `datasets.make_reverb_dataset`.
            target_update_period: Forwarded to the learner.
            samples_per_insert: Divisor in
                `observations_per_step = batch_size / samples_per_insert`.
            min_replay_size: Together with `batch_size`, determines
                `min_observations` (the larger of the two is used).
            max_replay_size: `max_size` of the replay table.
            importance_sampling_exponent: Forwarded to the learner.
            priority_exponent: Exponent of the table's prioritized sampler.
            n_step: Forwarded to the n-step transition adder.
            epsilon: Passed to `rlax.epsilon_greedy` in the acting policy.
            learning_rate: Learning rate of the Adam optimizer.
            discount: Forwarded to both the adder and the learner.
            seed: Integer seed from which separate learner and actor random
                keys are derived.
        """

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            """Sample an epsilon-greedy action from the network's action values."""
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # Derive two distinct keys from `seed`. Seeding both the learner's and
        # the actor's PRNGSequence with the same integer would make them draw
        # the identical stream of keys, i.e. reuse random keys across the two.
        key_sequence = hk.PRNGSequence(seed)
        key_learner = next(key_sequence)
        key_actor = next(key_sequence)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(key_learner),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        # The actor obtains its variables from the learner via this client.
        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(key_actor),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Exemplo n.º 3
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 0.5,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.05,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent.

        Builds the replay via `replay.make_reverb_prioritized_nstep_replay`,
        a `DQNLearner` and a `FeedForwardActor`, then passes actor and learner
        to the base class.

        Args:
            environment_spec: Environment description; forwarded to the replay
                factory, and its `observations` entry is given to the learner
                as `obs_spec`.
            network: Transformed network whose `apply` yields the action
                values used by the acting policy; also handed to the learner.
            batch_size: Forwarded to the replay factory. Also enters
                `min_observations` and `observations_per_step`.
            prefetch_size: Not read anywhere in this constructor.
            target_update_period: Forwarded to the learner.
            samples_per_insert: Divisor in
                `observations_per_step = batch_size / samples_per_insert`.
            min_replay_size: Forwarded to the replay factory; together with
                `batch_size` it also determines `min_observations`.
            max_replay_size: Forwarded to the replay factory.
            importance_sampling_exponent: Forwarded to the learner.
            priority_exponent: Forwarded to the replay factory.
            n_step: Forwarded to the replay factory.
            epsilon: Passed to `rlax.epsilon_greedy` in the acting policy.
            learning_rate: Learning rate of the Adam optimizer.
            discount: Forwarded to both the replay factory and the learner.
            seed: Integer seed from which separate learner and actor random
                keys are derived.
        """
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=n_step,
            batch_size=batch_size,
            max_replay_size=max_replay_size,
            min_replay_size=min_replay_size,
            priority_exponent=priority_exponent,
            discount=discount,
        )
        self._server = reverb_replay.server

        # Derive two distinct keys from `seed`. Seeding both the learner's and
        # the actor's PRNGSequence with the same integer would make them draw
        # the identical stream of keys, i.e. reuse random keys across the two.
        key_sequence = hk.PRNGSequence(seed)
        key_learner = next(key_sequence)
        key_actor = next(key_sequence)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(key_learner),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            """Sample an epsilon-greedy action from the network's action values."""
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)