def forward_loss(self, states, R, actions, action_rewards, hidden, dones):
    states, R, actions, action_rewards, dones = totorch_many(states, R, actions, action_rewards, dones, device=self.device)
    hidden = totorch_many(*hidden, device=self.device)
    actions_onehot = F.one_hot(actions.long(), num_classes=self.action_size)
    policies, values, _ = self.forward(states, action_rewards, hidden, dones)
    forward_loss = self.policy.loss(policies, R, values, actions_onehot)
    return forward_loss
def backprop(self, state, R, action, action_reward, hidden, done):
    """Single optimisation step for the recurrent actor-critic: compute the loss,
    backpropagate, clip gradients if configured, then step the optimiser and scheduler."""
    state, R, action, action_reward, done = totorch_many(state, R, action, action_reward, done, device=self.device)
    hidden = totorch_many(*hidden, device=self.device)
    action_onehot = F.one_hot(action.long(), num_classes=self.action_size)
    policy, value, hidden = self.forward(state, action_reward, hidden, done)
    loss = self.loss(policy, R, value, action_onehot)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
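# The tensor helpers used throughout this section (totorch, totorch_many, tonumpy,
# tonumpy_many) are not defined here. A minimal sketch consistent with how they are
# called, assuming float32 conversion by default, might look like the following
# (illustrative only, not the repository's actual implementation):
import numpy as np
import torch

def totorch(x, device='cpu'):
    # wrap a numpy array (or nested list) as a float tensor on the target device
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)

def totorch_many(*args, device='cpu'):
    # convert several arrays at once, preserving argument order
    return tuple(totorch(a, device) for a in args)

def tonumpy(t):
    # detach a tensor from the graph and move it back to host memory as numpy
    return t.detach().cpu().numpy()

def tonumpy_many(*tensors):
    return tuple(tonumpy(t) for t in tensors)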
def auxiliary_loss(self, reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R):
    """Auxiliary losses from replayed experience: reward prediction, value replay
    and optional pixel control."""
    reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R = totorch_many(
        reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R, device=self.device)
    policy_enc = self.policy.model(replay_states)
    replay_values = self.policy.Ve(policy_enc)
    reward_loss = self.reward_loss(reward_states, rewards)
    replay_loss = self.replay_loss(replay_R, replay_values)
    # weighted sum of the reward-prediction and value-replay losses
    aux_loss = self.RP * reward_loss + self.VR * replay_loss
    Qaux_actions = Qaux_actions.long()
    if self.pixel_control:
        Qaux = self.Qaux(policy_enc)
        pixel_loss = self.pixel_loss(Qaux, Qaux_actions, Qaux_target)
        aux_loss += self.PC * pixel_loss
    return aux_loss
def forward_loss(self, states, actions, Re, Ri, Adv, old_policy):
    states, actions, Re, Ri, Adv, old_policy = totorch_many(
        states, actions, Re, Ri, Adv, old_policy, device=self.device)
    actions_onehot = F.one_hot(actions.long(), self.action_size)
    policy, Ve, Vi = self.forward(states)
    forward_loss = self.policy.loss(policy, Re, Ri, Ve, Vi, Adv, actions_onehot, old_policy)
    return forward_loss
def intrinsic_reward(self, next_state: np.ndarray, state_mean: np.ndarray, state_std):
    next_state, state_mean, state_std = totorch_many(next_state, state_mean, state_std, device=self.device)
    with torch.no_grad():
        intr_reward = self._intr_reward(next_state, state_mean, state_std)
    return tonumpy(intr_reward)
def predictor_loss(self, next_states, state_mean, state_std):
    """Loss for the predictor network."""
    next_states, state_mean, state_std = totorch_many(next_states, state_mean, state_std, device=self.device)
    predictor_loss = self._intr_reward(next_states, state_mean, state_std).mean()
    return predictor_loss
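# _intr_reward above is assumed, in the style of Random Network Distillation, to be
# the per-sample prediction error between a frozen target network and a trained
# predictor on normalised observations. A hedged sketch of such a method body
# (self.target, self.predictor and the clamp range are assumptions, not confirmed
# by this section):
def _intr_reward(self, next_state, state_mean, state_std):
    # normalise and clip observations before feeding both networks
    norm_next = torch.clamp((next_state - state_mean) / state_std, -5, 5)
    with torch.no_grad():
        target_feat = self.target(norm_next)   # fixed, randomly initialised network
    pred_feat = self.predictor(norm_next)      # trained to match the target features
    return torch.mean(torch.square(pred_feat - target_feat), dim=-1)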
def evaluate(self, state: np.ndarray, hidden: np.ndarray = None, done=None):
    state = totorch(state, self.device)
    hidden = totorch_many(*hidden, device=self.device) if hidden is not None else None
    with torch.no_grad():
        policy, value, hidden = self.forward(state, hidden, done)
    return tonumpy(policy), tonumpy(value), tonumpy_many(*hidden)
def backprop(self, state, R):
    state, R = totorch_many(state, R, device=self.device)
    value = self.forward(state)
    loss = self.loss(value, R)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
def get_intr_reward(self, state, action, next_state):
    state, next_state, action = totorch_many(state, next_state, action, device=self.device)
    action = action.long()
    phi1 = self.phi(state)
    phi2 = self.phi(next_state)
    action_onehot = F.one_hot(action, self.action_size)
    with torch.no_grad():
        intr_reward = self.intr_reward(phi1, action_onehot, phi2)
    return intr_reward.cpu().numpy()
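# intr_reward above is assumed to be an ICM-style forward-model error: predict the
# next feature encoding from (phi1, action) and score the squared error against phi2.
# A hedged sketch of such a method body (self.forward_model is an assumed name):
def intr_reward(self, phi1, action_onehot, phi2):
    # predict phi2 from the current features and the chosen action
    phi2_pred = self.forward_model(torch.cat([phi1, action_onehot.float()], dim=-1))
    # per-sample squared prediction error serves as the intrinsic reward
    return 0.5 * torch.sum(torch.square(phi2_pred - phi2), dim=-1)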
def backprop(self, state: np.ndarray, R: np.ndarray, action: np.ndarray):
    state, R, action = totorch_many(state, R, action, device=self.device)
    action_onehot = F.one_hot(action.long(), num_classes=self.action_size)
    Qsa = self.forward(state)
    loss = self.loss(Qsa, R, action_onehot)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
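# Example of how a training step with the Q-learning backprop above might be driven,
# assuming a replay buffer yielding numpy batches; replay_buffer, target_model, model,
# gamma, batch_size and device are illustrative names only:
states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
with torch.no_grad():
    bootstrap = target_model.forward(totorch(next_states, device)).max(-1).values
R = rewards + gamma * (1.0 - dones) * tonumpy(bootstrap)   # one-step TD targets
loss = model.backprop(states, R, actions)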
def backprop(self, state, Re, Ri, Adv, action, old_policy):
    state, action, Re, Ri, Adv, old_policy = totorch_many(
        state, action, Re, Ri, Adv, old_policy, device=self.device)
    action_onehot = F.one_hot(action.long(), self.action_size)
    policy, Ve, Vi = self.forward(state)
    loss = self.loss(policy, Re, Ri, Ve, Vi, Adv, action_onehot, old_policy)
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
def auxiliary_loss(self, reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R, replay_hidden, replay_dones, replay_actsrews):
    """Auxiliary losses for the recurrent model: reward prediction, value replay
    and optional pixel control, computed over replayed sequences."""
    reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R, replay_dones, replay_actsrews = totorch_many(
        reward_states, rewards, Qaux_target, Qaux_actions, replay_states, replay_R, replay_dones, replay_actsrews, device=self.device)
    replay_hidden = totorch_many(*replay_hidden, device=self.device)
    lstm_state, _ = self.policy.lstm_forward(replay_states, replay_actsrews, replay_hidden, replay_dones)
    replay_values = self.policy.V(lstm_state)
    reward_loss = self.reward_loss(reward_states, rewards)
    replay_loss = self.replay_loss(replay_R, replay_values)
    # weighted sum of the reward-prediction and value-replay losses
    aux_loss = self.RP * reward_loss + self.VR * replay_loss
    if self.pixel_control:
        Qaux = self.Qaux(lstm_state)
        pixel_loss = self.pixel_loss(Qaux, Qaux_actions, Qaux_target)
        aux_loss += self.PC * pixel_loss
    return aux_loss
def backprop(self, state, next_state, R, Adv, action, state_mean, state_std):
    state, next_state, R, Adv, action, state_mean, state_std = totorch_many(
        state, next_state, R, Adv, action, state_mean, state_std, device=self.device)
    policy, value = self.AC.forward(state)
    action_onehot = F.one_hot(action.long(), self.action_size)
    policy_loss = self.AC.loss(policy, R, value, action_onehot)
    # ICM loss is computed on normalised observations
    ICM_loss = self.ICM.loss((state - state_mean) / state_std, action, (next_state - state_mean) / state_std)
    # combined objective: weighted actor-critic loss plus weighted curiosity loss
    loss = self.policy_importance * policy_loss + self.reward_scale * ICM_loss
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
def backprop(self, state, next_state, R_extr, R_intr, Adv, actions, old_policy, state_mean, state_std):
    state, next_state, R_extr, R_intr, Adv, actions, old_policy, state_mean, state_std = totorch_many(
        state, next_state, R_extr, R_intr, Adv, actions, old_policy, state_mean, state_std, device=self.device)
    policy, Ve, Vi = self.policy.forward(state)
    actions_onehot = F.one_hot(actions.long(), self.action_size)
    policy_loss = self.policy.loss(policy, R_extr, R_intr, Ve, Vi, Adv, actions_onehot, old_policy)
    # predictor loss is the mean intrinsic reward over the batch of next states
    predictor_loss = self._intr_reward(next_state, state_mean, state_std).mean()
    # joint update of the policy and the intrinsic-reward predictor
    loss = policy_loss + predictor_loss
    loss.backward()
    if self.grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip)
    self.optimiser.step()
    self.optimiser.zero_grad()
    self.scheduler.step()
    return loss.detach().cpu().numpy()
def get_pred_action(self, state, next_state):
    state, next_state = totorch_many(state, next_state, device=self.device)
    return self.pred_action(state, next_state)
def get_pixel_control(self, state: np.ndarray, action_reward, hidden):
    state, action_reward = totorch(state, self.device), totorch(action_reward, self.device)
    hidden = totorch_many(*hidden, device=self.device)
    with torch.no_grad():
        lstm_state, _ = self.policy.lstm_forward(state, action_reward, hidden, done=None)
        Qaux = self.Qaux(lstm_state)
    return tonumpy(Qaux)