Example #1
    def _train_policy_latent(self, replay_buffer, eval_fn):
        for it in range(self.args["actor_iterations"]):
            batch = replay_buffer.sample(self.args["actor_batch_size"])
            batch = to_torch(batch, torch.float, device=self.args["device"])
            rew = batch.rew
            done = batch.done
            obs = batch.obs
            act = batch.act
            obs_next = batch.obs_next

            # Critic Training
            with torch.no_grad():
                _, _, next_action = self.actor_target(obs_next,
                                                      self.vae.decode)

                target_q1 = self.critic1_target(obs_next, next_action)
                target_q2 = self.critic2_target(obs_next, next_action)

                target_q = self.args["lmbda"] * torch.min(
                    target_q1, target_q2
                ) + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
                target_q = rew + (1 - done) * self.args["discount"] * target_q

            current_q1 = self.critic1(obs, act)
            current_q2 = self.critic2(obs, act)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
                current_q2, target_q)

            self.critic1_opt.zero_grad()
            self.critic2_opt.zero_grad()
            critic_loss.backward()
            self.critic1_opt.step()
            self.critic2_opt.step()

            # Actor Training
            latent_actions, mid_actions, actions = self.actor(
                obs, self.vae.decode)
            actor_loss = -self.critic1(obs, actions).mean()

            self.actor.zero_grad()
            actor_loss.backward()
            self.actor_opt.step()

            # update target network
            self._sync_weight(self.actor_target, self.actor)
            self._sync_weight(self.critic1_target, self.critic1)
            self._sync_weight(self.critic2_target, self.critic2)

            if (it + 1) % 1000 == 0:
                print("mid_actions :", torch.abs(actions - mid_actions).mean())
                if eval_fn is None:
                    self.eval_policy()
                else:
                    self.vae._actor = copy.deepcopy(self.actor)
                    res = eval_fn(self.get_policy())
                    self.log_res((it + 1) // 1000, res)
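
The target networks above are refreshed every iteration through `self._sync_weight`, whose implementation is not shown here. A minimal sketch of the polyak-style soft update such a helper typically performs (the function name and the `tau` default are assumptions, not the repository's actual code):

import torch

def sync_weight(target_net: torch.nn.Module, net: torch.nn.Module, tau: float = 0.005) -> None:
    """Blend each target parameter towards its online counterpart (polyak averaging)."""
    with torch.no_grad():
        for tgt, src in zip(target_net.parameters(), net.parameters()):
            tgt.mul_(1.0 - tau).add_(tau * src)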
Example #2
    def _train_policy(self, train_buffer, callback_fn):
        for it in range(self.args["actor_iterations"]):
            batch = train_buffer.sample(self.args["actor_batch_size"])
            batch = to_torch(batch, torch.float, device=self.args["device"])
            rew = batch.rew
            done = batch.done
            obs = batch.obs
            act = batch.act
            obs_next = batch.obs_next

            # Critic Training
            with torch.no_grad():
                action_next_actor, _ = self.actor_target(obs_next)
                action_next_vae = self.vae.decode(obs_next, z=action_next_actor)

                target_q1 = self.critic1_target(obs_next, action_next_vae)
                target_q2 = self.critic2_target(obs_next, action_next_vae)
 
                target_q = self.args["lmbda"] * torch.min(
                    target_q1, target_q2
                ) + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
                target_q = rew + (1 - done) * self.args["discount"] * target_q

            current_q1 = self.critic1(obs, act)
            current_q2 = self.critic2(obs, act)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic1_opt.zero_grad()
            self.critic2_opt.zero_grad()
            critic_loss.backward()
            self.critic1_opt.step()
            self.critic2_opt.step()
            
            # Actor Training
            action_actor, _ = self.actor(obs)
            action_vae = self.vae.decode(obs, z=action_actor)
            actor_loss = -self.critic1(obs, action_vae).mean()
            
            self.actor.zero_grad()
            actor_loss.backward()
            self.actor_opt.step()

            # update target network
            self._sync_weight(self.actor_target, self.actor)
            self._sync_weight(self.critic1_target, self.critic1)
            self._sync_weight(self.critic2_target, self.critic2)
            
            if (it + 1) % 1000 == 0:
                if callback_fn is None:
                    self.eval_policy()
                else:
                    res = callback_fn(self.get_policy())
                    self.log_res((it + 1) // 1000, res)
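
Every example in this collection converts the sampled batch with `to_torch(batch, torch.float, device=...)` before training. That helper is not shown either; a rough sketch of the conversion it presumably performs, written against a plain dict for illustration (the real helper returns an attribute-accessible batch object, e.g. `batch.obs`):

import numpy as np
import torch

def to_torch_sketch(batch: dict, dtype=torch.float32, device: str = "cpu") -> dict:
    """Convert every array in a sampled batch to a torch tensor on `device`."""
    return {k: torch.as_tensor(np.asarray(v), dtype=dtype, device=device)
            for k, v in batch.items()}

sample = {"obs": np.zeros((2, 3)), "rew": np.ones((2, 1))}
print(to_torch_sketch(sample)["obs"].dtype)  # torch.float32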
Example #3
    def train(self, train_buffer, val_buffer, callback_fn):
        training_iters = 0
        while training_iters < self.args["max_timesteps"]:

            # Sample replay buffer
            batch = train_buffer.sample(self.args["batch_size"])
            batch = to_torch(batch, torch.float, device=self.args["device"])
            reward = batch.rew
            done = batch.done
            state = batch.obs
            action = batch.act.to(torch.int64)
            next_state = batch.obs_next

            # Compute the target Q value
            with torch.no_grad():
                q, imt, i = self.Q(next_state)
                imt = imt.exp()
                imt = (imt / imt.max(1, keepdim=True)[0] >
                       self.threshold).float()

                # Use large negative number to mask actions from argmax
                next_action = (imt * q + (1 - imt) * -1e8).argmax(1,
                                                                  keepdim=True)

                q, imt, i = self.Q_target(next_state)
                # Note: `done` is used directly as the continuation mask here,
                # unlike the (1 - done) convention in the other examples.
                target_Q = reward + done * self.discount * q.gather(
                    1, next_action).reshape(-1, 1)

            # Get current Q estimate
            current_Q, imt, i = self.Q(state)

            current_Q = current_Q.gather(1, action)

            # Compute Q loss
            q_loss = F.smooth_l1_loss(current_Q, target_Q)
            i_loss = F.nll_loss(imt, action.reshape(-1))

            Q_loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

            # Optimize the Q
            self.Q_optimizer.zero_grad()
            Q_loss.backward()
            self.Q_optimizer.step()

            # Update target network by polyak or full copy every X iterations.
            self.maybe_update_target()
            training_iters += 1
            if training_iters % self.args["eval_freq"] == 0:
                res = callback_fn(self.get_policy())

                self.log_res(training_iters // self.args["eval_freq"], res)
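
The `imt`-based masking above only lets the argmax pick actions whose relative behavior-cloning probability exceeds `self.threshold`. A small self-contained illustration of that masking on toy tensors (values chosen arbitrarily):

import torch

# Toy Q-values and BC logits for a batch of 2 states with 3 actions each.
q = torch.tensor([[1.0, 2.0, 3.0],
                  [0.5, 0.1, 0.2]])
imt = torch.log_softmax(torch.tensor([[2.0, 0.0, -3.0],
                                      [0.0, 0.0, 0.0]]), dim=1).exp()

threshold = 0.3
mask = (imt / imt.max(1, keepdim=True)[0] > threshold).float()

# Actions whose relative BC probability falls below the threshold are pushed
# to -1e8 so the argmax can never select them.
next_action = (mask * q + (1 - mask) * -1e8).argmax(1, keepdim=True)
print(next_action)  # tensor([[0], [0]]): actions 1 and 2 are masked out for the first state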
Example #4
    def _train_vae_step(self, batch):
        batch = to_torch(batch, torch.float, device=self.args["device"])
        obs = batch.obs
        act = batch.act
        
        recon, mean, std = self.vae(obs, act)
        recon_loss = F.mse_loss(recon, act)
        KL_loss = -self.args["vae_kl_weight"] * (
            1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_opt.zero_grad()
        vae_loss.backward()
        self.vae_opt.step()
        
        return vae_loss.cpu().data.numpy(), recon_loss.cpu().data.numpy(), KL_loss.cpu().data.numpy()
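
The `KL_loss` term above is the closed-form KL divergence between the diagonal Gaussian `N(mean, std^2)` produced by the VAE encoder and a standard normal prior; the usual `-1/2` factor is split between the sign in `KL_loss` and the `0.5 *` in `vae_loss`, on top of the `vae_kl_weight` scaling. A quick check of that identity against `torch.distributions`:

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(4, 2)
std = torch.rand(4, 2) + 0.1

# Closed-form KL(N(mean, std^2) || N(0, 1)), i.e. the per-element term used in
# _train_vae_step without the weight.
kl_closed_form = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2))
kl_reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0))

print(torch.allclose(kl_closed_form, kl_reference, atol=1e-6))  # True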
Example #5
    def _train_bc(self, buffer):
        bc_loss_list = []
        rew_loss_list = []
        for step in range(100000):
            for i in range(len(self.bcs)):
                batch = buffer.sample(256)
                batch = to_torch(batch,
                                 torch.float,
                                 device=self.args["device"])
                rew = batch.rew
                obs = batch.obs
                act = batch.act
                obs_next = batch.obs_next

                obs_act = torch.cat([obs, act], axis=1)
                obs_next_pre, _ = self.bcs[i](obs_act)
                rew_pre, _ = self.rews[i](torch.cat([obs, act, obs_next],
                                                    axis=1))

                bc_loss = F.mse_loss(obs_next_pre, obs_next)
                rew_loss = F.mse_loss(rew_pre, rew)

                self.bcs_opt[i].zero_grad()
                self.rews_opt[i].zero_grad()
                bc_loss.backward()
                rew_loss.backward()
                self.bcs_opt[i].step()
                self.rews_opt[i].step()

                bc_loss_list.append(bc_loss.item())
                rew_loss_list.append(rew_loss.item())

            if (step + 1) % 1000 == 0:
                logger.info('BC Epoch : {}, bc_loss : {:.4}',
                            (step + 1) // 1000, np.mean(bc_loss_list))
                logger.info('BC Epoch : {}, rew_loss : {:.4}',
                            (step + 1) // 1000, np.mean(rew_loss_list))
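
The ensemble members `self.bcs[i]` and `self.rews[i]` are queried as callables that return a `(prediction, extra)` tuple, but their architecture is not part of this snippet. A hypothetical sketch of one dynamics member consistent with how it is called above (the reward networks would look analogous, taking `[obs, act, obs_next]` and producing a scalar):

import torch
import torch.nn as nn

class DynamicsNet(nn.Module):
    """Hypothetical ensemble member: maps [obs, act] to a predicted next_obs.
    The second return value mimics the (prediction, extra) tuple unpacked above."""
    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, obs_dim),
        )

    def forward(self, obs_act: torch.Tensor):
        pred = self.net(obs_act)
        return pred, None

net = DynamicsNet(obs_dim=11, act_dim=3)
pred, _ = net(torch.zeros(4, 14))
print(pred.shape)  # torch.Size([4, 11])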
Example #6
    def _train(self, batch):
        self._current_epoch += 1
        batch = to_torch(batch, torch.float, device=self.args["device"])
        rewards = batch.rew
        terminals = batch.done
        obs = batch.obs
        actions = batch.act
        next_obs = batch.obs_next
        """
        Policy and Alpha Loss
        """
        new_obs_actions, log_pi = self.forward(obs)

        if self.args["use_automatic_entropy_tuning"]:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_opt.zero_grad()
            alpha_loss.backward()
            self.alpha_opt.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self._current_epoch < self.args["policy_bc_steps"]:
            """
            For the initial few epochs, optionally do behavioral cloning.
            Conventionally there is not much difference in performance between
            running these 20k gradient steps and skipping them.
            """
            policy_log_prob = self.actor.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        else:
            q_new_actions = torch.min(
                self.critic1(obs, new_obs_actions),
                self.critic2(obs, new_obs_actions),
            )

            policy_loss = (alpha * log_pi - q_new_actions).mean()
        self.actor_opt.zero_grad()
        policy_loss.backward()
        self.actor_opt.step()
        """
        QF Loss
        """
        q1_pred = self.critic1(obs, actions)
        q2_pred = self.critic2(obs, actions)

        new_next_actions, new_log_pi = self.forward(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        new_curr_actions, new_curr_log_pi = self.forward(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        if self.args["type_q_backup"] == "max":
            target_q_values = torch.max(
                self.critic1_target(next_obs, new_next_actions),
                self.critic2_target(next_obs, new_next_actions),
            )
            target_q_values = target_q_values - alpha * new_log_pi

        elif self.args["type_q_backup"] == "min":
            target_q_values = torch.min(
                self.critic1_target(next_obs, new_next_actions),
                self.critic2_target(next_obs, new_next_actions),
            )
            target_q_values = target_q_values - alpha * new_log_pi
        elif self.args["type_q_backup"] == "medium":
            target_q1_next = self.critic1_target(next_obs, new_next_actions)
            target_q2_next = self.critic2_target(next_obs, new_next_actions)
            target_q_values = self.args["q_backup_lmbda"] * torch.min(target_q1_next, target_q2_next) \
                            + (1 - self.args["q_backup_lmbda"]) * torch.max(target_q1_next, target_q2_next)
            target_q_values = target_q_values - alpha * new_log_pi

        else:
            """when using max q backup"""
            next_actions_temp, _ = self._get_policy_actions(
                next_obs, num_actions=10, network=self.forward)
            target_qf1_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.critic1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.critic2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        q_target = self.args["reward_scale"] * rewards + (
            1. - terminals) * self.args["discount"] * target_q_values.detach()

        qf1_loss = self.critic_criterion(q1_pred, q_target)
        qf2_loss = self.critic_criterion(q2_pred, q_target)

        ## add CQL
        random_actions_tensor = torch.FloatTensor(
            q2_pred.shape[0] * self.args["num_random"],
            actions.shape[-1]).uniform_(-1, 1).to(self.args["device"])
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(
            obs, num_actions=self.args["num_random"], network=self.forward)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
            next_obs,
            num_actions=self.args["num_random"],
            network=self.forward)
        q1_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.critic1)
        q2_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.critic2)
        q1_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.critic1)
        q2_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.critic2)
        q1_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.critic1)
        q2_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.critic2)

        cat_q1 = torch.cat(
            [q1_rand,
             q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1)
        cat_q2 = torch.cat(
            [q2_rand,
             q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1)

        if self.args["min_q_version"] == 3:
            # importance sampled version
            random_density = np.log(0.5**curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat([
                q1_rand - random_density, q1_next_actions -
                new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()
            ], 1)
            cat_q2 = torch.cat([
                q2_rand - random_density, q2_next_actions -
                new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()
            ], 1)

        min_qf1_loss = torch.logsumexp(
            cat_q1 / self.args["temp"],
            dim=1,
        ).mean() * self.args["min_q_weight"] * self.args["temp"]
        min_qf2_loss = torch.logsumexp(
            cat_q2 / self.args["temp"],
            dim=1,
        ).mean() * self.args["min_q_weight"] * self.args["temp"]
        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.args["min_q_weight"]
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.args["min_q_weight"]

        if self.args["lagrange_thresh"] >= 0:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                      min=0.0,
                                      max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss -
                                          self.args["lagrange_thresh"])
            min_qf2_loss = alpha_prime * (min_qf2_loss -
                                          self.args["lagrange_thresh"])

            self.alpha_prime_opt.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_opt.step()

        qf1_loss = self.args["explore"] * qf1_loss + (
            2 - self.args["explore"]) * min_qf1_loss
        qf2_loss = self.args["explore"] * qf2_loss + (
            2 - self.args["explore"]) * min_qf2_loss
        """
        Update critic networks
        """
        self.critic1_opt.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic1_opt.step()

        self.critic2_opt.zero_grad()
        qf2_loss.backward()
        self.critic2_opt.step()
        """
        Soft update of the target networks
        """
        self._sync_weight(self.critic1_target, self.critic1,
                          self.args["soft_target_tau"])
        self._sync_weight(self.critic2_target, self.critic2,
                          self.args["soft_target_tau"])

        self._n_train_steps_total += 1
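
The CQL terms above rely on `self._get_tensor_values` to score many sampled actions per state before the `logsumexp`. That helper is not shown; a rough sketch of what it presumably does, with the function name and the exact repeat layout being assumptions:

import torch

def get_tensor_values_sketch(obs: torch.Tensor, actions: torch.Tensor, q_net) -> torch.Tensor:
    """Evaluate q_net on several candidate actions per observation and reshape to
    (batch, num_repeat, 1), the layout the CQL terms concatenate over dim 1."""
    num_repeat = actions.shape[0] // obs.shape[0]
    obs_rep = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[-1])
    q = q_net(obs_rep, actions)
    return q.view(obs.shape[0], num_repeat, 1)

# Toy usage with a stand-in Q function.
q_net = lambda o, a: o.sum(dim=1, keepdim=True) + a.sum(dim=1, keepdim=True)
print(get_tensor_values_sketch(torch.zeros(2, 3), torch.zeros(20, 4), q_net).shape)  # (2, 10, 1)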
Example #7
    def _train_policy(self, replay_buffer, eval_fn):
        for it in range(self.args["actor_iterations"]):
            batch = replay_buffer.sample(self.args["actor_batch_size"])
            batch = to_torch(batch, torch.float, device=self.args["device"])
            rew = batch.rew
            done = batch.done
            obs = batch.obs
            act = batch.act
            obs_next = batch.obs_next

            rew_list = []
            obs = [obs for _ in self.bcs]
            for i in range(5):
                act = [self.actor_target(o)[0] for o in obs]
                obs_act = [torch.cat([o, a], axis=1) for o, a in zip(obs, act)]
                obs_next = [net(oa)[0] for net, oa in zip(self.bcs, obs_act)]
                obs_act_obs = [
                    torch.cat([oa, on], axis=1)
                    for oa, on in zip(obs_act, obs_next)
                ]
                rew = [net(oao)[0] for net, oao in zip(self.rews, obs_act_obs)]

                if i == 0:
                    r_p = [
                        torch.mean(torch.abs(self.vae(o, a)[0] - oa).detach(),
                                   axis=1)
                        for o, a, oa in zip(obs, act, obs_act)
                    ]
                    # Note: this reassigns r_p, discarding the VAE-based list computed just above.
                    r_p = torch.mean(torch.cat(rew, axis=1), axis=1)

                rew = torch.cat(rew, axis=1)
                #rew_min,_ = torch.min(rew, axis=1)
                #rew_mean = torch.mean(rew, axis=1)

                #rew = (0.5 * rew_min) + (0.5 *rew_mean) - (0.5*r_p)

                rew_list.append(rew)
                obs = obs_next
            r_e = None
            for index in range(len(rew_list)):
                if r_e is None:
                    r_e = rew_list[index]
                else:
                    r_e += rew_list[index] * (0.99**index)

            rew_min, _ = torch.min(r_e, axis=1)
            rew_mean = torch.mean(r_e, axis=1)

            rew = (0.5 * rew_min) + (0.5 * rew_mean) - (0.5 * r_p)
            done = batch.done
            obs = batch.obs
            act = batch.act
            obs_next = batch.obs_next

            # Critic Training
            with torch.no_grad():
                action_next, _ = self.actor_target(obs_next)
                target_q1 = self.critic1_target(obs_next, action_next)
                target_q2 = self.critic2_target(obs_next, action_next)

                target_q = self.args["lmbda"] * torch.min(
                    target_q1, target_q2
                ) + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
                target_q = rew + (1 - done) * self.args["discount"] * target_q

            current_q1 = self.critic1(obs, act)
            current_q2 = self.critic2(obs, act)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
                current_q2, target_q)

            self.critic1_opt.zero_grad()
            self.critic2_opt.zero_grad()
            critic_loss.backward()
            self.critic1_opt.step()
            self.critic2_opt.step()

            # Actor Training
            action, _ = self.actor(obs)

            actor_loss = -self.critic1(obs, action).mean()

            self.actor.zero_grad()
            actor_loss.backward()
            self.actor_opt.step()

            # update target network
            self._sync_weight(self.actor_target, self.actor)
            self._sync_weight(self.critic1_target, self.critic1)
            self._sync_weight(self.critic2_target, self.critic2)

            if (it + 1) % 1000 == 0:
                if eval_fn is None:
                    self.eval_policy()
                else:
                    self.vae._actor = copy.deepcopy(self.actor)
                    res = eval_fn(self.get_policy())
                    self.log_res((it + 1) // 1000, res)
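
The model-based reward `r_e` accumulated above is a discounted sum of the per-step ensemble rewards, `sum_k 0.99**k * r_k` over the 5 imagined steps. A toy illustration of that aggregation:

import torch

gamma = 0.99
# Five per-step reward tensors for a batch of 3 states and 2 ensemble members.
rew_list = [torch.full((3, 2), float(k)) for k in range(5)]

r_e = None
for index, r in enumerate(rew_list):
    r_e = r if r_e is None else r_e + r * gamma ** index

print(r_e[0, 0].item())  # 0 + 1*0.99 + 2*0.99**2 + 3*0.99**3 + 4*0.99**4 ≈ 9.70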
Example #8
    def get_action(self, obs):
        obs_tensor = to_torch(obs,
                              device=next(self.parameters()).device,
                              dtype=torch.float32)
        act = to_array_as(self.policy_infer(obs_tensor), obs)

        return act
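
`to_array_as` converts the policy output back into the same container type as the incoming observation, so a NumPy observation yields a NumPy action. A rough sketch of such a helper (the real one may handle more cases):

import numpy as np
import torch

def to_array_as_sketch(x: torch.Tensor, target):
    """Return `x` in the same container type as `target`."""
    if isinstance(target, np.ndarray):
        return x.detach().cpu().numpy()
    return x

print(type(to_array_as_sketch(torch.zeros(3), np.zeros(3))))  # <class 'numpy.ndarray'>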