Example 1
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):

        if self.update_ref_interval and self.total_batches % self.update_ref_interval == 0:
            ## copy parameters from self.model to self.ref_model
            self.ref_model.load_state_dict(self.model.state_dict())
        self.total_batches += 1

        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        q_value = values["q_value"]

        with torch.no_grad():
            next_values, next_states_update = self.ref_model.value(
                next_inputs, next_states)
            next_q_value = next_values["q_value"] * torch.abs(
                next_alive["alive"])
            next_value, _ = next_q_value.max(-1)
            next_value = next_value.unsqueeze(-1)

        assert q_value.size() == next_q_value.size()

        value = comf.idx_select(q_value, action)
        critic_value = reward + self.discount_factor * next_value
        cost = (critic_value - value)**2

        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)

        return dict(cost=cost), states_update, next_states_update
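
This example is a one-step Q-learning (DQN-style) update: the online model scores Q(s, a), a periodically synchronized reference model supplies the bootstrap values, and the masked max over next-state Q-values forms the TD target. Below is a minimal, self-contained sketch of that target computation in plain PyTorch; all shapes are made up, and a plain alive mask stands in for the framework helpers (comf.idx_select, comf.get_avg_cost). It illustrates the math, not the library's code.

    import torch

    # Hypothetical batch: 4 transitions, 3 discrete actions.
    q_value = torch.randn(4, 3, requires_grad=True)   # Q(s, .) from the online model
    next_q_value = torch.randn(4, 3)                  # Q(s', .) from the frozen ref model
    action = torch.randint(0, 3, (4, 1))              # actions actually taken
    reward = torch.randn(4, 1)
    alive = torch.ones(4, 1)                          # 0 marks terminal next states
    discount = 0.99

    value = q_value.gather(-1, action)                                  # Q(s, a)
    next_value, _ = (next_q_value * alive.abs()).max(-1, keepdim=True)  # masked max over actions
    target = reward + discount * next_value                             # bootstrapped TD target
    cost = (target - value) ** 2                                        # squared TD error
    cost.mean().backward()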
Example 2
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):
        """
        We keep predict() the same with SimpleQ.
        We have to override learn() to implement the learning of SR.
        This function requires four functions implemented by self.model:

        1. self.model.state_embedding() - receives an observation input
                                          and outputs a compact state feature vector
                                          for predicting immediate rewards

        2. self.model.goal()            - outputs a goal vector that has the same
                                          length with the compact state feature vector.
                                          Sometimes, the goal might depend on some inputs.

        3. self.model.sr()              - given the input,
                                          returns a tensor of successor representations, each
                                          one corresponding to an action.
                                          BxAxD where B is the batch size, A is the number of
                                          actions, and D is the dim of state embedding
        """
        action = actions["action"]
        next_action = next_actions["action"]
        reward = rewards["reward"]

        ## 1. learn to predict rewards
        next_state_embedding = self.model.state_embedding(next_inputs)
        # the goal and reward evaluation should be based on the current inputs
        goal = self.model.goal(inputs)
        pred_reward = comf.inner_prod(next_state_embedding, goal)
        reward_cost = (pred_reward - reward)**2 * self.reward_cost_weight

        ## 2. use Bellman equation to learn successor representation
        srs, states_update = self.model.sr(inputs, states)  ## BxAxD
        state_embedding_dim = srs.shape[-1]
        sr = torch.gather(input=srs,
                          dim=1,
                          index=action.unsqueeze(-1).expand(
                              -1, -1, state_embedding_dim))
        sr = sr.squeeze(1)  ## BxD

        with torch.no_grad():
            next_srs, next_states_update = self.model.sr(
                next_inputs, next_states)
            next_sr = torch.gather(input=next_srs,
                                   dim=1,
                                   index=next_action.unsqueeze(-1).expand(
                                       -1, -1, state_embedding_dim))
            next_sr = next_sr.squeeze(1) * torch.abs(next_alive["alive"])

        sr_cost = (next_state_embedding.detach() +
                   self.discount_factor * next_sr - sr)**2
        sr_cost = sr_cost.mean(-1).unsqueeze(-1)

        avg_cost = comf.get_avg_cost(reward_cost + sr_cost)
        avg_cost.backward(retain_graph=True)

        return dict(cost=reward_cost +
                    sr_cost), states_update, next_states_update
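
Assuming the shapes described in the docstring (successor features of shape BxAxD, a state embedding and a goal vector of length D), here is a stand-alone sketch of the two losses: the inner-product reward prediction and the SR Bellman update. All tensors, dimensions, and the alive mask below are made up for illustration; they mirror the structure of the snippet rather than reproduce the library's helpers.

    import torch

    # Hypothetical shapes: batch B=4, A=3 actions, embedding dim D=8.
    B, A, D = 4, 3, 8
    srs = torch.randn(B, A, D, requires_grad=True)   # successor features per action
    next_srs = torch.randn(B, A, D)                  # same model applied to s', no gradient
    next_phi = torch.randn(B, D)                     # state embedding of s'
    goal = torch.randn(B, D, requires_grad=True)     # goal / reward-weight vector
    action = torch.randint(0, A, (B, 1))
    next_action = torch.randint(0, A, (B, 1))
    reward = torch.randn(B, 1)
    alive = torch.ones(B, 1)
    gamma = 0.99

    # 1. reward prediction: r ~ <phi(s'), goal>
    pred_reward = (next_phi * goal).sum(-1, keepdim=True)
    reward_cost = (pred_reward - reward) ** 2

    # 2. SR Bellman update: psi(s, a) <- phi(s') + gamma * psi(s', a')
    sr = srs.gather(1, action.unsqueeze(-1).expand(-1, -1, D)).squeeze(1)
    next_sr = next_srs.gather(1, next_action.unsqueeze(-1).expand(-1, -1, D)).squeeze(1)
    next_sr = next_sr * alive.abs()
    sr_cost = (next_phi.detach() + gamma * next_sr - sr) ** 2

    (reward_cost.mean() + sr_cost.mean()).backward()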
Example 3
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):
        self.model.train()

        if self.update_ref_interval \
                and self.total_batches % self.update_ref_interval == 0:
            ## copy parameters from self.model to self.ref_model
            self.ref_model.load_state_dict(self.model.state_dict())
        self.total_batches += 1

        action = actions["action"]
        reward = rewards["reward"]

        self.optim.zero_grad()

        values, states_update = self.get_current_values(inputs, states)
        q_distributions = values["q_value_distribution"]
        q_distribution = self.select_q_distribution(q_distributions, action)

        with torch.no_grad():
            next_values, next_states_update, next_value = self.get_next_values(
                next_inputs, next_states)
            filter = next_alive["alive"]

            ## get next Q distributions for all actions.
            next_q_distributions = self.check_alive(
                next_values["q_value_distribution"], filter)

            ## get next greedy action on expected Q value.
            next_expected_q_values = next_value * torch.abs(filter)
            _, next_action = next_expected_q_values.max(-1)
            next_action = next_action.unsqueeze(-1)

            ## select next Q distribution with action
            next_q_distribution = self.select_q_distribution(
                next_q_distributions, next_action)
            ### check batch size.
            ### distributions may have different numbers of parameters.
            assert q_distribution.size()[0] == next_q_distribution.size()[0]

        cost = self.get_cost(q_distribution, next_q_distribution, reward,
                             values, next_values)
        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)

        if self.grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.grad_clip)
        self.optim.step()

        return dict(cost=cost)
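
This example uses a distributional Q-learning update; the helpers select_q_distribution(), check_alive(), and get_cost() are not shown in the snippet, so their exact form is unknown here. The sketch below assumes a categorical (C51-style) parameterization purely for illustration: the expected Q-value over a fixed support is used for greedy action selection, and the distribution of the chosen action is gathered out. The atom count, support range, and tensor shapes are hypothetical.

    import torch

    # Assumed categorical (C51-style) parameterization: B=4, A=3 actions, N=51 atoms.
    B, A, N = 4, 3, 51
    atoms = torch.linspace(-10.0, 10.0, N)                  # fixed support of the return distribution
    q_distributions = torch.softmax(torch.randn(B, A, N), dim=-1)
    action = torch.randint(0, A, (B, 1))

    # expected Q-value per action, used for greedy action selection
    expected_q = (q_distributions * atoms).sum(-1)          # BxA
    _, greedy_action = expected_q.max(-1)

    # pick the distribution of the chosen action (analogous to select_q_distribution)
    q_distribution = q_distributions.gather(
        1, action.unsqueeze(-1).expand(-1, -1, N)).squeeze(1)  # BxN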
Example 4
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):
        """
        This learn() is expected to be called multiple times on each minibatch
        """

        action = actions["action"]
        log_prob = actions["action_log_prob"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        value = values["v_value"]

        with torch.no_grad():
            next_values, next_states_update = self.model.value(
                next_inputs, next_states)
            next_value = next_values["v_value"] * torch.abs(
                next_alive["alive"])

        assert value.size() == next_value.size()

        critic_value = reward + self.discount_factor * next_value
        td_error = (critic_value - value).squeeze(-1)
        value_cost = td_error**2

        dist, _ = self.model.policy(inputs, states)
        dist = dist["action"]

        if action.dtype == torch.int64 or action.dtype == torch.int32:
            ## for discrete actions, we need to provide scalars to log_prob()
            new_log_prob = dist.log_prob(action.squeeze(-1))
        else:
            new_log_prob = dist.log_prob(action)

        ratio = torch.exp(new_log_prob - log_prob.squeeze(-1))
        ### clip pg_cost according to the ratio
        clipped_ratio = torch.clamp(ratio,
                                    min=1 - self.epsilon,
                                    max=1 + self.epsilon)

        pg_obj = torch.min(input=ratio * td_error.detach(),
                           other=clipped_ratio * td_error.detach())
        cost = self.value_cost_weight * value_cost - pg_obj

        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)

        return dict(cost=cost), states_update, next_states_update
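
This is a PPO-style update: the probability ratio between the new and stored (rollout-time) log-probabilities is clipped to [1 - epsilon, 1 + epsilon], and the surrogate objective takes the minimum of the clipped and unclipped terms, with the detached TD error serving as the advantage. A minimal sketch of just that clipped-surrogate computation in plain PyTorch, with made-up batch values:

    import torch

    # Hypothetical batch of 5 samples; the advantage here plays the role of the detached TD error.
    epsilon = 0.2
    old_log_prob = torch.randn(5)                   # log pi_old(a|s), stored at rollout time
    new_log_prob = torch.randn(5, requires_grad=True)
    advantage = torch.randn(5)

    ratio = torch.exp(new_log_prob - old_log_prob)
    clipped_ratio = torch.clamp(ratio, min=1 - epsilon, max=1 + epsilon)
    pg_obj = torch.min(ratio * advantage, clipped_ratio * advantage)
    (-pg_obj.mean()).backward()                     # maximize the clipped surrogate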
Example 5
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):

        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        V = values["v_value"]
        Q = values["q_value"]

        with torch.no_grad():
            next_values, next_states_update = self.model.value(
                next_inputs, next_states)
            V_hat = next_values["v_value"] * torch.abs(next_alive["alive"])
        assert V.size() == V_hat.size()

        dist, _ = self.model.policy(inputs, states)
        pi = dist["action"]
        pi_dist = pi.probs
        log_pi = pi_dist.log()

        # J_V
        target_V = Q - log_pi
        expected_target_V = torch.matmul(pi_dist.unsqueeze(1),
                                         target_V.unsqueeze(2)).squeeze(-1)
        V_diff = V - expected_target_V
        J_V = 0.5 * (V_diff**2)

        # J_Q
        Q_hat = reward + self.discount_factor * V_hat
        Q_i = comf.idx_select(Q, action)
        Q_diff = Q_i - Q_hat
        J_Q = 0.5 * (Q_diff**2)

        # J_Pi
        target = F.softmax(Q, 1)
        J_pi = F.kl_div(log_pi, target, reduction="none").sum(-1, keepdim=True)

        cost = self.lambda_V * J_V + self.lambda_Q * J_Q + self.lambda_Pi * J_pi

        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)
        return dict(cost=cost), states_update, next_states_update
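
This example combines three soft-actor-critic-style losses for discrete actions: J_V regresses the state value towards the policy's expected soft value, J_Q is a one-step TD regression for Q, and J_pi pulls the policy towards softmax(Q) via a KL divergence. The fragment below sketches the soft value target and the policy KL term with dummy tensors; the batch size, action count, and logits are assumptions, not the model's actual outputs.

    import torch
    import torch.nn.functional as F

    # Hypothetical discrete setup: B=4 samples, A=3 actions.
    B, A = 4, 3
    Q = torch.randn(B, A)
    logits = torch.randn(B, A)
    pi = F.softmax(logits, dim=-1)
    log_pi = F.log_softmax(logits, dim=-1)

    # J_V target: V(s) should match E_{a~pi}[Q(s,a) - log pi(a|s)]
    soft_V_target = (pi * (Q - log_pi)).sum(-1, keepdim=True)

    # J_pi: pull pi towards softmax(Q), measured by a KL divergence
    target = F.softmax(Q, dim=1)
    J_pi = F.kl_div(log_pi, target, reduction="none").sum(-1, keepdim=True)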
Example 6
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):

        action = actions["action"]
        next_action = next_actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        q_value = values["q_value"]

        with torch.no_grad():
            next_values, next_states_update = self.model.value(
                next_inputs, next_states)
            next_value = comf.idx_select(next_values["q_value"], next_action)
            next_value = next_value * torch.abs(next_alive["alive"])

        critic_value = reward + self.discount_factor * next_value
        cost = (critic_value - comf.idx_select(q_value, action))**2

        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)

        return dict(cost=cost), states_update, next_states_update
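
Unlike Example 1, this update is on-policy SARSA: the bootstrap uses the Q-value of the action actually taken at the next state rather than the max over actions. A minimal plain-PyTorch sketch of that target, with hypothetical shapes and an assumed alive mask in place of the framework helpers:

    import torch

    # Hypothetical on-policy (SARSA) target: bootstrap on the action taken at s'.
    B, A = 4, 3
    q_value = torch.randn(B, A, requires_grad=True)
    next_q_value = torch.randn(B, A)
    action = torch.randint(0, A, (B, 1))
    next_action = torch.randint(0, A, (B, 1))   # chosen by the behavior policy, not argmax
    reward = torch.randn(B, 1)
    alive = torch.ones(B, 1)
    gamma = 0.99

    next_value = next_q_value.gather(-1, next_action) * alive.abs()
    target = reward + gamma * next_value
    cost = (target - q_value.gather(-1, action)) ** 2
    cost.mean().backward()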
Example 7
    def learn(self, inputs, next_inputs, states, next_states, next_alive,
              actions, next_actions, rewards):

        action = actions["action"]
        reward = rewards["reward"]

        values, states_update = self.model.value(inputs, states)
        value = values["v_value"]

        with torch.no_grad():
            next_values, next_states_update = self.model.value(
                next_inputs, next_states)
            next_value = next_values["v_value"] * torch.abs(
                next_alive["alive"])

        assert value.size() == next_value.size()

        critic_value = reward + self.discount_factor * next_value
        td_error = (critic_value - value).squeeze(-1)
        value_cost = td_error**2

        dist, _ = self.model.policy(inputs, states)
        dist = dist["action"]

        if action.dtype == torch.int64 or action.dtype == torch.int32:
            ## for discrete actions, we need to provide scalars to log_prob()
            pg_cost = -dist.log_prob(action.squeeze(-1))
        else:
            pg_cost = -dist.log_prob(action)

        cost = self.value_cost_weight * value_cost \
               + pg_cost * td_error.detach() \
               - self.prob_entropy_weight * dist.entropy()  ## increase entropy for exploration

        avg_cost = comf.get_avg_cost(cost)
        avg_cost.backward(retain_graph=True)
        return dict(cost=cost), states_update, next_states_update
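
This last example is an advantage actor-critic (A2C-style) loss: a squared TD error trains the critic, the negative log-probability weighted by the detached TD error trains the actor, and an entropy bonus encourages exploration. Below is a self-contained sketch for discrete actions using torch.distributions.Categorical; the weights, shapes, and dummy tensors are assumptions for illustration only.

    import torch
    from torch.distributions import Categorical

    # Hypothetical discrete actor-critic loss: batch of 4 samples, 3 actions.
    value_cost_weight, entropy_weight = 0.5, 0.01
    logits = torch.randn(4, 3, requires_grad=True)
    value = torch.randn(4, 1, requires_grad=True)
    next_value = torch.randn(4, 1)
    action = torch.randint(0, 3, (4,))
    reward = torch.randn(4, 1)
    gamma = 0.99

    dist = Categorical(logits=logits)
    td_error = (reward + gamma * next_value - value).squeeze(-1)
    pg_cost = -dist.log_prob(action)                    # policy-gradient term
    cost = (value_cost_weight * td_error ** 2
            + pg_cost * td_error.detach()               # gradient blocked through the critic
            - entropy_weight * dist.entropy())          # entropy bonus for exploration
    cost.mean().backward()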