Example #1
 def make_actor(
         self,
         random_key: networks_lib.PRNGKey,
         policy_network,
         adder: Optional[adders.Adder] = None,
         variable_source: Optional[core.VariableSource] = None
 ) -> acme.Actor:
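     # The VariableClient periodically (every 1000 steps here) pulls the latest
     # 'network' parameters from the variable source and keeps them on CPU.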
     variable_client = variable_utils.VariableClient(client=variable_source,
                                                     key='network',
                                                     update_period=1000,
                                                     device='cpu')
     return acting.IMPALAActor(
         forward_fn=policy_network.forward_fn,
         initial_state_init_fn=policy_network.initial_state_init_fn,
         initial_state_fn=policy_network.initial_state_fn,
         variable_client=variable_client,
         adder=adder,
         rng=hk.PRNGSequence(random_key),
     )
Example #2
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy: impala_networks.IMPALANetworks,
     environment_spec: specs.EnvironmentSpec,
     variable_source: Optional[core.VariableSource] = None,
     adder: Optional[adders.Adder] = None,
 ) -> acme.Actor:
   del environment_spec
   variable_client = variable_utils.VariableClient(
       client=variable_source,
       key='network',
       update_period=self._config.variable_update_period,
       device='cpu')
   return acting.IMPALAActor(
       forward_fn=policy.forward_fn,
       initial_state_fn=policy.initial_state_fn,
       variable_client=variable_client,
       adder=adder,
       rng=hk.PRNGSequence(random_key),
   )
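Both make_actor variants wire up a VariableClient that fetches the latest 'network' parameters from a variable source (typically the learner) and pass it to an IMPALAActor. A minimal sketch of calling the Example #2 variant follows; builder, policy_networks, learner, adder, and env are hypothetical placeholders assumed to exist elsewhere, so this is an illustration rather than code from either project.

# Sketch only: invoking a make_actor like the one in Example #2. Every name
# below except the make_actor signature itself is a hypothetical placeholder.
import jax
from acme import specs

environment_spec = specs.make_environment_spec(env)  # env: some dm_env.Environment
actor = builder.make_actor(
    random_key=jax.random.PRNGKey(0),
    policy=policy_networks,           # an impala_networks.IMPALANetworks instance
    environment_spec=environment_spec,
    variable_source=learner,          # the learner doubles as the VariableSource
    adder=adder,                      # optional: writes trajectories into replay
)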
Example #3
File: agent.py Project: whl19910402/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        # Data is handled by the reverb replay queue.
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        extra_spec = {
            'core_state':
            hk.without_apply_rng(hk.transform(initial_state_fn)).apply(None),
            'logits':
            np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        reverb_queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extra_spec,
            max_queue_size=max_queue_size,
            sequence_length=sequence_length,
            sequence_period=sequence_period,
            batch_size=batch_size,
        )
        self._server = reverb_queue.server
        self._can_sample = reverb_queue.can_sample

        # Make the learner.
        optimizer = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm),
            optax.adam(learning_rate),
        )
        key_learner, key_actor = jax.random.split(jax.random.PRNGKey(seed))
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=reverb_queue.data_iterator,
            random_key=key_learner,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # Make the actor.
        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        transformed = hk.without_apply_rng(hk.transform(forward_fn))
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(transformed.apply, backend='cpu'),
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(key_actor),
            adder=reverb_queue.adder,
            variable_client=variable_client,
        )
Example #4
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
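        # A Reverb queue table serves as a bounded FIFO buffer for on-policy sequences.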
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Remove batch dimensions.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
Example #5
File: agent.py Project: vishalbelsare/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: impala_types.PolicyValueFn,
        unroll_init_fn: impala_types.PolicyValueInitFn,
        unroll_fn: impala_types.PolicyValueFn,
        initial_state_init_fn: impala_types.RecurrentStateInitFn,
        initial_state_fn: impala_types.RecurrentStateFn,
        config: impala_config.IMPALAConfig,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
    ):
        networks = impala_networks.IMPALANetworks(
            forward_fn=forward_fn,
            unroll_init_fn=unroll_init_fn,
            unroll_fn=unroll_fn,
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
        )

        self._config = config

        # Data is handled by the reverb replay queue.
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        key, key_initial_state = jax.random.split(
            jax.random.PRNGKey(self._config.seed))
        params = initial_state_init_fn(key_initial_state)
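        # Extras stored alongside each step: the recurrent core state and the
        # policy logits, which the learner later needs for V-trace importance weighting.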
        extra_spec = {
            'core_state': initial_state_fn(params),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }

        reverb_queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extra_spec,
            max_queue_size=self._config.max_queue_size,
            sequence_length=self._config.sequence_length,
            sequence_period=self._config.sequence_period,
            batch_size=self._config.batch_size,
        )
        self._server = reverb_queue.server
        self._can_sample = reverb_queue.can_sample

        # Make the learner.
        optimizer = optax.chain(
            optax.clip_by_global_norm(self._config.max_gradient_norm),
            optax.adam(self._config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(key)
        self._learner = learning.IMPALALearner(
            networks=networks,
            iterator=reverb_queue.data_iterator,
            random_key=key_learner,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=self._config.discount,
            entropy_cost=self._config.entropy_cost,
            baseline_cost=self._config.baseline_cost,
            max_abs_reward=self._config.max_abs_reward,
        )

        # Make the actor.
        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(forward_fn, backend='cpu'),
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(key_actor),
            adder=reverb_queue.adder,
            variable_client=variable_client,
        )
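Example #5 reads every hyperparameter from an IMPALAConfig rather than taking them as constructor arguments. A hedged sketch of building such a config is shown below; only fields referenced in these snippets are set, the values mirror the defaults used in the other examples, and the exact field set and import path may differ between Acme versions.

# Sketch only: a config object like the one consumed via self._config above.
# Field names are taken from the snippets; the values are illustrative.
import numpy as np
from acme.agents.jax.impala import config as impala_config

config = impala_config.IMPALAConfig(
    seed=0,
    batch_size=16,
    sequence_length=20,   # illustrative; no default appears in the snippets
    sequence_period=20,   # illustrative
    max_queue_size=100000,
    learning_rate=1e-3,
    discount=0.99,
    entropy_cost=0.01,
    baseline_cost=0.5,
    max_abs_reward=np.inf,
    max_gradient_norm=np.inf,
)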
Example #6
File: agent.py Project: staylonging/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state':
            hk.without_apply_rng(hk.transform(initial_state_fn,
                                              apply_rng=True)).apply(None),
            'logits':
            np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        signature = adders.SequenceAdder.signature(environment_spec,
                                                   extra_spec)
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size,
                                   signature=signature)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        # We don't use datasets.make_reverb_dataset() here in order to avoid
        # interleaving and prefetching, which interact poorly with the
        # can_sample() check performed on each update.
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=1,
            sequence_length=sequence_length,
            emit_timesteps=False)
        dataset = dataset.batch(batch_size, drop_remainder=True)

        optimizer = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm),
            optax.adam(learning_rate),
        )

        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=hk.PRNGSequence(seed),
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(
                hk.without_apply_rng(hk.transform(forward_fn, apply_rng=True)).apply,
                backend='cpu'),
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(seed),
            adder=adder,
            variable_client=variable_client,
        )
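Each constructor above builds a single-process IMPALA agent whose actor and learner live in the same process. A minimal, hedged sketch of driving such an agent with Acme's EnvironmentLoop follows; the environment and the Haiku network functions are placeholders that must be supplied for a concrete task, and the sketch assumes the class whose __init__ is shown subclasses acme.Agent and therefore also implements acme.Actor.

# Sketch only: running one of the agents defined above. `environment`,
# `forward_fn`, `unroll_fn`, and `initial_state_fn` are placeholders, and
# IMPALA stands for whichever agent class the __init__ above belongs to.
import acme
from acme import specs

environment_spec = specs.make_environment_spec(environment)
agent = IMPALA(
    environment_spec=environment_spec,
    forward_fn=forward_fn,
    unroll_fn=unroll_fn,
    initial_state_fn=initial_state_fn,
    sequence_length=20,
    sequence_period=20,
)
loop = acme.EnvironmentLoop(environment, agent)  # the agent both acts and learns in-process
loop.run(num_episodes=10)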