def _rl_learn(self, inputs, actions, next_values, rewards, states):
    action = actions["action"]
    reward = rewards["reward"]

    values, states_update = self.model.value(inputs, states)
    value = values["v_value"]
    next_value = next_values["v_value"]
    assert value.size() == next_value.size()

    critic_value = reward + self.discount_factor * next_value
    td_error = (critic_value - value).squeeze(-1)
    value_cost = td_error**2

    dist, _ = self.model.policy(inputs, states)
    dist = dist["action"]

    if action.dtype == torch.int64 or action.dtype == torch.int32:
        ## for discrete actions, we need to provide scalars to log_prob()
        pg_cost = -dist.log_prob(action.squeeze(-1))
    else:
        pg_cost = -dist.log_prob(action)

    value_cost *= self.value_cost_weight
    pg_cost *= td_error.detach()
    entropy_cost = self.prob_entropy_weight * dist.entropy()

    cost = value_cost + pg_cost - entropy_cost  ## increase entropy for exploration
    sum_cost, _ = comf.sum_cost_tensor(cost)
    sum_cost.backward(retain_graph=True)
    return dict(
        cost=cost,
        pg_cost=pg_cost,
        value_cost=value_cost,
        entropy_cost=entropy_cost), \
        states_update
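
## A minimal, self-contained sketch of the same one-step actor-critic loss on
## dummy tensors. This is illustration only, not part of the framework above;
## the weights 0.5 and 1e-2 are hypothetical stand-ins for
## self.value_cost_weight and self.prob_entropy_weight.
import torch
import torch.distributions as tdist

B, discount = 4, 0.99
value = torch.randn(B, 1, requires_grad=True)          ## V(s)
next_value = torch.randn(B, 1)                         ## V(s'), treated as fixed
reward = torch.randn(B, 1)
logits = torch.randn(B, 3, requires_grad=True)         ## discrete policy logits
action = torch.randint(0, 3, (B, ))

td_error = (reward + discount * next_value - value).squeeze(-1)
dist = tdist.Categorical(logits=logits)
value_cost = 0.5 * td_error**2                         ## critic regression loss
pg_cost = -dist.log_prob(action) * td_error.detach()   ## policy gradient term
entropy_cost = 1e-2 * dist.entropy()                   ## exploration bonus
(value_cost + pg_cost - entropy_cost).mean().backward()
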
def _rl_learn(self, inputs, actions, next_values, rewards, states):
    action = actions["action"]
    reward = rewards["reward"]

    values, states_update = self.model.value(inputs, states)
    q_value = values["q_value"]
    next_value = next_values["q_value"]

    critic_value = reward + self.discount_factor * next_value
    cost = (critic_value - comf.idx_select(q_value, action))**2

    sum_cost, _ = comf.sum_cost_tensor(cost)
    sum_cost.backward(retain_graph=True)
    return dict(cost=cost), states_update
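
## A minimal sketch of the same SARSA-style update on dummy tensors. It uses
## torch.gather where the code above uses comf.idx_select; treating that
## helper as "pick Q(s, a) for the taken action" is an assumption here.
import torch

B, A, discount = 4, 3, 0.99
q_values = torch.randn(B, A, requires_grad=True)   ## Q(s, .)
next_q = torch.randn(B, 1)                         ## Q(s', a'), already selected
reward = torch.randn(B, 1)
action = torch.randint(0, A, (B, 1))

value = q_values.gather(1, action)                 ## Q(s, a), shape Bx1
target = reward + discount * next_q                ## one-step TD target
((target - value)**2).mean().backward()
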
def learn(self, inputs, next_inputs, states, next_states, next_alive,
          actions, next_actions, rewards):
    self.model.train()

    if self.update_ref_interval and \
            self.total_batches % self.update_ref_interval == 0:
        ## copy parameters from self.model to self.ref_model
        self.__update_model(self.ref_model, self.model)
    self.total_batches += 1

    reward = rewards["reward"]

    ## 1. critic update
    self.critic_optim.zero_grad()
    values, states_update = self.model.value(inputs, states, actions)
    q_value = values["q_value"]

    with torch.no_grad():
        next_values, next_states_update = self.ref_model.value(
            next_inputs, next_states)
        next_value = next_values["q_value"] * torch.abs(next_alive["alive"])
    assert q_value.size() == next_value.size()

    critic_value = reward + self.discount_factor * next_value
    critic_loss = (q_value - critic_value).squeeze(-1)**2
    sum_critic_loss, _ = comf.sum_cost_tensor(critic_loss)
    sum_critic_loss.backward()
    self.critic_optim.step()

    ## 2. policy update
    self.policy_optim.zero_grad()
    values2, _ = self.model.value(inputs, states)
    policy_loss = -values2["q_value"].squeeze(-1)
    sum_policy_loss, _ = comf.sum_cost_tensor(policy_loss)
    sum_policy_loss.backward()
    self.policy_optim.step()

    return dict(critic_loss=critic_loss, policy_loss=policy_loss)
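
## A minimal sketch of the alternating critic/actor updates above, using dummy
## linear modules. `actor`, `critic`, and the target copies are hypothetical
## stand-ins for self.model / self.ref_model.
import copy
import torch
import torch.nn as nn

obs_dim, act_dim, discount = 8, 2, 0.99
actor = nn.Linear(obs_dim, act_dim)
critic = nn.Linear(obs_dim + act_dim, 1)
target_actor, target_critic = copy.deepcopy(actor), copy.deepcopy(critic)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

obs, next_obs = torch.randn(4, obs_dim), torch.randn(4, obs_dim)
act, reward = torch.randn(4, act_dim), torch.randn(4, 1)

## critic step: regress Q(s, a) toward r + gamma * Q'(s', pi'(s'))
critic_optim.zero_grad()
q = critic(torch.cat([obs, act], dim=-1))
with torch.no_grad():
    next_q = target_critic(
        torch.cat([next_obs, target_actor(next_obs)], dim=-1))
((q - (reward + discount * next_q))**2).mean().backward()
critic_optim.step()

## actor step: maximize Q(s, pi(s)) by minimizing its negation
actor_optim.zero_grad()
(-critic(torch.cat([obs, actor(obs)], dim=-1))).mean().backward()
actor_optim.step()
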
def _rl_learn(self, inputs, actions, next_values, rewards, states): """ This learn() is expected to be called multiple times on each minibatch """ action = actions["action"] log_prob = actions["action_log_prob"] reward = rewards["reward"] values, states_update = self.model.value(inputs, states) value = values["v_value"] next_value = next_values["v_value"] assert value.size() == next_value.size() critic_value = reward + self.discount_factor * next_value td_error = (critic_value - value).squeeze(-1) value_cost = self.value_cost_weight * td_error**2 dist, _ = self.model.policy(inputs, states) dist = dist["action"] if action.dtype == torch.int64 or action.dtype == torch.int32: ## for discrete actions, we need to provide scalars to log_prob() new_log_prob = dist.log_prob(action.squeeze(-1)) else: new_log_prob = dist.log_prob(action) ratio = torch.exp(new_log_prob - log_prob.squeeze(-1)) ### clip pg_cost according to the ratio clipped_ratio = torch.clamp(ratio, min=1 - self.epsilon, max=1 + self.epsilon) pg_obj = torch.min(input=ratio * td_error.detach(), other=clipped_ratio * td_error.detach()) entropy_cost = self.prob_entropy_weight * dist.entropy() cost = value_cost - pg_obj - entropy_cost ## increase entropy for exploration sum_cost, _ = comf.sum_cost_tensor(cost) sum_cost.backward(retain_graph=True) return dict(cost=cost, pg_obj=pg_obj, value_cost=value_cost, entropy_cost=entropy_cost), \ states_update
def _rl_learn(self, inputs, actions, next_values, rewards, states):
    if self.update_ref_interval and \
            self.total_batches % self.update_ref_interval == 0:
        ## copy parameters from self.model to self.ref_model
        self.ref_model.load_state_dict(self.model.state_dict())
    self.total_batches += 1

    action = actions["action"]
    reward = rewards["reward"]

    values, states_update = self.model.value(inputs, states)
    q_value = values["q_value"]
    next_value = next_values["q_value"]

    value = comf.idx_select(q_value, action)
    critic_value = reward + self.discount_factor * next_value
    cost = (critic_value - value)**2

    sum_cost, _ = comf.sum_cost_tensor(cost)
    sum_cost.backward(retain_graph=True)
    return dict(cost=cost), states_update
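
## A minimal sketch of the periodic hard sync of the reference (target)
## network used above; q_net and update_ref_interval are placeholders for
## self.model and the corresponding hyperparameter.
import copy
import torch.nn as nn

q_net = nn.Linear(8, 4)
ref_q_net = copy.deepcopy(q_net)
update_ref_interval, total_batches = 100, 0

for _ in range(250):  ## pretend training loop
    if update_ref_interval and total_batches % update_ref_interval == 0:
        ref_q_net.load_state_dict(q_net.state_dict())  ## hard parameter copy
    total_batches += 1
    ## ... one TD update of q_net against ref_q_net would go here ...
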
def learn(self, inputs, next_inputs, states, next_states, next_alive,
          actions, next_actions, rewards):
    """
    We keep predict() the same as SimpleQ. We have to override learn()
    to implement the learning of SR.

    This function requires three functions implemented by self.model:

    1. self.model.state_embedding() - receives an observation input and
       outputs a compact state feature vector for predicting immediate
       rewards
    2. self.model.goal() - outputs a goal vector that has the same length
       as the compact state feature vector. Sometimes, the goal might
       depend on some inputs.
    3. self.model.sr() - given the input, returns a tensor of successor
       representations, one per action: BxAxD, where B is the batch size,
       A is the number of actions, and D is the dim of the state embedding
    """
    self.model.train()

    action = actions["action"]
    next_action = next_actions["action"]
    reward = rewards["reward"]

    self.optim.zero_grad()

    ## 1. learn to predict rewards
    next_state_embedding, recon_cost = self.model.state_embedding(
        next_inputs)
    recon_cost *= self.recon_cost_weight
    # the goal and reward evaluation should be based on the current inputs
    goal = self.model.goal(inputs)
    pred_reward = comf.inner_prod(next_state_embedding, goal)
    reward_cost = (pred_reward - reward).squeeze(-1)**2 \
        * self.reward_cost_weight

    ## 2. use the Bellman equation to learn the successor representation
    srs, states_update = self.model.sr(inputs, states)  ## BxAxD
    state_embedding_dim = srs.shape[-1]
    sr = torch.gather(
        input=srs,
        dim=1,
        index=action.unsqueeze(-1).expand(-1, -1, state_embedding_dim))
    sr = sr.squeeze(1)  ## BxD

    with torch.no_grad():
        next_srs, next_states_update = self.model.sr(
            next_inputs, next_states)
        next_sr = torch.gather(
            input=next_srs,
            dim=1,
            index=next_action.unsqueeze(-1).expand(
                -1, -1, state_embedding_dim))
        next_sr = next_sr.squeeze(1) * torch.abs(next_alive["alive"])

    sr_cost = (next_state_embedding.detach()
               + self.discount_factor * next_sr - sr)**2
    sr_cost = sr_cost.mean(-1)

    cost = recon_cost + reward_cost + sr_cost
    sum_cost, _ = comf.sum_cost_tensor(cost)
    sum_cost.backward()
    self.optim.step()

    return dict(
        cost=cost,
        reconstruction_cost=recon_cost,
        reward_cost=reward_cost,
        sr_cost=sr_cost)
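
## A minimal sketch of the two SR losses above on dummy tensors: reward
## prediction as an inner product with a goal vector, and a Bellman update of
## the successor features. All names here are placeholders.
import torch

B, A, D, discount = 4, 3, 16, 0.99
phi_next = torch.randn(B, D, requires_grad=True)   ## phi(s'), state embedding
goal = torch.randn(B, D, requires_grad=True)       ## goal / reward weights
reward = torch.randn(B, 1)
srs = torch.randn(B, A, D, requires_grad=True)     ## psi(s, .), successor features
next_srs = torch.randn(B, A, D)
action = torch.randint(0, A, (B, 1))
next_action = torch.randint(0, A, (B, 1))

pred_reward = (phi_next * goal).sum(-1, keepdim=True)  ## r_hat = <phi(s'), goal>
reward_cost = (pred_reward - reward).squeeze(-1)**2

idx = action.unsqueeze(-1).expand(-1, -1, D)
sr = srs.gather(1, idx).squeeze(1)                     ## psi(s, a), BxD
next_idx = next_action.unsqueeze(-1).expand(-1, -1, D)
next_sr = next_srs.gather(1, next_idx).squeeze(1)      ## psi(s', a'), no grad
sr_cost = ((phi_next.detach() + discount * next_sr - sr)**2).mean(-1)

(reward_cost + sr_cost).mean().backward()
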