Example #1
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        batch_size = obs.size(0)
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_all, _, _ = self.qf(obs, new_obs_actions)
        q_all = q_all.view(batch_size, self.num_samples, 1)
        q_new_actions, _ = torch.min(q_all, dim=1)

        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q_pred, mu, std = self.qf(obs, actions)
        print(q_pred)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )

        target_all, _, _ = self.target_qf(next_obs, new_next_actions)
        target_all = target_all.view(self.num_samples, batch_size, 1)

        print(target_all)
        print(target_all.size())

        target_q_values, _ = torch.min(target_all, dim=1)

        print(target_q_values)
        target_q_values = target_q_values - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values

        if self.weighted_mse:
            raise NotImplementedError
        else:
            q_target = q_target.repeat_interleave(self.num_samples, dim=0)
            qf_loss = self.qf_criterion(q_pred, q_target.detach())

        qf_loss += self.beta * kl_divergence(mu, std)
        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        # exit()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
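
The soft update above relies on rlkit's ptu.soft_update_from_to helper. A minimal self-contained sketch of that kind of Polyak averaging (plain PyTorch; the layer sizes and tau value below are illustrative, not taken from the example):

import torch
import torch.nn as nn

def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    """target <- tau * source + (1 - tau) * target, parameter by parameter."""
    with torch.no_grad():
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.mul_(1.0 - tau).add_(src_p, alpha=tau)

qf = nn.Linear(8, 1)
target_qf = nn.Linear(8, 1)
target_qf.load_state_dict(qf.state_dict())   # start the target equal to the online net
soft_update(qf, target_qf, tau=0.005)
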
Example #2
    def train_from_torch(self, batch):

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        # make rewards positive
        if self.rewards_shift_param is not None:
            rewards = rewards - self.rewards_shift_param

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward(retain_graph=True)
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        # use rnd in policy
        if self.use_rnd_policy:
            actor_bonus = self._get_bonus(obs, new_obs_actions)
            q_new_actions = q_new_actions - self.beta * actor_bonus

        policy_loss = (alpha * log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )

        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        # use rnd in critic
        if self.use_rnd_critic:
            critic_bonus = self._get_bonus(next_obs, new_next_actions)
            target_q_values = target_q_values - self.beta * critic_bonus

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Update networks
        """
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
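
A minimal standalone sketch of the automatic entropy-temperature (alpha) update performed in the examples above, assuming the standard SAC formulation; the batch size, learning rate, and target entropy are placeholders:

import torch

target_entropy = -6.0                        # e.g. -dim(A) for a 6-dimensional action space
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

log_pi = torch.randn(256, 1)                 # stand-in for the policy's log-probabilities
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()

alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp()                      # entropy weight used in the policy and Q targets
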
Example #3
File: twin_sac.py Project: szk9876/rlkit
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        v_pred = self.vf(obs)
        # Make sure policy accounts for squashing functions like tanh correctly!
        policy_outputs = self.policy(
            obs,
            reparameterize=self.train_policy_with_reparameterization,
            return_log_prob=True)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        """
        Alpha Loss (if applicable)
        """
        if self.use_automatic_entropy_tuning:
            """
            Alpha Loss
            """
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = self.fixed_entropy
            alpha_loss = 0
        """
        QF Loss
        """
        target_v_values = self.target_vf(next_obs)
        q_target = self.reward_scale * rewards.squeeze_().unsqueeze_(-1) + (
            1. - terminals) * self.discount * target_v_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        VF Loss
        """
        q_new_actions = torch.min(
            self.qf1(obs, new_actions),
            self.qf2(obs, new_actions),
        )
        v_target = q_new_actions - alpha * log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        policy_loss = None
        if self._n_train_steps_total % self.policy_update_period == 0:
            """
            Policy Loss
            """
            if self.train_policy_with_reparameterization:
                policy_loss = (alpha * log_pi - q_new_actions).mean()
            else:
                log_policy_target = q_new_actions - v_pred
                policy_loss = (
                    log_pi *
                    (alpha * log_pi - log_policy_target).detach()).mean()
            mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**
                                                           2).mean()
            std_reg_loss = self.policy_std_reg_weight * (policy_log_std**
                                                         2).mean()
            pre_tanh_value = policy_outputs[-1]
            pre_activation_reg_loss = self.policy_pre_activation_weight * (
                (pre_tanh_value**2).sum(dim=1).mean())
            policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
            policy_loss = policy_loss + policy_reg_loss

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.vf, self.target_vf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            if policy_loss is None:
                if self.train_policy_with_reparameterization:
                    policy_loss = (log_pi - q_new_actions).mean()
                else:
                    log_policy_target = q_new_actions - v_pred
                    policy_loss = (
                        log_pi * (log_pi - log_policy_target).detach()).mean()

                mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**
                                                               2).mean()
                std_reg_loss = self.policy_std_reg_weight * (policy_log_std**
                                                             2).mean()
                pre_tanh_value = policy_outputs[-1]
                pre_activation_reg_loss = self.policy_pre_activation_weight * (
                    (pre_tanh_value**2).sum(dim=1).mean())
                policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
                policy_loss = policy_loss + policy_reg_loss

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
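
This variant regresses a separate value function. A minimal sketch of how its targets are formed, using random stand-in tensors rather than real network outputs:

import torch

gamma, alpha = 0.99, 0.2
rewards   = torch.randn(4, 1)
terminals = torch.zeros(4, 1)
target_v_next = torch.randn(4, 1)            # stand-in for target_vf(next_obs)
q1_new = torch.randn(4, 1)                   # stand-in for qf1(obs, new_actions)
q2_new = torch.randn(4, 1)                   # stand-in for qf2(obs, new_actions)
log_pi = torch.randn(4, 1)

q_target = rewards + (1.0 - terminals) * gamma * target_v_next   # regression target for both Q-nets
v_target = torch.min(q1_new, q2_new) - alpha * log_pi            # regression target for the V-net
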
Example #4
    def train_from_torch(self, batch):
        rewards_n = batch['rewards'].detach()
        terminals_n = batch['terminals'].detach()
        obs_n = batch['observations'].detach()
        actions_n = batch['actions'].detach()
        next_obs_n = batch['next_observations'].detach()

        batch_size = rewards_n.shape[0]
        num_agent = rewards_n.shape[1]
        whole_obs = obs_n.view(batch_size, -1)
        whole_actions = actions_n.view(batch_size, -1)
        whole_next_obs = next_obs_n.view(batch_size, -1) 

        """
        Policy operations.
        """
        online_actions_n, online_pre_values_n, online_log_pis_n = [], [], []
        for agent in range(num_agent):
            policy_actions, info = self.policy_n[agent](
                obs_n[:,agent,:], return_info=True,
            )
            online_actions_n.append(policy_actions)
            online_pre_values_n.append(info['preactivation'])
            online_log_pis_n.append(info['log_prob'])
        k0_actions = torch.stack(online_actions_n) # num_agent x batch x a_dim
        k0_actions = k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

        k0_inputs = torch.cat([obs_n, k0_actions],dim=-1)
        k1_actions = self.cactor(k0_inputs, deterministic=self.deterministic_cactor_in_graph)
        k1_inputs = torch.cat([obs_n, k1_actions],dim=-1)

        k1_contexts_1 = self.cg1(k1_inputs)
        k1_contexts_2 = self.cg2(k1_inputs)

        policy_gradients_n = []
        alpha_n = []
        for agent in range(num_agent):
            policy_actions = online_actions_n[agent]
            pre_value = online_pre_values_n[agent]
            log_pi = online_log_pis_n[agent]
            if self.pre_activation_weight > 0.:
                pre_activation_policy_loss = (
                    (pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_policy_loss = torch.tensor(0.).to(ptu.device) 
            if self.use_entropy_loss:
                if self.use_automatic_entropy_tuning:
                    if self.state_dependent_alpha:
                        alpha = self.log_alpha_n[agent](obs_n[:,agent,:]).exp()
                    else:
                        alpha = self.log_alpha_n[agent].exp()
                    alpha_loss = -(alpha * (log_pi + self.target_entropy).detach()).mean()
                    self.alpha_optimizer_n[agent].zero_grad()
                    alpha_loss.backward()
                    self.alpha_optimizer_n[agent].step()
                    if self.state_dependent_alpha:
                        alpha = self.log_alpha_n[agent](obs_n[:,agent,:]).exp().detach()
                    else:
                        alpha = self.log_alpha_n[agent].exp().detach()
                        alpha_n.append(alpha)
                else:
                    alpha_loss = torch.tensor(0.).to(ptu.device)
                    alpha = torch.tensor(self.init_alpha).to(ptu.device)
                    alpha_n.append(alpha)
                entropy_loss = (alpha*log_pi).mean()
            else:
                entropy_loss = torch.tensor(0.).to(ptu.device)

            q1_input = torch.cat([policy_actions,k1_contexts_1[:,agent,:]],dim=-1)
            q1_output = self.qf1(q1_input)
            q2_input = torch.cat([policy_actions,k1_contexts_2[:,agent,:]],dim=-1)
            q2_output = self.qf2(q2_input)
            q_output = torch.min(q1_output,q2_output)
            raw_policy_loss = -q_output.mean()
            policy_loss = (
                    raw_policy_loss +
                    pre_activation_policy_loss * self.pre_activation_weight +
                    entropy_loss
            )

            policy_gradients_n.append(torch.autograd.grad(policy_loss, self.policy_n[agent].parameters(),retain_graph=True))

            if self._need_to_update_eval_statistics:
                self.eval_statistics['Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    policy_loss
                ))
                self.eval_statistics['Raw Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_policy_loss
                ))
                self.eval_statistics['Preactivation Policy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_policy_loss
                ))
                self.eval_statistics['Entropy Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    entropy_loss
                ))
                if self.use_entropy_loss:
                    if self.state_dependent_alpha:
                        self.eval_statistics.update(create_stats_ordered_dict(
                            'Alpha {}'.format(agent),
                            ptu.get_numpy(alpha),
                        ))
                    else:
                        self.eval_statistics['Alpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                            alpha
                        ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy Action {}'.format(agent),
                    ptu.get_numpy(policy_actions),
                ))

        for agent in range(num_agent):
            # self.policy_optimizer_n[agent].zero_grad()
            for pid,p in enumerate(self.policy_n[agent].parameters()):
                p.grad = policy_gradients_n[agent][pid]
            self.policy_optimizer_n[agent].step()

        """
        Critic operations.
        """
        with torch.no_grad():
            next_actions_n, next_log_pis_n = [], []
            for agent in range(num_agent):
                next_actions, next_info = self.policy_n[agent](
                    next_obs_n[:,agent,:], return_info=True,
                    deterministic=self.deterministic_next_action,
                )
                next_actions_n.append(next_actions)
                next_log_pis_n.append(next_info['log_prob'])
            next_k0_actions = torch.stack(next_actions_n) # num_agent x batch x a_dim
            next_k0_actions = next_k0_actions.transpose(0,1).contiguous() # batch x num_agent x a_dim

            next_k0_inputs = torch.cat([next_obs_n, next_k0_actions],dim=-1)
            next_k1_actions = self.cactor(next_k0_inputs, deterministic=self.deterministic_cactor_in_graph)

            next_k1_inputs = torch.cat([next_obs_n, next_k1_actions],dim=-1)
            next_k1_contexts_1 = self.cg1(next_k1_inputs)
            next_k1_contexts_2 = self.cg2(next_k1_inputs)

        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)

        buffer_contexts_1 = self.cg1(buffer_inputs) # batch x num_agent x c_dim
        q1_inputs = torch.cat([actions_n, buffer_contexts_1],dim=-1)
        q1_preds_n = self.qf1(q1_inputs)

        buffer_contexts_2 = self.cg2(buffer_inputs) # batch x num_agent x c_dim
        q2_inputs = torch.cat([actions_n, buffer_contexts_2],dim=-1)
        q2_preds_n = self.qf2(q2_inputs)

        raw_qf1_loss_n, raw_qf2_loss_n, q_target_n = [], [], []
        for agent in range(num_agent):
            with torch.no_grad():
                next_policy_actions = next_actions_n[agent]
                next_log_pi = next_log_pis_n[agent]

                next_q1_input = torch.cat([next_policy_actions,next_k1_contexts_1[:,agent,:]],dim=-1)
                next_target_q1_values = self.target_qf1(next_q1_input)

                next_q2_input = torch.cat([next_policy_actions,next_k1_contexts_2[:,agent,:]],dim=-1)
                next_target_q2_values = self.target_qf2(next_q2_input)

                next_target_q_values = torch.min(next_target_q1_values, next_target_q2_values)

                if self.use_entropy_reward:
                    if self.state_dependent_alpha:
                        next_alpha = self.log_alpha_n[agent](next_obs_n[:,agent,:]).exp()
                    else:
                        next_alpha = alpha_n[agent]
                    next_target_q_values =  next_target_q_values - next_alpha * next_log_pi

                q_target = self.reward_scale*rewards_n[:,agent,:] + (1. - terminals_n[:,agent,:]) * self.discount * next_target_q_values
                q_target = torch.clamp(q_target, self.min_q_value, self.max_q_value)
                q_target_n.append(q_target)

            q1_pred = q1_preds_n[:,agent,:]
            raw_qf1_loss = self.qf_criterion(q1_pred, q_target)
            raw_qf1_loss_n.append(raw_qf1_loss)

            q2_pred = q2_preds_n[:,agent,:]
            raw_qf2_loss = self.qf_criterion(q2_pred, q_target)
            raw_qf2_loss_n.append(raw_qf2_loss)

            if self._need_to_update_eval_statistics:
                self.eval_statistics['QF1 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(raw_qf1_loss))
                self.eval_statistics['QF2 Loss {}'.format(agent)] = np.mean(ptu.get_numpy(raw_qf2_loss))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q1 Predictions {}'.format(agent),
                    ptu.get_numpy(q1_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions {}'.format(agent),
                    ptu.get_numpy(q2_pred),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q Targets {}'.format(agent),
                    ptu.get_numpy(q_target),
                ))

        if self.sum_n_loss:
            raw_qf1_loss = torch.sum(torch.stack(raw_qf1_loss_n))
            raw_qf2_loss = torch.sum(torch.stack(raw_qf2_loss_n))
        else:
            raw_qf1_loss = torch.mean(torch.stack(raw_qf1_loss_n))
            raw_qf2_loss = torch.mean(torch.stack(raw_qf2_loss_n))

        if self.negative_sampling:
            perturb_actions = actions_n.clone() # batch x agent x |A|
            batch_size, num_agent, a_dim = perturb_actions.shape
            perturb_agents = torch.randint(low=0,high=num_agent,size=(batch_size,))
            neg_actions = torch.rand(batch_size,a_dim)*2.-1. # ranged in -1 to 1
            perturb_actions[torch.arange(batch_size),perturb_agents,:] = neg_actions
                
            perturb_inputs = torch.cat([obs_n,perturb_actions],dim=-1)

            perturb_contexts_1 = self.cg1(perturb_inputs) # batch x num_agent x c_dim
            perturb_q1_inputs = torch.cat([actions_n, perturb_contexts_1],dim=-1)
            perturb_q1_preds = self.qf1(perturb_q1_inputs)[torch.arange(batch_size),perturb_agents,:]

            perturb_contexts_2 = self.cg2(perturb_inputs) # batch x num_agent x c_dim
            perturb_q2_inputs = torch.cat([actions_n, perturb_contexts_2],dim=-1)
            perturb_q2_preds = self.qf2(perturb_q2_inputs)[torch.arange(batch_size),perturb_agents,:]

            perturb_q_targets = torch.stack(q_target_n).transpose(0,1).contiguous()[torch.arange(batch_size),perturb_agents,:]

            neg_loss1 = self.qf_criterion(perturb_q1_preds, perturb_q_targets)
            neg_loss2 = self.qf_criterion(perturb_q2_preds, perturb_q_targets)
        else:
            neg_loss1, neg_loss2 = torch.tensor(0.).to(ptu.device), torch.tensor(0.).to(ptu.device)

        if self.qf_weight_decay > 0:
            reg_loss1 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in list(self.qf1.regularizable_parameters())+list(self.cg1.regularizable_parameters())
            )

            reg_loss2 = self.qf_weight_decay * sum(
                torch.sum(param ** 2)
                for param in list(self.qf2.regularizable_parameters())+list(self.cg2.regularizable_parameters())
            )
        else:
            reg_loss1, reg_loss2 = torch.tensor(0.).to(ptu.device), torch.tensor(0.).to(ptu.device)

        qf1_loss = raw_qf1_loss + reg_loss1 + neg_loss1
        qf2_loss = raw_qf2_loss + reg_loss2 + neg_loss2

        if self._need_to_update_eval_statistics:
            self.eval_statistics['raw_qf1_loss'] = np.mean(ptu.get_numpy(raw_qf1_loss))
            self.eval_statistics['raw_qf2_loss'] = np.mean(ptu.get_numpy(raw_qf2_loss))
            self.eval_statistics['neg_qf1_loss'] = np.mean(ptu.get_numpy(neg_loss1))
            self.eval_statistics['neg_qf2_loss'] = np.mean(ptu.get_numpy(neg_loss2))
            self.eval_statistics['reg_qf1_loss'] = np.mean(ptu.get_numpy(reg_loss1))
            self.eval_statistics['reg_qf2_loss'] = np.mean(ptu.get_numpy(reg_loss2))

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        """
        Central actor operations.
        """
        buffer_inputs = torch.cat([obs_n, actions_n],dim=-1)
        cactor_actions, cactor_infos = self.cactor(buffer_inputs,return_info=True)
        # batch x agent_num x |A|
        buffer_contexts_1 = self.cg1(buffer_inputs)
        buffer_contexts_2 = self.cg2(buffer_inputs)

        cactor_loss_n = []
        for agent in range(num_agent):
            cactor_pre_value = cactor_infos['preactivation'][:,agent,:]
            if self.pre_activation_weight > 0:
                pre_activation_cactor_loss = (
                    (cactor_pre_value**2).sum(dim=1).mean()
                )
            else:
                pre_activation_cactor_loss = torch.tensor(0.).to(ptu.device)
            if self.use_cactor_entropy_loss:
                cactor_log_pi = cactor_infos['log_prob'][:,agent,:]
                if self.use_automatic_entropy_tuning:
                    if self.state_dependent_alpha:
                        calpha = self.log_calpha_n[agent](whole_obs).exp()
                    else:
                        calpha = self.log_calpha_n[agent].exp()
                    calpha_loss = -(calpha * (cactor_log_pi + self.target_entropy).detach()).mean()
                    self.calpha_optimizer_n[agent].zero_grad()
                    calpha_loss.backward()
                    self.calpha_optimizer_n[agent].step()
                    if self.state_dependent_alpha:
                        calpha = self.log_calpha_n[agent](whole_obs).exp().detach()
                    else:
                        calpha = self.log_calpha_n[agent].exp().detach()
                else:
                    calpha_loss = torch.tensor(0.).to(ptu.device)
                    calpha = torch.tensor(self.init_alpha).to(ptu.device)
                cactor_entropy_loss = (calpha*cactor_log_pi).mean()
            else:
                cactor_entropy_loss = torch.tensor(0.).to(ptu.device)
            
            q1_input = torch.cat([cactor_actions[:,agent,:],buffer_contexts_1[:,agent,:]],dim=-1)
            q1_output = self.qf1(q1_input)
            q2_input = torch.cat([cactor_actions[:,agent,:],buffer_contexts_2[:,agent,:]],dim=-1)
            q2_output = self.qf2(q2_input)
            q_output = torch.min(q1_output,q2_output)
            raw_cactor_loss = -q_output.mean()
            cactor_loss = (
                    raw_cactor_loss +
                    pre_activation_cactor_loss * self.pre_activation_weight +
                    cactor_entropy_loss
            )
            cactor_loss_n.append(cactor_loss)

            if self._need_to_update_eval_statistics:
                if self.use_cactor_entropy_loss:
                    if self.state_dependent_alpha:
                        self.eval_statistics.update(create_stats_ordered_dict(
                            'CAlpha {}'.format(agent),
                            ptu.get_numpy(calpha),
                        ))
                    else:
                        self.eval_statistics['CAlpha {} Mean'.format(agent)] = np.mean(ptu.get_numpy(
                            calpha
                        ))
                self.eval_statistics['Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_loss
                ))
                self.eval_statistics['Raw Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    raw_cactor_loss
                ))
                self.eval_statistics['Preactivation Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    pre_activation_cactor_loss
                ))
                self.eval_statistics['Entropy Cactor Loss {}'.format(agent)] = np.mean(ptu.get_numpy(
                    cactor_entropy_loss
                ))
        if self.sum_n_loss:
            cactor_loss = torch.sum(torch.stack(cactor_loss_n))
        else:
            cactor_loss = torch.mean(torch.stack(cactor_loss_n))
        self.cactor_optimizer.zero_grad()
        cactor_loss.backward()
        self.cactor_optimizer.step()
                
        self._need_to_update_eval_statistics = False
        self._update_target_networks()
        self._n_train_steps_total += 1
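
The per-agent policy update above computes gradients with torch.autograd.grad and writes them into .grad before stepping. A minimal single-network sketch of that pattern (toy layer and loss, chosen only for illustration):

import torch
import torch.nn as nn

policy = nn.Linear(4, 2)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

loss = policy(torch.randn(8, 4)).pow(2).mean()
# retain_graph=True mirrors the multi-agent case, where the shared graph is reused
grads = torch.autograd.grad(loss, policy.parameters(), retain_graph=True)

optimizer.zero_grad()
for param, grad in zip(policy.parameters(), grads):
    param.grad = grad          # assign the precomputed gradient
optimizer.step()
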
Example #5
    def train_from_torch(self, batch):
        obs = batch['observations']
        old_log_pi = batch['log_prob']
        advantage = batch['advantage']
        returns = batch['returns']
        actions = batch['actions']

        """
        Policy Loss
        """
        _, _, all_probs = self.policy(obs)
        new_log_pi = torch.zeros(all_probs.shape[0], 1)
        for i in range(all_probs.shape[0]):
            new_log_pi[i] = Categorical(all_probs[i]).log_prob(actions[i]).sum()

        # Advantage Clip
        ratio = torch.exp(new_log_pi - old_log_pi)
        left = ratio * advantage
        right = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantage

        policy_loss = (-1 * torch.min(left, right)).mean()

        """
        VF Loss
        """
        v_pred = self.vf(obs)
        v_target = returns
        vf_loss = self.vf_criterion(v_pred, v_target)

        """
        Update networks
        """
        loss = policy_loss + vf_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.last_approx_kl is None or not self._need_to_update_eval_statistics:
            self.last_approx_kl = (old_log_pi - new_log_pi).detach()
        
        approx_ent = -new_log_pi

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            policy_grads = torch.cat([p.grad.flatten() for p in self.policy.parameters()])
            value_grads = torch.cat([p.grad.flatten() for p in self.vf.parameters()])

            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Predictions',
                ptu.get_numpy(v_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'V Target',
                ptu.get_numpy(v_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Gradients',
                ptu.get_numpy(policy_grads),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Value Gradients',
                ptu.get_numpy(value_grads),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy KL',
                ptu.get_numpy(self.last_approx_kl),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy Entropy',
                ptu.get_numpy(approx_ent),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'New Log Pis',
                ptu.get_numpy(new_log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Old Log Pis',
                ptu.get_numpy(old_log_pi),
            ))
        self._n_train_steps_total += 1
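
A minimal standalone sketch of the clipped PPO surrogate computed above, with random stand-ins for the log-probabilities and advantages:

import torch

epsilon = 0.2
old_log_pi = torch.randn(32, 1)
new_log_pi = torch.randn(32, 1)
advantage  = torch.randn(32, 1)

ratio = torch.exp(new_log_pi - old_log_pi)          # importance ratio pi_new / pi_old
unclipped = ratio * advantage
clipped   = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
policy_loss = -torch.min(unclipped, clipped).mean() # PPO clipped objective (negated for descent)
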
Example #6
    def get_param_values_np(self):
        state_dict = self.state_dict()
        np_dict = OrderedDict()
        for key, tensor in state_dict.items():
            np_dict[key] = ptu.get_numpy(tensor)
        return np_dict
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, torch.autograd.Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
Example #8
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Critic operations.
        """

        # add some gaussian noise
        next_actions = self.target_policy(next_obs)
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        bellman_errors = []
        if self.stop_critic_training is None or self.stop_critic_training >= self.current_epoch:

            for idx in range(self.num_q):
                target_q_values = self.target_qfs[idx](next_obs,
                                                       noisy_next_actions)
                q_target = rewards + (
                    1. - terminals) * self.discount * target_q_values
                q_target = q_target.detach()
                q_pred = self.qfs[idx](obs, actions)
                bellman_errors.append((q_pred - q_target)**2)
        """
        Update Networks
        """
        if self.stop_critic_training is None or self.stop_critic_training >= self.current_epoch:
            for idx in range(self.num_q):
                self.qf_optimizers[idx].zero_grad()
                bellman_errors[idx].mean().backward()
                self.qf_optimizers[idx].step()

        policy_actions = policy_loss = None
        var_q_grad_sum = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            """
            Update the policy and target networks
            """
            if self.stop_actor_training is None or self.stop_actor_training >= self.current_epoch:
                policy_actions = self.policy(obs)
                policy_actions.retain_grad()
                q_output = self.qfs[0](obs, policy_actions)
                policy_loss = -q_output.mean()

                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

                ptu.soft_update_from_to(self.policy, self.target_policy,
                                        self.tau)
                for idx in range(self.num_q):
                    ptu.soft_update_from_to(self.qfs[idx],
                                            self.target_qfs[idx], self.tau)

        if self.eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            if policy_loss is None:
                """
                Compute the policy loss using the Q-value function
                """
                policy_actions = self.policy(obs)
                q_output = self.qfs[0](obs, policy_actions)
                policy_loss = -q_output.mean()

            if var_q_grad_sum is None:
                """
                Compute gradients of the policy loss with respect to different
                variables (actions and policy parameters) for each Q-function in the ensemble
                """
                ensemble_q_grads = []
                ensemble_j_grads = []

                for idx in range(self.num_q):
                    policy_actions = self.policy(obs)
                    policy_actions.retain_grad()

                    q_output = self.qfs[idx](obs, policy_actions)
                    policy_loss = -q_output.mean()

                    self.policy_optimizer.zero_grad()
                    policy_loss.backward()

                    ensemble_q_grads.append(
                        torch.mean(policy_actions.grad, 0).view(1, -1))

                    all_grad = torch.cat(
                        [(torch.squeeze(param.grad.view(1, -1))).view(1, -1)
                         for param in self.policy.parameters()],
                        dim=1)
                    ensemble_j_grads.append(all_grad)

                ensemble_q_grads = torch.cat(ensemble_q_grads)
                ensemble_j_grads = torch.cat(ensemble_j_grads)

                average_g_grads = torch.mean(ensemble_q_grads, dim=0)
                average_j_grads = torch.mean(ensemble_j_grads, dim=0)

                average_g_grad_norm = torch.norm(average_g_grads, p=2)
                average_j_grad_norm = torch.norm(average_j_grads, p=2)

                var_q_grads = ensemble_q_grads.std(dim=0)**2
                var_j_grads = ensemble_j_grads.std(dim=0)**2

                var_q_grad_sum = torch.sum(var_q_grads)
                var_j_grad_sum = torch.sum(var_j_grads)

            self.eval_statistics = OrderedDict()
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))

            self.eval_statistics['Analysis: Var Q gradients'] = np.mean(
                ptu.get_numpy(var_q_grad_sum))
            self.eval_statistics['Analysis: Var J gradients'] = np.mean(
                ptu.get_numpy(var_j_grad_sum))
            self.eval_statistics['Analysis: Mean Q grad norm'] = np.mean(
                ptu.get_numpy(average_g_grad_norm))
            self.eval_statistics['Analysis: Mean J grad norm'] = np.mean(
                ptu.get_numpy(average_j_grad_norm))
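
A minimal sketch of the target-policy smoothing used above (TD3-style clipped Gaussian noise added to the target action); unlike the snippet, this sketch also clamps the final action to [-1, 1], which is a common but here assumed choice:

import torch

target_policy_noise = 0.2
target_policy_noise_clip = 0.5

next_actions = torch.rand(16, 6) * 2 - 1               # stand-in for target_policy(next_obs)
noise = torch.randn_like(next_actions) * target_policy_noise
noise = noise.clamp(-target_policy_noise_clip, target_policy_noise_clip)
noisy_next_actions = (next_actions + noise).clamp(-1.0, 1.0)
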
Example #9
    def pretrain_policy_with_bc(
        self,
        policy,
        train_buffer,
        test_buffer,
        steps,
        label="policy",
    ):
        logger.remove_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )

        optimizer = self.optimizers[policy]
        prev_time = time.time()
        for i in range(steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_stats = self.run_bc_batch(
                train_buffer, policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            optimizer.zero_grad()
            train_policy_loss.backward()
            optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_stats = self.run_bc_batch(
                test_buffer, policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch":
                    i,
                    "pretrain_bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time":
                    time.time() - prev_time,
                }

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(
                    self.policy,
                    open(logger.get_snapshot_dir() + '/bc_%s.pkl' % label,
                         "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_%s.csv' % label,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
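
run_bc_batch itself is not shown above. A minimal sketch of what such a behavior-cloning batch typically computes, under the assumption of a Gaussian policy head; the toy network and shapes are placeholders:

import torch
import torch.nn as nn

obs     = torch.randn(64, 11)                # stand-in observation batch
actions = torch.rand(64, 3) * 2 - 1          # stand-in action batch in [-1, 1]

mean_net = nn.Linear(11, 3)                  # toy Gaussian policy head
log_std  = torch.zeros(3, requires_grad=True)

mean = mean_net(obs)
dist = torch.distributions.Normal(mean, log_std.exp())
logp_loss = -dist.log_prob(actions).sum(dim=-1).mean()   # negative log-likelihood term
mse_loss  = ((mean - actions) ** 2).mean()               # MSE diagnostic
policy_loss = logp_loss                                  # scaled by bc_weight in the example
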
Example #10
    def train_from_torch(self, batch):
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Behavior clone a policy
        """
        recon, mean, std = self.vae(obs, actions)
        recon_loss = self.qf_criterion(recon, actions)
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                          std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        """
        Critic Training
        """
        with torch.no_grad():
            # Duplicate state 10 times (10 is a hyperparameter chosen by BCQ)
            state_rep = next_obs.unsqueeze(1).repeat(1, 10, 1).view(
                next_obs.shape[0] * 10, next_obs.shape[1])  # 10BxS
            # Compute value of perturbed actions sampled from the VAE
            action_rep = self.policy(state_rep)[0]  # 10BxA
            target_qf1 = self.target_qf1(state_rep, action_rep)  # 10Bx1
            target_qf2 = self.target_qf2(state_rep, action_rep)  # 10Bx1
            # Soft Clipped Double Q-learning
            target_Q = 0.75 * torch.min(target_qf1,
                                        target_qf2) + 0.25 * torch.max(
                                            target_qf1, target_qf2)  # 10Bx1
            max_target_action = target_Q.view(next_obs.shape[0],
                                              -1).max(1)  # 10Bx1 > Bx10 > B,B
            target_Q = max_target_action[0].view(-1, 1)  # B > Bx1
            # 10BxA > Bx10xA > BxA
            max_actions = action_rep.view(
                next_obs.shape[0], 10,
                action_rep.shape[1])[torch.arange(next_obs.shape[0]),
                                     max_target_action[1]]
            target_Q = self.reward_scale * rewards + (
                1.0 - terminals) * self.discount * target_Q  # Bx1

        qf1_pred = self.qf1(obs, actions)  # Bx1
        qf2_pred = self.qf2(obs, actions)  # Bx1
        critic_unc = uncertainty(next_obs, max_actions, self.T, self.beta,
                                 self.pre_model, self.pre_model_name)
        qf1_loss = ((qf1_pred - target_Q.detach()).pow(2) * critic_unc).mean()
        qf2_loss = ((qf2_pred - target_Q.detach()).pow(2) * critic_unc).mean()
        """
        Actor Training
        """
        sampled_actions, raw_sampled_actions = self.vae.decode_multiple(
            obs, num_decode=self.num_samples_mmd_match)
        actor_samples, _, _, _, _, _, _, raw_actor_actions = self.policy(
            obs.unsqueeze(1).repeat(1, self.num_samples_mmd_match,
                                    1).view(-1, obs.shape[1]),
            return_log_prob=True)
        actor_samples = actor_samples.view(obs.shape[0],
                                           self.num_samples_mmd_match,
                                           actions.shape[1])
        raw_actor_actions = raw_actor_actions.view(obs.shape[0],
                                                   self.num_samples_mmd_match,
                                                   actions.shape[1])

        if self.kernel_choice == 'laplacian':
            mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                               raw_actor_actions,
                                               sigma=self.mmd_sigma)
        elif self.kernel_choice == 'gaussian':
            mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                              raw_actor_actions,
                                              sigma=self.mmd_sigma)

        action_divergence = ((sampled_actions - actor_samples)**2).sum(-1)
        raw_action_divergence = ((raw_sampled_actions -
                                  raw_actor_actions)**2).sum(-1)

        q_val1 = self.qf1(obs, actor_samples[:, 0, :])
        q_val2 = self.qf2(obs, actor_samples[:, 0, :])
        actor_unc = uncertainty(obs, actor_samples[:, 0, :], self.T, self.beta,
                                self.pre_model, self.pre_model_name)

        if self.policy_update_style == '0':
            policy_loss = torch.min(q_val1, q_val2)[:, 0] * actor_unc[:, 0]
        elif self.policy_update_style == '1':
            # torch.mean does not take two tensors; average the two Q estimates explicitly
            policy_loss = ((q_val1 + q_val2) / 2.0)[:, 0] * actor_unc[:, 0]

        # Use uncertainty after some epochs
        if self._n_train_steps_total >= 40000:
            if self.mode == 'auto':
                policy_loss = (-policy_loss + self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = (-policy_loss + 100 * mmd_loss).mean()
        else:
            if self.mode == 'auto':
                policy_loss = (self.log_alpha.exp() *
                               (mmd_loss - self.target_mmd_thresh)).mean()
            else:
                policy_loss = 100 * mmd_loss.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        if self.mode == 'auto':
            policy_loss.backward(retain_graph=True)
        else:
            policy_loss.backward()
        self.policy_optimizer.step()

        if self.mode == 'auto':
            self.alpha_optimizer.zero_grad()
            (-policy_loss).backward()
            self.alpha_optimizer.step()
            self.log_alpha.data.clamp_(min=-5.0, max=10.0)
        """
        Update networks
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(qf1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(qf2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(target_Q),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict('MMD Loss', ptu.get_numpy(mmd_loss)))
            self.eval_statistics.update(
                create_stats_ordered_dict('Action Divergence',
                                          ptu.get_numpy(action_divergence)))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Raw Action Divergence',
                    ptu.get_numpy(raw_action_divergence)))
            if self.mode == 'auto':
                self.eval_statistics['Alpha'] = self.log_alpha.exp().item()

        self._n_train_steps_total += 1
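
A minimal sketch of a Laplacian-kernel MMD between two sets of action samples, the quantity mmd_loss_laplacian estimates above; the kernel bandwidth handling in the actual implementation may differ:

import torch

def mmd_laplacian(x, y, sigma=20.0):
    """x, y: (batch, n_samples, action_dim) raw action samples."""
    xx = (x.unsqueeze(2) - x.unsqueeze(1)).abs().sum(-1)   # (B, n, n) L1 distances
    yy = (y.unsqueeze(2) - y.unsqueeze(1)).abs().sum(-1)
    xy = (x.unsqueeze(2) - y.unsqueeze(1)).abs().sum(-1)
    k_xx = torch.exp(-xx / sigma).mean(dim=(1, 2))
    k_yy = torch.exp(-yy / sigma).mean(dim=(1, 2))
    k_xy = torch.exp(-xy / sigma).mean(dim=(1, 2))
    return (k_xx + k_yy - 2 * k_xy).clamp(min=1e-6).sqrt()  # per-state MMD estimate

mmd = mmd_laplacian(torch.randn(8, 4, 3), torch.randn(8, 4, 3))
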
Example #11
    def encode(self, obs):
        if self.normalize:
            return ptu.get_numpy(self.model.encode(
                ptu.from_numpy(obs) / 255.0))
        return ptu.get_numpy(self.model.encode(ptu.from_numpy(obs)))
Example #12
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Classifier and Policy
        """
        class_actions = self.policy(obs)
        class_prob = self.classifier(obs, actions)
        prob_target = 1 + rewards[:, -1] #MAYBE INSTEAD PREDICT WHOLE ARRAY? THEN JUST USE LAST ONE

        neg_log_prob = - torch.log(self.classifier(obs, class_actions))
        policy_loss = (neg_log_prob).mean()

        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, policy_mean, policy_log_std, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        # NOTE: this trainer has no entropy-tuning machinery, so a fixed entropy
        # weight of 1 is assumed here (alpha was left undefined in the snippet).
        alpha = 1.
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(new_log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
Example #13
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = ptu.randn(next_actions.shape) * self.target_policy_noise
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            for i in range(policy_actions.shape[1]):
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy Action%s' % i,
                        ptu.get_numpy(policy_actions[:, i:i + 1]),
                    ))
            policy_grads = np.array([
                np.linalg.norm(ptu.get_numpy(a.grad))
                if a.grad is not None else 0.0
                for a in self.policy.parameters()
            ])
            pg_stats = OrderedDict({})
            pg_stats["Policy Gradient Norm/Mean"] = np.mean(policy_grads)
            self.eval_statistics.update(pg_stats)
        self._n_train_steps_total += 1
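
A self-contained sketch of the TD3 target-policy smoothing step used above; the final clamp to the action bounds and the default noise values are assumptions here (this snippet leaves the noisy actions unclamped):

import torch

def smoothed_target_actions(target_policy, next_obs, noise_std=0.2,
                            noise_clip=0.5, max_action=1.0):
    # Perturb the deterministic target action with clipped Gaussian noise,
    # then clamp back into the valid action range.
    with torch.no_grad():
        actions = target_policy(next_obs)
        noise = (torch.randn_like(actions) * noise_std).clamp(-noise_clip, noise_clip)
        return (actions + noise).clamp(-max_action, max_action)
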
Example #14
    def train_from_torch(self, batch, fast_batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Prob Clasifier
        """
        if self.imp:
            fast_obs = fast_batch['observations']
            fast_actions = fast_batch['actions']

            slow_samples = torch.cat((obs, actions), dim=1)
            fast_samples = torch.cat((fast_obs, fast_actions), dim=1)

            zeros = torch.zeros(slow_samples.size(0))
            ones = torch.ones(fast_samples.size(0))

            if obs.is_cuda:
                zeros = zeros.cuda()
                ones = ones.cuda()

            slow_preds = self.prob_classifier(slow_samples)
            fast_preds = self.prob_classifier(fast_samples)

            loss = F.binary_cross_entropy(torch.sigmoid(slow_preds), zeros) + \
                   F.binary_cross_entropy(torch.sigmoid(fast_preds), ones)

            if self._n_train_steps_total % 1000 == 0:
                print(loss)
    
            self.prob_classifier_optimizer.zero_grad()
            loss.backward()
            self.prob_classifier_optimizer.step()

            importance_weights = torch.sigmoid(slow_preds / self.temperature).detach()
            importance_weights = importance_weights / torch.sum(importance_weights)

        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )

        if self.imp and self.policy_net:
            policy_loss = (alpha*log_pi - q_new_actions)
            policy_loss = (policy_loss * importance_weights.detach()).sum()
        else:
            policy_loss = (alpha*log_pi - q_new_actions).mean()


        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values

        if self.imp and self.q_net and self.residual:
            q1_imp = (q1_pred - q_target.detach()) * importance_weights.detach() 
            q2_imp = (q2_pred - q_target.detach()) * importance_weights.detach() 
#            qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
#            qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

            qf1_loss = (q1_imp ** 2).sum()
            qf2_loss = (q2_imp ** 2).sum()
#            qf1_loss = (qf1_loss * importance_weights.detach()).sum()
#            qf2_loss = (qf2_loss * importance_weights.detach()).sum()
        elif self.imp and self.q_net and not self.residual:
            qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
            qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

            qf1_loss = (qf1_loss * importance_weights.detach()).sum()
            qf2_loss = (qf2_loss * importance_weights.detach()).sum()
        else:
            qf1_loss = self.qf_criterion(q1_pred, q_target.detach()).mean()
            qf2_loss = self.qf_criterion(q2_pred, q_target.detach()).mean()

        """
        Update networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
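
A standalone sketch of the classifier-based importance weighting used above: a discriminator is trained to separate "slow" (label 0) from "fast" (label 1) samples, and its sigmoid output, scaled by a temperature and normalized over the batch, gives the per-sample weights. The logits-based BCE here is a numerically safer stand-in for the sigmoid + binary_cross_entropy pair in the snippet; the classifier architecture is illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

def classifier_loss(classifier: nn.Module,
                    slow_samples: torch.Tensor,
                    fast_samples: torch.Tensor) -> torch.Tensor:
    # Binary cross-entropy with logits: slow samples get label 0, fast samples label 1.
    slow_logits = classifier(slow_samples)
    fast_logits = classifier(fast_samples)
    return (F.binary_cross_entropy_with_logits(slow_logits, torch.zeros_like(slow_logits)) +
            F.binary_cross_entropy_with_logits(fast_logits, torch.ones_like(fast_logits)))

def importance_weights(classifier: nn.Module,
                       slow_samples: torch.Tensor,
                       temperature: float = 1.0) -> torch.Tensor:
    # sigmoid(logit / T) acts as an unnormalized weight, then normalize over the batch.
    with torch.no_grad():
        w = torch.sigmoid(classifier(slow_samples) / temperature)
        return w / w.sum()
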
Example #15
    def decode_np(self, inputs, cont=True):
        return np.clip(
            ptu.get_numpy(self.decode(ptu.from_numpy(inputs), cont=cont)),
            0, 1)
Example #16
    def test_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['validation/QF1 Loss'] = np.mean(
            ptu.get_numpy(qf1_loss))
        self.eval_statistics['validation/QF2 Loss'] = np.mean(
            ptu.get_numpy(qf2_loss))
        self.eval_statistics['validation/Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q2 Predictions',
                ptu.get_numpy(q2_pred),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Q Targets',
                ptu.get_numpy(q_target),
            ))
        self.eval_statistics.update(
            create_stats_ordered_dict(
                'validation/Log Pis',
                ptu.get_numpy(log_pi),
            ))
        policy_statistics = add_prefix(dist.get_diagnostics(),
                                       "validation/policy/")
        self.eval_statistics.update(policy_statistics)
    def _take_step(self, indices, context):

        num_tasks = len(indices)

        # data is (task, batch, feat)
        # obs, actions, rewards, next_obs, terms = self.sample_data(indices)  # sample (s, a, r, s', d) from the replay buffer
        obs, actions, rewards, next_obs, terms = self.sample_sac(indices)

        # run inference in networks
        # policy_outputs, task_z = self.agent(obs, context)
        policy_outputs, task_z = self.agent(obs, context)  # policy forward outputs and the task latent variable z
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]  # actions sampled by the policy and their log-probabilities  # line 63
        # flattens out the task dimension:
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)

        with torch.no_grad():
            q1_next_target, q2_next_target = self.critic_target(next_obs, new_actions,
                                                                task_z)  # target q                     line 64
            min_qf_next_target = torch.min(q1_next_target, q2_next_target)  # take the smaller target Q  # line 65
            next_q_value = rewards_flat + (1. - terms_flat) * self.discount * (
                        min_qf_next_target - self.alpha * log_pi)  # Q = r + (1 - d) * γ * V(s_{t+1})  # line 66
        q1, q2 = self.critic(obs, actions,
                             task_z)  # forward                                                           line 68
        q1_loss = F.mse_loss(q1,
                             next_q_value)  # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]       line 69
        q2_loss = F.mse_loss(q2, next_q_value)  # line 70
        # pi, log_pi, _ = self.agent.policy.sample(obs)  # actions and their log-probabilities
        # print(obs.size())  #[1024,27]
        # print(task_z.size())
        in_policy = torch.cat([obs, task_z], 1)
        pi, _, _, log_pi, _, _, _, _ = self.agent.policy(in_policy)  # line 72

        q1_pi, q2_pi = self.critic(obs, pi,
                                   task_z.detach())  # Q-values of the policy's actions  # line 74
        min_q_pi = torch.min(q1_pi, q2_pi)  # line 75

        # KL constraint on z if probabilistic
        self.context_optimizer.zero_grad()
        if self.use_information_bottleneck:
            kl_div = self.agent.compute_kl_div()  # the context encoder's forward pass gives μ and σ; compute the KL between that posterior and the prior
            kl_loss = self.kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        self.critic_optimizer.zero_grad()
        q1_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        self.critic_optimizer.zero_grad()
        q2_loss.backward(retain_graph=True)
        self.critic_optimizer.step()

        self.context_optimizer.step()

        soft_update(self.critic_target, self.critic, self.soft_target_tau)

        policy_loss = ((self.alpha * log_pi) - min_q_pi).mean()  # line 77
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=True)
        self.policy_optimizer.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()  # E[-αlogπ(at|st)-αH]
            self.alpha_optim.zero_grad()
            alpha_loss.backward(retain_graph=True)
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
                z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)

            self.eval_statistics['Q1 Loss'] = np.mean(ptu.get_numpy(q1_loss))
            self.eval_statistics['Q2 Loss'] = np.mean(ptu.get_numpy(q2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                policy_loss
            ))
            if self.automatic_entropy_tuning:
                self.eval_statistics['Alpha Loss'] = np.mean(ptu.get_numpy(alpha_loss))
                self.eval_statistics['Alpha'] = np.mean(ptu.get_numpy(self.alpha))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q2 Predictions',
                ptu.get_numpy(q2),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy mu',
                ptu.get_numpy(policy_mean),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Policy log std',
                ptu.get_numpy(policy_log_std),
            ))
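
The information-bottleneck term above calls self.agent.compute_kl_div(), whose internals are not shown; a generic sketch of the usual closed-form computation, the KL between the diagonal-Gaussian posterior over the task latent and a standard-normal prior, would look like this (the names z_means / z_vars mirror the statistics logged above):

import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl_to_prior(z_means: torch.Tensor, z_vars: torch.Tensor) -> torch.Tensor:
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions.
    posterior = Normal(z_means, torch.sqrt(z_vars))
    prior = Normal(torch.zeros_like(z_means), torch.ones_like(z_vars))
    return kl_divergence(posterior, prior).sum(dim=-1)
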
Example #18
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # if self._n_train_steps_total % 100 == 0:
            #     import ipdb; ipdb.set_trace()
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError(
                "Unknown normalize_over_state: {}".format(self.normalize_over_state))

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch == False:
                weights = score
            else:
                raise ValueError(
                    "Unknown normalize_over_batch: {}".format(self.normalize_over_batch))
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data['next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
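
A compact sketch of the advantage-weighted regression weights computed above for the default normalize_over_batch == True path (the optional score clipping mirrors the snippet):

import torch
import torch.nn.functional as F

def awr_weights(q_adv: torch.Tensor, v_pi: torch.Tensor, beta: float,
                clip_score: float = None) -> torch.Tensor:
    # Advantage score A(s, a) = Q(s, a) - V(s); a batch softmax of A / beta
    # yields the regression weights applied to -log pi(a|s).
    score = q_adv - v_pi
    if clip_score is not None:
        score = torch.clamp(score, max=clip_score)
    return F.softmax(score / beta, dim=0)
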
Example #19
def np_ify(tensor_or_other):
    if isinstance(tensor_or_other, Variable):
        return ptu.get_numpy(tensor_or_other)
    else:
        return tensor_or_other
Example #20
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['goals']
        num_steps_left = batch['num_steps_left']
        """
        Policy operations.
        """
        policy_actions, pre_tanh_value = self.policy(
            obs,
            goals,
            num_steps_left,
            return_preactivations=True,
        )
        pre_activation_policy_loss = ((pre_tanh_value**2).sum(dim=1).mean())
        q_output = self.qf(
            observations=obs,
            actions=policy_actions,
            num_steps_left=num_steps_left,
            goals=goals,
        )
        raw_policy_loss = -q_output.mean()
        policy_loss = (
            raw_policy_loss +
            pre_activation_policy_loss * self.policy_pre_activation_weight)
        """
        Critic operations.
        """
        next_actions = self.target_policy(
            observations=next_obs,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        # speed up computation by not backpropping these gradients
        next_actions = next_actions.detach()
        target_q_values = self.target_qf(
            observations=next_obs,
            actions=next_actions,
            goals=goals,
            num_steps_left=num_steps_left - 1,
        )
        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()
        q_pred = self.qf(
            observations=obs,
            actions=actions,
            goals=goals,
            num_steps_left=num_steps_left,
        )
        if self.tdm_normalizer:
            q_pred = self.tdm_normalizer.distance_normalizer.normalize_scale(
                q_pred)
            q_target = self.tdm_normalizer.distance_normalizer.normalize_scale(
                q_target)
        bellman_errors = (q_pred - q_target)**2
        qf_loss = self.qf_criterion(q_pred, q_target)
        """
        Update Networks
        """
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self._update_target_networks()
        """
        Save some statistics for eval using just one batch.
        """
        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics['Raw Policy Loss'] = np.mean(
                ptu.get_numpy(raw_policy_loss))
            self.eval_statistics['Preactivation Policy Loss'] = (
                self.eval_statistics['Policy Loss'] -
                self.eval_statistics['Raw Policy Loss'])
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors',
                    ptu.get_numpy(bellman_errors),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Example #21
    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SACLosses, LossStatistics]:
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(-1)
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            # alpha = 0
            # Temperature Decay
            alpha = self.log_alpha

        q_new_actions = torch.min(
            self.qf1(obs, new_obs_actions),
            self.qf2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(-1)
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Save some statistics for eval
        """
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            eval_statistics.update(policy_statistics)
            if self.use_automatic_entropy_tuning:
                eval_statistics['Alpha'] = alpha.item()
                eval_statistics['Alpha Loss'] = alpha_loss.item()
            else:
                eval_statistics['Alpha'] = alpha

        loss = SACLosses(
            policy_loss=policy_loss,
            qf1_loss=qf1_loss,
            qf2_loss=qf2_loss,
            alpha_loss=alpha_loss,
        )

        return loss, eval_statistics
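
For context, a minimal sketch of the automatic entropy-temperature update that the use_automatic_entropy_tuning branches above implement; the Adam learning rate, action dimension, and the target_entropy = -|A| heuristic are the usual defaults and are assumptions here:

import torch

action_dim = 6                           # illustrative
target_entropy = -float(action_dim)      # common SAC heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(log_pi: torch.Tensor) -> torch.Tensor:
    # One gradient step on J(alpha) = E[-alpha * (log_pi + target_entropy)].
    alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()      # the alpha used in the policy / Q losses
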
Example #22
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        raw_actions = batch['raw_actions']
        next_obs = batch['next_observations']
        """
        Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = torch.tensor(1.)
        """
        Policy and VF Loss
        """
        v1_pred = self.vf1(obs)
        v2_pred = self.vf2(obs)
        pi_pred = self.policy.log_prob(obs, actions, raw_actions)

        target_v_values = torch.min(
            self.target_vf1(next_obs),
            self.target_vf2(next_obs),
        )
        # target_v_values = self.target_vf1(next_obs)
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_v_values

        vf1_loss = self.vf_criterion(
            v1_pred,
            q_target.detach() - (alpha * pi_pred).detach())
        vf2_loss = self.vf_criterion(
            v2_pred,
            q_target.detach() - (alpha * pi_pred).detach())
        policy_loss = self.pi_criterion(
            alpha.detach() * pi_pred,
            q_target.detach() - torch.min(v1_pred, v2_pred).detach())
        # policy_loss = self.pi_criterion(alpha.detach()*pi_pred, q_target.detach()-v1_pred.detach())
        """
        Update networks
        """
        self.vf1_optimizer.zero_grad()
        vf1_loss.backward()
        if self.clip_gradient:
            nn.utils.clip_grad_norm_(self.vf1.parameters(), self.clip_gradient)
        self.vf1_optimizer.step()

        self.vf2_optimizer.zero_grad()
        vf2_loss.backward()
        if self.clip_gradient:
            nn.utils.clip_grad_norm_(self.vf2.parameters(), self.clip_gradient)
        self.vf2_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        if self.clip_gradient:
            nn.utils.clip_grad_norm_(self.policy.parameters(),
                                     self.clip_gradient)
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.vf1, self.target_vf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.vf2, self.target_vf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['VF1 Loss'] = np.mean(ptu.get_numpy(vf1_loss))
            self.eval_statistics['VF2 Loss'] = np.mean(ptu.get_numpy(vf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V1 Predictions',
                    ptu.get_numpy(v1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V2 Predictions',
                    ptu.get_numpy(v2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
Example #23
File: td3.py  Project: johndpope/DRL
    def _do_training(self):
        batch = self.get_batch()
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Critic operations.
        """

        next_actions = self.target_policy(next_obs)
        noise = torch.normal(
            torch.zeros_like(next_actions),
            self.target_policy_noise,
        )
        noise = torch.clamp(noise, -self.target_policy_noise_clip,
                            self.target_policy_noise_clip)
        noisy_next_actions = next_actions + noise

        target_q1_values = self.target_qf1(next_obs, noisy_next_actions)
        target_q2_values = self.target_qf2(next_obs, noisy_next_actions)
        target_q_values = torch.min(target_q1_values, target_q2_values)
        q_target = rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        q1_pred = self.qf1(obs, actions)
        bellman_errors_1 = (q1_pred - q_target)**2
        qf1_loss = bellman_errors_1.mean()

        q2_pred = self.qf2(obs, actions)
        bellman_errors_2 = (q2_pred - q_target)**2
        qf2_loss = bellman_errors_2.mean()
        """
        Update Networks
        """
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

        policy_actions = policy_loss = None
        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            policy_actions = self.policy(obs)
            q_output = self.qf1(obs, policy_actions)
            policy_loss = -q_output.mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

        if self.eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            if policy_loss is None:
                policy_actions = self.policy(obs)
                q_output = self.qf1(obs, policy_actions)
                policy_loss = -q_output.mean()

            self.eval_statistics = OrderedDict()
            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 1',
                    ptu.get_numpy(bellman_errors_1),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Bellman Errors 2',
                    ptu.get_numpy(bellman_errors_2),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))
Example #24
def get_np_prediction(model, state, action):
    state = ptu.np_to_var(np.expand_dims(state, 0))
    action = ptu.np_to_var(np.expand_dims(action, 0))
    delta = model(state, action)
    return ptu.get_numpy(delta.squeeze(0))
Example #25
File: td4.py  Project: xtma/dsac
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        gt.stamp('preback_start', unique=False)
        """
        Update QF
        """
        with torch.no_grad():
            next_actions = self.target_policy(next_obs)
            noise = ptu.randn(next_actions.shape) * self.target_policy_noise
            noise = torch.clamp(noise, -self.target_policy_noise_clip,
                                self.target_policy_noise_clip)
            noisy_next_actions = torch.clamp(next_actions + noise,
                                             -self.max_action, self.max_action)

            next_tau, next_tau_hat, next_presum_tau = self.get_tau(
                next_obs, noisy_next_actions, fp=self.target_fp)
            target_z1_values = self.target_zf1(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z2_values = self.target_zf2(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z_values = torch.min(target_z1_values, target_z2_values)
            z_target = self.reward_scale * rewards + (
                1. - terminals) * self.discount * target_z_values

        tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
        z1_pred = self.zf1(obs, actions, tau_hat)
        z2_pred = self.zf2(obs, actions, tau_hat)
        zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat,
                                     next_presum_tau)
        zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat,
                                     next_presum_tau)
        gt.stamp('preback_zf', unique=False)

        self.zf1_optimizer.zero_grad()
        zf1_loss.backward()
        self.zf1_optimizer.step()
        gt.stamp('backward_zf1', unique=False)

        self.zf2_optimizer.zero_grad()
        zf2_loss.backward()
        self.zf2_optimizer.step()
        gt.stamp('backward_zf2', unique=False)
        """
        Update FP
        """
        if self.tau_type == 'fqf':
            with torch.no_grad():
                dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                                z1_pred[:, :-1] - z1_pred[:, 1:] +
                                2 * self.zf2(obs, actions, tau[:, :-1]) -
                                z2_pred[:, :-1] - z2_pred[:, 1:])
                dWdtau /= dWdtau.shape[0]  # (N, T-1)
            gt.stamp('preback_fp', unique=False)
            self.fp_optimizer.zero_grad()
            tau[:, :-1].backward(gradient=dWdtau)
            self.fp_optimizer.step()
            gt.stamp('backward_fp', unique=False)
        """
        Policy Loss
        """
        policy_actions = self.policy(obs)
        risk_param = self.risk_schedule(self._n_train_steps_total)

        if self.risk_type == 'VaR':
            tau_ = ptu.ones_like(rewards) * risk_param
            q_new_actions = self.zf1(obs, policy_actions, tau_)
        else:
            with torch.no_grad():
                new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                    obs, policy_actions, fp=self.fp)
            z_new_actions = self.zf1(obs, policy_actions, new_tau_hat)
            if self.risk_type in ['neutral', 'std']:
                q_new_actions = torch.sum(new_presum_tau * z_new_actions,
                                          dim=1,
                                          keepdims=True)
                if self.risk_type == 'std':
                    q_std = new_presum_tau * (z_new_actions -
                                              q_new_actions).pow(2)
                    q_new_actions -= risk_param * q_std.sum(
                        dim=1, keepdims=True).sqrt()
            else:
                with torch.no_grad():
                    risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                                 risk_param)
                q_new_actions = torch.sum(risk_weights * new_presum_tau *
                                          z_new_actions,
                                          dim=1,
                                          keepdims=True)

        policy_loss = -q_new_actions.mean()

        gt.stamp('preback_policy', unique=False)

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
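            # Delayed updates (as in TD3): the actor and all target networks are
            # refreshed only once every `policy_and_target_update_period` steps.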
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                                  self.clip_norm)
            self.policy_optimizer.step()
            gt.stamp('backward_policy', unique=False)

            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                    self.soft_target_tau)
            if self.tau_type == 'fqf':
                ptu.soft_update_from_to(self.fp, self.target_fp,
                                        self.soft_target_tau)
        gt.stamp('soft_update', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
            self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z1 Predictions',
                    ptu.get_numpy(z1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z2 Predictions',
                    ptu.get_numpy(z2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z Targets',
                    ptu.get_numpy(z_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))

        self._n_train_steps_total += 1
Example #26
File: sac.py  Project: xtma/dsac
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        gt.stamp('preback_start', unique=False)
        """
        Update Alpha
        """
        new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        Update QF
        """
        with torch.no_grad():
            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs,
                reparameterize=True,
                return_log_prob=True,
            )
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions),
                self.target_qf2(next_obs, new_next_actions),
            ) - alpha * new_log_pi

            q_target = self.reward_scale * rewards + (
                1. - terminals) * self.discount * target_q_values
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)
        gt.stamp('preback_qf', unique=False)

        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()
        gt.stamp('backward_qf1', unique=False)

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()
        gt.stamp('backward_qf2', unique=False)
        """
        Update Policy
        """
        q_new_actions = torch.min(
            self.qf1(obs, new_actions),
            self.qf2(obs, new_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()
        gt.stamp('preback_policy', unique=False)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                              self.clip_norm)
        self.policy_optimizer.step()
        gt.stamp('backward_policy', unique=False)
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['QF1 Loss'] = qf1_loss.item()
            self.eval_statistics['QF2 Loss'] = qf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
Example #27
    def _do_policy_training(self, epoch):
        if self.policy_optim_batch_size_from_expert > 0:
            policy_batch_from_policy_buffer = self.get_batch(
                self.policy_optim_batch_size -
                self.policy_optim_batch_size_from_expert, False)
            policy_batch_from_expert_buffer = self.get_batch(
                self.policy_optim_batch_size_from_expert, True)
            policy_batch = {}
            for k in policy_batch_from_policy_buffer:
                policy_batch[k] = torch.cat([
                    policy_batch_from_policy_buffer[k],
                    policy_batch_from_expert_buffer[k]
                ],
                                            dim=0)
        else:
            policy_batch = self.get_batch(self.policy_optim_batch_size, False)

        obs = policy_batch['observations'] / self.rescale
        obs = torch.index_select(obs, 1, self.state_indices)

        if self.wrap_absorbing:
            # obs = torch.cat([obs, policy_batch['absorbing'][:, 0:1]], dim=-1)
            ebm_input = obs  # define ebm_input on this branch too so the energy call below does not fail
        else:
            self.ebm.eval()
            ebm_input = obs

        ebm_rew = self.get_energy(ebm_input).detach()
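        # The learned energy acts as the reward signal; it is detached so the
        # policy update does not backpropagate into the EBM, and optionally
        # clamped in magnitude below.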

        if self.clamp_magnitude > 0:
            print(self.clamp_magnitude, self.clamp_magnitude is None)
            ebm_rew = torch.clamp(ebm_rew,
                                  min=-1.0 * self.clamp_magnitude,
                                  max=self.clamp_magnitude)

        # compute the reward using the algorithm
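        # The EBM reward is additionally shaped by the first observation
        # dimension, selected via index_select below.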
        shape_reward = torch.index_select(
            obs, 1,
            torch.LongTensor([0]).to(ptu.device)
        )  # torch.index_select(policy_batch['actions'], 1, torch.LongTensor([0]).to(ptu.device))
        policy_batch['rewards'] = reward_func(
            self.rew_func, self.cons)(ebm_rew) * shape_reward

        # print('ebm_obs: ', np.mean(ptu.get_numpy(policy_batch['observations']), axis=0))
        # print('ebm: ', np.mean(ptu.get_numpy(ebm_rew)))
        # print('ebm_rew: ', np.mean(ptu.get_numpy(policy_batch['rewards'])))

        # buffer = self.target_state_buffer
        # exp_obs = buffer[np.random.choice(buffer.shape[0], size=100)]
        # exp_input = torch.Tensor(exp_obs / self.rescale).to(ptu.device)
        # exp_ebm = self.get_energy(exp_input).detach()
        # exp_rew = reward_func(self.rew_func, self.cons)(exp_ebm)

        # print('exp_obs: ', np.mean(exp_obs,axis=0))
        # print("exp data ebm: ", np.mean(ptu.get_numpy(exp_ebm)))
        # print("exp data rew: ", np.mean(ptu.get_numpy(exp_rew)))

        # policy optimization step
        self.policy_trainer.train_step(policy_batch)
        """
        Save some statistics for eval
        """
        if self.ebm_eval_statistics is None:
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            self.ebm_eval_statistics = OrderedDict()

        self.ebm_eval_statistics['ebm Rew Mean'] = np.mean(
            ptu.get_numpy(policy_batch['rewards']))
        self.ebm_eval_statistics['ebm Rew Std'] = np.std(
            ptu.get_numpy(policy_batch['rewards']))
        self.ebm_eval_statistics['ebm Rew Max'] = np.max(
            ptu.get_numpy(policy_batch['rewards']))
        self.ebm_eval_statistics['ebm Rew Min'] = np.min(
            ptu.get_numpy(policy_batch['rewards']))
Example #28
    def decode_one_np(self, inputs, cont=True):
        return np.clip(
            ptu.get_numpy(
                self.decode(ptu.from_numpy(inputs).reshape(1, -1),
                            cont=cont))[0], 0, 1)
Example #29
    def encode_np(self, inputs, cont=True):
        return ptu.get_numpy(self.encode(ptu.from_numpy(inputs), cont=cont))
Example #30
    def sample_np(self, batch_size):
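        # Sample latent codes from a standard normal prior and decode them,
        # returning the decoded batch as a numpy array.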
        latents = np.random.normal(size=(batch_size, self.latent_dim))
        latents_torch = ptu.from_numpy(latents)
        return ptu.get_numpy(self.decode(latents_torch))