Example #1
0
    def add_expert_path(self, expert_paths):
        """Store the expert demonstrations as torch tensors on ``self.device``.

        Args:
            expert_paths: object whose ``get(('obs', 'action'))`` returns the expert
                observations and the expert actions.

        Side effects:
            - ``self.expert_var`` is set to ``(obs_tensor, act_tensor)``. ``act_tensor``
              is the action representation given to the discriminator: one-hot for a
              single-head discriminator, integer indices for a multi-head one, raw
              actions for continuous control.
            - ``self.expert_act_index_var`` is set to the action representation used for
              the policy log-prob: integer indices for discrete actions, raw actions
              otherwise.
            - ``self.n_expert`` is set to the number of expert observations.
        """
        obs, act = expert_paths.get(('obs', 'action'))
        obs_tensor = obs_to_torch(obs, device=self.device)

        if not self.discrete:
            # Continuous control has no integer action index: both consumers share the
            # raw action tensor.
            act_tensor = to_torch(act, device=self.device)
            index_tensor = act_tensor
        else:
            # Integer indices feed the policy log-prob and the multi-head discriminator.
            act_idx = act.astype(int)
            index_tensor = to_torch(act_idx, device=self.device)
            if self.discriminator.use_multi_head:
                act_tensor = index_tensor
            else:
                # A single-head discriminator consumes one-hot encoded actions.
                one_hot = onehot_from_index(act_idx, self.action_dim)
                act_tensor = to_torch(one_hot, device=self.device)

        self.expert_var = (obs_tensor, act_tensor)
        self.expert_act_index_var = index_tensor
        self.n_expert = len(obs)
Example #2
0
    def add_expert_path(self, expert_paths):
        """Store the expert demonstrations as torch tensors in ``self.data_e``.

        When ``self.act_space`` is ``Discrete``, the actions are cast to int and one-hot
        encoded (width ``self.action_dim``) before conversion. Otherwise they are
        converted as-is.

        Args:
            expert_paths: object whose ``get(('obs', 'action'))`` returns the expert
                observations and the expert actions.

        Side effects:
            Sets ``self.data_e = (obs_tensor, act_tensor)`` and sets ``self.n_expert`` to
            the number of expert observations.
        """
        obs, act = expert_paths.get(('obs', 'action'))

        if isinstance(self.act_space, Discrete):
            # integer representation -> one-hot representation
            act = onehot_from_index(act.astype(int), self.action_dim)

        obs_tensor = obs_to_torch(obs, device=self.device)
        act_tensor = to_torch(act, device=self.device)
        self.data_e = (obs_tensor, act_tensor)
        self.n_expert = len(obs)
Example #3
0
    def update_reward(self, experiences, policy, ent_wt):
        """Relabel a batch of transitions with the reward computed by ``self.get_reward``.

        Args:
            experiences: tuple ``(obs, action, next_obs, g_reward, mask)``. The incoming
                ``g_reward`` is discarded and replaced.
            policy: object providing ``get_log_prob_from_obs_action_pairs(action=..., obs=...)``.
                Its (detached) output is passed to ``self.get_reward`` as ``log_pi_a``.
            ent_wt: entropy weight forwarded to ``self.get_reward``.

        Returns:
            ``(obs, action_idx, next_obs, reward, mask)``.
            - ``action_idx`` is the int-cast action for a ``Discrete`` action space and
              the original action otherwise.
            - ``reward`` is ``to_numpy`` of the squeezed, detached output of
              ``self.get_reward``.
        """
        obs, action, next_obs, _, mask = experiences
        obs_var = obs_to_torch(obs, device=self.device)

        if not isinstance(self.act_space, Discrete):
            action_idx = action
            reward_action = action
        else:
            # The policy receives integer indices; get_reward receives one-hot actions.
            action_idx = action.astype(int)
            reward_action = onehot_from_index(action_idx, self.action_dim)

        idx_tensor = to_torch(action_idx, device=self.device)
        log_pi = policy.get_log_prob_from_obs_action_pairs(action=idx_tensor, obs=obs_var).detach()

        raw_reward = self.get_reward(obs=obs_var,
                                     action=to_torch(reward_action, device=self.device),
                                     log_pi_a=log_pi,
                                     ent_wt=ent_wt)
        reward = to_numpy(raw_reward.squeeze().detach())

        return obs, action_idx, next_obs, reward, mask
Example #4
0
    def update_reward(self, experiences, policy, ent_wt):
        """Relabel a batch of transitions with the reward computed by ``self.get_reward``.

        Args:
            experiences: tuple ``(obs, action, next_obs, g_reward, mask)``. The incoming
                ``g_reward`` is discarded and replaced.
            policy: object providing ``get_log_prob_from_obs_action_pairs(action=..., obs=...)``.
                Its (detached) output is passed to ``self.get_reward`` as ``log_pi_a``.
            ent_wt: entropy weight forwarded to ``self.get_reward``.

        Returns:
            ``(obs, action_idx, next_obs, reward, mask)``.
            - ``action_idx`` is the int-cast action when ``self.discrete`` and the
              original action otherwise.
            - ``reward`` is ``to_numpy`` of the squeezed, detached output of
              ``self.get_reward``.
        """
        obs, action, next_obs, _, mask = experiences
        obs_var = obs_to_torch(obs, device=self.device)

        # Defaults cover continuous control, where index and action coincide.
        action_idx = action
        reward_action = action
        if self.discrete:
            action_idx = action.astype(int)
            if self.discriminator.use_multi_head:
                # A multi-head discriminator receives the integer indices directly.
                reward_action = action_idx
            else:
                # A single-head discriminator receives one-hot encoded actions.
                reward_action = onehot_from_index(action_idx, self.action_dim)

        idx_tensor = to_torch(action_idx, device=self.device)
        log_pi = policy.get_log_prob_from_obs_action_pairs(action=idx_tensor, obs=obs_var).detach()

        raw_reward = self.get_reward(obs=obs_var,
                                     action=to_torch(reward_action, device=self.device),
                                     log_pi_a=log_pi,
                                     ent_wt=ent_wt)
        reward = to_numpy(raw_reward.squeeze().detach())

        return obs, action_idx, next_obs, reward, mask
Example #5
0
    def fit(self, data, batch_size, policy, n_epochs_per_update, logger, **kwargs):
        """
        Train the Discriminator to distinguish expert from learner.

        Learner transitions:
            Each epoch shuffles the collected (learner) transitions and walks them in
            ``n_trans // batch_size`` mini-batches. Trailing transitions that do not fill
            a whole mini-batch are skipped in that epoch.

        Expert transitions:
            Each mini-batch is paired with ``batch_size`` expert transitions, drawn
            without replacement via ``random.sample`` from the data stored in
            ``self.expert_var`` / ``self.expert_act_index_var``.

        Labels: learner samples are labelled 0 and expert samples 1.

        Args:
            data: indexable whose first two items are the learner observations and
                actions.
            batch_size: number of learner (and of expert) samples per mini-batch. When at
                least one mini-batch is formed, ``random.sample`` raises ``ValueError`` if
                this exceeds ``self.n_expert``.
            policy: object providing
                ``get_log_prob_from_obs_action_pairs(obs=..., action=...)``. It is
                evaluated once, without gradient, on all learner and all expert pairs.
            n_epochs_per_update: number of passes over the learner transitions.
            logger: unused here.
            **kwargs: unused here.

        Returns:
            A dict built from ``pop_all_means()`` of the epoch iterator, which holds the
            recorded ``'d_loss'`` values. Returns an empty dict when the epoch loop does
            not run (e.g. ``n_epochs_per_update == 0``).
        """
        obs, act = data[0], data[1]

        # Create the torch variables
        obs_var = obs_to_torch(obs, device=self.device)

        if self.discrete:
            # index is used for policy log_prob and for multi_head discriminator
            act_index = act.astype(int)
            act_index_var = to_torch(act_index, device=self.device)

            if self.discriminator.use_multi_head:
                act_var = act_index_var
            else:
                # one-hot is used with single head discriminator
                act = onehot_from_index(act_index, self.action_dim)
                act_var = to_torch(act, device=self.device)
        else:
            # continuous control has no index actions, so index and normal actions are the same tensor
            act_var = to_torch(act, device=self.device)
            act_index_var = act_var

        expert_obs_var, expert_act_var = self.expert_var
        expert_act_index_var = self.expert_act_index_var

        # Evaluate the log-prob of every transition under the current policy; these values are
        # fed to the discriminator. No grad: when the policy is the discriminator itself (as for
        # ASQF) we do not want gradients passing through it.
        with torch.no_grad():
            trans_log_probas = policy.get_log_prob_from_obs_action_pairs(obs=obs_var, action=act_index_var)
            expert_log_probas = policy.get_log_prob_from_obs_action_pairs(obs=expert_obs_var,
                                                                          action=expert_act_index_var)

        n_trans = len(obs)
        n_expert = self.n_expert

        # Pre-bind the loop variable. If the epoch loop runs zero times it would otherwise be
        # unbound below and raise UnboundLocalError.
        it_update = None

        # Train discriminator
        for it_update in TrainingIterator(n_epochs_per_update):
            shuffled_idxs_trans = torch.randperm(n_trans, device=self.device)

            for i, _ in enumerate(TrainingIterator(n_trans // batch_size)):
                # the epoch is defined on the collected transition data and not on the expert data
                batch_idxs_trans = shuffled_idxs_trans[batch_size * i: batch_size * (i + 1)]
                batch_idxs_expert = torch.tensor(random.sample(range(n_expert), k=batch_size), device=self.device)

                # learner obs/act and their log-prob under the current policy
                obs_batch = obs_var.get_from_index(batch_idxs_trans)
                act_batch = act_var[batch_idxs_trans]
                lprobs_batch = trans_log_probas[batch_idxs_trans]

                # expert obs/act and their log-prob under the current policy
                expert_obs_batch = expert_obs_var.get_from_index(batch_idxs_expert)
                expert_act_batch = expert_act_var[batch_idxs_expert]
                expert_lprobs_batch = expert_log_probas[batch_idxs_expert]

                labels = torch.zeros((batch_size * 2, 1), device=self.device)
                labels[batch_size:] = 1.0  # expert is one
                total_obs_batch = torch_cat_obs([obs_batch, expert_obs_batch], dim=0)
                total_act_batch = torch.cat([act_batch, expert_act_batch], dim=0)
                total_lprobs_batch = torch.cat([lprobs_batch, expert_lprobs_batch], dim=0)

                loss = self.discriminator.get_classification_loss(obs=total_obs_batch, action=total_act_batch,
                                                                  log_pi_a=total_lprobs_batch, target=labels)

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_norm_clip)
                self.optimizer.step()

                # detach() is the supported replacement for the legacy `.data` attribute.
                it_update.record('d_loss', loss.detach().cpu().numpy())

        if it_update is None:
            # The epoch loop never ran, so nothing was recorded.
            return {}
        return dict(it_update.pop_all_means())
Example #6
0
    def fit(self, data, batch_size, n_epochs_per_update, logger, **kwargs):
        """
        Train the discriminator to distinguish expert from learner.

        Learner transitions:
            Each epoch shuffles the agent transitions and walks them in
            ``n_trans // batch_size`` mini-batches. Trailing transitions that do not fill
            a whole mini-batch are skipped in that epoch.

        Expert transitions:
            Each mini-batch is paired with ``batch_size`` expert transitions, drawn
            without replacement via ``random.sample`` from ``self.data_e``.

        Loss:
            Agent samples are labelled 0 and expert samples 1. When
            ``self.gradient_penalty_coef`` is non-zero, the discriminator's gradient
            penalty is added to the classification loss.

        Args:
            data: indexable whose first two items are the agent observations and actions.
                When ``self.act_space`` is ``Discrete``, actions are cast to int and
                one-hot encoded (width ``self.action_dim``).
            batch_size: number of agent (and of expert) samples per mini-batch. When at
                least one mini-batch is formed, ``random.sample`` raises ``ValueError`` if
                this exceeds ``self.n_expert``.
            n_epochs_per_update: number of passes over the agent transitions.
            logger: unused here.
            **kwargs: unused here.

        Returns:
            A dict built from ``pop_all_means()`` of the epoch iterator, which holds the
            recorded ``'d_loss'`` values. Returns an empty dict when the epoch loop does
            not run (e.g. ``n_epochs_per_update == 0``).

        Raises:
            RuntimeError: if ``self.data_e`` or ``self.n_expert`` is None (expert data
                not set).
        """
        agent_obs, agent_act = data[0], data[1]

        if isinstance(self.act_space, Discrete):
            # convert actions from integer representation to one-hot representation
            agent_act = onehot_from_index(agent_act.astype(int), self.action_dim)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.data_e is None or self.n_expert is None:
            raise RuntimeError("expert data is not set (self.data_e / self.n_expert is None)")

        # create the torch variables
        agent_obs_var = obs_to_torch(agent_obs, device=self.device)
        expert_obs_var = self.data_e[0]

        act_var = to_torch(agent_act, device=self.device)
        expert_act_var = self.data_e[1]

        n_trans = len(agent_obs)
        n_expert = self.n_expert

        # Pre-bind the loop variable. If the epoch loop runs zero times it would otherwise be
        # unbound below and raise UnboundLocalError.
        it_update = None

        # Train discriminator for n_epochs_per_update
        for it_update in TrainingIterator(n_epochs_per_update):  # epoch loop
            shuffled_idxs_trans = torch.randperm(n_trans, device=self.device)

            for i, _ in enumerate(TrainingIterator(n_trans // batch_size)):  # mini-batch loop
                # the epoch is defined on the collected transition data and not on the expert data
                batch_idxs_trans = shuffled_idxs_trans[batch_size * i: batch_size * (i + 1)]
                batch_idxs_expert = torch.tensor(random.sample(range(n_expert), k=batch_size), device=self.device)

                # get mini-batch of agent transitions
                obs_batch = agent_obs_var.get_from_index(batch_idxs_trans)
                act_batch = act_var[batch_idxs_trans]

                # get mini-batch of expert transitions
                expert_obs_batch = expert_obs_var.get_from_index(batch_idxs_expert)
                expert_act_batch = expert_act_var[batch_idxs_expert]

                labels = torch.zeros((batch_size * 2, 1), device=self.device)
                labels[batch_size:] = 1.0  # expert is one
                total_obs_batch = torch_cat_obs([obs_batch, expert_obs_batch], dim=0)
                total_act_batch = torch.cat([act_batch, expert_act_batch], dim=0)

                loss = self.discriminator.get_classification_loss(obs=total_obs_batch, action=total_act_batch,
                                                                  target=labels)

                if self.gradient_penalty_coef != 0.0:
                    grad_penalty = self.discriminator.get_grad_penality(
                        obs_e=expert_obs_batch, obs_l=obs_batch, act_e=expert_act_batch, act_l=act_batch,
                        gradient_penalty_coef=self.gradient_penalty_coef
                    )
                    # Out-of-place add: an in-place `+=` on an autograd output fails in
                    # backward() if that tensor was saved by the op that produced it.
                    loss = loss + grad_penalty

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_norm_clip)
                self.optimizer.step()

                # detach() is the supported replacement for the legacy `.data` attribute.
                it_update.record('d_loss', loss.detach().cpu().numpy())

        if it_update is None:
            # The epoch loop never ran, so nothing was recorded.
            return {}
        return dict(it_update.pop_all_means())