Example #1
 def _update_network(self):
     # sample the episodes
     transitions = self.buffer.sample(self.args.batch_size)
     # pre-process the observation and goal
     o, o_next, g = transitions['obs'], transitions[
         'obs_next'], transitions['g']
     transitions['obs'], transitions['g'] = self._preproc_og(o, g)
     transitions['obs_next'], transitions['g_next'] = self._preproc_og(
         o_next, g)
     # start to do the update
     obs_norm = self.o_norm.normalize(transitions['obs'])
     g_norm = self.g_norm.normalize(transitions['g'])
     inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
     obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
     g_next_norm = self.g_norm.normalize(transitions['g_next'])
     inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
     # transfer them into the tensor
     inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
     inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                            dtype=torch.float32)
     actions_tensor = torch.tensor(transitions['actions'],
                                   dtype=torch.float32)
     r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
     if self.args.cuda:
         inputs_norm_tensor = inputs_norm_tensor.cuda()
         inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
         actions_tensor = actions_tensor.cuda()
         r_tensor = r_tensor.cuda()
     # calculate the target Q value function
     with torch.no_grad():
         # do the normalization
         # concatenate the stuffs
         actions_next = self.actor_target_network(inputs_next_norm_tensor)
         q_next_value = self.critic_target_network(inputs_next_norm_tensor,
                                                   actions_next)
         q_next_value = q_next_value.detach()
         target_q_value = r_tensor + self.args.gamma * q_next_value
         target_q_value = target_q_value.detach()
         # clip the q value
         clip_return = 1 / (1 - self.args.gamma)
         target_q_value = torch.clamp(target_q_value, -clip_return, 0)
     # the q loss
     real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
     critic_loss = (target_q_value - real_q_value).pow(2).mean()
     # the actor loss
     actions_real = self.actor_network(inputs_norm_tensor)
     actor_loss = -self.critic_network(inputs_norm_tensor,
                                       actions_real).mean()
     actor_loss += self.args.action_l2 * (
         actions_real / self.env_params['action_max']).pow(2).mean()
     # start to update the network
     self.actor_optim.zero_grad()
     actor_loss.backward()
     sync_grads(self.actor_network)
     self.actor_optim.step()
     # update the critic_network
     self.critic_optim.zero_grad()
     critic_loss.backward()
     sync_grads(self.critic_network)
     self.critic_optim.step()
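
Every example on this page calls a sync_grads helper that is not shown. Below is a minimal sketch of what such a helper typically does in MPI-parallel HER/DDPG code, assuming mpi4py is used; the body and buffer handling are assumptions, not the original project's implementation.

import numpy as np
import torch
from mpi4py import MPI

def sync_grads(network):
    """Average the gradients of `network` across all MPI workers (a sketch)."""
    comm = MPI.COMM_WORLD
    # flatten every parameter gradient into a single vector
    flat_grads = np.concatenate([
        p.grad.detach().cpu().numpy().ravel()
        for p in network.parameters() if p.grad is not None
    ])
    global_grads = np.zeros_like(flat_grads)
    comm.Allreduce(flat_grads, global_grads, op=MPI.SUM)
    global_grads /= comm.Get_size()
    # copy the averaged values back into the parameter gradients
    idx = 0
    for p in network.parameters():
        if p.grad is None:
            continue
        n = p.grad.numel()
        p.grad.copy_(torch.as_tensor(global_grads[idx:idx + n],
                                     dtype=p.grad.dtype).view_as(p.grad))
        idx += n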
Example #2
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32).reshape(
            transitions['r'].shape[0], -1)
        # if self.args.scale_rewards:
        #     r_tensor = r_tensor/self.reward_scales
        #         print(r_tensor.shape)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()

        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            rep = self.actor_target_network['rep'](inputs_next_norm_tensor)
            actions_next = self.actor_target_network[0](rep)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor,
                                                      actions_next)
            q_next_value = q_next_value.detach()
            # print('r_tensor_shape :', r_tensor.shape)
            # print('q_next_value :', q_next_value.shape)
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        # if self.args.ddpg_vq_version=='ver3':
        #     real_q_value = self.critic_network.deep_forward(inputs_norm_tensor, actions_tensor)
        # else:
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)

        if self.args.critic_loss_type == 'MSE':
            critic_loss = (target_q_value - real_q_value).pow(2).mean()
        elif self.args.critic_loss_type == 'MAE':
            critic_loss = (target_q_value - real_q_value).abs().mean()

        # if self.args.actor_loss_type=='default':
        # actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real)).mean()
        # actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # update_index = None

        tasks = list(range(self.env.num_reward))
        # elif self.args.actor_loss_type=='min':
        #     actor_loss = -((self.critic_network(inputs_norm_tensor, actions_real)).detach().cpu().numpy()/self.reward_scales).min(axis=1)[0].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        #     update_index = None
        # elif self.args.actor_loss_type=='batch_min':
        #     update_index = np.argmin((self.critic_network(inputs_norm_tensor, actions_real)).detach().cpu().numpy().mean(axis=0)/self.reward_scales)
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:,update_index].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # elif self.args.actor_loss_type=='softmin':
        #     actor_loss = -(self.critic_network.softmin_forward(inputs_norm_tensor, actions_real)).min(axis=1)[0].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # elif self.args.actor_loss_type=='strict_random':
        #     update_index = np.random.choice(self.env.num_reward)
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:,update_index].mean()
        #     # print(actor_loss)
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # elif self.args.actor_loss_type=='random':
        #     update_index_sampling_prob= F.softmin((self.critic_network(inputs_norm_tensor, actions_real).cpu().mean(axis=0))\
        #      /(self.args.softmax_temperature*torch.Tensor(self.reward_scales).cpu()), dim=0).detach().cpu().numpy()

        #     update_index = np.random.choice(self.env.num_reward, p= update_index_sampling_prob)
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:,update_index].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # This is MGDA

        loss_data = {}
        grads = {}
        scale = {}
        mask = None
        masks = {}

        for t_num, t in enumerate(tasks):
            # Compute gradients of each loss function w.r.t. the shared parameters
            self.actor_optim.zero_grad()
            # rep, mask = model['rep'](images, mask)
            # out_t, masks[t] = model[t](rep, None)
            # loss = loss_fn[t](out_t, labels[t])
            rep = self.actor_network['rep'](inputs_norm_tensor)
            actions_real = self.actor_network[t_num](rep)
            loss = -(self.critic_network(inputs_norm_tensor,
                                         actions_real))[:, t_num].mean()
            loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            # print(loss)
            loss_data[t] = loss.data.item()
            loss.backward()
            grads[t] = []
            for param in self.actor_network['rep'].parameters():
                if param.grad is not None:
                    # with torch.no_grad:
                    grads[t].append(
                        Variable(param.grad.data.clone(), requires_grad=False))
        rep.detach_()
        actions_real.detach_()
        # Normalize all gradients; this is optional and not included in the paper.
        # print(grads)
        gn = gradient_normalizers(grads, loss_data,
                                  self.args.normalization_type)
        for t in tasks:
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        # Frank-Wolfe iteration to compute scales.
        sol, min_norm = MinNormSolver.find_min_norm_element(
            [grads[t] for t in tasks])
        for i, t in enumerate(tasks):
            scale[t] = float(sol[i])

        # print(scale)
        # Scaled back-propagation
        # actions_real_2 = self.actor_network(inputs_norm_tensor)

        self.actor_optim.zero_grad()
        # rep, _ = model['rep'](images, mask)
        each_loss = []
        rep = self.actor_network['rep'](inputs_norm_tensor)
        for i, t in enumerate(tasks):
            actions_real = self.actor_network[i](rep)
            loss_t = -(self.critic_network(inputs_norm_tensor,
                                           actions_real))[:, i].mean()
            # print(loss_t)
            loss_t += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            each_loss.append(loss_t)
            loss_data[t] = loss_t.data.item()
            if i > 0:
                loss = loss + scale[t] * loss_t
            else:
                loss = scale[t] * loss_t
        loss.backward()
        self.actor_optim.step()
        # print(actions_real)
        # scale_tensor = torch.Tensor(list(scale.values()))
        # if self.args.cuda:
        #     scale_tensor = scale_tensor.cuda()
        #         # for t_num, t in enumerate(tasks):
        #     # out_t, _ = model[t](rep, masks[t])
        #     # loss_t = loss_fn[t](out_t, labels[t])
        # loss_t = -(self.critic_network(inputs_norm_tensor, actions_real)).mean(axis=0)
        # # print(loss_t)
        # loss_t += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # print(loss_t)
        # loss_t.retain_grad()
        # loss_data[t] = loss_t.data.item()
        # if i > 0:
        #     loss = loss + scale[t]*loss_t
        # else:
        #     loss = scale[t]*loss_t
        # print(.is_cuda)
        # print(loss_t.scale_tensoris_cuda)
        # print(scale_tensor)
        # print(loss_t)

        # actor_loss = torch.matmul(scale_tensor,loss_t)
        # print(loss)

        # actions_real_2.detach_()
        # self.actor_optim.zero_grad()
        # actor_loss.backward()
        # sync_grads(self.actor_network)
        # self.actor_optim.step()
        # self.actor_optim.zero_grad()
        # actor_loss.backward()
        # sync_grads(self.actor_network)
        # self.actor_optim.step()
        # update the critic_network
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()

        return loss, each_loss, critic_loss, scale
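
Example #2 relies on MinNormSolver.find_min_norm_element and gradient_normalizers from the MGDA multi-task code, neither of which is listed here. For two tasks the min-norm convex combination has a closed form; the sketch below shows that special case only (the project's solver handles any number of tasks with Frank-Wolfe iterations).

import torch

def min_norm_two_tasks(grads_t1, grads_t2):
    """Return (scale1, scale2) minimising ||a*g1 + (1-a)*g2|| over a in [0, 1]."""
    g1 = torch.cat([g.flatten() for g in grads_t1])
    g2 = torch.cat([g.flatten() for g in grads_t2])
    diff = g1 - g2
    denom = diff.dot(diff).clamp_min(1e-12)
    a = ((g2 - g1).dot(g2) / denom).clamp(0.0, 1.0).item()
    return a, 1.0 - a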
Example #3
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        for i in range(self.env.num_reward):
            r_tensor = torch.tensor(transitions['r'],
                                    dtype=torch.float32).reshape(
                                        transitions['r'].shape[0], -1)[:, i]
            # print(r_tensor.shape)
            if self.args.cuda:
                inputs_norm_tensor = inputs_norm_tensor.cuda()
                inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
                actions_tensor = actions_tensor.cuda()
                r_tensor = r_tensor.cuda()
            # calculate the target Q value function
            with torch.no_grad():
                # do the normalization
                # concatenate the stuffs
                actions_next = self.actor_target_network(
                    inputs_next_norm_tensor)
                q_next_value = self.critic_target_network[i](
                    inputs_next_norm_tensor, actions_next)
                q_next_value = q_next_value.detach()
                # print('r_tensor_shape :', r_tensor.shape)
                # print('q_next_value :', q_next_value.shape)
                target_q_value = r_tensor + self.args.gamma * q_next_value
                target_q_value = target_q_value.detach()
                # clip the q value
                clip_return = 1 / (1 - self.args.gamma)
                target_q_value = torch.clamp(target_q_value, -clip_return, 0)
            # the q loss
            # if self.args.ddpg_vq_version=='ver3':
            #     real_q_value = self.critic_network.deep_forward(inputs_norm_tensor, actions_tensor)
            # else:
            real_q_value = self.critic_network[i](inputs_norm_tensor,
                                                  actions_tensor)

            # print('target_q_value :', target_q_value.shape)
            # print('real_q_value :', real_q_value.shape)
            # print((target_q_value - real_q_value).shape)
            # print(( (target_q_value - real_q_value).pow(2)).shape)
            # print((target_q_value - real_q_value).pow(2).mean().shape)
            if self.args.critic_loss_type == 'MSE':
                critic_loss = (target_q_value - real_q_value).pow(2).mean()
            # elif self.args.critic_loss_type=='max':
            #     critic_loss, _ = torch.max((target_q_value - real_q_value).pow(2),dim=1)
            #     critic_loss = torch.mean(critic_loss)
            elif self.args.critic_loss_type == 'MAE':
                critic_loss = (target_q_value - real_q_value).abs().mean()

                # print(critic_loss.shape)
                # .mean()
                # update the critic_network
            self.critic_optim[i].zero_grad()
            critic_loss.backward()
            sync_grads(self.critic_network[i])
            self.critic_optim[i].step()

        # print('critic_loss :', critic_loss.shape)
        # the actor loss
        actions_real = self.actor_network(inputs_norm_tensor)
        update_index_sampling_prob = []
        for i in range(self.env.num_reward):
            update_index_sampling_prob.append(self.critic_network[i](
                inputs_norm_tensor,
                actions_real).mean().data.cpu().numpy().item())
        update_index_sampling_prob = torch.Tensor(
            np.array(update_index_sampling_prob))
        update_index_sampling_prob = torch.nn.Softmin(dim=0)(
            update_index_sampling_prob /
            self.args.softmax_temperature).numpy()

        if self.args.actor_loss_type == 'default':
            actor_loss = -(self.critic_network(inputs_norm_tensor,
                                               actions_real)).mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
        elif self.args.actor_loss_type == 'min':
            actor_loss = -(self.critic_network(
                inputs_norm_tensor, actions_real)).min(axis=1)[0].mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
        elif self.args.actor_loss_type == 'softmin':
            # print((self.critic_network.softmin_forward(inputs_norm_tensor, actions_real)).shape)
            actor_loss = -(self.critic_network.softmin_forward(
                inputs_norm_tensor, actions_real)).mean()
            # print(actor_loss.shape)
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
        elif self.args.actor_loss_type == 'random':
            # print((self.critic_network.softmin_forward(inputs_norm_tensor, actions_real)).shape)
            update_index = np.random.choice(self.env.num_reward,
                                            p=update_index_sampling_prob)
            actor_loss = -(self.critic_network[update_index](
                inputs_norm_tensor, actions_real)).mean()
            # print(actor_loss)
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            # print(actor_loss)
            # for i in range(self.env.num_reward):
            #     self.writer.add_scalar('update_probability/number_{}'.format(i), update_index_sampling_prob[i])

        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
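
All of these updates begin by calling self._preproc_og, which is not listed. In the standard HER baselines this step simply clips observations and goals; a sketch follows, with the clip range as an assumption (the real value usually comes from args.clip_obs).

import numpy as np

def _preproc_og(self, o, g, clip_obs=200.0):
    # clip raw observations and goals before normalization
    o = np.clip(o, -clip_obs, clip_obs)
    g = np.clip(g, -clip_obs, clip_obs)
    return o, g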
Example #4
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
            actions_tensor = actions_tensor.cuda(self.device)
            r_tensor = r_tensor.cuda(self.device)
        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            actions_next += self.args.noise_eps * self.env_params[
                'action_max'] * torch.randn(actions_next.shape).cuda(
                    self.device)
            actions_next = torch.clamp(actions_next,
                                       -self.env_params['action_max'],
                                       self.env_params['action_max'])
            q_next_value1 = self.critic_target_network1(
                inputs_next_norm_tensor, actions_next)
            q_next_value2 = self.critic_target_network2(
                inputs_next_norm_tensor, actions_next)
            target_q_value = r_tensor + self.args.gamma * torch.min(
                q_next_value1, q_next_value2)
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
            target_q_value = target_q_value.detach()
        # the q loss
        real_q_value1 = self.critic_network1(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
        real_q_value2 = self.critic_network2(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()
        # the actor loss
        actions_real = self.actor_network(inputs_norm_tensor)
        actor_loss = -torch.min(
            self.critic_network1(inputs_norm_tensor, actions_real),
            self.critic_network2(inputs_norm_tensor, actions_real)).mean()
        actor_loss += self.args.action_l2 * (
            actions_real / self.env_params['action_max']).pow(2).mean()
        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic_network
        self.critic_optim1.zero_grad()
        critic_loss1.backward()
        sync_grads(self.critic_network1)
        self.critic_optim1.step()

        self.critic_optim2.zero_grad()
        critic_loss2.backward()
        sync_grads(self.critic_network2)
        self.critic_optim2.step()

        self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
        self.logger.store(LossQ=(critic_loss1 +
                                 critic_loss2).detach().cpu().numpy())
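
Example #4 is a TD3-style twin-critic variant; its critic_target_network1/2 copies are refreshed elsewhere, usually with polyak averaging after each update. A sketch of such a helper, with the coefficient value as an assumption:

import torch

def _soft_update_target_network(target, source, polyak=0.95):
    # move target parameters a small step toward the online network
    with torch.no_grad():
        for t_param, param in zip(target.parameters(), source.parameters()):
            t_param.data.mul_(polyak).add_((1 - polyak) * param.data)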
Example #5
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32).reshape(
            transitions['r'].shape[0], -1)
        # if self.args.scale_rewards:
        #     r_tensor = r_tensor/self.reward_scales
        #         print(r_tensor.shape)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()

        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor,
                                                      actions_next)
            q_next_value = q_next_value.detach()
            # print('r_tensor_shape :', r_tensor.shape)
            # print('q_next_value :', q_next_value.shape)
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        # if self.args.ddpg_vq_version=='ver3':
        #     real_q_value = self.critic_network.deep_forward(inputs_norm_tensor, actions_tensor)
        # else:
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)

        if self.args.critic_loss_type == 'MSE':
            critic_loss = (target_q_value - real_q_value).pow(2).mean()
        elif self.args.critic_loss_type == 'MAE':
            critic_loss = (target_q_value - real_q_value).abs().mean()

        actions_real = self.actor_network(inputs_norm_tensor)
        with torch.no_grad():
            each_reward = (self.critic_network(inputs_norm_tensor,
                                               actions_real)).mean(axis=0)
            # print(each_reward.shape)

        if self.args.actor_loss_type == 'sum':
            actor_loss = -(self.critic_network(inputs_norm_tensor,
                                               actions_real)).mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            update_index = 0
        elif self.args.actor_loss_type == 'smoothed_minmax':
            actor_loss = torch.exp(
                (self.args.softmax_temperature) *
                (-(self.critic_network(inputs_norm_tensor, actions_real)))
            ).mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            update_index = 0
        # elif self.args.actor_loss_type=='smoothed_minmax':
        #     actor_loss = torch.exp((self.args.softmax_temperature)*(-(self.critic_network(inputs_norm_tensor, actions_real)))).mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        #     update_index = None
        # elif self.args.actor_loss_type=='min':
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real)).min(axis=1)[0].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        #     update_index = None
        elif self.args.actor_loss_type == 'minmax':
            update_index = np.argmin((self.critic_network(
                inputs_norm_tensor,
                actions_real)).detach().cpu().numpy().mean(axis=0))
            # print(update_index)
            onehot = torch.zeros(self.env.num_reward)
            onehot[update_index] = 1.
            if self.args.cuda:
                onehot = onehot.cuda()
            actor_loss = (
                -(self.critic_network(inputs_norm_tensor, actions_real)) *
                onehot).mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()

        # elif self.args.actor_loss_type=='softmin':
        #     actor_loss = -(self.critic_network.softmin_forward(inputs_norm_tensor, actions_real)).min(axis=1)[0].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # elif self.args.actor_loss_type=='strict_random':
        #     update_index = np.random.choice(self.env.num_reward)
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:,update_index].mean()
        #     # print(actor_loss)
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # elif self.args.actor_loss_type=='random':
        #     update_index_sampling_prob= F.softmin((self.critic_network(inputs_norm_tensor, actions_real).cpu().mean(axis=0))\
        #      /(self.args.softmax_temperature), dim=0).detach().cpu().numpy()

        #     update_index = np.random.choice(self.env.num_reward, p= update_index_sampling_prob)
        #     actor_loss = -(self.critic_network(inputs_norm_tensor, actions_real))[:,update_index].mean()
        #     actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        elif self.args.actor_loss_type == 'prod':
            actor_loss = torch.prod(
                -self.critic_network(inputs_norm_tensor, actions_real),
                1).mean()
            actor_loss += self.args.action_l2 * (
                actions_real / self.env_params['action_max']).pow(2).mean()
            update_index = None

        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic_network
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()

        return update_index, actor_loss, each_reward, critic_loss
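
Example #5 switches between several ways of collapsing the per-reward-head Q values into a single actor loss. A toy illustration of the three uncommented branches on a fake Q matrix (the names and the temperature value are local to this demo only):

import torch

q = torch.randn(4, 3)                       # batch of 4, 3 reward heads
loss_sum = (-q).mean()                      # 'sum': average over every head
worst = q.mean(dim=0).argmin()              # 'minmax': find the worst head...
onehot = torch.zeros(3)
onehot[worst] = 1.0
loss_minmax = (-q * onehot).mean()          # ...and optimise only that one
loss_smooth = torch.exp(2.0 * (-q)).mean()  # 'smoothed_minmax', temperature 2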
Example #6
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'],
                                dtype=torch.float32) * self.args.reward_scale
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
            actions_tensor = actions_tensor.cuda(self.device)
            r_tensor = r_tensor.cuda(self.device)

        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            actions_next, log_prob_actions_next = self.actor_network(
                inputs_next_norm_tensor)
            q_next_value1 = self.critic_target_network1(
                inputs_next_norm_tensor, actions_next).detach()
            q_next_value2 = self.critic_target_network2(
                inputs_next_norm_tensor, actions_next).detach()
            target_q_value = r_tensor + self.args.gamma * (
                torch.min(q_next_value1, q_next_value2) -
                self.alpha * log_prob_actions_next)
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        real_q_value1 = self.critic_network1(inputs_norm_tensor,
                                             actions_tensor)
        real_q_value2 = self.critic_network2(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
        critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()

        # the actor loss
        actions, log_prob_actions = self.actor_network(inputs_norm_tensor)
        log_prob_actions = log_prob_actions.mean()
        actor_loss = self.alpha * log_prob_actions - torch.min(
            self.critic_network1(inputs_norm_tensor, actions),
            self.critic_network2(inputs_norm_tensor, actions)).mean()

        # actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()

        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic_network
        self.critic_optim1.zero_grad()
        critic_loss1.backward()
        sync_grads(self.critic_network1)
        self.critic_optim1.step()
        self.critic_optim2.zero_grad()
        critic_loss2.backward()
        sync_grads(self.critic_network2)
        self.critic_optim2.step()

        self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
        self.logger.store(LossQ=(critic_loss1 +
                                 critic_loss2).detach().cpu().numpy())
        self.logger.store(Entropy=-log_prob_actions.detach().cpu().numpy())

        # auto temperature
        if self.args.alpha < 0:

            comm = MPI.COMM_WORLD
            log_prob_actions = log_prob_actions.detach().cpu().numpy()
            global_log_prob_actions = np.zeros_like(log_prob_actions)
            comm.Allreduce(log_prob_actions,
                           global_log_prob_actions,
                           op=MPI.SUM)
            global_log_prob_actions /= MPI.COMM_WORLD.Get_size()

            logalpha_loss = -self.log_alpha * (global_log_prob_actions +
                                               self.target_entropy)

            self.alpha_optim.zero_grad()
            logalpha_loss.backward()
            self.alpha_optim.step()
            with torch.no_grad():
                self.alpha = self.log_alpha.exp().detach()

        self.logger.store(alpha=self.alpha.detach().cpu().numpy())
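
The automatic-temperature branch at the end of Example #6 assumes that log_alpha, alpha, target_entropy and alpha_optim were created during initialisation. A sketch of a typical setup (attribute names follow the example; the learning rate and the -|A| target-entropy heuristic are assumptions):

import torch

class TemperatureInitSketch:
    def _init_temperature(self, action_dim, lr=3e-4):
        # SAC heuristic: target entropy equal to minus the action dimension
        self.target_entropy = -float(action_dim)
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha = self.log_alpha.exp().detach()
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=lr)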
Example #7
    def _update_network(self):
        # sample the episodes
        batches = self.buffer.sample(self.args.batch_size)

        o = torch.FloatTensor(batches['obs']).to(self.device)
        o2 = torch.FloatTensor(batches['obs2']).to(self.device)
        a = torch.FloatTensor(batches['act']).to(self.device)
        r = torch.FloatTensor(batches['rew']).to(self.device)
        c = torch.FloatTensor(batches['cost']).to(self.device)
        d = torch.FloatTensor(batches['done']).to(self.device)

        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            a2 = self.actor_network(o2)
            q_next_value1 = self.critic_target_network1(o2, a2).detach()
            q_next_value2 = self.critic_target_network2(o2, a2).detach()
            target_q_value = r + self.args.gamma * (1 - d) * torch.min(
                q_next_value1, q_next_value2)
            target_q_value = target_q_value.detach()

            p_next_value1 = self.advice_target_network1(o2, a2).detach()
            p_next_value2 = self.advice_target_network2(o2, a2).detach()
            target_p_value = -c + self.args.gamma * (1 - d) * torch.min(
                p_next_value1, p_next_value2)
            target_p_value = target_p_value.detach()

        # the q loss
        real_q_value1 = self.critic_network1(o, a)
        real_q_value2 = self.critic_network2(o, a)
        critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
        critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()

        # the p loss
        real_p_value1 = self.advice_network1(o, a)
        real_p_value2 = self.advice_network2(o, a)
        advice_loss1 = (target_p_value - real_p_value1).pow(2).mean()
        advice_loss2 = (target_p_value - real_p_value2).pow(2).mean()

        # the actor loss
        o_exp = o.repeat(self.args.expand_batch, 1)
        a_exp = self.actor_network(o_exp)
        actor_loss = -torch.min(self.critic_network1(o_exp, a_exp),
                                self.critic_network2(o_exp, a_exp)).mean()
        actor_loss -= self.args.advice * torch.min(
            self.advice_network1(o_exp, a_exp),
            self.advice_network2(o_exp, a_exp)).mean()

        mmd_entropy = torch.tensor(0.0)

        if self.args.mmd:
            # mmd is computationally expensive
            a_exp_reshape = a_exp.view(self.args.expand_batch, -1,
                                       a_exp.shape[-1]).transpose(0, 1)
            with torch.no_grad():
                uniform_actions = (2 * torch.rand_like(a_exp_reshape) - 1)
            mmd_entropy = mmd(a_exp_reshape, uniform_actions)
            if self.args.beta_mmd <= 0.0:
                mmd_entropy.detach_()
            else:
                actor_loss += self.args.beta_mmd * mmd_entropy

        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic_network
        self.critic_optim1.zero_grad()
        critic_loss1.backward()
        sync_grads(self.critic_network1)
        self.critic_optim1.step()
        self.critic_optim2.zero_grad()
        critic_loss2.backward()
        sync_grads(self.critic_network2)
        self.critic_optim2.step()

        self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
        self.logger.store(LossQ=(critic_loss1 +
                                 critic_loss2).detach().cpu().numpy())
        self.logger.store(MMDEntropy=mmd_entropy.detach().cpu().numpy())
Example #8
    def _update_network(self, step, if_write=False):
        # # for agent 1, the protagonist agent

        # sample the episodes
        transitions = self.buffer_1.sample(
            self.args.batch_size)  # sample from the replay_buffer
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm_1.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm_1.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.to(self.args.device)
            inputs_next_norm_tensor = inputs_next_norm_tensor.to(
                self.args.device)
            actions_tensor = actions_tensor.to(self.args.device)
            r_tensor = r_tensor.to(self.args.device)
        # calculate the target Q value function
        with torch.no_grad():  # gradients are not tracked inside this block
            # do the normalization
            # concatenate the stuffs
            actions_next, next_state_logp, _ = self.actor_target_network_1.sample(
                inputs_next_norm_tensor)
            q_next_value = self.critic_target_network_1(
                inputs_next_norm_tensor, actions_next)
            q_next_value = q_next_value.detach()
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()

            q_next_value_1_2 = self.critic_target_network_1_2(
                inputs_next_norm_tensor, actions_next)
            q_next_value_1_2 = q_next_value_1_2.detach()
            target_q_value_1_2 = r_tensor + self.args.gamma * q_next_value_1_2
            target_q_value_1_2 = target_q_value_1_2.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            # make target_q_value between -clip_return and 0
            target_q_value = torch.min(
                target_q_value,
                target_q_value_1_2) - self.alpha_1 * next_state_logp
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)

        # the q loss
        real_q_value = self.critic_network_1(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
        # the q loss
        real_q_value_1_2 = self.critic_network_1_2(inputs_norm_tensor,
                                                   actions_tensor)
        critic_loss_1_2 = (target_q_value_1_2 - real_q_value_1_2).pow(2).mean()
        # the actor loss
        actions_real, logp, _ = self.actor_network_1.sample(inputs_norm_tensor)
        actor_loss = (self.alpha_1 * logp - torch.min(
            self.critic_network_1(inputs_norm_tensor, actions_real),
            self.critic_network_1_2(inputs_norm_tensor, actions_real))).mean()
        actor_loss += self.args.action_l2 * (
            actions_real / self.env_params['action_max']).pow(2).mean()
        alpha_loss_1 = -(self.log_alpha_1 *
                         (logp + self.target_entropy_1).detach()).mean()

        self.alpha_optim_1.zero_grad()
        alpha_loss_1.backward()
        sync_grads_for_tensor(self.log_alpha_1)
        self.alpha_optim_1.step()
        self.alpha_1 = self.log_alpha_1.exp()

        # start to update the network
        self.actor_optim_1.zero_grad()
        actor_loss.backward()

        # # clip the grad
        # clip_grad_norm_(self.actor_network_1.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.actor_network_1)
        self.actor_optim_1.step()
        # update the critic_network
        self.critic_optim_1.zero_grad()
        critic_loss.backward()

        # # clip the grad
        # clip_grad_norm_(self.critic_network_1.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.critic_network_1)
        self.critic_optim_1.step()

        self.critic_optim_1_2.zero_grad()
        critic_loss_1_2.backward()

        # # clip the grad
        # clip_grad_norm_(self.critic_network_1_2.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.critic_network_1_2)
        self.critic_optim_1_2.step()
        if self.writer is not None and if_write:
            self.writer.add_scalar('actor_loss_1', actor_loss, step)
            self.writer.add_scalar('critic_loss_1', critic_loss, step)
        """"""
        # # for agent 2, the adversary agent

        # sample the episodes
        transitions = self.buffer_2.sample(
            self.args.batch_size)  # sample from the replay_buffer
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm_2.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm_2.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.to(self.args.device)
            inputs_next_norm_tensor = inputs_next_norm_tensor.to(
                self.args.device)
            actions_tensor = actions_tensor.to(self.args.device)
            r_tensor = r_tensor.to(self.args.device)
        # calculate the target Q value function
        with torch.no_grad():  # gradients are not tracked inside this block
            # do the normalization
            # concatenate the stuffs
            actions_next, next_state_logp, _ = self.actor_target_network_2.sample(
                inputs_next_norm_tensor)
            q_next_value = self.critic_target_network_2(
                inputs_next_norm_tensor, actions_next)
            q_next_value = q_next_value.detach()
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()

            q_next_value_2_2 = self.critic_target_network_2_2(
                inputs_next_norm_tensor, actions_next)
            q_next_value_2_2 = q_next_value_2_2.detach()
            target_q_value_2_2 = r_tensor + self.args.gamma * q_next_value_2_2
            target_q_value_2_2 = target_q_value_2_2.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            # make target_q_value between -clip_return and 0
            target_q_value = torch.min(
                target_q_value,
                target_q_value_2_2) - self.alpha_2 * next_state_logp
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)

        # the q loss
        real_q_value = self.critic_network_2(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
        # the q loss
        real_q_value_2_2 = self.critic_network_2_2(inputs_norm_tensor,
                                                   actions_tensor)
        critic_loss_2_2 = (target_q_value_2_2 - real_q_value_2_2).pow(2).mean()
        # the actor loss
        actions_real, logp, _ = self.actor_network_2.sample(inputs_norm_tensor)
        actor_loss = (self.alpha_2 * logp + torch.min(
            self.critic_network_2(inputs_norm_tensor, actions_real),
            self.critic_network_2_2(inputs_norm_tensor, actions_real))
                      ).mean()  # for adversary agent
        actor_loss += self.args.action_l2 * (
            actions_real / self.env_params['action_max']).pow(2).mean()

        alpha_loss_2 = -(self.log_alpha_2 *
                         (logp + self.target_entropy_2).detach()).mean()

        self.alpha_optim_2.zero_grad()
        alpha_loss_2.backward()
        sync_grads_for_tensor(self.log_alpha_2)
        self.alpha_optim_2.step()
        self.alpha_2 = self.log_alpha_2.exp()

        # start to update the network
        self.actor_optim_2.zero_grad()
        actor_loss.backward()

        # # clip grad loss
        # clip_grad_norm_(self.actor_network_2.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.actor_network_2)
        self.actor_optim_2.step()
        # update the critic_network
        self.critic_optim_2.zero_grad()
        critic_loss.backward()

        # # clip grad loss
        # clip_grad_norm_(self.critic_network_2.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.critic_network_2)
        self.critic_optim_2.step()

        self.critic_optim_2_2.zero_grad()
        critic_loss_2_2.backward()

        # # clip grad loss
        # clip_grad_norm_(self.critic_network_2_2.parameters(), max_norm=20, norm_type=2)

        sync_grads(self.critic_network_2_2)
        self.critic_optim_2_2.step()
        if self.writer is not None and if_write:
            self.writer.add_scalar('actor_loss_2', actor_loss, step)
            self.writer.add_scalar('critic_loss_2', critic_loss, step)
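
Example #8 also calls sync_grads_for_tensor on log_alpha_1 / log_alpha_2. By analogy with sync_grads, it presumably averages the gradient of a single tensor across MPI workers; a sketch under that assumption:

import numpy as np
import torch
from mpi4py import MPI

def sync_grads_for_tensor(tensor):
    """Average the gradient of one tensor (e.g. log_alpha) across MPI workers."""
    comm = MPI.COMM_WORLD
    grad = tensor.grad.detach().cpu().numpy()
    global_grad = np.zeros_like(grad)
    comm.Allreduce(grad, global_grad, op=MPI.SUM)
    tensor.grad.copy_(torch.as_tensor(global_grad / comm.Get_size(),
                                      dtype=tensor.grad.dtype))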
Example #9
    def _update_network(self):
        # sample the episodes
        transitions = self.buffer.sample(self.args.batch_size)
        # pre-process the observation and goal
        o, o_next, g = transitions['obs'], transitions[
            'obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(
            o_next, g)
        # start to do the update
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into the tensor
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                               dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'],
                                      dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'],
                                dtype=torch.float32) * self.args.reward_scale
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda(self.device)
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda(self.device)
            actions_tensor = actions_tensor.cuda(self.device)
            r_tensor = r_tensor.cuda(self.device)

        # calculate the target Q value function
        with torch.no_grad():
            # do the normalization
            # concatenate the stuffs
            actions_next = self.actor_network(inputs_next_norm_tensor)
            q_next_value1 = self.critic_target_network1(
                inputs_next_norm_tensor, actions_next).detach()
            q_next_value2 = self.critic_target_network2(
                inputs_next_norm_tensor, actions_next).detach()
            target_q_value = r_tensor + self.args.gamma * torch.min(
                q_next_value1, q_next_value2)
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        real_q_value1 = self.critic_network1(inputs_norm_tensor,
                                             actions_tensor)
        real_q_value2 = self.critic_network2(inputs_norm_tensor,
                                             actions_tensor)
        critic_loss1 = (target_q_value - real_q_value1).pow(2).mean()
        critic_loss2 = (target_q_value - real_q_value2).pow(2).mean()

        # the actor loss
        exp_inputs_norm_tensor = inputs_norm_tensor.repeat(
            self.args.expand_batch, 1)
        exp_actions_real = self.actor_network(exp_inputs_norm_tensor)
        actor_loss = -torch.min(
            self.critic_network1(exp_inputs_norm_tensor, exp_actions_real),
            self.critic_network2(exp_inputs_norm_tensor,
                                 exp_actions_real)).mean()

        mmd_entropy = torch.tensor(0.0)

        if self.args.mmd:
            # mmd is computationally expensive
            exp_actions_real2 = exp_actions_real.view(
                self.args.expand_batch, -1,
                exp_actions_real.shape[-1]).transpose(0, 1)
            with torch.no_grad():
                uniform_actions = (2 * torch.rand_like(exp_actions_real2) - 1)
            mmd_entropy = mmd(exp_actions_real2, uniform_actions)
            if self.args.beta_mmd <= 0.0:
                mmd_entropy.detach_()
            else:
                actor_loss += self.args.beta_mmd * mmd_entropy

        # actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # start to update the network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic_network
        self.critic_optim1.zero_grad()
        critic_loss1.backward()
        sync_grads(self.critic_network1)
        self.critic_optim1.step()
        self.critic_optim2.zero_grad()
        critic_loss2.backward()
        sync_grads(self.critic_network2)
        self.critic_optim2.step()

        self.logger.store(LossPi=actor_loss.detach().cpu().numpy())
        self.logger.store(LossQ=(critic_loss1 +
                                 critic_loss2).detach().cpu().numpy())
        self.logger.store(MMDEntropy=mmd_entropy.detach().cpu().numpy())
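
Examples #7 and #9 penalise the policy with mmd(actions, uniform_actions), but the estimator itself is not listed. Below is a sketch of a Gaussian-kernel maximum mean discrepancy between two batches of action samples; the kernel choice and bandwidth are assumptions.

import torch

def mmd(x, y, sigma=1.0):
    """x, y: [batch, n_samples, action_dim] -> scalar MMD estimate."""
    def kernel_mean(a, b):
        # pairwise squared distances between the sample sets of each batch item
        d2 = ((a.unsqueeze(2) - b.unsqueeze(1)) ** 2).sum(-1)
        return torch.exp(-d2 / (2.0 * sigma ** 2)).mean(dim=(1, 2))
    return (kernel_mean(x, x) + kernel_mean(y, y) - 2.0 * kernel_mean(x, y)).mean()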
Example #10
def update_disentangled(actor_network, critic_network, critic_target_network,
                        configuration_network, policy_optim, critic_optim,
                        alpha, log_alpha, target_entropy, alpha_optim,
                        obs_norm, ag_norm, g_norm, obs_next_norm, ag_next_norm,
                        g_next_norm, actions, rewards, args):
    obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32)
    obs_next_norm_tensor = torch.tensor(obs_next_norm, dtype=torch.float32)
    ag_norm_tensor = torch.tensor(ag_norm, dtype=torch.float32)
    ag_next_norm_tensor = torch.tensor(ag_next_norm, dtype=torch.float32)
    g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32)
    g_next_norm_tensor = torch.tensor(g_next_norm, dtype=torch.float32)

    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards,
                            dtype=torch.float32).reshape(rewards.shape[0], 1)

    if args.cuda:
        #inputs_norm_tensor = inputs_norm_tensor.cuda()
        #inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()

    config_ag_z = configuration_network(ag_norm_tensor)
    config_ag_z_next = configuration_network(ag_next_norm_tensor)
    config_g_z = configuration_network(g_norm_tensor)
    config_g_z_next = configuration_network(g_next_norm_tensor)

    with torch.no_grad():
        # build the critic/actor inputs from the observation and the detached
        # goal embeddings produced by the configuration network
        inputs_norm_tensor = torch.tensor(
            np.concatenate(
                [obs_norm, config_ag_z.detach(), config_g_z.detach()], axis=1),
            dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(
            np.concatenate([
                obs_next_norm,
                config_ag_z_next.detach(),
                config_g_z_next.detach()
            ], axis=1),
            dtype=torch.float32)
        actions_next, log_pi_next, _ = actor_network.sample(
            inputs_next_norm_tensor)
        qf1_next_target, qf2_next_target = critic_target_network(
            obs_next_norm_tensor, actions_next, config_ag_z_next.detach(),
            config_g_z_next.detach())
        min_qf_next_target = torch.min(qf1_next_target,
                                       qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target
        # q-value clipping is disabled in this variant:
        # clip_return = 1 / (1 - args.gamma)
        # next_q_value = torch.clamp(next_q_value, 0, clip_return)

    # the q loss
    qf1, qf2 = critic_network(obs_norm_tensor, actions_tensor, config_ag_z,
                              config_g_z)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)

    # the actor loss
    pi, log_pi, _ = actor_network.sample(inputs_norm_tensor)
    qf1_pi, qf2_pi = critic_network(obs_norm_tensor, pi, config_ag_z,
                                    config_g_z)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()

    # start to update the network
    policy_optim.zero_grad()
    policy_loss.backward(retain_graph=True)
    sync_grads(actor_network)
    policy_optim.step()

    # update the critic_network
    configuration_network.zero_grad()
    critic_optim.zero_grad()
    qf1_loss.backward(retain_graph=True)
    sync_grads(critic_network)
    critic_optim.step()

    critic_optim.zero_grad()
    qf2_loss.backward()
    sync_grads(critic_network)
    critic_optim.step()

    # the configuration network's gradients are synchronized, but its optimizer
    # step (configuration_optim.step()) is left commented out in this snippet
    sync_grads(configuration_network)

    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy,
                                             log_pi, alpha_optim, args)

    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
    ), alpha_loss.item(), alpha_tlogs.item()
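
The update_entropy helper called here (and in the examples below) is not shown. It presumably performs the standard SAC temperature update; a sketch under that assumption follows, where the automatic_entropy_tuning flag name and the returned alpha_tlogs convention are hypothetical.

import torch

def update_entropy(alpha, log_alpha, target_entropy, log_pi, alpha_optim, args):
    """Standard SAC temperature update (a sketch, not the original helper)."""
    if getattr(args, 'automatic_entropy_tuning', True):  # assumed flag name
        # push the policy entropy towards -target_entropy
        alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        alpha_tlogs = log_alpha.exp()      # current temperature, for logging
    else:
        alpha_loss = torch.tensor(0.0)
        alpha_tlogs = torch.tensor(alpha)  # keep the fixed temperature
    return alpha_loss, alpha_tlogs
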
Example #11
def update_flat(actor_network, critic_network, critic_target_network,
                policy_optim, critic_optim, alpha, log_alpha, target_entropy,
                alpha_optim, obs_norm, ag_norm, g_norm, obs_next_norm, actions,
                rewards, args):
    inputs_norm = np.concatenate([obs_norm, ag_norm, g_norm], axis=1)
    inputs_next_norm = np.concatenate([obs_next_norm, ag_norm, g_norm], axis=1)

    # Tensorize
    inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
    inputs_next_norm_tensor = torch.tensor(inputs_next_norm,
                                           dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards,
                            dtype=torch.float32).reshape(rewards.shape[0], 1)

    if args.cuda:
        inputs_norm_tensor = inputs_norm_tensor.cuda()
        inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()

    with torch.no_grad():
        actions_next, log_pi_next, _ = actor_network.sample(
            inputs_next_norm_tensor)
        qf1_next_target, qf2_next_target = critic_target_network(
            inputs_next_norm_tensor, actions_next)
        min_qf_next_target = torch.min(qf1_next_target,
                                       qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target

    # the q loss
    qf1, qf2 = critic_network(inputs_norm_tensor, actions_tensor)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)

    # the actor loss
    pi, log_pi, _ = actor_network.sample(inputs_norm_tensor)
    qf1_pi, qf2_pi = critic_network(inputs_norm_tensor, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()

    # start to update the network
    policy_optim.zero_grad()
    policy_loss.backward()
    sync_grads(actor_network)
    policy_optim.step()

    # update the critic_network
    critic_optim.zero_grad()
    qf1_loss.backward()
    sync_grads(critic_network)
    critic_optim.step()

    critic_optim.zero_grad()
    qf2_loss.backward()
    sync_grads(critic_network)
    critic_optim.step()

    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy,
                                             log_pi, alpha_optim, args)

    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
    ), alpha_loss.item(), alpha_tlogs.item()
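
The sync_grads call that follows every backward() in these examples typically averages gradients across MPI workers before the optimizer step. A minimal sketch with mpi4py, assuming the usual flatten-and-Allreduce layout:

import numpy as np
import torch
from mpi4py import MPI

def sync_grads(network):
    """Average gradients across MPI workers (sketch of the usual HER-style helper)."""
    comm = MPI.COMM_WORLD
    # flatten all gradients into one contiguous vector
    flat_grads = np.concatenate(
        [p.grad.cpu().numpy().flatten() for p in network.parameters()])
    global_grads = np.zeros_like(flat_grads)
    comm.Allreduce(flat_grads, global_grads, op=MPI.SUM)
    global_grads /= comm.Get_size()
    # write the averaged gradients back into the parameters
    offset = 0
    for p in network.parameters():
        numel = p.grad.numel()
        avg = global_grads[offset:offset + numel].reshape(p.grad.shape)
        p.grad.copy_(torch.as_tensor(avg, dtype=p.grad.dtype, device=p.grad.device))
        offset += numel
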
Example #12
def update_deepsets(model, language, policy_optim, critic_optim, alpha,
                    log_alpha, target_entropy, alpha_optim, obs_norm, ag_norm,
                    g_norm, obs_next_norm, ag_next_norm, anchor_g, actions,
                    rewards, language_goals, args):
    # Tensorize
    obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32)
    obs_next_norm_tensor = torch.tensor(obs_next_norm, dtype=torch.float32)
    if language:
        g_norm_tensor = g_norm
    else:
        g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32)
    ag_norm_tensor = torch.tensor(ag_norm, dtype=torch.float32)
    ag_next_norm_tensor = torch.tensor(ag_next_norm, dtype=torch.float32)
    actions_tensor = torch.tensor(actions, dtype=torch.float32)
    r_tensor = torch.tensor(rewards,
                            dtype=torch.float32).reshape(rewards.shape[0], 1)

    anchor_g_tensor = torch.tensor(anchor_g, dtype=torch.float32)

    if args.cuda:
        obs_norm_tensor = obs_norm_tensor.cuda()
        obs_next_norm_tensor = obs_next_norm_tensor.cuda()
        g_norm_tensor = g_norm_tensor.cuda()
        ag_norm_tensor = ag_norm_tensor.cuda()
        ag_next_norm_tensor = ag_next_norm_tensor.cuda()
        actions_tensor = actions_tensor.cuda()
        r_tensor = r_tensor.cuda()

    with torch.no_grad():
        if args.algo == 'language':
            model.forward_pass(obs_next_norm_tensor,
                               language_goals=language_goals)
        elif args.algo == 'continuous':
            model.forward_pass(obs_next_norm_tensor, ag_next_norm_tensor,
                               g_norm_tensor)
        else:
            model.forward_pass(obs_next_norm_tensor, ag_next_norm_tensor,
                               g_norm_tensor, anchor_g_tensor)
        actions_next, log_pi_next = model.pi_tensor, model.log_prob
        qf1_next_target, qf2_next_target = model.target_q1_pi_tensor, model.target_q2_pi_tensor
        min_qf_next_target = torch.min(qf1_next_target,
                                       qf2_next_target) - alpha * log_pi_next
        next_q_value = r_tensor + args.gamma * min_qf_next_target

    # the q loss
    if args.algo == 'language':
        qf1, qf2 = model.forward_pass(obs_norm_tensor,
                                      actions=actions_tensor,
                                      language_goals=language_goals)
    elif args.algo == 'continuous':
        qf1, qf2 = model.forward_pass(obs_norm_tensor,
                                      ag_norm_tensor,
                                      g_norm_tensor,
                                      actions=actions_tensor)
    else:
        qf1, qf2 = model.forward_pass(obs_norm_tensor,
                                      ag_norm_tensor,
                                      g_norm_tensor,
                                      anchor_g_tensor,
                                      actions=actions_tensor)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)
    qf_loss = qf1_loss + qf2_loss

    # the actor loss (the policy outputs and Q(s, pi) values were cached by the
    # forward_pass call above)
    pi, log_pi = model.pi_tensor, model.log_prob
    qf1_pi, qf2_pi = model.q1_pi_tensor, model.q2_pi_tensor
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((alpha * log_pi) - min_qf_pi).mean()

    # start to update the network
    policy_optim.zero_grad()
    policy_loss.backward(retain_graph=True)
    sync_grads(model.actor)
    policy_optim.step()

    # update the critic_network
    critic_optim.zero_grad()
    qf_loss.backward()
    sync_grads(model.critic)
    critic_optim.step()

    alpha_loss, alpha_tlogs = update_entropy(alpha, log_alpha, target_entropy,
                                             log_pi, alpha_optim, args)

    return qf1_loss.item(), qf2_loss.item(), policy_loss.item(
    ), alpha_loss.item(), alpha_tlogs.item()
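
None of the snippets show how the target networks (critic_target_network, actor_target_network) are refreshed. This is usually done with a polyak soft update after each training cycle; a sketch, with the coefficient name polyak as an assumed hyperparameter:

import torch

def soft_update_target_network(target, source, polyak=0.95):
    """Polyak-average the online parameters into the target network (sketch).

    target <- polyak * target + (1 - polyak) * source
    """
    with torch.no_grad():
        for t_param, param in zip(target.parameters(), source.parameters()):
            t_param.data.mul_(polyak).add_((1.0 - polyak) * param.data)

# e.g. soft_update_target_network(critic_target_network, critic_network)
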