Example #1
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             init_rnn_state=None):
        if init_rnn_state is not None:
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)

        else:
            dist_info, value = self.agent(*agent_inputs)

        dist = self.agent.distribution

        lr = dist.likelihood_ratio(action,
                                   old_dist_info=old_dist_info,
                                   new_dist_info=dist_info)
        kl = dist.kl(old_dist_info=old_dist_info, new_dist_info=dist_info)

        if init_rnn_state is not None:
            raise NotImplementedError
        else:
            mean_kl = valid_mean(kl)
            surr_loss = -valid_mean(lr * advantage)

        loss = surr_loss
        entropy = dist.mean_entropy(dist_info, valid)
        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, entropy, perplexity
Example #2
    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
            init_rnn_state=None):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
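
The clipped surrogate used in Example #2 (and most of the examples below) can be reproduced in isolation with plain tensors. The sketch below is a minimal, self-contained version; the `masked_mean` helper is a hypothetical stand-in for the masked averaging that `valid_mean` performs, not the library's implementation.

import torch

def masked_mean(x, valid=None):
    # Hypothetical stand-in for valid_mean: mean over elements where valid == 1.
    if valid is None:
        return x.mean()
    return (x * valid).sum() / valid.sum()

def clipped_surrogate_loss(new_logp, old_logp, advantage, ratio_clip=0.2, valid=None):
    # Likelihood ratio pi_new(a|s) / pi_old(a|s), from log-probabilities.
    ratio = torch.exp(new_logp - old_logp)
    surr_1 = ratio * advantage
    surr_2 = torch.clamp(ratio, 1. - ratio_clip, 1. + ratio_clip) * advantage
    # PPO takes the pessimistic (minimum) surrogate, then negates for gradient descent.
    return -masked_mean(torch.min(surr_1, surr_2), valid)

# Toy usage: 5 timesteps, last two marked invalid (e.g. past episode end).
new_logp = torch.tensor([-1.0, -0.9, -1.2, -1.1, -1.0], requires_grad=True)
old_logp = torch.tensor([-1.1, -1.0, -1.0, -1.1, -1.0])
advantage = torch.tensor([0.5, -0.2, 1.0, 0.0, 0.3])
valid = torch.tensor([1., 1., 1., 0., 0.])
loss = clipped_surrogate_loss(new_logp, old_logp, advantage, valid=valid)
loss.backward()
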
Example #3
    def loss(self, samples):
        agent_inputs = AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        if self.agent.recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T = 0.
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        # TODO: try to compute everything on device.
        return_, advantage, valid = self.process_returns(samples)

        dist = self.agent.distribution
        logli = dist.log_likelihood(samples.agent.action, dist_info)
        pi_loss = -valid_mean(logli * advantage, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, entropy, perplexity
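
Several of these losses mask out timesteps after an episode ends via the `valid` tensor returned by `process_returns`. The sketch below shows how such a mask can be derived from per-step `done` flags; it mirrors the idea behind rlpyt's `valid_from_done`, but the helper is a simplified assumption, not the library code.

import torch

def valid_from_done_sketch(done):
    # done: [T, B]; a step stays valid until a done has occurred at an earlier step.
    done = done.float()
    valid = torch.ones_like(done)
    # Steps strictly after the first done become invalid.
    valid[1:] = 1. - torch.clamp(torch.cumsum(done[:-1], dim=0), max=1.)
    return valid

done = torch.tensor([[0., 0.], [0., 1.], [0., 0.], [1., 0.]])
print(valid_from_done_sketch(done))
# column 0: all steps valid (done only at the last step, which is still valid here)
# column 1: invalid from t=2 onward (done at t=1)
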
Example #4
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             old_value,
             init_rnn_state=None):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state,
                                                      device=action.device)
        else:
            dist_info, value = self.agent(*agent_inputs, device=action.device)
        dist = self.agent.distribution

        # Surrogate policy loss
        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        # Surrogate value loss (if doing)
        if self.clip_vf_loss:
            v_loss_unclipped = (value - return_)**2
            v_clipped = old_value + torch.clamp(
                value - old_value, -self.ratio_clip, self.ratio_clip)
            v_loss_clipped = (v_clipped - return_)**2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            value_error = 0.5 * v_loss_max.mean()
        else:
            value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, pi_loss, value_loss, entropy, perplexity
Example #5
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             init_rnn_state=None):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            if self.agent.both_actions:
                dist_info, og_dist_info, value, og_value = self.agent(
                    *agent_inputs)
            else:
                dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution
        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        if self.similarity_loss:  # Try KL next
            # pi_sim = self.agent.distribution.kl(og_dist_info, dist_info)
            pi_sim = F.cosine_similarity(dist_info.prob, og_dist_info.prob)
            value_sim = (value - og_value)**2
            # Keep the similarity terms inside the flag so pi_sim/value_sim exist when used.
            loss += -self.similarity_coeff * pi_sim.mean() + 0.5 * value_sim.mean()
            # loss += self.similarity_coeff * (pi_sim.mean() + value_sim)

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
Example #6
    def loss(
        self,
        agent_inputs,
        action,
        return_,
        advantage,
        valid,
        old_dist_info,
        init_rnn_state=None,
    ):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        ratio = ratio.clamp_max(1000)  # added (to prevent ratio == inf)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1.0 - self.ratio_clip,
                                    1.0 + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
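
Example #6 additionally clamps the raw likelihood ratio at 1000 before forming the surrogates. The small numeric check below shows why: a large positive log-probability gap overflows exp() to inf, and with a negative advantage the inf-valued surrogate poisons the min(); clamping the ratio first keeps the loss finite. (Sketch only; 1000 is the threshold used in the example above, not a general recommendation.)

import torch

old_logp = torch.tensor([-100.0, -1.0])
new_logp = torch.tensor([-1.0, -1.0])
ratio = torch.exp(new_logp - old_logp)      # tensor([inf, 1.]) in float32
advantage = torch.tensor([-1.0, 1.0])
print(torch.min(ratio * advantage,
                torch.clamp(ratio, 0.8, 1.2) * advantage))  # first entry is -inf
ratio = ratio.clamp_max(1000)               # as in Example #6
print(torch.min(ratio * advantage,
                torch.clamp(ratio, 0.8, 1.2) * advantage))  # finite: [-1000., 1.]
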
Example #7
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     probs, value, rnn_state = self.model(*agent_inputs,
                                          self.prev_rnn_state)
     dist_info = DistInfo(prob=probs)
     if self._mode == 'sample':
         action = self.distribution.sample(dist_info)
     elif self._mode == 'eval':
         action = torch.argmax(probs, dim=-1)
     # Model handles None, but Buffer does not, make zeros if needed:
     if self.prev_rnn_state is None:
         prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
     else:
         prev_rnn_state = self.prev_rnn_state
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
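
The `buffer_to` / `buffer_func` / `buffer_method` calls used throughout these examples apply an operation to every tensor leaf of a (possibly nested) namedarraytuple, such as an RNN state. The sketch below illustrates that pattern on a plain namedtuple; the helper is a simplified assumption for illustration, not rlpyt's implementation.

from collections import namedtuple
import torch

RnnState = namedtuple("RnnState", ["h", "c"])

def buffer_apply(buf, fn):
    # Recursively apply fn to every tensor leaf of a (possibly nested) namedtuple.
    if isinstance(buf, torch.Tensor) or buf is None:
        return fn(buf) if buf is not None else None
    return type(buf)(*(buffer_apply(b, fn) for b in buf))

state = RnnState(h=torch.zeros(2, 3, 4), c=torch.zeros(2, 3, 4))  # [B,N,H]
# In the spirit of buffer_method(state, "transpose", 0, 1) followed by "contiguous":
state_nbh = buffer_apply(state, lambda t: t.transpose(0, 1).contiguous())
print(state_nbh.h.shape)  # torch.Size([3, 2, 4]) -> [N,B,H]
# In the spirit of buffer_to(state, device="cpu"):
state_cpu = buffer_apply(state, lambda t: t.to("cpu"))
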
Example #8
 def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.need_reset = np.zeros(len(self.envs), dtype=np.bool)
     self.done = np.zeros(len(
         self.envs), dtype=np.bool)  # Done flags for all environments, initialized to "not done".
     self.temp_observation = buffer_method(
         self.samples_np.env.observation[0, :len(self.envs)], "copy")
Example #9
    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        if samples is not None:
            self.replay_buffer.append_samples(self.samples_to_buffer(samples))

        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))

        batch_generator = self.replay_buffer.batch_generator(replay_ratio=self.epochs)
        for batch, init_rnn_state, buffer_wait_time in batch_generator:
            self.optimizer.zero_grad()
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            dist_info, value, _ = self.agent(*batch.agent_inputs, init_rnn_state)

            loss, opt_info = self.process_returns(reward=batch.reward,
                                                  done=batch.done,
                                                  value_prediction=value,
                                                  action=batch.action,
                                                  dist_info=dist_info,
                                                  old_dist_info=batch.dist_info,
                                                  opt_info=opt_info)
            loss.backward()
            self.optimizer.step()
            self.clamp_lagrange_multipliers()

            opt_info.loss.append(loss.item())
            opt_info.optim_buffer_wait_time.append(buffer_wait_time)
            self.update_counter += 1
        return opt_info
Example #10
 def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     self.need_reset = np.zeros(len(self.envs), dtype=np.bool)
     # e.g. For episodic lives, hold the observation output when done, record
     # blanks for the rest of the batch, but reinstate the observation to start
     # next batch.
     self.temp_observation = buffer_method(self.step_buffer_np.observation, "copy")
Example #11
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs,
                                                self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #12
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            int_pi, int_value, int_rnn_state = self.model_int(
                *agent_inputs, self.prev_int_rnn_state)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

        if self.dual_model:
            prev_int_rnn_state = self.prev_int_rnn_state or buffer_func(
                int_rnn_state, torch.zeros_like)
            prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose",
                                               0, 1)
            agent_info = AgentInfoRnnTwin(
                dist_info=dist_info,
                value=value,
                prev_rnn_state=prev_rnn_state,
                dist_int_info=dist_int_info,
                int_value=int_value,
                prev_int_rnn_state=prev_int_rnn_state)
        else:
            agent_info = AgentInfoRnn(dist_info=dist_info,
                                      value=value,
                                      prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        if self.dual_model:
            self.advance_int_rnn_state(int_rnn_state)
        return AgentStep(action=action, agent_info=agent_info)
Example #13
    def loss(self, samples):
        """
        Computes the training loss: policy_loss + value_loss + entropy_loss.
        Policy loss: log-likelihood of actions * advantages
        Value loss: 0.5 * (estimated_value - return) ^ 2
        Organizes agent inputs from training samples, calls the agent instance
        to run forward pass on training data, and uses the
        ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        agent_inputs = AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        if self.agent.recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[
                0]  # T = 0.
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(
                *agent_inputs,
                init_rnn_state,
                device=agent_inputs.prev_action.device)
        else:
            dist_info, value = self.agent(
                *agent_inputs, device=agent_inputs.prev_action.device)
        # TODO: try to compute everything on device.
        return_, advantage, valid = self.process_returns(samples)

        dist = self.agent.distribution
        logli = dist.log_likelihood(samples.agent.action, dist_info)
        pi_loss = -valid_mean(logli * advantage, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, pi_loss, value_loss, entropy, perplexity
Example #14
    def beta_kl_losses(
        self,
        agent_inputs,
        action,
        return_,
        advantage,
        valid,
        old_dist_info,
        c_return,
        c_advantage,
        init_rnn_state=None,
    ):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *agent_inputs, init_rnn_state)
        else:
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *agent_inputs)
        dist = self.agent.distribution

        r_ratio = dist.likelihood_ratio(action,
                                        old_dist_info=old_dist_info,
                                        new_dist_info=r_dist_info)
        surr_1 = r_ratio * advantage
        r_clipped_ratio = torch.clamp(r_ratio, 1.0 - self.ratio_clip,
                                      1.0 + self.ratio_clip)
        surr_2 = r_clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        beta_r_loss = -valid_mean(surrogate, valid)

        c_ratio = dist.likelihood_ratio(action,
                                        old_dist_info=old_dist_info,
                                        new_dist_info=c_dist_info)
        c_surr_1 = c_ratio * c_advantage
        c_clipped_ratio = torch.clamp(c_ratio, 1.0 - self.ratio_clip,
                                      1.0 + self.ratio_clip)
        c_surr_2 = c_clipped_ratio * c_advantage
        c_surrogate = torch.max(c_surr_1, c_surr_2)
        beta_c_loss = valid_mean(c_surrogate, valid)

        return beta_r_loss, beta_c_loss
Example #15
 def rollout_policy(self, steps: int, policy, prev_state: RSSMState):
     """
     Roll out the model with a policy function.
     :param steps: number of steps to roll out
     :param policy: RSSMState -> action
     :param prev_state: RSSM state, size(batch_size, state_size)
     :return: next states size(time_steps, batch_size, state_size),
              actions size(time_steps, batch_size, action_size)
     """
     state = prev_state
     next_states = []
     actions = []
     state = buffer_method(state, 'detach')
     for t in range(steps):
         action, _ = policy(buffer_method(state, 'detach'))
         state = self.transition_model(action, state)
         next_states.append(state)
         actions.append(action)
     next_states = stack_states(next_states, dim=0)
     actions = torch.stack(actions, dim=0)
     return next_states, actions
Example #16
 def to_agent_step(self, output):
     """Convert the output of the NN model into step info for the agent.
     """
     q, rnn_state = output
     # q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state, action, q = buffer_to((prev_rnn_state, action, q), device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #17
    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
            init_rnn_state=None):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)

        # TODO: if multi-agent, reshape things.
        # Actually, the dist.* functions seem to operate on multi-dimensional tensors
        # (entropy and likelihood ratios are computed along the last dim).
        # Double-check that this is true, but things seem to be ok.

        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
Example #18
 def step(self, observation, prev_action, prev_reward):
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
         prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #19
 def step(self, observation, prev_action, prev_reward):
     """Computes Q-values for states/observations and selects actions by
     epsilon-greedy (no grad).  Advances RNN state."""
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     q, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)  # Model handles None.
     q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state = buffer_to(prev_rnn_state, device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #20
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     q, rnn_state = self.model(*agent_inputs,
                               self.prev_rnn_state)  # Model handles None.
     q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state or buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state = buffer_to(prev_rnn_state, device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #21
    def imagine_trajectories(self, _initial_states: RSSMState, batch_t: int,
                             batch_b: int):
        ############# Imagine trajectories ##########
        ########### {sτ ; aτ } from each st ##########

        # no gradient for input (initial) states
        with torch.no_grad():
            initial_states = buffer_method(_initial_states[:-1, :], 'reshape',
                                           (batch_t - 1) * (batch_b),
                                           -1)  # RSSM mean..(2450, 30)

        # imagine trajectories with a finite horizon H
        w_transition_represent = self.agent.model.rollout
        policy = self.agent.model.policy
        with FreezeParameters(self.world_modules):
            imagined_states, _ = w_transition_represent.rollout_policy(
                self.horizon, policy,
                initial_states)  # RSSM mean..(10, 2450, 30)

        return imagined_states
Example #22
 def optimize_agent(self, itr, samples=None, sampler_itr=None):
     """
     Train the agent, for multiple epochs over minibatches taken from the
     input samples.  Organizes agent inputs from the training data, and
     moves them to device (e.g. GPU) up front, so that minibatches are
     formed within device, without further data transfer.
     """
     opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
     agent_inputs = AgentInputs(  # Move inputs to device once, index there.
         observation=samples.env.observation,
         prev_action=samples.agent.prev_action,
         prev_reward=samples.env.prev_reward,
     )
     agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
     init_rnn_states = buffer_to(samples.agent.agent_info.prev_rnn_state[0],
                                 device=self.agent.device)
     T, B = samples.env.reward.shape[:2]
     mb_size = B // self.minibatches
     for _ in range(self.epochs):
         for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
             self.optimizer.zero_grad()
             init_rnn_state = buffer_method(init_rnn_states[idxs],
                                            "transpose", 0, 1)
             dist_info, value, _ = self.agent(*agent_inputs[:, idxs],
                                              init_rnn_state)
             loss, opt_info = self.process_returns(
                 samples.env.reward[:, idxs],
                 done=samples.env.done[:, idxs],
                 value_prediction=value.cpu(),
                 action=samples.agent.action[:, idxs],
                 dist_info=dist_info,
                 old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                 opt_info=opt_info)
             loss.backward()
             self.optimizer.step()
             self.clamp_lagrange_multipliers()
             opt_info.loss.append(loss.item())
             self.update_counter += 1
     return opt_info
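
Example #22 forms minibatches along the batch (B) dimension with `iterate_mb_idxs`. A minimal sketch of such a shuffled minibatch-index generator is shown below; it is a simplified assumption of the behavior, not the library function itself.

import numpy as np

def iterate_mb_idxs_sketch(data_length, minibatch_size, shuffle=True):
    # Yield arrays of indices covering [0, data_length) in chunks of minibatch_size.
    order = np.random.permutation(data_length) if shuffle else np.arange(data_length)
    for start in range(0, data_length - minibatch_size + 1, minibatch_size):
        yield order[start:start + minibatch_size]

for idxs in iterate_mb_idxs_sketch(8, minibatch_size=4):
    print(idxs)  # two shuffled index arrays of length 4
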
Example #23
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi, rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     terminations = torch.bernoulli(beta).bool()  # Sample terminations
     dist_info_omega = DistInfo(prob=pi)
     new_o = self.sample_option(terminations, dist_info_omega)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi * q).sum(-1),
                                 termination=terminations,
                                 inter_option_dist_info=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #24
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_option_input = self._prev_option
     if prev_option_input is None:  # Hack to extract previous option
         prev_option_input = torch.full_like(prev_action, -1)
     prev_action = self.distribution.to_onehot(prev_action)
     prev_option_input = self.distribution_omega.to_onehot_with_invalid(
         prev_option_input)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, prev_option_input),
         device=self.device)
     pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs,
                                                   self.prev_rnn_state)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi_omega * q).sum(-1),
                                 termination=terminations,
                                 dist_info_omega=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #25
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward), device=self.device
     )
     pi, value, rnn_state, conv = self.model(*model_inputs, self.prev_rnn_state)
     if self._act_uniform:
         pi[:] = 1.0 / pi.shape[-1]  # uniform
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnnConv(
         dist_info=dist_info,
         value=value,
         prev_rnn_state=prev_rnn_state,
         conv=conv if self.store_latent else None,
     )  # Don't write the extra data.
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #26
    def loss(self, samples):
        """Samples have leading Time and Batch dimentions [T,B,..]. Move all
        samples to device first, and then slice for sub-sequences.  Use same
        init_rnn_state for agent and target; start both at same t.  Warmup the
        RNN state first on the warmup subsequence, then train on the remaining
        subsequence.

        Returns loss (usually use MSE, not Huber), TD-error absolute values,
        and new sequence-wise priorities, based on weighted sum of max and mean
        TD-error over the sequence."""
        all_observation, all_action, all_reward = buffer_to(
            (samples.all_observation, samples.all_action, samples.all_reward),
            device=self.agent.device)
        wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
        if wT > 0:
            warmup_slice = slice(None, wT)  # Same for agent and target.
            warmup_inputs = AgentInputs(
                observation=all_observation[warmup_slice],
                prev_action=all_action[warmup_slice],
                prev_reward=all_reward[warmup_slice],
            )
        agent_slice = slice(wT, wT + bT)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice],
            prev_action=all_action[agent_slice],
            prev_reward=all_reward[agent_slice],
        )
        target_slice = slice(wT, None)  # Same start t as agent. (wT + bT + nsr)
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )
        action = samples.all_action[wT + 1:wT + 1 + bT]  # CPU.
        return_ = samples.return_[wT:wT + bT]
        done_n = samples.done_n[wT:wT + bT]
        if self.store_rnn_state_interval == 0:
            init_rnn_state = None
        else:
            # [B,N,H]-->[N,B,H] cudnn.
            init_rnn_state = buffer_method(samples.init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        if wT > 0:  # Do warmup.
            with torch.no_grad():
                _, target_rnn_state = self.agent.target(*warmup_inputs, init_rnn_state)
                _, init_rnn_state = self.agent(*warmup_inputs, init_rnn_state)
            # Recommend aligning sampling batch_T and store_rnn_interval with
            # warmup_T (and no mid_batch_reset), so that end of trajectory
            # during warmup leads to new trajectory beginning at start of
            # training segment of replay.
            warmup_invalid_mask = valid_from_done(samples.done[:wT])[-1] == 0  # [B]
            init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
            target_rnn_state[:, warmup_invalid_mask] = 0
        else:
            target_rnn_state = init_rnn_state

        qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]
        q = select_at_indexes(action, qs)
        with torch.no_grad():
            target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
            if self.double_dqn:
                next_qs, _ = self.agent(*target_inputs, init_rnn_state)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
            target_q = target_q[-bT:]  # Same length as q.

        disc = self.discount ** self.n_step_return
        y = self.value_scale(return_ + (1 - done_n.float()) * disc *
            self.inv_value_scale(target_q))  # [T,B]
        delta = y - q
        losses = 0.5 * delta ** 2
        abs_delta = abs(delta)
        # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights.unsqueeze(0)  # weights: [B] --> [1,B]
        valid = valid_from_done(samples.done[wT:])  # 0 after first done.
        loss = valid_mean(losses, valid)
        td_abs_errors = abs_delta.detach()
        if self.delta_clip is not None:
            td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)  # [T,B]
        valid_td_abs_errors = td_abs_errors * valid
        max_d = torch.max(valid_td_abs_errors, dim=0).values
        mean_d = valid_mean(td_abs_errors, valid, dim=0)  # Still high if less valid.
        priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]

        return loss, valid_td_abs_errors, priorities
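
Example #26 wraps the n-step target in `self.value_scale` / `self.inv_value_scale`, which in R2D2-style agents is usually the invertible value rescaling h(x) = sign(x)(sqrt(|x|+1) - 1) + eps*x. The sketch below writes out that transform and its closed-form inverse under this assumption; the actual methods on the algorithm may differ.

import torch

def value_scale(x, eps=1e-3):
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.) - 1.) + eps * x

def inv_value_scale(x, eps=1e-3):
    # Closed-form inverse of h, as used in R2D2 (Kapturowski et al., 2019).
    return torch.sign(x) * ((((torch.sqrt(1. + 4. * eps * (torch.abs(x) + 1. + eps))
                               - 1.) / (2. * eps)) ** 2) - 1.)

x = torch.linspace(-10., 10., 5, dtype=torch.float64)
print(torch.allclose(inv_value_scale(value_scale(x)), x, atol=1e-6))  # True
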
Example #27
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             bc_observations,
             bc_actions,
             init_rnn_state=None):
        """
        Compute the BC-augmented training loss:
            policy_loss + value_loss + entropy_loss + bc_loss
        Policy loss: min(likelihood-ratio * advantage,
                         clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        BC loss: xent(policy(demo_states), action_labels)
        Calls the agent to compute forward pass on training data, and uses the
        ``agent.distribution`` to compute likelihoods and entropies. Valid for
        feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        # TODO: log the value error and correlation
        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        # BC loss (this is the only new part)
        if self.bc_loss_coeff:
            if init_rnn_state is not None:
                raise NotImplementedError("doesn't quite work with RNNs yet")
                # bc_dist_info, _, _ = self.agent(*bc_agent_inputs,
                #                                 init_rnn_state)
            else:
                # This will break if I have an agent/model that actually needs
                # the previous action and reward. (IIRC that only includes
                # recurrent agents in rlpyt, though)
                dummy_prev_action = bc_actions
                dummy_prev_reward = torch.zeros(bc_actions.shape[0],
                                                device=bc_actions.device)
                bc_dist_info, _ = self.agent(bc_observations,
                                             dummy_prev_action,
                                             dummy_prev_reward)
            expert_ll = dist.log_likelihood(bc_actions, bc_dist_info)
            # bc_loss = -self.bc_loss_coeff * valid_mean(expert_ll, bc_valid)
            # TODO: also log BC accuracy (or maybe do it somewhere else, IDK)
            bc_loss = -self.bc_loss_coeff * expert_ll.mean()
        else:
            bc_loss = 0.0

        loss = pi_loss + value_loss + entropy_loss + bc_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
Example #28
    def combined_loss(self,
                      agent_inputs,
                      action,
                      next_obs,
                      ext_return,
                      ext_adv,
                      int_return,
                      int_adv,
                      valid,
                      old_dist_info,
                      init_rnn_state=None):
        """
        Alternative to ``loss`` in PPO.
        This functions runs ``bonus_call``, performing a forward pass of the intrinsic bonus model
        and producing a combined reward/advantage stream, and then a combined loss.
        """
        # Run base actor critic model
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, ext_value, int_value, _rnn_state = self.agent(
                *agent_inputs, init_rnn_state)
        else:
            dist_info, ext_value, int_value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        # Second call to bonus model, generates self-supervised bonus model loss
        # Leading batch dims have already been flattened after entering minibatch
        bonus_model_inputs = self.agent.extract_bonus_inputs(
            observation=agent_inputs.observation,
            next_observation=
            next_obs,  # May be same as observation (dummy placeholder) if algo set next_obs=False
            action=action)
        _, bonus_loss = self.agent.bonus_call(bonus_model_inputs)
        bonus_loss *= self.bonus_loss_coeff

        # Fuse reward streams by producing combined advantages
        advantage = self.ext_rew_coeff * ext_adv + self.int_rew_coeff * int_adv

        # Construct PPO loss
        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        ext_value_error = 0.5 * (ext_value - ext_return)**2
        int_value_error = 0.5 * (int_value - int_return)**2
        value_loss = self.value_loss_coeff * (
            valid_mean(ext_value_error, valid) + int_value_error.mean())

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss + bonus_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity, pi_loss, value_loss, entropy_loss, bonus_loss
Example #29
    def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
        """
        Compute the loss for a batch of data.  This includes computing the model and reward losses on the given data,
        as well as using the dynamics model to generate additional rollouts, which are used for the actor and value
        components of the loss.
        :param samples: samples from replay
        :param sample_itr: sample iteration
        :param opt_itr: optimization iteration
        :return: FloatTensor containing the loss
        """
        model = self.agent.model

        observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
        action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = samples.all_reward[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
        reward = reward.unsqueeze(2)
        done = samples.done
        done = done.unsqueeze(2)

        # Extract tensors from the Samples object
        # They all have the batch_t dimension first, but we'll put the batch_b dimension first.
        # Also, we convert all tensors to floats so they can be fed into our models.

        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            observation, 3)
        # squeeze batch sizes to single batch dimension for imagination roll-out
        batch_size = batch_t * batch_b

        # normalize image
        observation = observation.type(self.type) / 255.0 - 0.5
        # embed the image
        embed = model.observation_encoder(observation)

        prev_state = model.representation.initial_state(batch_b,
                                                        device=action.device,
                                                        dtype=action.dtype)
        # Rollout model by taking the same series of actions as the real model
        prior, post = model.rollout.rollout_representation(
            batch_t, embed, action, prev_state)
        # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
        # since we're going to do a new rollout starting from each state visited in each batch.

        # Compute losses for each component of the model

        # Model Loss
        feat = get_feat(post)
        image_pred = model.observation_decoder(feat)
        reward_pred = model.reward_model(feat)
        reward_loss = -torch.mean(reward_pred.log_prob(reward))
        image_loss = -torch.mean(image_pred.log_prob(observation))
        pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
        if self.use_pcont:
            pcont_pred = model.pcont(feat)
            pcont_target = self.discount * (1 - done.float())
            pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
        prior_dist = get_dist(prior)
        post_dist = get_dist(post)
        div = torch.mean(
            torch.distributions.kl.kl_divergence(post_dist, prior_dist))
        div = torch.max(div, div.new_full(div.size(), self.free_nats))
        model_loss = self.kl_scale * div + reward_loss + image_loss
        if self.use_pcont:
            model_loss += self.pcont_scale * pcont_loss

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Actor Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            if self.use_pcont:
                # "Last step could be terminal." Done in TF2 code, but unclear why
                flat_post = buffer_method(post[:-1, :], 'reshape',
                                          (batch_t - 1) * (batch_b), -1)
            else:
                flat_post = buffer_method(post, 'reshape', batch_size, -1)
        # Rollout the policy for self.horizon steps. Variable names with imag_ indicate this data is imagined not real.
        # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
        with FreezeParameters(self.model_modules):
            imag_dist, _ = model.rollout.rollout_policy(
                self.horizon, model.policy, flat_post)

        # Use state features (deterministic and stochastic) to predict the image and reward
        imag_feat = get_feat(
            imag_dist)  # [horizon, batch_t * batch_b, feature_size]
        # Assumes these are normal distributions. The TF code uses the mode, but for a normal distribution mean = mode.
        # If we want to use other distributions we'll have to fix this.
        # We calculate the target here so no grad necessary

        # freeze model parameters as only action model gradients needed
        with FreezeParameters(self.model_modules + self.value_modules):
            imag_reward = model.reward_model(imag_feat).mean
            value = model.value_model(imag_feat).mean
        # Compute the exponential discounted sum of rewards
        if self.use_pcont:
            with FreezeParameters([model.pcont]):
                discount_arr = model.pcont(imag_feat).mean
        else:
            discount_arr = self.discount * torch.ones_like(imag_reward)
        returns = self.compute_return(imag_reward[:-1],
                                      value[:-1],
                                      discount_arr[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self.discount_lambda)
        # Make the top row 1 so the cumulative product starts with discount^0
        discount_arr = torch.cat(
            [torch.ones_like(discount_arr[:1]), discount_arr[1:]])
        discount = torch.cumprod(discount_arr[:-1], 0)
        actor_loss = -torch.mean(discount * returns)

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss

        # remove gradients from previously calculated tensors
        with torch.no_grad():
            value_feat = imag_feat[:-1].detach()
            value_discount = discount.detach()
            value_target = returns.detach()
        value_pred = model.value_model(value_feat)
        log_prob = value_pred.log_prob(value_target)
        value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

        # ------------------------------------------  Gradient Barrier  ------------------------------------------------
        # loss info
        with torch.no_grad():
            prior_ent = torch.mean(prior_dist.entropy())
            post_ent = torch.mean(post_dist.entropy())
            loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent,
                                 post_ent, div, reward_loss, image_loss,
                                 pcont_loss)

            if self.log_video:
                if opt_itr == self.train_steps - 1 and sample_itr % self.video_every == 0:
                    self.write_videos(observation,
                                      action,
                                      image_pred,
                                      post,
                                      step=sample_itr,
                                      n=self.video_summary_b,
                                      t=self.video_summary_t)

        return model_loss, actor_loss, value_loss, loss_info
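
The `self.compute_return` call above produces the lambda-returns that feed the actor loss. A minimal sketch of that recursion, R_t = r_t + discount_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}) with the final step bootstrapped from the last value, is given below; it is an illustrative assumption of the standard Dreamer formulation, not necessarily the exact implementation here.

import torch

def compute_lambda_return(reward, value, discount, bootstrap, lambda_=0.95):
    # reward, value, discount: [H, B, 1]; bootstrap: [B, 1] value at the final step.
    next_values = torch.cat([value[1:], bootstrap.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(reward)
    last = bootstrap
    for t in reversed(range(reward.shape[0])):
        last = reward[t] + discount[t] * ((1. - lambda_) * next_values[t] + lambda_ * last)
        returns[t] = last
    return returns

H, B = 10, 4
reward = torch.rand(H, B, 1)
value = torch.rand(H, B, 1)
discount = torch.full((H, B, 1), 0.99)
returns = compute_lambda_return(reward, value, discount, bootstrap=value[-1])
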
Example #30
    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target Q-values and an entropy term.  Computes reparameterized policy loss, and loss for tuning entropy weighting, alpha.  
        
        Input samples have leading batch dimension [B,..] (but not time).
        """
        # SamplesFromReplay = namedarraytuple("SamplesFromReplay",
        # ["all_observation", "all_action", "all_reward", "return_", "done", "done_n", "init_rnn_state"])
        all_observation, all_action, all_reward = buffer_to(
            (samples.all_observation, samples.all_action, samples.all_reward),
            device=self.agent.device)  # all have (wT + bT + nsr) x bB
        wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
        if wT > 0:
            warmup_slice = slice(None, wT)  # Same for agent and target.
            warmup_inputs = AgentInputs(
                observation=all_observation[warmup_slice],
                prev_action=all_action[warmup_slice],
                prev_reward=all_reward[warmup_slice],
            )
        agent_slice = slice(wT, wT + bT)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice],
            prev_action=all_action[agent_slice],
            prev_reward=all_reward[agent_slice],
        )
        target_slice = slice(wT,
                             None)  # Same start t as agent. (wT + bT + nsr)
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )
        warmup_action = samples.all_action[1:wT + 1]
        action = samples.all_action[
            wT + 1:wT + 1 +
            bT]  # 'current' action by shifting index by 1 from prev_action
        return_ = samples.return_[wT:wT + bT]
        done_n = samples.done_n[wT:wT + bT]
        if self.store_rnn_state_interval == 0:
            init_rnn_state = None
        else:
            # [B,N,H]-->[N,B,H] cudnn.
            init_rnn_state = buffer_method(samples.init_rnn_state, "transpose",
                                           0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        if wT > 0:  # Do warmup.
            with torch.no_grad():
                _, target_q1_rnn_state, _, target_q2_rnn_state = self.agent.target_q(
                    *warmup_inputs, warmup_action, init_rnn_state,
                    init_rnn_state)
                _, _, _, init_rnn_state = self.agent.pi(
                    *warmup_inputs, init_rnn_state)
            # Recommend aligning sampling batch_T and store_rnn_interval with
            # warmup_T (and no mid_batch_reset), so that end of trajectory
            # during warmup leads to new trajectory beginning at start of
            # training segment of replay.
            warmup_invalid_mask = valid_from_done(
                samples.done[:wT])[-1] == 0  # [B]
            init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
            target_q1_rnn_state[:, warmup_invalid_mask] = 0
            target_q2_rnn_state[:, warmup_invalid_mask] = 0
        else:
            target_q1_rnn_state = init_rnn_state
            target_q2_rnn_state = init_rnn_state

        valid = valid_from_done(samples.done)[-bT:]

        q1, _, q2, _ = self.agent.q(*agent_inputs, action, init_rnn_state,
                                    init_rnn_state)
        with torch.no_grad():
            target_action, target_log_pi, _, _ = self.agent.pi(
                *target_inputs, init_rnn_state)
            target_q1, _, target_q2, _ = self.agent.target_q(
                *target_inputs, target_action, target_q1_rnn_state,
                target_q2_rnn_state)
            target_q1 = target_q1[-bT:]  # Same length as q.
            target_q2 = target_q2[-bT:]
            target_log_pi = target_log_pi[-bT:]

        min_target_q = torch.min(target_q1, target_q2)
        target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * return_ +
             (1 - done_n.float()) * disc * target_value)
        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, (pi_mean, pi_log_std), _ = self.agent.pi(
            *agent_inputs, init_rnn_state)
        log_target1, _, log_target2, _ = self.agent.q(*agent_inputs,
                                                      new_action,
                                                      init_rnn_state,
                                                      init_rnn_state)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None and self.fixed_alpha is None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
        return losses, values
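
Example #30 returns `alpha_loss` among the losses; in SAC the temperature is typically stored as `log_alpha`, updated by its own optimizer, and re-read as `alpha = exp(log_alpha)` after each step. The sketch below shows that update loop under the standard SAC temperature objective; the variable names are hypothetical, not this class's attributes.

import torch

target_entropy = -6.0                       # e.g. -dim(action space)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

log_pi = torch.randn(256) - 5.0             # stand-in for a batch of policy log-probs

# alpha_loss = -log_alpha * (log_pi + target_entropy), averaged over the batch.
alpha_loss = -(log_alpha * (log_pi.detach() + target_entropy)).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp().detach()            # temperature used in the Q and pi losses
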