Example #1
    def rsample(self, *args, **kwargs):
        event = super().rsample(*args, **kwargs)
        clipped = torch.max(
            torch.min(event,
                      ptu.ones_like(event) - self._clip),
            -1 * ptu.ones_like(event) + self._clip,
        )
        event = event - event.detach() + clipped.detach()
        event *= self._mult
        return event
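The rsample above clamps a reparameterized sample into [-1 + clip, 1 - clip] while letting gradients flow through the unclipped sample (a straight-through estimator). A minimal standalone sketch of the same trick, assuming plain torch in place of rlkit's ptu helpers:

import torch

def straight_through_clip(event, clip, mult=1.0):
    # Forward value is the clipped sample; the detach trick keeps d(out)/d(event) = 1.
    clipped = torch.clamp(event, -1.0 + clip, 1.0 - clip)
    out = event - event.detach() + clipped.detach()
    return out * mult

x = torch.randn(4, requires_grad=True)
y = straight_through_clip(x, clip=0.1)
y.sum().backward()
print(x.grad)  # all ones: the clipping does not block the gradient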
Example #2
    def decode(self, input):
        output = self(input)
        if self.output_var == 'learned':
            mu, logvar = torch.split(output, 2, dim=1)
            var = logvar.exp()
        else:
            mu = output
            var = self.output_var * ptu.ones_like(mu)
        return mu, var
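One caveat in the learned-variance branch: torch.split(output, 2, dim=1) returns chunks of width 2 along dim 1, so unpacking it into (mu, logvar) only works when the output has exactly four channels. For an arbitrary width, torch.chunk splits the tensor into two halves. A small sketch with hypothetical sizes:

import torch

output = torch.randn(8, 6)                   # hypothetical decoder output, 2 * latent_dim = 6
mu, logvar = torch.chunk(output, 2, dim=1)   # two halves of width 3
var = logvar.exp()
print(mu.shape, var.shape)                   # torch.Size([8, 3]) torch.Size([8, 3])

# torch.split(output, 2, dim=1) would instead return three width-2 chunks here,
# so unpacking it into (mu, logvar) would raise a ValueError.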
Example #3
    def encode(self,
               input,
               lstm_hidden=None,
               return_hidden=False,
               return_vae_latent=False):
        '''
        input: [seq_len x batch x flatten_img_dim] of flattened images
        lstm_hidden: [lstm_layers x batch x lstm_hidden_size]
        note: behavior changes depending on how the latent distribution parameters are used
        '''
        seq_len, batch_size, feature_size = input.shape
        # print("in lstm encode: ", seq_len, batch_size, feature_size)
        input = input.reshape((-1, feature_size))
        feature = self.encoder(input)  # [seq_len x batch x conv_output_size]

        vae_mu = self.vae_fc1(feature)
        if self.log_min_variance is None:
            vae_logvar = self.vae_fc2(feature)
        else:
            vae_logvar = self.log_min_variance + torch.abs(
                self.vae_fc2(feature))

        # lstm_input = self.rsample((vae_mu, vae_logvar))
        # if self.detach_vae_output:
        #     lstm_input = lstm_input.detach()
        if self.detach_vae_output:
            lstm_input = vae_mu.detach().clone()
        else:
            lstm_input = vae_mu
        lstm_input = lstm_input.view((seq_len, batch_size, -1))
        # if self.detach_vae_output:
        #     lstm_input = lstm_input.detach()

        if lstm_hidden is None:
            lstm_hidden = (ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size), \
                     ptu.zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size))

        h, hidden = self.lstm(
            lstm_input, lstm_hidden)  # [seq_len, batch_size, lstm_hidden_size]

        lstm_latent = self.lstm_fc(h)

        ret = (lstm_latent, ptu.ones_like(lstm_latent))
        if return_vae_latent:
            ret += (vae_mu, vae_logvar)

        if return_hidden:
            return ret, hidden
        return ret  #, lstm_input # [seq_len, batch_size, representation_size]
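The commented-out lines suggest the LSTM input can also be a reparameterized sample of the VAE latent instead of vae_mu. A minimal sketch of such an rsample helper under the usual (mu, logvar) parameterization (a generic illustration, not this repo's implementation):

import torch

def rsample(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and logvar
    std = (0.5 * logvar).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

mu, logvar = torch.zeros(2, 4), torch.zeros(2, 4)
z = rsample(mu, logvar)  # shape (2, 4); a standard normal sample in this toy case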
Example #4
File: dsac.py Project: xtma/dsac
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        gt.stamp('preback_start', unique=False)
        """
        Update Alpha
        """
        new_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha.exp() *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        gt.stamp('preback_alpha', unique=False)
        """
        Update ZF
        """
        with torch.no_grad():
            new_next_actions, _, _, new_log_pi, *_ = self.target_policy(
                next_obs,
                reparameterize=True,
                return_log_prob=True,
            )
            next_tau, next_tau_hat, next_presum_tau = self.get_tau(
                next_obs, new_next_actions, fp=self.target_fp)
            target_z1_values = self.target_zf1(next_obs, new_next_actions,
                                               next_tau_hat)
            target_z2_values = self.target_zf2(next_obs, new_next_actions,
                                               next_tau_hat)
            target_z_values = torch.min(target_z1_values,
                                        target_z2_values) - alpha * new_log_pi
            z_target = self.reward_scale * rewards + (
                1. - terminals) * self.discount * target_z_values

        tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
        z1_pred = self.zf1(obs, actions, tau_hat)
        z2_pred = self.zf2(obs, actions, tau_hat)
        zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat,
                                     next_presum_tau)
        zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat,
                                     next_presum_tau)
        gt.stamp('preback_zf', unique=False)

        self.zf1_optimizer.zero_grad()
        zf1_loss.backward()
        self.zf1_optimizer.step()
        gt.stamp('backward_zf1', unique=False)

        self.zf2_optimizer.zero_grad()
        zf2_loss.backward()
        self.zf2_optimizer.step()
        gt.stamp('backward_zf2', unique=False)
        """
        Update FP
        """
        if self.tau_type == 'fqf':
            with torch.no_grad():
                dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                                z1_pred[:, :-1] - z1_pred[:, 1:] +
                                2 * self.zf2(obs, actions, tau[:, :-1]) -
                                z2_pred[:, :-1] - z2_pred[:, 1:])
                dWdtau /= dWdtau.shape[0]  # (N, T-1)
            gt.stamp('preback_fp', unique=False)

            self.fp_optimizer.zero_grad()
            tau[:, :-1].backward(gradient=dWdtau)
            self.fp_optimizer.step()
            gt.stamp('backward_fp', unique=False)
        """
        Update Policy
        """
        risk_param = self.risk_schedule(self._n_train_steps_total)

        if self.risk_type == 'VaR':
            tau_ = ptu.ones_like(rewards) * risk_param
            q1_new_actions = self.zf1(obs, new_actions, tau_)
            q2_new_actions = self.zf2(obs, new_actions, tau_)
        else:
            with torch.no_grad():
                new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                    obs, new_actions, fp=self.fp)
            z1_new_actions = self.zf1(obs, new_actions, new_tau_hat)
            z2_new_actions = self.zf2(obs, new_actions, new_tau_hat)
            if self.risk_type in ['neutral', 'std']:
                q1_new_actions = torch.sum(new_presum_tau * z1_new_actions,
                                           dim=1,
                                           keepdims=True)
                q2_new_actions = torch.sum(new_presum_tau * z2_new_actions,
                                           dim=1,
                                           keepdims=True)
                if self.risk_type == 'std':
                    q1_std = new_presum_tau * (z1_new_actions -
                                               q1_new_actions).pow(2)
                    q2_std = new_presum_tau * (z2_new_actions -
                                               q2_new_actions).pow(2)
                    q1_new_actions -= risk_param * q1_std.sum(
                        dim=1, keepdims=True).sqrt()
                    q2_new_actions -= risk_param * q2_std.sum(
                        dim=1, keepdims=True).sqrt()
            else:
                with torch.no_grad():
                    risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                                 risk_param)
                q1_new_actions = torch.sum(risk_weights * new_presum_tau *
                                           z1_new_actions,
                                           dim=1,
                                           keepdims=True)
                q2_new_actions = torch.sum(risk_weights * new_presum_tau *
                                           z2_new_actions,
                                           dim=1,
                                           keepdims=True)
        q_new_actions = torch.min(q1_new_actions, q2_new_actions)

        policy_loss = (alpha * log_pi - q_new_actions).mean()
        gt.stamp('preback_policy', unique=False)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                              self.clip_norm)
        self.policy_optimizer.step()
        gt.stamp('backward_policy', unique=False)
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                    self.soft_target_tau)
            if self.tau_type == 'fqf':
                ptu.soft_update_from_to(self.fp, self.target_fp,
                                        self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
            self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z1 Predictions',
                    ptu.get_numpy(z1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z2 Predictions',
                    ptu.get_numpy(z2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z Targets',
                    ptu.get_numpy(z_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()
        self._n_train_steps_total += 1
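The soft-update block relies on ptu.soft_update_from_to to Polyak-average the online networks into their targets. A minimal sketch of that update rule, assuming the (source, target, tau) argument order used above:

import torch

def soft_update_from_to(source, target, tau):
    # target <- tau * source + (1 - tau) * target, parameter by parameter
    with torch.no_grad():
        for tgt_p, src_p in zip(target.parameters(), source.parameters()):
            tgt_p.data.mul_(1.0 - tau).add_(src_p.data, alpha=tau)

online, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
soft_update_from_to(online, target, tau=0.005)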
Example #5
    def forward_ith_model(self, input, i):
        mean = self.ensemble[i](input)
        return self.get_dist(mean, ptu.ones_like(mean))
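Here ptu.ones_like(mean) supplies a unit standard deviation, so each ensemble member defines a fixed-variance Gaussian around its predicted mean. A sketch of what such a get_dist could look like (a hypothetical helper, not the repo's actual one):

import torch
from torch.distributions import Normal

def get_dist(mean, std):
    # unit-variance Gaussian around the ensemble member's prediction
    return Normal(mean, std)

mean = torch.zeros(3)
dist = get_dist(mean, torch.ones_like(mean))
sample = dist.rsample()       # differentiable sample
logp = dist.log_prob(sample)  # elementwise log-density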
Example #6
    def mode(self):
        mode = torch.max(
            torch.min(self.mean, ptu.ones_like(self.mean)),
            -1 * ptu.ones_like(self.mean),
        )
        return mode
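The nested torch.max/torch.min above simply clamps the mean into [-1, 1]; in plain torch the same mode can be written as a single clamp:

import torch

mean = torch.tensor([-1.7, 0.3, 2.5])
mode = torch.clamp(mean, -1.0, 1.0)   # tensor([-1.0000, 0.3000, 1.0000])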
Example #7
    def forward(
        self,
        obs,
        action,
        use_network_action=False,
        state=None,
        batch_indices=None,
        raps_obs_indices=None,
    ):
        """
        Forward world model on trajectory.

        :param obs: (batch_size, path_length, obs_dim)
        :param action: List [(batch_size, path_length, high_level_action_dim), (batch_size, path_length, low_level_action_dim)]
        :param use_network_action:
        :param state:
        :param batch_indices:
        :param raps_obs_indices:

        :return post: Dict
            mean: (batch_size, path_length, stoch_size)
            std: (batch_size, path_length, stoch_size)
            stoch: (batch_size, path_length, stoch_size)
            deter: (batch_size, path_length, deter_size)
        :return prior:
            mean: (batch_size, path_length, stoch_size)
            std: (batch_size, path_length, stoch_size)
            stoch: (batch_size, path_length, stoch_size)
            deter: (batch_size, path_length, deter_size)
        :return post_dist:
            mean: (batch_size*(path_length), stoch_size)
            std: (batch_size*(path_length), stoch_size)
        :return prior_dist:
            mean: (batch_size*(path_length), stoch_size)
            std: (batch_size*(path_length), stoch_size)
        :return image_dist:
            mean: (batch_size*(path_length), obs_dim)
        :return reward_dist:
            mean: (batch_size*(raps_path_length), 1)
        :return pred_discount_dist
            logits: (batch_size*(raps_path_length), 1)
        :return embed: (batch_size, path_length, embed_dim)
        :return low_level_action_preds: (batch_size, path_length, low_level_action_dim)
        """
        assert (obs.shape[:2] == action[0].shape[:2] == action[1].shape[:2]
                ), "Obs and action first two dimensions should be the same."
        original_batch_size = action[1].shape[0]
        path_length = action[1].shape[1]
        if state is None:
            state = self.initial(original_batch_size)
        post, prior = (
            dict(mean=[], std=[], stoch=[], deter=[]),
            dict(mean=[], std=[], stoch=[], deter=[]),
        )
        obs_path_len = obs.shape[1]
        obs = obs.reshape(-1, obs.shape[-1])
        embed = self.encode(obs)
        embedding_size = embed.shape[1]
        embed = embed.reshape(original_batch_size, obs_path_len,
                              embedding_size)

        if obs_path_len < path_length:
            idxs = raps_obs_indices.tolist()
        else:
            idxs = np.arange(
                0,
                path_length,
                1,
            ).tolist()

        post, prior, low_level_action_preds = self.forward_batch(
            path_length,
            action,
            embed,
            post,
            prior,
            state,
            idxs,
            use_network_action,
        )

        for key in post.keys():
            post[key] = torch.cat(post[key], dim=1)

        for key in prior.keys():
            prior[key] = torch.cat(prior[key], dim=1)

        if self.use_prior_instead_of_posterior:
            # In this case, o_hat_t depends on a_t-1 and o_t-1; reset obs is decoded from the null state + action.
            # This only works when the first state is the reset obs and never changes.
            feat = self.get_features(prior)
        else:
            feat = self.get_features(post)

        raps_obs_feat = feat[:, raps_obs_indices]
        raps_obs_feat = raps_obs_feat.reshape(-1, raps_obs_feat.shape[-1])

        if batch_indices.shape != raps_obs_indices.shape:
            feat = get_indexed_arr_from_batch_indices(feat,
                                                      batch_indices).reshape(
                                                          -1, feat.shape[-1])
        else:
            feat = feat[:, batch_indices]

        images = self.decode(feat)
        rewards = self.reward(raps_obs_feat)
        pred_discounts = self.pred_discount(raps_obs_feat)

        if batch_indices.shape != raps_obs_indices.shape:
            post_dist = self.get_dist(
                get_indexed_arr_from_batch_indices(post["mean"],
                                                   batch_indices).reshape(
                                                       -1,
                                                       post["mean"].shape[-1]),
                get_indexed_arr_from_batch_indices(post["std"],
                                                   batch_indices).reshape(
                                                       -1,
                                                       post["std"].shape[-1]),
            )
            prior_dist = self.get_dist(
                get_indexed_arr_from_batch_indices(
                    prior["mean"],
                    batch_indices).reshape(-1, prior["mean"].shape[-1]),
                get_indexed_arr_from_batch_indices(prior["std"],
                                                   batch_indices).reshape(
                                                       -1,
                                                       prior["std"].shape[-1]),
            )
        else:
            post_dist = self.get_dist(
                post["mean"][:, batch_indices].reshape(-1,
                                                       post["mean"].shape[-1]),
                post["std"][:, batch_indices].reshape(-1,
                                                      post["std"].shape[-1]),
            )
            prior_dist = self.get_dist(
                prior["mean"][:,
                              batch_indices].reshape(-1,
                                                     prior["mean"].shape[-1]),
                prior["std"][:, batch_indices].reshape(-1,
                                                       prior["std"].shape[-1]),
            )
        image_dist = self.get_dist(images, ptu.ones_like(images), dims=3)
        if self.reward_classifier:
            reward_dist = self.get_dist(rewards, None, normal=False)
        else:
            reward_dist = self.get_dist(rewards, ptu.ones_like(rewards))
        pred_discount_dist = self.get_dist(pred_discounts, None, normal=False)
        return (
            post,
            prior,
            post_dist,
            prior_dist,
            image_dist,
            reward_dist,
            pred_discount_dist,
            embed,
            low_level_action_preds,
        )
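The dims=3 argument on the image distribution suggests the last three dimensions (channels, height, width) are treated as a single event. A sketch of a get_dist with that behavior, built on torch.distributions.Independent (an assumption about the helper, not its actual code):

import torch
from torch.distributions import Independent, Normal

def get_dist(mean, std, dims=1):
    # Gaussian whose last `dims` dimensions form one event
    return Independent(Normal(mean, std), dims)

images = torch.zeros(16, 3, 64, 64)
image_dist = get_dist(images, torch.ones_like(images), dims=3)
print(image_dist.batch_shape, image_dist.event_shape)  # torch.Size([16]) torch.Size([3, 64, 64])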
Example #8
    def forward(self, obs, action):
        """
        Forward world model on trajectory.

        :param obs: (batch_size, path_length, obs_dim)
        :param action: (batch_size, path_length, action_dim)

        :return post: Dict
            mean: (batch_size, path_length, stoch_size)
            std: (batch_size, path_length, stoch_size)
            stoch: (batch_size, path_length, stoch_size)
            deter: (batch_size, path_length, deter_size)
        :return prior:
            mean: (batch_size, path_length, stoch_size)
            std: (batch_size, path_length, stoch_size)
            stoch: (batch_size, path_length, stoch_size)
            deter: (batch_size, path_length, deter_size)
        :return post_dist:
            mean: (batch_size*(path_length), stoch_size)
            std: (batch_size*(path_length), stoch_size)
        :return prior_dist:
            mean: (batch_size*(path_length), stoch_size)
            std: (batch_size*(path_length), stoch_size)
        :return image_dist:
            mean: (batch_size*(path_length), obs_dim)
        :return reward_dist:
            mean: (batch_size*(raps_path_length), 1)
        :return pred_discount_dist
            logits: (batch_size*(raps_path_length), 1)
        :return embed: (batch_size, path_length, embed_dim)
        """
        original_batch_size = obs.shape[0]
        state = self.initial(original_batch_size)
        path_length = obs.shape[1]
        post, prior = (
            dict(mean=[], std=[], stoch=[], deter=[]),
            dict(mean=[], std=[], stoch=[], deter=[]),
        )
        obs = obs.reshape(-1, obs.shape[-1])
        embed = self.encode(obs)
        embedding_size = embed.shape[1]
        embed = embed.reshape(original_batch_size, path_length, embedding_size)
        post, prior = self.forward_batch(
            path_length,
            action,
            embed,
            post,
            prior,
            state,
        )

        for key in post.keys():
            post[key] = torch.cat(post[key], dim=1)

        for key in prior.keys():
            prior[key] = torch.cat(prior[key], dim=1)

        if self.use_prior_instead_of_posterior:
            # In this case, o_hat_t depends on a_t-1 and o_t-1, reset obs decoded from null state + action.
            # This only works when first state is reset obs and never changes.
            feat = self.get_features(prior)
        else:
            feat = self.get_features(post)
        feat = feat.reshape(-1, feat.shape[-1])
        images = self.decode(feat)
        rewards = self.reward(feat)
        pred_discounts = self.pred_discount(feat)
        post_dist = self.get_dist(
            post["mean"].reshape(-1, post["mean"].shape[-1]),
            post["std"].reshape(-1, post["std"].shape[-1]),
        )
        prior_dist = self.get_dist(
            prior["mean"].reshape(-1, prior["mean"].shape[-1]),
            prior["std"].reshape(-1, prior["std"].shape[-1]),
        )
        image_dist = self.get_dist(images, ptu.ones_like(images), dims=3)
        reward_dist = self.get_dist(rewards, ptu.ones_like(rewards))
        pred_discount_dist = self.get_dist(pred_discounts, None, normal=False)
        return (
            post,
            prior,
            post_dist,
            prior_dist,
            image_dist,
            reward_dist,
            pred_discount_dist,
            embed,
        )
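Dreamer-style world models typically build their feature vector by concatenating the stochastic and deterministic parts of the state; the sketch below assumes get_features does the same (an assumption, not taken from this repo):

import torch

def get_features(state):
    # concatenate stochastic and deterministic state along the last dimension
    # (assumed behavior of get_features, not copied from the source)
    return torch.cat([state["stoch"], state["deter"]], dim=-1)

state = {"stoch": torch.zeros(4, 10, 32), "deter": torch.zeros(4, 10, 200)}
feat = get_features(state)  # shape (4, 10, 232)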
Example #9
File: td4.py Project: xtma/dsac
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        gt.stamp('preback_start', unique=False)
        """
        Update QF
        """
        with torch.no_grad():
            next_actions = self.target_policy(next_obs)
            noise = ptu.randn(next_actions.shape) * self.target_policy_noise
            noise = torch.clamp(noise, -self.target_policy_noise_clip,
                                self.target_policy_noise_clip)
            noisy_next_actions = torch.clamp(next_actions + noise,
                                             -self.max_action, self.max_action)

            next_tau, next_tau_hat, next_presum_tau = self.get_tau(
                next_obs, noisy_next_actions, fp=self.target_fp)
            target_z1_values = self.target_zf1(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z2_values = self.target_zf2(next_obs, noisy_next_actions,
                                               next_tau_hat)
            target_z_values = torch.min(target_z1_values, target_z2_values)
            z_target = self.reward_scale * rewards + (
                1. - terminals) * self.discount * target_z_values

        tau, tau_hat, presum_tau = self.get_tau(obs, actions, fp=self.fp)
        z1_pred = self.zf1(obs, actions, tau_hat)
        z2_pred = self.zf2(obs, actions, tau_hat)
        zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat,
                                     next_presum_tau)
        zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat,
                                     next_presum_tau)
        gt.stamp('preback_zf', unique=False)

        self.zf1_optimizer.zero_grad()
        zf1_loss.backward()
        self.zf1_optimizer.step()
        gt.stamp('backward_zf1', unique=False)

        self.zf2_optimizer.zero_grad()
        zf2_loss.backward()
        self.zf2_optimizer.step()
        gt.stamp('backward_zf2', unique=False)
        """
        Update FP
        """
        if self.tau_type == 'fqf':
            with torch.no_grad():
                dWdtau = 0.5 * (2 * self.zf1(obs, actions, tau[:, :-1]) -
                                z1_pred[:, :-1] - z1_pred[:, 1:] +
                                2 * self.zf2(obs, actions, tau[:, :-1]) -
                                z2_pred[:, :-1] - z2_pred[:, 1:])
                dWdtau /= dWdtau.shape[0]  # (N, T-1)
            gt.stamp('preback_fp', unique=False)
            self.fp_optimizer.zero_grad()
            tau[:, :-1].backward(gradient=dWdtau)
            self.fp_optimizer.step()
            gt.stamp('backward_fp', unique=False)
        """
        Policy Loss
        """
        policy_actions = self.policy(obs)
        risk_param = self.risk_schedule(self._n_train_steps_total)

        if self.risk_type == 'VaR':
            tau_ = ptu.ones_like(rewards) * risk_param
            q_new_actions = self.zf1(obs, policy_actions, tau_)
        else:
            with torch.no_grad():
                new_tau, new_tau_hat, new_presum_tau = self.get_tau(
                    obs, policy_actions, fp=self.fp)
            z_new_actions = self.zf1(obs, policy_actions, new_tau_hat)
            if self.risk_type in ['neutral', 'std']:
                q_new_actions = torch.sum(new_presum_tau * z_new_actions,
                                          dim=1,
                                          keepdims=True)
                if self.risk_type == 'std':
                    q_std = new_presum_tau * (z_new_actions -
                                              q_new_actions).pow(2)
                    q_new_actions -= risk_param * q_std.sum(
                        dim=1, keepdims=True).sqrt()
            else:
                with torch.no_grad():
                    risk_weights = distortion_de(new_tau_hat, self.risk_type,
                                                 risk_param)
                q_new_actions = torch.sum(risk_weights * new_presum_tau *
                                          z_new_actions,
                                          dim=1,
                                          keepdims=True)

        policy_loss = -q_new_actions.mean()

        gt.stamp('preback_policy', unique=False)

        if self._n_train_steps_total % self.policy_and_target_update_period == 0:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_grad = ptu.fast_clip_grad_norm(self.policy.parameters(),
                                                  self.clip_norm)
            self.policy_optimizer.step()
            gt.stamp('backward_policy', unique=False)

            ptu.soft_update_from_to(self.policy, self.target_policy,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf1, self.target_zf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.zf2, self.target_zf2,
                                    self.soft_target_tau)
            if self.tau_type == 'fqf':
                ptu.soft_update_from_to(self.fp, self.target_fp,
                                        self.soft_target_tau)
        gt.stamp('soft_update', unique=False)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """

            self.eval_statistics['ZF1 Loss'] = zf1_loss.item()
            self.eval_statistics['ZF2 Loss'] = zf2_loss.item()
            self.eval_statistics['Policy Loss'] = policy_loss.item()
            self.eval_statistics['Policy Grad'] = policy_grad
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z1 Predictions',
                    ptu.get_numpy(z1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z2 Predictions',
                    ptu.get_numpy(z2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Z Targets',
                    ptu.get_numpy(z_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy Action',
                    ptu.get_numpy(policy_actions),
                ))

        self._n_train_steps_total += 1
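ptu.fast_clip_grad_norm clips the global gradient norm before the policy optimizer step; in plain PyTorch the closest equivalent is torch.nn.utils.clip_grad_norm_. A minimal sketch:

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# grad_norm is the total norm before clipping; gradients are rescaled in place if it exceeds max_norm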
Example #10
    def train_from_torch(self, batch):
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        context = batch['context']

        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist, p_z, task_z_with_grad = self.agent(
            obs,
            context,
            return_latent_posterior_and_task_z=True,
        )
        task_z_detached = task_z_with_grad.detach()
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        log_pi = log_pi.unsqueeze(1)
        next_dist = self.agent(next_obs, context)

        if self._debug_ignore_context:
            task_z_with_grad = task_z_with_grad * 0

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)
        unscaled_rewards_flat = rewards.view(t * b, 1)
        rewards_flat = unscaled_rewards_flat * self.reward_scale
        terms_flat = terminals.view(t * b, 1)

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        if self.backprop_q_loss_into_encoder:
            q1_pred = self.qf1(obs, actions, task_z_with_grad)
            q2_pred = self.qf2(obs, actions, task_z_with_grad)
        else:
            q1_pred = self.qf1(obs, actions, task_z_detached)
            q2_pred = self.qf2(obs, actions, task_z_detached)
        # Make sure policy accounts for squashing functions like tanh correctly!
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        new_log_pi = new_log_pi.unsqueeze(1)
        with torch.no_grad():
            target_q_values = torch.min(
                self.target_qf1(next_obs, new_next_actions, task_z_detached),
                self.target_qf2(next_obs, new_next_actions, task_z_detached),
            ) - alpha * new_log_pi

        q_target = rewards_flat + (
            1. - terms_flat) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Context Encoder Loss
        """
        if self._debug_use_ground_truth_context:
            kl_div = kl_loss = ptu.zeros(0)
        else:
            kl_div = kl_divergence(p_z,
                                   self.agent.latent_prior).mean(dim=0).sum()
            kl_loss = self.kl_lambda * kl_div

        if self.train_context_decoder:
            # TODO: change to use a distribution
            reward_pred = self.context_decoder(obs, actions, task_z_with_grad)
            reward_prediction_loss = ((reward_pred -
                                       unscaled_rewards_flat)**2).mean()
            context_loss = kl_loss + reward_prediction_loss
        else:
            context_loss = kl_loss
            reward_prediction_loss = ptu.zeros(1)
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions, task_z_detached)
        qf2_new_actions = self.qf2(obs, new_obs_actions, task_z_detached)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u, task_z_detached)
                q2 = self.qf2(obs, u, task_z_detached)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions, task_z_detached)
            v2_pi = self.qf2(obs, new_obs_actions, task_z_detached)
            v_pi = torch.min(v1_pi, v2_pi)

        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)
            beta_loss = ptu.zeros(1)

        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        weights = batch.get('weights', None)
        if self.weight_loss and weights is None:
            if self.normalize_over_batch == True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif self.normalize_over_batch == False:
                weights = score
            elif self.normalize_over_batch == 'uniform':
                weights = F.softmax(ptu.ones_like(score) / beta, dim=0)
            else:
                raise ValueError(self.normalize_over_batch)
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.train_reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            if self.train_encoder_decoder:
                self.context_optimizer.zero_grad()
            if self.train_agent:
                self.qf1_optimizer.zero_grad()
                self.qf2_optimizer.zero_grad()
            context_loss.backward(retain_graph=True)
            # retain graph because the encoder is trained by both QF losses
            qf1_loss.backward(retain_graph=True)
            qf2_loss.backward()
            if self.train_agent:
                self.qf1_optimizer.step()
                self.qf2_optimizer.step()
            if self.train_encoder_decoder:
                self.context_optimizer.step()

        if self.train_agent:
            if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()
        self._num_gradient_steps += 1
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics['task_embedding/kl_divergence'] = (
                ptu.get_numpy(kl_div))
            self.eval_statistics['task_embedding/kl_loss'] = (
                ptu.get_numpy(kl_loss))
            self.eval_statistics['task_embedding/reward_prediction_loss'] = (
                ptu.get_numpy(reward_prediction_loss))
            self.eval_statistics['task_embedding/context_loss'] = (
                ptu.get_numpy(context_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))
            self.eval_statistics['reparam_weight'] = self.train_reparam_weight
            self.eval_statistics['num_gradient_steps'] = (
                self._num_gradient_steps)

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

        self._n_train_steps_total += 1
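In the normalize_over_batch == True branch above, the advantage weights are a softmax of the advantage score over the batch, with beta acting as a temperature. A self-contained sketch of that weighting:

import torch
import torch.nn.functional as F

score = torch.tensor([[0.5], [-0.2], [1.3], [0.0]])  # advantage q_adv - v_pi, shape (batch, 1)
beta = 1.0
weights = F.softmax(score / beta, dim=0)             # non-negative, sums to 1 over the batch
weights = weights[:, 0]
# The policy loss then reweights each sample's log-likelihood by len(weights) * weights.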
Example #11
    def update_parameters(self, memory, batch_size, updates):
        """
        Update parameters of SAC-NF
        Exactly like SAC, but keep two separate Adam optimizers for the Gaussian policy AND the NF layers
        .backward() on them sequentially
        """
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        obs = torch.FloatTensor(state_batch).to(self.device)
        next_obs = torch.FloatTensor(next_state_batch).to(self.device)
        actions = torch.FloatTensor(action_batch).to(self.device)
        rewards = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        # for visualization
        #with torch.no_grad():
        #    sample_size = 500
        #    _action, _logprob, _preact, _, _ = self.policy.evaluate(state_batch, num_samples=sample_size)
        #    _action = _action.cpu().detach()
        #    _preact = _preact.cpu().detach()
        #    _logprob = _logprob.view(batch_size, sample_size, -1).cpu().detach()
        #    info = {
        #         'action': _action,
        #         'preact': _preact,
        #         'logprob': _logprob,
        #    }
        info = {}

        ''' update critic '''
        with torch.no_grad():
            new_next_actions, next_state_log_pi, _, _, _ = self.policy.evaluate(next_obs)
            next_tau, next_tau_hat, next_presum_tau = self.get_tau(new_next_actions)
            target_z1_values = self.target_zf1(next_obs, new_next_actions, next_tau_hat)
            target_z2_values = self.target_zf2(next_obs, new_next_actions, next_tau_hat)
            min_qf_next_target = torch.min(target_z1_values, target_z2_values) - self.alpha * next_state_log_pi
            z_target = rewards + mask_batch * self.gamma * min_qf_next_target
        tau, tau_hat, presum_tau = self.get_tau(actions)
        z1_pred = self.zf1(obs, actions, tau_hat)
        z2_pred = self.zf2(obs, actions, tau_hat)
        # Two Q-functions to mitigate positive bias in the policy improvement step
        zf1_loss = self.zf_criterion(z1_pred, z_target, tau_hat, next_presum_tau)
        zf2_loss = self.zf_criterion(z2_pred, z_target, tau_hat, next_presum_tau)

        new_actions, log_pi, _, _, _ = self.policy.evaluate(obs)

        # update
        self.zf1_optimizer.zero_grad()
        zf1_loss.backward()
        self.zf1_optimizer.step()
        self.zf2_optimizer.zero_grad()
        zf2_loss.backward()
        self.zf2_optimizer.step()
        risk_param = self.risk_schedule(self._n_train_steps_total)

        if self.risk_type == 'VaR':
            tau_ = ptu.ones_like(rewards) * risk_param
            q1_new_actions = self.zf1(obs, new_actions, tau_)
            q2_new_actions = self.zf2(obs, new_actions, tau_)
        else:
            with torch.no_grad():
                new_tau, new_tau_hat, new_presum_tau = self.get_tau(obs, new_actions)
            z1_new_actions = self.zf1(obs, new_actions, new_tau_hat)
            z2_new_actions = self.zf2(obs, new_actions, new_tau_hat)
            if self.risk_type in ['neutral', 'std']:
                q1_new_actions = torch.sum(new_presum_tau * z1_new_actions, dim=1, keepdims=True)
                q2_new_actions = torch.sum(new_presum_tau * z2_new_actions, dim=1, keepdims=True)
                if self.risk_type == 'std':
                    q1_std = new_presum_tau * (z1_new_actions - q1_new_actions).pow(2)
                    q2_std = new_presum_tau * (z2_new_actions - q2_new_actions).pow(2)
                    q1_new_actions -= risk_param * q1_std.sum(dim=1, keepdims=True).sqrt()
                    q2_new_actions -= risk_param * q2_std.sum(dim=1, keepdims=True).sqrt()
            else:
                with torch.no_grad():
                    risk_weights = distortion_de(new_tau_hat, self.risk_type, risk_param)
                q1_new_actions = torch.sum(risk_weights * new_presum_tau * z1_new_actions, dim=1, keepdims=True)
                q2_new_actions = torch.sum(risk_weights * new_presum_tau * z2_new_actions, dim=1, keepdims=True)
        q_new_actions = torch.min(q1_new_actions, q2_new_actions)
        policy_loss = (self.alpha * log_pi - q_new_actions).mean()
        nf_loss = ((self.alpha * log_pi) - q_new_actions).mean()
        self.policy_optim.zero_grad()
        policy_loss.backward(retain_graph=True)
        self.policy_optim.step()

        self.nf_optim.zero_grad()
        nf_loss.backward()
        self.nf_optim.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone() # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs

        # update target value functions
        if updates % self.target_update_interval == 0:
            soft_update(self.target_zf1, self.zf1, self.tau)
            soft_update(self.target_zf2, self.zf2, self.tau)
        return zf1_loss.item(), zf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item(), info
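The automatic entropy tuning follows the standard SAC recipe: log_alpha is a single learned scalar whose loss drives the policy entropy toward target_entropy. A minimal sketch with stand-in values:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -6.0                 # commonly -|A| for an action space of dimension |A|
log_pi = torch.full((256, 1), -5.0)   # stand-in for the policy's log-probabilities

alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.exp()               # temperature used in the actor and critic losses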