def forward_loss(self, states, R, actions, action_rewards, hidden, dones):
    # Convert the rollout batch to tensors, one-hot the actions and
    # evaluate the recurrent policy to get the policy-gradient loss.
    states, R, actions, action_rewards, dones = totorch_many(
        states, R, actions, action_rewards, dones, device=self.device)
    hidden = totorch_many(*hidden, device=self.device)
    actions_onehot = F.one_hot(actions.long(), num_classes=self.action_size)
    policies, values, _ = self.forward(states, action_rewards, hidden, dones)
    forward_loss = self.policy.loss(policies, R, values, actions_onehot)
    return forward_loss

def backprop(self, state, R, action, action_reward, hidden, done):
    # Single optimisation step: forward pass, loss, backward pass,
    # optional gradient clipping, then optimiser and scheduler updates.
    state, R, action, action_reward, done = totorch_many(
        state, R, action, action_reward, done, device=self.device)
    hidden = totorch_many(*hidden, device=self.device)
    action_onehot = F.one_hot(action.long(), num_classes=self.action_size)
    policy, value, hidden = self.forward(state, action_reward, hidden, done)
    loss = self.loss(policy, R, value, action_onehot)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
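
The snippets on this page all lean on a small set of conversion helpers (totorch, totorch_many and, later, tonumpy, tonumpy_many) that come from the surrounding library and are not shown here. As a rough sketch of what totorch and totorch_many are assumed to do, based only on how they are called above (turn one or several numpy arrays into tensors on a given device, returned in the same order):

# Assumed behaviour of the conversion helpers used throughout these examples;
# the real library versions may differ in dtype handling and extra options.
import numpy as np
import torch

def totorch(x: np.ndarray, device='cpu') -> torch.Tensor:
    # Single array -> float32 tensor on the requested device.
    return torch.from_numpy(np.asarray(x)).float().to(device)

def totorch_many(*arrays, device='cpu'):
    # Several arrays -> tuple of tensors, so callers can unpack in one line.
    return tuple(totorch(a, device=device) for a in arrays)
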
Example #3
    def auxiliary_loss(self, reward_states, rewards, Qaux_target, Qaux_actions,
                       replay_states, replay_R):
        reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R = totorch_many(
            reward_states,
            rewards,
            Qaux_target,
            Qaux_actions,
            replay_states,
            replay_R,
            device=self.device)

        policy_enc = self.policy.model(replay_states)
        replay_values = self.policy.Ve(policy_enc)
        reward_loss = self.reward_loss(reward_states, rewards)
        replay_loss = self.replay_loss(replay_R, replay_values)
        aux_loss = self.RP * reward_loss + self.VR * replay_loss

        Qaux_actions = Qaux_actions.long()

        if self.pixel_control:
            Qaux = self.Qaux(policy_enc)
            pixel_loss = self.pixel_loss(Qaux, Qaux_actions, Qaux_target)
            aux_loss += self.PC * pixel_loss

        return aux_loss
Example #4
def forward_loss(self, states, actions, Re, Ri, Adv, old_policy):
    states, actions, Re, Ri, Adv, old_policy = totorch_many(
        states, actions, Re, Ri, Adv, old_policy, device=self.device)
    actions_onehot = F.one_hot(actions.long(), self.action_size)
    policy, Ve, Vi = self.forward(states)
    forward_loss = self.policy.loss(policy, Re, Ri, Ve, Vi, Adv,
                                    actions_onehot, old_policy)
    return forward_loss
Example #5
def intrinsic_reward(self, next_state: np.ndarray, state_mean: np.ndarray,
                     state_std):
    next_state, state_mean, state_std = totorch_many(next_state,
                                                     state_mean,
                                                     state_std,
                                                     device=self.device)
    with torch.no_grad():
        intr_reward = self._intr_reward(next_state, state_mean, state_std)
    return tonumpy(intr_reward)
Example #6
def predictor_loss(self, next_states, state_mean, state_std):
    """Loss for the predictor network."""
    next_states, state_mean, state_std = totorch_many(next_states,
                                                      state_mean,
                                                      state_std,
                                                      device=self.device)
    predictor_loss = self._intr_reward(next_states, state_mean,
                                       state_std).mean()
    return predictor_loss
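
Both intrinsic_reward and predictor_loss above defer to a shared self._intr_reward that is not included in these excerpts. Judging only from its arguments (a next state plus the running mean and std used for observation normalisation), it is consistent with an RND-style prediction error; a hypothetical sketch, assuming predictor and frozen target networks named self.predictor and self.target, might look like:

# Hypothetical shape of _intr_reward: per-sample error between a trained
# predictor and a frozen, randomly initialised target network, computed on
# observations normalised with the supplied running statistics.
def _intr_reward(self, next_state, state_mean, state_std):
    norm_obs = ((next_state - state_mean) / state_std).clamp(-5, 5)
    target_features = self.target(norm_obs).detach()  # fixed random network
    pred_features = self.predictor(norm_obs)          # trained to match it
    return ((pred_features - target_features) ** 2).sum(dim=-1)
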
Example #7
def evaluate(self,
             state: np.ndarray,
             hidden: np.ndarray = None,
             done=None):
    state = totorch(state, self.device)
    hidden = totorch_many(
        *hidden, device=self.device) if hidden is not None else None
    with torch.no_grad():
        policy, value, hidden = self.forward(state, hidden, done)
    return tonumpy(policy), tonumpy(value), tonumpy_many(*hidden)
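
The return path here uses tonumpy and tonumpy_many, the inverse of the helpers sketched earlier; they are likewise imported from the library, so the following is only an assumed sketch of their behaviour (detach, move to CPU, convert to numpy):

# Assumed inverses of totorch / totorch_many; the library versions may differ.
import numpy as np
import torch

def tonumpy(x: torch.Tensor) -> np.ndarray:
    # Detach from the graph, move to host memory, convert to a numpy array.
    return x.detach().cpu().numpy()

def tonumpy_many(*tensors):
    # Tuple version, mirroring how the recurrent hidden state is unpacked above.
    return tuple(tonumpy(t) for t in tensors)
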
Example #8
    def backprop(self, state, R):
        state, R = totorch_many(state, R, device=self.device)
        value = self.forward(state)
        loss = self.loss(value, R)

        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
        self.optimiser.step()
        self.optimiser.zero_grad()
        self.scheduler.step()
        return loss.detach().cpu().numpy()
Example #9
def get_intr_reward(self, state, action, next_state):
    state, next_state, action = totorch_many(state,
                                             next_state,
                                             action,
                                             device=self.device)
    action = action.long()
    phi1 = self.phi(state)
    phi2 = self.phi(next_state)
    action_onehot = F.one_hot(action, self.action_size)
    with torch.no_grad():
        intr_reward = self.intr_reward(phi1, action_onehot, phi2)
    return intr_reward.cpu().numpy()
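
Here self.phi is a learned feature encoder and self.intr_reward is not shown; its signature (encoded state, one-hot action, encoded next state) matches an ICM-style forward-model error. A hypothetical sketch, assuming a forward model self.forward_model that predicts the next feature vector and an illustrative scaling coefficient eta:

# Hypothetical intr_reward: curiosity as the forward model's prediction error
# in feature space; self.forward_model and eta are assumptions, not source code.
def intr_reward(self, phi1, action_onehot, phi2, eta=0.5):
    phi2_pred = self.forward_model(torch.cat([phi1, action_onehot.float()], dim=-1))
    return eta * ((phi2_pred - phi2) ** 2).sum(dim=-1)
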
Example #10
def backprop(self, state: np.ndarray, R: np.ndarray, action: np.ndarray):
    state, R, action = totorch_many(state, R, action, device=self.device)
    action_onehot = F.one_hot(action.long(), num_classes=self.action_size)
    Qsa = self.forward(state)
    loss = self.loss(Qsa, R, action_onehot)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
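
As a usage sketch, each of these backprop variants is driven with plain numpy batches and returns the scalar loss as a numpy value. Assuming a model instance exposing the method above, a single training step on an illustrative batch (the model, shapes and targets below are placeholders, not from the source) reduces to one call:

# Illustrative call pattern only; `model`, the shapes and the return targets
# are placeholders rather than part of the original library.
import numpy as np

batch_size, obs_dim, n_actions = 32, 8, 4
states = np.random.randn(batch_size, obs_dim).astype(np.float32)
returns = np.random.randn(batch_size).astype(np.float32)  # bootstrapped returns R
actions = np.random.randint(0, n_actions, size=batch_size)

loss = model.backprop(states, returns, actions)  # one optimisation step
print(float(loss))
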
Example #11
    def backprop(self, state, Re, Ri, Adv, action, old_policy):
        state, action, Re, Ri, Adv, old_policy = totorch_many(
            state, action, Re, Ri, Adv, old_policy, device=self.device)
        action_onehot = F.one_hot(action.long(), self.action_size)
        policy, Ve, Vi = self.forward(state)
        loss = self.loss(policy, Re, Ri, Ve, Vi, Adv, action_onehot,
                         old_policy)

        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
        self.optimiser.step()
        self.optimiser.zero_grad()
        self.scheduler.step()
        return loss.detach().cpu().numpy()
Example #12
def auxiliary_loss(self, reward_states, rewards, Qaux_target, Qaux_actions,
                   replay_states, replay_R, replay_hidden, replay_dones,
                   replay_actsrews):
    (reward_states, rewards, Qaux_target, Qaux_actions, replay_states,
     replay_R, replay_dones, replay_actsrews) = totorch_many(
         reward_states, rewards, Qaux_target, Qaux_actions, replay_states,
         replay_R, replay_dones, replay_actsrews, device=self.device)
    replay_hidden = totorch_many(*replay_hidden, device=self.device)
    lstm_state, _ = self.policy.lstm_forward(replay_states, replay_actsrews,
                                             replay_hidden, replay_dones)
    replay_values = self.policy.V(lstm_state)

    reward_loss = self.reward_loss(reward_states, rewards)
    replay_loss = self.replay_loss(replay_R, replay_values)
    aux_loss = self.RP * reward_loss + self.VR * replay_loss

    if self.pixel_control:
        Qaux = self.Qaux(lstm_state)
        pixel_loss = self.pixel_loss(Qaux, Qaux_actions, Qaux_target)
        aux_loss += self.PC * pixel_loss

    return aux_loss
Example #13
def backprop(self, state, next_state, R, Adv, action, state_mean,
             state_std):
    state, next_state, R, Adv, action, state_mean, state_std = totorch_many(
        state,
        next_state,
        R,
        Adv,
        action,
        state_mean,
        state_std,
        device=self.device)
    policy, value = self.AC.forward(state)
    action_onehot = F.one_hot(action.long(), self.action_size)
    policy_loss = self.AC.loss(policy, R, value, action_onehot)
    ICM_loss = self.ICM.loss((state - state_mean) / state_std, action,
                             (next_state - state_mean) / state_std)
    loss = self.policy_importance * policy_loss + self.reward_scale * ICM_loss
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
Example #14
    def backprop(self, state, next_state, R_extr, R_intr, Adv, actions,
                 old_policy, state_mean, state_std):
        state, next_state, R_extr, R_intr, Adv, actions, old_policy, state_mean, state_std = totorch_many(
            state,
            next_state,
            R_extr,
            R_intr,
            Adv,
            actions,
            old_policy,
            state_mean,
            state_std,
            device=self.device)
        policy, Ve, Vi = self.policy.forward(state)
        actions_onehot = F.one_hot(actions.long(), self.action_size)
        policy_loss = self.policy.loss(policy, R_extr, R_intr, Ve, Vi, Adv,
                                       actions_onehot, old_policy)

        predictor_loss = self._intr_reward(next_state, state_mean,
                                           state_std).mean()
        loss = policy_loss + predictor_loss

        loss.backward()
        if self.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
        self.optimiser.step()
        self.optimiser.zero_grad()
        self.scheduler.step()
        return loss.detach().cpu().numpy()
Example #15
def get_pred_action(self, state, next_state):
    state, next_state = totorch_many(state, next_state, device=self.device)
    return self.pred_action(state, next_state)
Example #16
def get_pixel_control(self, state: np.ndarray, action_reward, hidden):
    state = totorch(state, self.device)
    action_reward = totorch(action_reward, self.device)
    hidden = totorch_many(*hidden, device=self.device)
    with torch.no_grad():
        lstm_state, _ = self.policy.lstm_forward(state, action_reward,
                                                 hidden, done=None)
        Qaux = self.Qaux(lstm_state)
    return tonumpy(Qaux)