Example #1
  def make_learner(
      self,
      random_key: networks_lib.PRNGKey,
      networks: impala_networks.IMPALANetworks,
      dataset: Iterator[reverb.ReplaySample],
      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
      replay_client: Optional[reverb.Client] = None,
      counter: Optional[counting.Counter] = None,
  ) -> core.Learner:
    # These arguments are unused by this learner factory.
    del environment_spec, replay_client

    # Clip gradients by global norm, then apply Adam.
    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(
            self._config.learning_rate,
            b1=self._config.adam_momentum_decay,
            b2=self._config.adam_variance_decay,
            eps=self._config.adam_eps,
            eps_root=self._config.adam_eps_root))

    # Assemble the IMPALA learner from the networks, dataset and optimizer.
    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
        counter=counter,
        logger=logger_fn('learner'),
    )
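
The optimizer above chains global-norm gradient clipping with Adam, so gradients are clipped before the Adam update is computed. A minimal sketch of driving such a chained optimizer on its own, with a toy parameter tree and hypothetical hyperparameter values (not the project's config):

import jax.numpy as jnp
import optax

# Hypothetical values standing in for self._config.max_gradient_norm and
# self._config.learning_rate.
optimizer = optax.chain(
    optax.clip_by_global_norm(40.0),
    optax.adam(1e-4),
)

params = {'w': jnp.zeros(3), 'b': jnp.zeros(1)}   # toy parameter tree
opt_state = optimizer.init(params)

grads = {'w': jnp.ones(3), 'b': jnp.ones(1)}      # pretend gradients
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)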
Example #2
File: agent.py Project: whl19910402/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        # Data is handled by the reverb replay queue.
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        extra_spec = {
            'core_state':
                hk.without_apply_rng(hk.transform(initial_state_fn)).apply(None),
            'logits': np.ones(shape=(num_actions,), dtype=np.float32),
        }
        reverb_queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extra_spec,
            max_queue_size=max_queue_size,
            sequence_length=sequence_length,
            sequence_period=sequence_period,
            batch_size=batch_size,
        )
        self._server = reverb_queue.server
        self._can_sample = reverb_queue.can_sample

        # Make the learner.
        optimizer = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm),
            optax.adam(learning_rate),
        )
        key_learner, key_actor = jax.random.split(jax.random.PRNGKey(seed))
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=reverb_queue.data_iterator,
            random_key=key_learner,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # Make the actor.
        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        transformed = hk.without_apply_rng(hk.transform(forward_fn))
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(transformed.apply, backend='cpu'),
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(key_actor),
            adder=reverb_queue.adder,
            variable_client=variable_client,
        )
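
A sketch of how an agent built this way is typically driven, assuming the enclosing class implements Acme's actor interface (as Acme's single-process agents do). The class name IMPALA, the environment object, and the Haiku network functions below are hypothetical placeholders:

from acme import environment_loop

# Hypothetical: `IMPALA` is the class whose __init__ is shown above;
# `environment`, `environment_spec`, `forward_fn`, `unroll_fn` and
# `initial_state_fn` are assumed to be defined elsewhere.
agent = IMPALA(
    environment_spec=environment_spec,
    forward_fn=forward_fn,
    unroll_fn=unroll_fn,
    initial_state_fn=initial_state_fn,
    sequence_length=20,
    sequence_period=20,
)

# The loop alternates acting in the environment with calls to agent.update().
loop = environment_loop.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)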
Example #3
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # Extras stored alongside each transition: the recurrent core state
        # and the behaviour policy logits.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions,), dtype=np.float32)
        }
        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        # Make the learner. (optix is the pre-optax gradient-processing library.)
        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # Make the actor.
        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
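
The `_can_sample` lambda gates learning on whether the queue holds a full batch. The agent's `update()` method is not part of this excerpt; the sketch below shows one common way such a gate is consumed, and its body is an assumption rather than the project's actual code:

    def update(self):
        # Assumed pattern: drain the queue while a full batch is available,
        # then refresh the actor's copy of the policy parameters.
        stepped = False
        while self._can_sample():
            self._learner.step()
            stepped = True
        if stepped:
            self._actor.update()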
Example #4
File: agent.py Project: vishalbelsare/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: impala_types.PolicyValueFn,
        unroll_init_fn: impala_types.PolicyValueInitFn,
        unroll_fn: impala_types.PolicyValueFn,
        initial_state_init_fn: impala_types.RecurrentStateInitFn,
        initial_state_fn: impala_types.RecurrentStateFn,
        config: impala_config.IMPALAConfig,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
    ):
        networks = impala_networks.IMPALANetworks(
            forward_fn=forward_fn,
            unroll_init_fn=unroll_init_fn,
            unroll_fn=unroll_fn,
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
        )

        self._config = config

        # Data is handled by the reverb replay queue.
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        key, key_initial_state = jax.random.split(
            jax.random.PRNGKey(self._config.seed))
        params = initial_state_init_fn(key_initial_state)
        extra_spec = {
            'core_state': initial_state_fn(params),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }

        reverb_queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extra_spec,
            max_queue_size=self._config.max_queue_size,
            sequence_length=self._config.sequence_length,
            sequence_period=self._config.sequence_period,
            batch_size=self._config.batch_size,
        )
        self._server = reverb_queue.server
        self._can_sample = reverb_queue.can_sample

        # Make the learner.
        optimizer = optax.chain(
            optax.clip_by_global_norm(self._config.max_gradient_norm),
            optax.adam(self._config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(key)
        self._learner = learning.IMPALALearner(
            networks=networks,
            iterator=reverb_queue.data_iterator,
            random_key=key_learner,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=self._config.discount,
            entropy_cost=self._config.entropy_cost,
            baseline_cost=self._config.baseline_cost,
            max_abs_reward=self._config.max_abs_reward,
        )

        # Make the actor.
        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(forward_fn, backend='cpu'),
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(key_actor),
            adder=reverb_queue.adder,
            variable_client=variable_client,
        )
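
Examples #1 and #4 read their hyperparameters from an `impala_config.IMPALAConfig` object. The dataclass below is reconstructed purely from the `self._config.<name>` attributes those snippets access; the field set and the default values are assumptions for illustration, not the library's actual definition:

import dataclasses
import numpy as np

@dataclasses.dataclass
class IMPALAConfig:
    # Fields inferred from the attribute accesses above; defaults are guesses.
    seed: int = 0
    discount: float = 0.99
    sequence_length: int = 20
    sequence_period: int = 20
    batch_size: int = 16
    max_queue_size: int = 100000
    learning_rate: float = 1e-3
    adam_momentum_decay: float = 0.9
    adam_variance_decay: float = 0.99
    adam_eps: float = 1e-8
    adam_eps_root: float = 0.0
    entropy_cost: float = 0.01
    baseline_cost: float = 0.5
    max_abs_reward: float = np.inf
    max_gradient_norm: float = np.inf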
Example #5
File: agent.py Project: staylonging/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state':
                hk.without_apply_rng(hk.transform(
                    initial_state_fn, apply_rng=True)).apply(None),
            'logits': np.ones(shape=(num_actions,), dtype=np.float32),
        }
        signature = adders.SequenceAdder.signature(environment_spec,
                                                   extra_spec)
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size,
                                   signature=signature)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        # We don't use datasets.make_reverb_dataset() here, to avoid interleaving
        # and prefetching, which don't work well with the can_sample() check on
        # update.
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=1,
            sequence_length=sequence_length,
            emit_timesteps=False)
        dataset = dataset.batch(batch_size, drop_remainder=True)

        optimizer = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm),
            optax.adam(learning_rate),
        )

        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=hk.PRNGSequence(seed),
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(
                hk.without_apply_rng(
                    hk.transform(forward_fn, apply_rng=True)).apply,
                backend='cpu'),
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(seed),
            adder=adder,
            variable_client=variable_client,
        )
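
The actor's forward_fn above is a Haiku function that is transformed and then jitted onto the CPU backend. A minimal, self-contained illustration of that transform-then-jit pattern with a toy network (the network and input shapes are hypothetical, not the project's PolicyValueRNN):

import haiku as hk
import jax
import jax.numpy as jnp

def toy_forward(x):
    # Stand-in for the project's forward function.
    return hk.Linear(4)(x)

transformed = hk.without_apply_rng(hk.transform(toy_forward))
params = transformed.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))

# Jit the pure apply function on CPU, mirroring the actor construction above.
apply = jax.jit(transformed.apply, backend='cpu')
out = apply(params, jnp.zeros((1, 8)))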