Example #1
 def initialize_replay_buffer(self, examples, batch_spec, async_=False):
     example_to_buffer = SamplesToBuffer(
         observation=examples["observation"],
         action=examples["action"],
         reward=examples["reward"],
         done=examples["done"],
     )
     replay_kwargs = dict(
         example=example_to_buffer,
         size=self.replay_size,
         B=batch_spec.B,
         discount=self.discount,
         n_step_return=self.n_step_return,
     )
     if self.args['n_step_nce'] > 1 or self.args['n_step_nce'] < 0:
         ReplayCls = UniformSequenceReplayFrameBuffer
         replay_kwargs['rnn_state_interval'] = 0
         replay_kwargs['batch_T'] = batch_spec.T + self.warmup_T
     elif self.prioritized_replay:
         replay_kwargs.update(
             dict(
                 alpha=self.pri_alpha,
                 beta=self.pri_beta_init,
                 default_priority=self.default_priority,
             ))
         ReplayCls = (AsyncPrioritizedReplayFrameBuffer
                      if async_ else PrioritizedReplayFrameBuffer)
     else:
         ReplayCls = (AsyncUniformReplayFrameBuffer
                      if async_ else UniformReplayFrameBuffer)
     self.replay_buffer = ReplayCls(**replay_kwargs)
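
The examples in this listing all construct the same SamplesToBuffer container. As a point of reference, here is a minimal sketch of how these containers are typically declared in rlpyt with its namedarraytuple helper; the field lists are inferred from the examples themselves, and the import path should be treated as an assumption.

from rlpyt.utils.collections import namedarraytuple  # assumed module path

# Leading dims of every field are [T, B]; namedarraytuple lets the replay
# buffers slice all fields at once with array-style indexing.
SamplesToBuffer = namedarraytuple("SamplesToBuffer",
    ["observation", "action", "reward", "done"])

# Variant that also stores the recurrent state (used by the sequence examples).
SamplesToBufferRnn = namedarraytuple("SamplesToBufferRnn",
    SamplesToBuffer._fields + ("prev_rnn_state",))

# Wrapper used when input priorities accompany the samples (see Example #3).
PrioritiesSamplesToBuffer = namedarraytuple("PrioritiesSamplesToBuffer",
    ["priorities", "samples"])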
Example #2
 def examples_to_buffer(self, examples):
     if self.store_latent:
         observation = examples["agent_info"].conv
     else:
         observation = examples["observation"]
     return SamplesToBuffer(
         observation=observation,
         action=examples["action"],
         reward=examples["reward"],
         done=examples["done"],
     )
Example #3
    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Similar to DQN, except allows to compute the priorities of new samples
        as they enter the replay buffer (input priorities), instead of only once they are
        used in training (important because the replay-ratio is quite low, about 1,
        so must avoid un-informative samples).
        """

        # TODO: estimate priorities for samples entering the replay buffer.
        # Steven says: workers did this approximately by using the online
        # network only for td-errors (not the target network).
        # This could be tough since samples are added before the priorities
        # are ready (next batch), and in the async case the workers must do it.
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr
        if samples is not None:
            samples_to_buffer = SamplesToBuffer(
                observation=samples.env.observation,
                action=samples.agent.action,
                reward=samples.env.reward,
                done=samples.env.done,
            )
            if self.store_rnn_state_interval > 0:
                samples_to_buffer = SamplesToBufferRnn(*samples_to_buffer,
                    prev_rnn_state=samples.agent.agent_info.prev_rnn_state)
            if self.input_priorities:
                priorities = self.compute_input_priorities(samples)
                samples_to_buffer = PrioritiesSamplesToBuffer(
                    priorities=priorities, samples=samples_to_buffer)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(self.batch_B)
            self.optimizer.zero_grad()
            loss, td_abs_errors, priorities = self.loss(samples_from_replay)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            if self.prioritized_replay:
                self.replay_buffer.update_batch_priorities(priorities)
            opt_info.loss.append(loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.tdAbsErr.extend(td_abs_errors[::8].numpy())
            opt_info.priority.extend(priorities)
            self.update_counter += 1
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target()
        self.update_itr_hyperparams(itr)
        return opt_info
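
The OptInfo container filled by the loop above is declared outside the snippet. A small sketch follows, with field names inferred from the attribute accesses in the loop; the declaration style mirrors the usual rlpyt pattern and is not verbatim source.

from collections import namedtuple

# One list per diagnostic; the loop appends one entry per gradient update.
OptInfo = namedtuple("OptInfo", ["loss", "gradNorm", "tdAbsErr", "priority"])

# The idiom at the top of optimize_agent() simply builds an OptInfo of empty lists:
opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
opt_info.loss.append(0.123)     # illustrative values only
opt_info.gradNorm.append(1.0)
print(opt_info._asdict())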
Example #4
 def samples_to_buffer(self, samples):
     """Defines how to add data from sampler into the replay buffer. Called
     in optimize_agent() if samples are provided to that method.  In
     asynchronous mode, will be called in the memory_copier process."""
     if self.store_latent:
         observation = samples.agent.agent_info.conv
     else:
         observation = samples.env.observation
     return SamplesToBuffer(
         observation=observation,
         action=samples.agent.action,
         reward=samples.env.reward,
         done=samples.env.done,
     )
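
For context on the docstring above, a hedged sketch of the calling side: optimize_agent() converts each sampler batch with samples_to_buffer() before appending it, while in asynchronous mode the memory_copier process performs the same conversion. Names follow the rlpyt conventions already visible in Example #3; this is not verbatim library source.

def optimize_agent(self, itr, samples=None, sampler_itr=None):
    # Synchronous mode: the runner passes sampler batches here directly.
    # Asynchronous mode: samples is None, and the memory_copier process
    # calls samples_to_buffer() and append_samples() instead.
    if samples is not None:
        self.replay_buffer.append_samples(self.samples_to_buffer(samples))
    # ... gradient updates as in Example #3 ...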
Example #5
 def initialize_replay_buffer(self, examples, batch_spec, async_=False):
     """Similar to DQN but uses replay buffers which return sequences, and
     stores the agent's recurrent state."""
     example_to_buffer = SamplesToBuffer(
         observation=examples["observation"],
         action=examples["action"],
         reward=examples["reward"],
         done=examples["done"],
     )
     if self.store_rnn_state_interval > 0:
         example_to_buffer = SamplesToBufferRnn(
             *example_to_buffer,
             prev_rnn_state=examples["agent_info"].prev_rnn_state,
         )
     replay_kwargs = dict(
         example=example_to_buffer,
         size=self.replay_size,
         B=batch_spec.B,
         discount=self.discount,
         n_step_return=self.n_step_return,
         rnn_state_interval=self.store_rnn_state_interval,
         # batch_T fixed for prioritized, (relax if rnn_state_interval=1 or 0).
         batch_T=self.batch_T + self.warmup_T,
     )
     if self.prioritized_replay:
         replay_kwargs.update(
             dict(
                 alpha=self.pri_alpha,
                 beta=self.pri_beta_init,
                 default_priority=self.default_priority,
                 input_priorities=self.input_priorities,  # True/False.
                 input_priority_shift=self.input_priority_shift,
             ))
         ReplayCls = (AsyncPrioritizedSequenceReplayFrameBuffer
                      if async_ else PrioritizedSequenceReplayFrameBuffer)
     else:
         ReplayCls = (AsyncUniformSequenceReplayFrameBuffer
                      if async_ else UniformSequenceReplayFrameBuffer)
     if self.ReplayBufferCls is not None:
         ReplayCls = self.ReplayBufferCls
         logger.log(
             f"WARNING: ignoring internal selection logic and using"
             f" input replay buffer class: {ReplayCls} -- compatibility not"
             " guaranteed.")
     self.replay_buffer = ReplayCls(**replay_kwargs)
     return self.replay_buffer
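
The final ReplayBufferCls check lets a caller bypass the internal buffer selection from the algorithm constructor. A hypothetical usage sketch follows; the R2D1 import path and the ReplayBufferCls constructor argument are assumptions based on the standard rlpyt layout, and MySequenceReplayBuffer is a made-up placeholder.

from rlpyt.algos.dqn.r2d1 import R2D1  # assumed module path


class MySequenceReplayBuffer:
    """Hypothetical stand-in. It must accept the same replay_kwargs as the
    built-in sequence buffers (example, size, B, discount, n_step_return,
    rnn_state_interval, batch_T, plus the prioritized arguments when
    prioritized_replay=True) and expose append_samples()/sample_batch()."""

    def __init__(self, **replay_kwargs):
        self.replay_kwargs = replay_kwargs


algo = R2D1(ReplayBufferCls=MySequenceReplayBuffer)
# initialize_replay_buffer() will then log the compatibility warning and
# instantiate MySequenceReplayBuffer(**replay_kwargs) instead of one of the
# (Async)(Prioritized|Uniform)SequenceReplayFrameBuffer classes above.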
Example #6
 def initialize_replay_buffer(self, examples, batch_spec, async_=False):
     example_to_buffer = SamplesToBuffer(
         observation=examples["observation"],
         action=examples["action"],
         reward=examples["reward"],
         done=examples["done"],
     )
     if self.store_rnn_state_interval > 0:
         example_to_buffer = SamplesToBufferRnn(
             *example_to_buffer,
             prev_rnn_state=examples["agent_info"].prev_rnn_state,
         )
     replay_kwargs = dict(
         example=example_to_buffer,
         size=self.replay_size,
         B=batch_spec.B,
         discount=self.discount,
         n_step_return=self.n_step_return,
         rnn_state_interval=self.store_rnn_state_interval,
         # batch_T fixed for prioritized, (relax if rnn_state_interval=1 or 0).
         batch_T=self.batch_T + self.warmup_T,
     )
     if self.prioritized_replay:
         replay_kwargs.update(
             dict(
                 alpha=self.pri_alpha,
                 beta=self.pri_beta_init,
                 default_priority=self.default_priority,
                 input_priorities=self.input_priorities,  # True/False.
                 input_priority_shift=self.input_priority_shift,
             ))
         ReplayCls = (AsyncPrioritizedSequenceReplayFrameBuffer
                      if async_ else PrioritizedSequenceReplayFrameBuffer)
     else:
         ReplayCls = (AsyncUniformSequenceReplayFrameBuffer
                      if async_ else UniformSequenceReplayFrameBuffer)
     self.replay_buffer = ReplayCls(**replay_kwargs)
     return self.replay_buffer
Example #7
    def samples_to_buffer(self, samples):
        """
        Prepares samples for insertion
        to replay buffer. We insert poison
        rewards here, but from the last sample
        as we need the next observation to compute
        the attacker Q-function.

        So, we return the poisoned sample from the
        last sampler time step, which is fine since 
        we're just adding it to a replay buffer.
        """
        if self.opt_itr < self.first_poison_itr:
            self.last_samples = copy.deepcopy(samples)
            return super().samples_to_buffer(samples)

        last_obs, last_act, last_rew, last_done = self._unpack_samples(
            self.last_samples)
        curr_obs, _, _, _ = self._unpack_samples(samples)

        # Squeeze the singleton T dim (T=1 is asserted in initialize), poison
        # the rewards, then add the singleton T dim back at the end.
        poisoned_last_rew = self._poison_rewards(
            last_obs.squeeze(),
            last_act.squeeze(),
            last_rew.squeeze(),
            last_done.squeeze(),
            curr_obs.squeeze(),
        ).unsqueeze(0)

        self.last_samples = copy.deepcopy(samples)  # Save for the next iteration.

        return SamplesToBuffer(
            observation=last_obs,
            action=last_act,
            reward=poisoned_last_rew,  # poison inserted here
            done=last_done)
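
The private helpers _unpack_samples() and _poison_rewards() are not part of this snippet. Based on the field accesses used throughout the other examples, here is a hypothetical sketch of what _unpack_samples() plausibly does; it is an inference, not the original implementation.

    def _unpack_samples(self, samples):
        # Hypothetical: pull the four buffer fields out of a sampler batch,
        # each with a leading singleton T dim (T=1 per the docstring) and a
        # B dim.
        return (samples.env.observation, samples.agent.action,
                samples.env.reward, samples.env.done)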