Example #1
    def _rl_learn(self, inputs, actions, next_values, rewards, states):
        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        value = values["v_value"]
        next_value = next_values["v_value"]
        assert value.size() == next_value.size()

        critic_value = reward + self.discount_factor * next_value
        td_error = (critic_value - value).squeeze(-1)
        value_cost = td_error**2

        dist, _ = self.model.policy(inputs, states)
        dist = dist["action"]

        if action.dtype == torch.int64 or action.dtype == torch.int32:
            ## for discrete actions, we need to provide scalars to log_prob()
            pg_cost = -dist.log_prob(action.squeeze(-1))
        else:
            pg_cost = -dist.log_prob(action)

        value_cost *= self.value_cost_weight
        pg_cost *= td_error.detach()
        entropy_cost = self.prob_entropy_weight * dist.entropy()

        cost = value_cost + pg_cost - entropy_cost  ## increase entropy for exploration

        sum_cost, _ = comf.sum_cost_tensor(cost)
        sum_cost.backward(retain_graph=True)
        return dict(cost=cost,
                    pg_cost=pg_cost,
                    value_cost=value_cost,
                    entropy_cost=entropy_cost), \
            states_update
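A minimal self-contained sketch of the same actor-critic cost on toy tensors may make the terms easier to follow. It replaces the flare helpers with plain PyTorch (`comf.sum_cost_tensor` is assumed to reduce the per-sample cost to a scalar), and all tensor names and sizes below are illustrative rather than taken from the repo:

import torch
from torch.distributions import Categorical

batch, num_actions = 4, 3
value = torch.randn(batch, 1, requires_grad=True)      # V(s)
next_value = torch.randn(batch, 1)                     # V(s'), treated as a constant target
reward = torch.randn(batch, 1)
logits = torch.randn(batch, num_actions, requires_grad=True)
action = torch.randint(num_actions, (batch,))
discount, value_w, entropy_w = 0.99, 0.5, 0.01

td_error = (reward + discount * next_value - value).squeeze(-1)
value_cost = value_w * td_error ** 2                   # critic regression loss
dist = Categorical(logits=logits)
pg_cost = -dist.log_prob(action) * td_error.detach()   # policy-gradient term weighted by the TD error
entropy_cost = entropy_w * dist.entropy()
cost = value_cost + pg_cost - entropy_cost             # subtracting entropy encourages exploration
cost.sum().backward()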
Example #2
    def _rl_learn(self, inputs, actions, next_values, rewards, states):
        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        q_value = values["q_value"]
        next_value = next_values["q_value"]

        ## squared one-step TD error for the selected action's Q-value
        critic_value = reward + self.discount_factor * next_value
        cost = (critic_value - comf.idx_select(q_value, action))**2

        sum_cost, _ = comf.sum_cost_tensor(cost)
        sum_cost.backward(retain_graph=True)
        return dict(cost=cost), states_update
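For reference, here is a hedged sketch of the same Q-learning cost in plain PyTorch; `comf.idx_select(q_value, action)` is assumed to pick the Q-value of the taken action, which `torch.gather` does below. Names and sizes are illustrative:

import torch

batch, num_actions = 4, 3
q_value = torch.randn(batch, num_actions, requires_grad=True)   # Q(s, .)
next_value = torch.randn(batch, 1)                              # bootstrap value of s', precomputed by the caller
reward = torch.randn(batch, 1)
action = torch.randint(num_actions, (batch, 1))
discount = 0.99

chosen_q = q_value.gather(1, action)                            # Q(s, a), stand-in for comf.idx_select
critic_value = reward + discount * next_value                   # one-step TD target
cost = (critic_value - chosen_q) ** 2
cost.sum().backward()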
Example #3
File: ddpg.py Project: yu239-zz/flare
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):
        self.model.train()

        if self.update_ref_interval and self.total_batches % self.update_ref_interval == 0:
            ## copy parameters from self.model to self.ref_model
            self.__update_model(self.ref_model, self.model)
        self.total_batches += 1

        reward = rewards["reward"]

        self.critic_optim.zero_grad()
        values, states_update = self.model.value(inputs, states, actions)
        q_value = values["q_value"]

        with torch.no_grad():
            next_values, next_states_update = self.ref_model.value(
                next_inputs, next_states)
            next_value = next_values["q_value"] * torch.abs(
                next_alive["alive"])

        assert q_value.size() == next_value.size()

        critic_value = reward + self.discount_factor * next_value
        critic_loss = (q_value - critic_value).squeeze(-1)**2
        sum_critic_loss, _ = comf.sum_cost_tensor(critic_loss)
        sum_critic_loss.backward()
        self.critic_optim.step()

        self.policy_optim.zero_grad()
        values2, _ = self.model.value(inputs, states)
        policy_loss = -values2["q_value"].squeeze(-1)
        sum_policy_loss, _ = comf.sum_cost_tensor(policy_loss)
        sum_policy_loss.backward()
        self.policy_optim.step()

        return dict(critic_loss=critic_loss, policy_loss=policy_loss)
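The two losses in this DDPG example can be sketched in isolation. The sketch below assumes, as the code suggests, that `self.model.value(inputs, states)` without explicit actions evaluates Q(s, actor(s)); it uses toy linear networks and ignores recurrent states, so it is an illustration rather than the flare implementation:

import copy
import torch
import torch.nn as nn

obs_dim, act_dim, batch = 8, 2, 4
critic = nn.Linear(obs_dim + act_dim, 1)            # Q(s, a)
actor = nn.Linear(obs_dim, act_dim)                 # deterministic policy
ref_critic, ref_actor = copy.deepcopy(critic), copy.deepcopy(actor)   # "ref" (target) networks

obs, next_obs = torch.randn(batch, obs_dim), torch.randn(batch, obs_dim)
act, reward = torch.randn(batch, act_dim), torch.randn(batch, 1)
alive = torch.ones(batch, 1)                        # masks out terminal next states
discount = 0.99

## critic loss: squared error against the TD target computed from the ref networks
with torch.no_grad():
    next_q = ref_critic(torch.cat([next_obs, ref_actor(next_obs)], dim=-1)) * alive.abs()
critic_loss = ((critic(torch.cat([obs, act], dim=-1)) - (reward + discount * next_q)) ** 2).sum()

## policy loss: maximize Q(s, actor(s)) by minimizing its negation
policy_loss = -critic(torch.cat([obs, actor(obs)], dim=-1)).sum()

critic_loss.backward()   # would be followed by critic_optim.step()
policy_loss.backward()   # would be followed by policy_optim.step()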
Example #4
    def _rl_learn(self, inputs, actions, next_values, rewards, states):
        """
        This learn() is expected to be called multiple times on each minibatch
        """

        action = actions["action"]
        log_prob = actions["action_log_prob"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        value = values["v_value"]
        next_value = next_values["v_value"]
        assert value.size() == next_value.size()

        critic_value = reward + self.discount_factor * next_value
        td_error = (critic_value - value).squeeze(-1)
        value_cost = self.value_cost_weight * td_error**2

        dist, _ = self.model.policy(inputs, states)
        dist = dist["action"]

        if action.dtype == torch.int64 or action.dtype == torch.int32:
            ## for discrete actions, we need to provide scalars to log_prob()
            new_log_prob = dist.log_prob(action.squeeze(-1))
        else:
            new_log_prob = dist.log_prob(action)

        ratio = torch.exp(new_log_prob - log_prob.squeeze(-1))
        ## clip the probability ratio to bound the surrogate objective
        clipped_ratio = torch.clamp(ratio,
                                    min=1 - self.epsilon,
                                    max=1 + self.epsilon)

        pg_obj = torch.min(input=ratio * td_error.detach(),
                           other=clipped_ratio * td_error.detach())
        entropy_cost = self.prob_entropy_weight * dist.entropy()
        cost = value_cost - pg_obj - entropy_cost  ## increase entropy for exploration

        sum_cost, _ = comf.sum_cost_tensor(cost)
        sum_cost.backward(retain_graph=True)
        return dict(cost=cost,
                    pg_obj=pg_obj,
                    value_cost=value_cost,
                    entropy_cost=entropy_cost), \
            states_update
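The clipped surrogate in this PPO-style example can be reproduced with a few toy tensors. In the sketch below, `old_log_prob` stands in for the stored `action_log_prob` and `advantage` for `td_error.detach()`; names and sizes are illustrative:

import torch
from torch.distributions import Categorical

batch, num_actions, epsilon = 4, 3, 0.2
logits = torch.randn(batch, num_actions, requires_grad=True)
action = torch.randint(num_actions, (batch,))
old_log_prob = torch.randn(batch)                   # log-prob recorded when the action was sampled
advantage = torch.randn(batch)                      # plays the role of td_error.detach()

dist = Categorical(logits=logits)
ratio = torch.exp(dist.log_prob(action) - old_log_prob)
clipped_ratio = torch.clamp(ratio, min=1 - epsilon, max=1 + epsilon)
pg_obj = torch.min(ratio * advantage, clipped_ratio * advantage)   # clipped surrogate objective
(-pg_obj).sum().backward()                          # maximized by minimizing its negation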
Example #5
    def _rl_learn(self, inputs, actions, next_values, rewards, states):

        if self.update_ref_interval and self.total_batches % self.update_ref_interval == 0:
            ## copy parameters from self.model to self.ref_model
            self.ref_model.load_state_dict(self.model.state_dict())
        self.total_batches += 1

        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        q_value = values["q_value"]
        next_value = next_values["q_value"]

        ## squared one-step TD error for the selected action's Q-value
        value = comf.idx_select(q_value, action)
        critic_value = reward + self.discount_factor * next_value
        cost = (critic_value - value)**2

        sum_cost, _ = comf.sum_cost_tensor(cost)
        sum_cost.backward(retain_graph=True)
        return dict(cost=cost), states_update
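A self-contained sketch of the same target-network ("ref model") pattern: parameters are copied periodically with `load_state_dict`, and the online network is regressed onto a one-step TD target bootstrapped from the target network. In the example above `next_value` arrives precomputed; taking it as the target network's max Q over actions is an assumption of this sketch:

import copy
import torch
import torch.nn as nn

obs_dim, num_actions, batch = 8, 3, 4
model = nn.Linear(obs_dim, num_actions)             # online Q-network
ref_model = copy.deepcopy(model)                    # target network
ref_model.load_state_dict(model.state_dict())       # periodic parameter sync

obs, next_obs = torch.randn(batch, obs_dim), torch.randn(batch, obs_dim)
reward = torch.randn(batch, 1)
action = torch.randint(num_actions, (batch, 1))
discount = 0.99

q = model(obs).gather(1, action)                    # Q(s, a) of the taken action
with torch.no_grad():
    next_value = ref_model(next_obs).max(dim=1, keepdim=True).values
cost = (reward + discount * next_value - q) ** 2
cost.sum().backward()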
Example #6
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):
        """
        We keep predict() the same as SimpleQ.
        We have to override learn() to implement the learning of SR.
        This function requires three functions implemented by self.model:

        1. self.model.state_embedding() - receives an observation input
                                          and outputs a compact state feature vector
                                          for predicting immediate rewards

        2. self.model.goal()            - outputs a goal vector of the same
                                          length as the compact state feature vector.
                                          Sometimes the goal might depend on some inputs.

        3. self.model.sr()              - given the input, returns a BxAxD tensor of
                                          successor representations, one per action,
                                          where B is the batch size, A is the number of
                                          actions, and D is the dim of the state embedding
        """
        self.model.train()

        action = actions["action"]
        next_action = next_actions["action"]
        reward = rewards["reward"]

        self.optim.zero_grad()

        ## 1. learn to predict rewards
        next_state_embedding, recon_cost = self.model.state_embedding(
            next_inputs)
        recon_cost *= self.recon_cost_weight

        # the goal and reward evaluation should be based on the current inputs
        goal = self.model.goal(inputs)
        pred_reward = comf.inner_prod(next_state_embedding, goal)
        reward_cost = ((pred_reward - reward).squeeze(-1)**2
                       * self.reward_cost_weight)

        ## 2. use Bellman equation to learn successor representation
        srs, states_update = self.model.sr(inputs, states)  ## BxAxD
        state_embedding_dim = srs.shape[-1]
        sr = torch.gather(input=srs,
                          dim=1,
                          index=action.unsqueeze(-1).expand(
                              -1, -1, state_embedding_dim))
        sr = sr.squeeze(1)  ## BxD

        with torch.no_grad():
            next_srs, next_states_update = self.model.sr(
                next_inputs, next_states)
            next_sr = torch.gather(input=next_srs,
                                   dim=1,
                                   index=next_action.unsqueeze(-1).expand(
                                       -1, -1, state_embedding_dim))
            next_sr = next_sr.squeeze(1) * torch.abs(next_alive["alive"])

        sr_cost = (next_state_embedding.detach() +
                   self.discount_factor * next_sr - sr)**2
        sr_cost = sr_cost.mean(-1)

        cost = recon_cost + reward_cost + sr_cost
        sum_cost, _ = comf.sum_cost_tensor(cost)
        sum_cost.backward()

        self.optim.step()

        return dict(cost=cost,
                    reconstruction_cost=recon_cost,
                    reward_cost=reward_cost,
                    sr_cost=sr_cost)
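The main ingredients of this successor-representation update can be sketched with toy tensors: the predicted reward is the inner product of the next state embedding phi(s') and the goal vector w, and each successor representation psi(s, a) is regressed onto the Bellman target phi(s') + gamma * psi(s', a'). All shapes and names below are illustrative, and the reconstruction term is omitted:

import torch

batch, num_actions, dim = 4, 3, 6
state_emb = torch.randn(batch, dim)                         # phi(s'), from state_embedding()
goal = torch.randn(batch, dim, requires_grad=True)          # w, from goal()
srs = torch.randn(batch, num_actions, dim, requires_grad=True)    # psi(s, .), BxAxD
next_srs = torch.randn(batch, num_actions, dim)                   # psi(s', .), BxAxD
action = torch.randint(num_actions, (batch, 1))
next_action = torch.randint(num_actions, (batch, 1))
reward = torch.randn(batch, 1)
discount = 0.99

## reward prediction: r_hat = <phi(s'), w>
pred_reward = (state_emb * goal).sum(-1, keepdim=True)
reward_cost = (pred_reward - reward).squeeze(-1) ** 2

## SR Bellman target: psi(s, a) should match phi(s') + gamma * psi(s', a')
sr = srs.gather(1, action.unsqueeze(-1).expand(-1, -1, dim)).squeeze(1)
next_sr = next_srs.gather(1, next_action.unsqueeze(-1).expand(-1, -1, dim)).squeeze(1)
sr_cost = ((state_emb.detach() + discount * next_sr - sr) ** 2).mean(-1)

(reward_cost + sr_cost).sum().backward()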