Example #1
File: system.py  Project: NetColby/DNRL
    def _get_extra_specs(self) -> Any:
        """helper to establish specs for extra information

        Returns:
            Dict[str, Any]: dictionary containing extra specs
        """

        agents = self._environment_spec.get_agent_ids()
        core_state_specs = {}
        core_message_specs = {}

        networks = self._network_factory(  # type: ignore
            environment_spec=self._environment_spec
        )
        for agent in agents:
            agent_type = agent.split("_")[0]
            core_state_specs[agent] = (
                tf2_utils.squeeze_batch_dim(
                    networks["q_networks"][agent_type].initial_state(1)
                ),
            )
            if self._communication_module_fn is not None:
                core_message_specs[agent] = (
                    tf2_utils.squeeze_batch_dim(
                        networks["q_networks"][agent_type].initial_message(1)
                    ),
                )

        extras = {
            "core_states": core_state_specs,
            "core_messages": core_message_specs,
        }
        return extras
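
A note on the pattern: initial_state(1) returns state with a dummy batch dimension of size 1, and squeeze_batch_dim strips it again so the result can serve as a per-timestep spec. A minimal sketch of that behaviour, assuming only TensorFlow and dm-tree (an illustration, not the library's exact implementation):

import tensorflow as tf
import tree  # dm-tree

def squeeze_batch_dim_sketch(nest):
    # Remove the leading size-1 batch dimension from every tensor in the nest.
    return tree.map_structure(lambda t: tf.squeeze(t, axis=0), nest)

extra_spec = {
    "core_state": tf.zeros([1, 128]),             # batched RNN hidden state
    "logits": tf.ones([1, 5], dtype=tf.float32),  # batched dummy logits
}
extra_spec = squeeze_batch_dim_sketch(extra_spec)
# extra_spec["core_state"].shape == (128,); extra_spec["logits"].shape == (5,)
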
Example #2
  def replay(self) -> List[reverb.Table]:
    """The replay storage."""
    network = self._network_factory(self._environment_spec.actions)
    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    if self._samples_per_insert:
      limiter = reverb.rate_limiters.SampleToInsertRatio(
          min_size_to_sample=self._min_replay_size,
          samples_per_insert=self._samples_per_insert,
          error_buffer=self._batch_size)
    else:
      limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
    table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(self._priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=self._max_replay_size,
        rate_limiter=limiter,
        signature=adders.SequenceAdder.signature(
            self._environment_spec,
            extra_spec,
            sequence_length=self._burn_in_length + self._trace_length + 1))

    return [table]
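
The branch on self._samples_per_insert chooses how Reverb throttles inserts relative to samples. A hedged sketch of the two limiter configurations with illustrative values (the variable names here are placeholders, not attributes of the class above):

import reverb

min_replay_size = 1000     # illustrative values only
samples_per_insert = 32.0
batch_size = 32

if samples_per_insert:
    # Keep the sample:insert ratio near samples_per_insert once the table holds
    # at least min_size_to_sample items; error_buffer adds slack before calls block.
    limiter = reverb.rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_replay_size,
        samples_per_insert=samples_per_insert,
        error_buffer=batch_size)
else:
    # Only require a minimum table size before sampling may begin.
    limiter = reverb.rate_limiters.MinSize(min_replay_size)
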
Example #3
File: builder.py  Project: NetColby/DNRL
    def make_replay_tables(
        self,
        environment_spec: specs.MAEnvironmentSpec,
    ) -> List[reverb.Table]:
        """Create tables to insert data into.

        Args:
            environment_spec (specs.MAEnvironmentSpec): description of the action and
                observation spaces etc. for each agent in the system.

        Returns:
            List[reverb.Table]: a list of data tables for inserting data.
        """

        agent_specs = environment_spec.get_agent_specs()
        extras_spec: Dict[str, Dict[str, acme_types.NestedArray]] = {"log_probs": {}}

        for agent, spec in agent_specs.items():
            # Make dummy log_probs
            extras_spec["log_probs"][agent] = tf.ones(shape=(1,), dtype=tf.float32)

        # Squeeze the batch dim.
        extras_spec = tf2_utils.squeeze_batch_dim(extras_spec)

        replay_table = reverb.Table.queue(
            name=self._config.replay_table_name,
            max_size=self._config.max_queue_size,
            signature=reverb_adders.ParallelSequenceAdder.signature(
                environment_spec, extras_spec=extras_spec
            ),
        )

        return [replay_table]
Example #4
File: tf_utils.py  Project: NetColby/DNRL
def spec(output: tf.Tensor) -> tf.TensorSpec:
    # If the output is not a Tensor, return None as spec is ill-defined.
    if not isinstance(output, tf.Tensor):
        return None
    # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
    if tf.rank(output) > 0:
        output = squeeze_batch_dim(output)
    return tf.TensorSpec(output.shape, output.dtype)
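
A brief usage sketch of the helper above; the expected outputs assume that squeeze_batch_dim drops a leading size-1 dimension, as elsewhere in tf_utils.py:

import tensorflow as tf

print(spec(tf.ones([1, 3])))   # TensorSpec(shape=(3,), dtype=tf.float32, name=None)
print(spec(tf.constant(2.0)))  # scalar input: TensorSpec(shape=(), dtype=tf.float32, name=None)
print(spec("not a tensor"))    # None, since a spec is ill-defined for non-Tensors
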
Example #5
File: tf_utils.py  Project: NetColby/DNRL
def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Builds the network with dummy inputs to create the necessary variables.
    Args:
      network: Sonnet Module whose variables are to be created.
      input_spec: list of input specs to the network. The length of this list
        should match the number of arguments expected by `network`.
    Returns:
      output_spec: only returns an output spec if the output is a tf.Tensor, else
          it doesn't return anything (None); e.g. if the output is a
          tfp.distributions.Distribution.
    """
    # Create a dummy observation with no batch dimension.
    dummy_input = [
        OLT(
            observation=zeros_like(in_spec.observation),
            legal_actions=ones_like(in_spec.legal_actions),
            terminal=zeros_like(in_spec.terminal),
        ) for in_spec in input_spec
    ]

    # If we have an RNNCore the hidden state will be an additional input.
    if isinstance(network, snt.RNNCore):
        initial_state = squeeze_batch_dim(network.initial_state(1))
        dummy_input += [initial_state]

    # Forward pass of the network which will create variables as a side effect.
    dummy_output = network(*add_batch_dim(dummy_input))

    # Evaluate the input signature by converting the dummy input into a
    # TensorSpec. We then save the signature as a property of the network. This is
    # done so that we can later use it when creating snapshots. We do this here
    # because the snapshot code may not have access to the precise form of the
    # inputs.
    input_signature = tree.map_structure(
        lambda t: tf.TensorSpec((None, ) + t.shape, t.dtype), dummy_input)
    network._input_signature = input_signature  # pylint: disable=protected-access

    def spec(output: tf.Tensor) -> tf.TensorSpec:
        # If the output is not a Tensor, return None as spec is ill-defined.
        if not isinstance(output, tf.Tensor):
            return None
        # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
        if tf.rank(output) > 0:
            output = squeeze_batch_dim(output)
        return tf.TensorSpec(output.shape, output.dtype)

    return tree.map_structure(spec, dummy_output)
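
At its core, create_variables performs a dummy forward pass: it builds zero/one-filled inputs from the specs, calls the network once so Sonnet creates its variables as a side effect, and records the output spec with the batch dimension stripped. A self-contained sketch of that pattern with plain Sonnet and TensorFlow (illustrative only; the helper above additionally handles OLT specs, RNN state, and the saved input signature):

import sonnet as snt
import tensorflow as tf

def create_variables_sketch(network: snt.Module,
                            input_spec: tf.TensorSpec) -> tf.TensorSpec:
    # Dummy batched input built from the spec.
    dummy_input = tf.zeros((1,) + tuple(input_spec.shape), input_spec.dtype)
    # Forward pass creates the variables as a side effect.
    dummy_output = network(dummy_input)
    # Strip the dummy batch dimension before recording the output spec.
    output = tf.squeeze(dummy_output, axis=0)
    return tf.TensorSpec(output.shape, output.dtype)

net = snt.Linear(4)
out_spec = create_variables_sketch(net, tf.TensorSpec((8,), tf.float32))
print(len(net.variables), out_spec)  # 2 variables (w and b); TensorSpec(shape=(4,), ...)
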
Example #6
File: agent.py  Project: srsohn/mtsgi
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        queue: adder.Adder,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        n_step_horizon: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
        verbose_level: Optional[int] = 0,
    ):
        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        tf2_utils.create_variables(network, [environment_spec.observations])

        actor = acting.A2CActor(environment_spec=environment_spec,
                                verbose_level=verbose_level,
                                network=network,
                                queue=queue)
        learner = learning.A2CLearner(
            environment_spec=environment_spec,
            network=network,
            dataset=queue,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=0,
                         observations_per_step=n_step_horizon)
Example #7
    def _get_extra_specs(self) -> Any:
        """helper to establish specs for extra information

        Returns:
            Dict[str, Any]: dictionary containing extra specs
        """

        agents = self._environment_spec.get_agent_ids()
        core_state_specs = {}
        networks = self._network_factory(  # type: ignore
            environment_spec=self._environment_spec)
        for agent in agents:
            agent_type = agent.split("_")[0]
            core_state_specs[agent] = (tf2_utils.squeeze_batch_dim(
                networks["policies"][agent_type].initial_state(1)), )
        return {"core_states": core_state_specs}
Example #8
  def queue(self):
    """The queue."""
    num_actions = self._environment_spec.actions.num_values
    network = self._network_factory(self._environment_spec.actions)
    extra_spec = {
        'core_state': network.initial_state(1),
        'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    signature = adders.SequenceAdder.signature(
        self._environment_spec,
        extra_spec,
        sequence_length=self._sequence_length)
    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE,
        max_size=self._max_queue_size,
        signature=signature)
    return [queue]
Example #9
File: agent.py  Project: dzorlu/acme
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        target_network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        demonstration_generator: Iterator,  # yields (timestep, action) pairs
        demonstration_ratio: float,
        model_directory: str,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        epsilon: float = 0.01,
        learning_rate: float = 1e-3,
        log_to_bigtable: bool = False,
        log_name: str = 'agent',
        checkpoint: bool = True,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
    ):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Replay table.
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        # Demonstration table.
        demonstration_table = reverb.Table(
            name='demonstration_table',
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))

        # launch server
        self._server = reverb.Server([replay_table, demonstration_table],
                                     port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1

        # Component to add things into replay and demo
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)
        priority_function = {demonstration_table.name: lambda x: 1.}
        demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                          priority_fns=priority_function,
                                          **sequence_kwargs)
        # Play the demonstrations and write them to the demonstration table by
        # exhausting the generator.
        # TODO: MAX REPLAY SIZE
        _prev_action = 1  # This has to come from the spec.
        _add_first = True
        # Include this to make the datasets equivalent.
        numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
        for ts, action in demonstration_generator:
            if _add_first:
                demo_adder.add_first(ts)
                _add_first = False
            else:
                demo_adder.add(_prev_action, ts, extras=(numpy_state, ))
            _prev_action = action
            # reset to new episode
            if ts.last():
                _prev_action = None
                _add_first = True

        # Replay dataset.
        max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
        dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            # Memory/performance improvement attempt, see
            # https://github.com/deepmind/acme/issues/33.
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        # Demonstration dataset.
        d_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=demonstration_table.name,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=2,
            sequence_length=sequence_length,
            emit_timesteps=sequence_length is None)

        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, d_dataset],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            directory=model_directory,
            subdirectory='r2d2_learner_v1',
            time_delta_minutes=15,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None,
                                                   time_delta_minutes=15000.,
                                                   directory=model_directory)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
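
The demonstration/replay mix above is driven by sample_from_datasets, which interleaves elements from several datasets according to the given weights. A toy illustration (illustrative datasets only):

import tensorflow as tf

replay = tf.data.Dataset.from_tensor_slices(tf.zeros([1000]))
demos = tf.data.Dataset.from_tensor_slices(tf.ones([1000]))
demonstration_ratio = 0.25

mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], [1 - demonstration_ratio, demonstration_ratio])

# Roughly a quarter of the sampled elements come from the demonstration dataset.
fraction = sum(int(x) for x in mixed.take(400)) / 400
print(fraction)  # ~0.25 in expectation
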
Example #10
File: agent.py  Project: novatig/acme
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 encoder_network: types.TensorTransformation = tf.identity,
                 entropy_coeff: float = 0.01,
                 target_update_period: int = 0,
                 discount: float = 0.99,
                 batch_size: int = 256,
                 policy_learn_rate: float = 3e-4,
                 critic_learn_rate: float = 5e-4,
                 prefetch_size: int = 4,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 250000,
                 samples_per_insert: float = 64.0,
                 n_step: int = 5,
                 sigma: float = 0.5,
                 clipping: bool = True,
                 logger: loggers.Logger = None,
                 counter: counting.Counter = None,
                 checkpoint: bool = True,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
        """Initialize the agent.

    Args:
      environment_spec: description of the actions, observations, etc.
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      observation_network: optional network to transform the observations before
        they are fed into any network.
      discount: discount to use for TD updates.
      batch_size: batch size for updates.
      prefetch_size: size to prefetch from replay.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      min_replay_size: minimum replay size before updating.
      max_replay_size: maximum replay size.
      samples_per_insert: number of samples to take from replay for every insert
        that is made.
      n_step: number of steps to squash into a single transition.
      sigma: standard deviation of zero-mean, Gaussian exploration noise.
      clipping: whether to clip gradients by global norm.
      logger: logger object to be used by learner.
      counter: counter object used to keep track of steps.
      checkpoint: boolean indicating whether to checkpoint the learner.
      replay_table_name: string indicating what name to give the replay table.
    """
        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.

        dim_actions = np.prod(environment_spec.actions.shape, dtype=int)
        extra_spec = {
            'logP': tf.ones(shape=(1), dtype=tf.float32),
            'policy': tf.ones(shape=(1, dim_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec, extras_spec=extra_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(table=replay_table_name,
                                               server_address=address,
                                               batch_size=batch_size,
                                               prefetch_size=prefetch_size)

        # Make sure observation network is a Sonnet Module.
        observation_network = model.MDPNormalization(environment_spec,
                                                     encoder_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations

        # Create the behavior policy.
        sampling_head = model.SquashedGaussianSamplingHead(act_spec, sigma)
        self._behavior_network = model.PolicyValueBehaviorNet(
            snt.Sequential([observation_network, policy_network]),
            sampling_head)

        # Create variables.
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])

        # Create the actor which defines how we take actions.
        actor = model.SACFeedForwardActor(self._behavior_network, adder)

        if target_update_period > 0:
            target_policy_network = copy.deepcopy(policy_network)
            target_critic_network = copy.deepcopy(critic_network)
            target_observation_network = copy.deepcopy(observation_network)

            tf2_utils.create_variables(target_policy_network, [emb_spec])
            tf2_utils.create_variables(target_critic_network,
                                       [emb_spec, act_spec])
            tf2_utils.create_variables(target_observation_network, [obs_spec])
        else:
            target_policy_network = policy_network
            target_critic_network = critic_network
            target_observation_network = observation_network

        # Create optimizers.
        policy_optimizer = snt.optimizers.Adam(learning_rate=policy_learn_rate)
        critic_optimizer = snt.optimizers.Adam(learning_rate=critic_learn_rate)

        # The learner updates the parameters (and initializes them).
        learner = learning.SACLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            sampling_head=sampling_head,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            target_update_period=target_update_period,
            learning_rate=policy_learn_rate,
            clipping=clipping,
            entropy_coeff=entropy_coeff,
            discount=discount,
            dataset=dataset,
            counter=counter,
            logger=logger,
            checkpoint=checkpoint,
        )

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
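
The final super().__init__ call wires the actor/learner cadence from batch_size and samples_per_insert. Worked numbers for the defaults above (the reading that roughly one learner step runs every observations_per_step environment steps once min_observations have been collected follows the acme Agent base class, stated here as an assumption):

batch_size = 256
samples_per_insert = 64.0
min_replay_size = 1000

min_observations = max(batch_size, min_replay_size)             # 1000
observations_per_step = float(batch_size) / samples_per_insert  # 4.0
print(min_observations, observations_per_step)  # 1000 4.0
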
Example #11
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        burn_in_length: int,
        trace_length: int,
        replay_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        batch_size: int = 32,
        prefetch_size: int = tf.data.experimental.AUTOTUNE,
        target_update_period: int = 100,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        epsilon_init: float = 1.0,
        epsilon_final: float = 0.01,
        epsilon_schedule_timesteps: float = 20000,
        learning_rate: float = 1e-3,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        store_lstm_state: bool = True,
        max_priority_weight: float = 0.9,
        checkpoint: bool = True,
    ):

        if store_lstm_state:
            extra_spec = {
                'core_state':
                tf2_utils.squeeze_batch_dim(network.initial_state(1)),
            }
        else:
            extra_spec = ()

        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        self._adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=replay_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = make_reverb_dataset(server_address=address,
                                      batch_size=batch_size,
                                      prefetch_size=prefetch_size,
                                      sequence_length=sequence_length)

        target_network = copy.deepcopy(network)
        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            sequence_length=sequence_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=store_lstm_state,
            max_priority_weight=max_priority_weight,
        )

        self._saver = tf2_savers.Saver(learner.state)

        policy_network = snt.DeepRNN([
            network,
            EpsilonGreedyExploration(
                epsilon_init=epsilon_init,
                epsilon_final=epsilon_final,
                epsilon_schedule_timesteps=epsilon_schedule_timesteps)
        ])
        actor = actors.RecurrentActor(policy_network,
                                      self._adder,
                                      store_recurrent_state=store_lstm_state)

        max_Q_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
        ])
        self._deterministic_actor = actors.RecurrentActor(
            max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
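
The exploration schedule is configured through epsilon_init, epsilon_final and epsilon_schedule_timesteps. The EpsilonGreedyExploration module itself is not shown; a hedged sketch of a plain linear annealing schedule under that assumption:

def linear_epsilon(step: int,
                   epsilon_init: float = 1.0,
                   epsilon_final: float = 0.01,
                   epsilon_schedule_timesteps: int = 20000) -> float:
    # Linearly anneal epsilon from epsilon_init to epsilon_final over
    # epsilon_schedule_timesteps steps, then hold it constant.
    frac = min(step / epsilon_schedule_timesteps, 1.0)
    return epsilon_init + frac * (epsilon_final - epsilon_init)

print(linear_epsilon(0), linear_epsilon(10000), linear_epsilon(50000))
# 1.0 0.505 0.01
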
Example #12
    def __init__(self,
                 environment_spec: specs.EnvironmentSpec,
                 network: snt.RNNCore,
                 target_network: snt.RNNCore,
                 burn_in_length: int,
                 trace_length: int,
                 replay_period: int,
                 demonstration_dataset: tf.data.Dataset,
                 demonstration_ratio: float,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 discount: float = 0.99,
                 batch_size: int = 32,
                 target_update_period: int = 100,
                 importance_sampling_exponent: float = 0.2,
                 epsilon: float = 0.01,
                 learning_rate: float = 1e-3,
                 log_to_bigtable: bool = False,
                 log_name: str = 'agent',
                 checkpoint: bool = True,
                 min_replay_size: int = 1000,
                 max_replay_size: int = 1000000,
                 samples_per_insert: float = 32.0):

        extra_spec = {
            'core_state': network.initial_state(1),
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.SequenceAdder.signature(environment_spec,
                                                     extra_spec))
        self._server = reverb.Server([replay_table], port=None)
        address = f'localhost:{self._server.port}'

        sequence_length = burn_in_length + trace_length + 1
        # Component to add things into replay.
        sequence_kwargs = dict(
            period=replay_period,
            sequence_length=sequence_length,
        )
        adder = adders.SequenceAdder(client=reverb.Client(address),
                                     **sequence_kwargs)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               sequence_length=sequence_length)

        # Combine with demonstration dataset.
        transition = functools.partial(_sequence_from_episode,
                                       extra_spec=extra_spec,
                                       **sequence_kwargs)
        dataset_demos = demonstration_dataset.map(transition)
        dataset = tf.data.experimental.sample_from_datasets(
            [dataset, dataset_demos],
            [1 - demonstration_ratio, demonstration_ratio])

        # Batch and prefetch.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tf2_utils.create_variables(network, [environment_spec.observations])
        tf2_utils.create_variables(target_network,
                                   [environment_spec.observations])

        learner = learning.R2D2Learner(
            environment_spec=environment_spec,
            network=network,
            target_network=target_network,
            burn_in_length=burn_in_length,
            dataset=dataset,
            reverb_client=reverb.TFClient(address),
            counter=counter,
            logger=logger,
            sequence_length=sequence_length,
            discount=discount,
            target_update_period=target_update_period,
            importance_sampling_exponent=importance_sampling_exponent,
            max_replay_size=max_replay_size,
            learning_rate=learning_rate,
            store_lstm_state=False,
        )

        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='r2d2_learner',
            time_delta_minutes=60,
            objects_to_save=learner.state,
            enable_checkpointing=checkpoint,
        )

        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={'network': network}, time_delta_minutes=60.)

        policy_network = snt.DeepRNN([
            network,
            lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
        ])

        actor = actors.RecurrentActor(policy_network, adder)
        observations_per_step = (float(replay_period * batch_size) /
                                 samples_per_insert)
        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=replay_period *
                         max(batch_size, min_replay_size),
                         observations_per_step=observations_per_step)
Example #13
File: agent.py  Project: wilixx/acme
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.RNNCore,
      burn_in_length: int,
      trace_length: int,
      replay_period: int,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      discount: float = 0.99,
      batch_size: int = 32,
      prefetch_size: int = tf.data.experimental.AUTOTUNE,
      target_update_period: int = 100,
      importance_sampling_exponent: float = 0.2,
      priority_exponent: float = 0.6,
      epsilon: float = 0.01,
      learning_rate: float = 1e-3,
      min_replay_size: int = 1000,
      max_replay_size: int = 1000000,
      samples_per_insert: float = 32.0,
      store_lstm_state: bool = True,
      max_priority_weight: float = 0.9,
      checkpoint: bool = True,
  ):

    extra_spec = {
        'core_state': network.initial_state(1),
    }
    # Remove batch dimensions.
    extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent),
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
        signature=adders.SequenceAdder.signature(environment_spec,
                                                   extra_spec))
    self._server = reverb.Server([replay_table], port=None)
    address = f'localhost:{self._server.port}'

    sequence_length = burn_in_length + trace_length + 1
    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=replay_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    reverb_client = reverb.TFClient(address)
    dataset = datasets.make_reverb_dataset(
        client=reverb_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        prefetch_size=prefetch_size,
        extra_spec=extra_spec,
        sequence_length=sequence_length)

    target_network = copy.deepcopy(network)
    tf2_utils.create_variables(network, [environment_spec.observations])
    tf2_utils.create_variables(target_network, [environment_spec.observations])

    learner = learning.R2D2Learner(
        environment_spec=environment_spec,
        network=network,
        target_network=target_network,
        burn_in_length=burn_in_length,
        sequence_length=sequence_length,
        dataset=dataset,
        reverb_client=reverb_client,
        counter=counter,
        logger=logger,
        discount=discount,
        target_update_period=target_update_period,
        importance_sampling_exponent=importance_sampling_exponent,
        max_replay_size=max_replay_size,
        learning_rate=learning_rate,
        store_lstm_state=store_lstm_state,
        max_priority_weight=max_priority_weight,
    )

    self._checkpointer = tf2_savers.Checkpointer(
        subdirectory='r2d2_learner',
        time_delta_minutes=60,
        objects_to_save=learner.state,
        enable_checkpointing=checkpoint,
    )
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={'network': network}, time_delta_minutes=60.)

    policy_network = snt.DeepRNN([
        network,
        lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
    ])

    actor = actors.RecurrentActor(policy_network, adder)
    observations_per_step = (
        float(replay_period * batch_size) / samples_per_insert)
    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=replay_period * max(batch_size, min_replay_size),
        observations_per_step=observations_per_step)
Example #14
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: snt.RNNCore,
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        max_abs_reward: Optional[float] = None,
        max_gradient_norm: Optional[float] = None,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        extra_spec = {
            'core_state': network.initial_state(1),
            'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
        }
        # Remove batch dimensions.
        extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size,
                                   signature=adders.SequenceAdder.signature(
                                       environment_spec,
                                       extras_spec=extra_spec,
                                       sequence_length=sequence_length))
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(server_address=address,
                                               batch_size=batch_size)

        tf2_utils.create_variables(network, [environment_spec.observations])

        self._actor = acting.IMPALAActor(network, adder)
        self._learner = learning.IMPALALearner(
            environment_spec=environment_spec,
            network=network,
            dataset=dataset,
            counter=counter,
            logger=logger,
            discount=discount,
            learning_rate=learning_rate,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_gradient_norm=max_gradient_norm,
            max_abs_reward=max_abs_reward,
        )