Example #1
    def add_expert_path(self, expert_paths):
        expert_obs, expert_act = expert_paths.get(('obs', 'action'))

        # Create the torch variables

        expert_obs_var = obs_to_torch(expert_obs, device=self.device)

        if self.discrete:
            # index is used for policy log_prob and for multi_head discriminator
            expert_act_index = expert_act.astype(int)
            expert_act_index_var = to_torch(expert_act_index, device=self.device)

            # one-hot is used with single head discriminator
            if (not self.discriminator.use_multi_head):
                expert_act = onehot_from_index(expert_act_index, self.action_dim)
                expert_act_var = to_torch(expert_act, device=self.device)

            else:
                expert_act_var = expert_act_index_var
        else:
            # there are no index actions for continuous control, so the index action and the regular action are the same variable
            expert_act_var = to_torch(expert_act, device=self.device)

            expert_act_index_var = expert_act_var

        self.expert_var = (expert_obs_var, expert_act_var)
        self.expert_act_index_var = expert_act_index_var
        self.n_expert = len(expert_obs)
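
The helper onehot_from_index is used throughout these examples but is not defined on this page. A minimal sketch of the behaviour assumed above (integer action indices mapped to one-hot rows) could look like this; it is an illustrative reimplementation, not the repository's code:

import numpy as np

def onehot_from_index(index, action_dim):
    # Hypothetical reimplementation: map integer indices of shape (N,) or (N, 1)
    # to a one-hot array of shape (N, action_dim).
    index = np.asarray(index).reshape(-1)
    onehot = np.zeros((index.shape[0], action_dim), dtype=np.float32)
    onehot[np.arange(index.shape[0]), index] = 1.0
    return onehot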
Example #2
 def act(self, obs, sample):
     q_s = self(obs_to_torch(obs, unsqueeze_dim=0, device=self.device))
     return int(
         CategoricalPolicy.act_from_logits(
             logits=q_s / self.log_alpha.value.exp(),
             sample=sample,
             return_log_pi=False).data.cpu().numpy())
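
obs_to_torch is another helper that is not shown on this page. In the repository it appears to return an observation container that also supports get_from_index (used in later examples); a plain-dict approximation of the conversion itself, given only as an assumption, might look like:

import torch

def obs_to_torch(obs, device=None, unsqueeze_dim=None):
    # Hypothetical approximation: convert every array in the observation dict to a
    # float tensor on `device`, optionally inserting a batch dimension.
    out = {}
    for name, value in obs.items():
        tensor = torch.as_tensor(value, dtype=torch.float32, device=device)
        if unsqueeze_dim is not None:
            tensor = tensor.unsqueeze(unsqueeze_dim)
        out[name] = tensor
    return out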
Example #3
    def add_expert_path(self, expert_paths):
        self.expert_paths = expert_paths
        expert_obs, expert_act = self.expert_paths.get(('obs', 'action'))
        self.n_tot = len(expert_act)

        obs = obs_to_torch(expert_obs, device=self.device)
        act = to_torch(expert_act, device=self.device)

        if self.use_validation:
            split_prop = 0.7
            shuffle = True
        else:
            split_prop = 1.
            shuffle = False

        self.n_train = int(split_prop * self.n_tot)

        shuffling_idxs = torch.randperm(self.n_tot) if shuffle else torch.arange(0, self.n_tot)
        obs = obs.get_from_index(shuffling_idxs)
        act = act[shuffling_idxs]
        train_idx = shuffling_idxs[0:self.n_train]
        valid_idx = shuffling_idxs[self.n_train:self.n_tot]
        self.train_obs = obs.get_from_index(train_idx)
        self.train_act = act[train_idx]

        self.valid_obs = obs.get_from_index(valid_idx)
        self.valid_act = act[valid_idx]
Example #4
 def add_expert_path(self, expert_paths):
     obs, act, mask = expert_paths.get(['obs', 'action', 'mask'])
     self.data_e = (
         obs_to_torch(obs, device=self.device),
         to_torch(act, device=self.device),
         mask,
     )
Example #5
 def act(self, obs, sample):
     obs = obs_to_torch(obs, unsqueeze_dim=0, device=self.device)
     action = self.actor.act(obs, sample=sample,
                             return_log_pi=False)[0].data.cpu().numpy()
     if self.discrete:
         return int(action)
     else:
         return action
Example #6
    def act(self, obs, sample, **kwargs):
        obs = obs_to_torch(obs, device=self.device, unsqueeze_dim=0)

        if self.discrete:
            action = self.classifier.act(obs, sample=sample, return_log_pi=False)[0].cpu().numpy()
            action = int(action)
        else:
            action = self.classifier.act(obs)[0].cpu().numpy()

        return action
Example #7
    def add_expert_path(self, expert_paths):
        expert_obs, expert_act = expert_paths.get(('obs', 'action'))

        if isinstance(self.act_space, Discrete):
            # convert actions from integer representation to one-hot representation

            expert_act = onehot_from_index(expert_act.astype(int), self.action_dim)

        # create the torch variables
        self.data_e = (
            obs_to_torch(expert_obs, device=self.device),
            to_torch(expert_act, device=self.device),
        )
        self.n_expert = len(expert_obs)
Example #8
    def update_reward(self, experiences, policy, ent_wt):
        (obs, action, next_obs, g_reward, mask) = experiences
        obs_var = obs_to_torch(obs, device=self.device)

        if isinstance(self.act_space, Discrete):
            # convert actions from integer representation to one-hot representation
            action_idx = action.astype(int)
            action = onehot_from_index(action_idx, self.action_dim)
        else:
            action_idx = action

        log_pi_list = policy.get_log_prob_from_obs_action_pairs(action=to_torch(action_idx, device=self.device),
                                                                obs=obs_var).detach()
        reward = to_numpy(self.get_reward(obs=obs_var,
                                          action=to_torch(action, device=self.device),
                                          log_pi_a=log_pi_list,
                                          ent_wt=ent_wt).squeeze().detach())

        return (obs, action_idx, next_obs, reward, mask)
Example #9
    def update_reward(self, experiences, policy, ent_wt):
        (obs, action, next_obs, g_reward, mask) = experiences
        obs_var = obs_to_torch(obs, device=self.device)
        if self.discrete:
            if (not self.discriminator.use_multi_head):
                action_idx = action.astype(int)
                action = onehot_from_index(action_idx, self.action_dim)
            else:
                action_idx = action.astype(int)
                action = action_idx
        else:
            action_idx = action

        log_pi_list = policy.get_log_prob_from_obs_action_pairs(action=to_torch(action_idx, device=self.device),
                                                                obs=obs_var).detach()
        reward = to_numpy(self.get_reward(obs=obs_var,
                                          action=to_torch(action, device=self.device),
                                          log_pi_a=log_pi_list,
                                          ent_wt=ent_wt).squeeze().detach())

        return (obs, action_idx, next_obs, reward, mask)
Example #10
    def fit(self, data, batch_size, policy, n_epochs_per_update, logger, **kwargs):
        """
        Train the Discriminator to distinguish expert from learner.
        """
        obs, act = data[0], data[1]

        # Create the torch variables

        obs_var = obs_to_torch(obs, device=self.device)

        if self.discrete:
            # index is used for policy log_prob and for multi_head discriminator
            act_index = act.astype(int)
            act_index_var = to_torch(act_index, device=self.device)

            # one-hot is used with single head discriminator
            if (not self.discriminator.use_multi_head):
                act = onehot_from_index(act_index, self.action_dim)
                act_var = to_torch(act, device=self.device)

            else:
                act_var = act_index_var
        else:
            # there are no index actions for continuous control, so the index action and the regular action are the same variable
            act_var = to_torch(act, device=self.device)
            act_index_var = act_var

        expert_obs_var, expert_act_var = self.expert_var
        expert_act_index_var = self.expert_act_index_var

        # Evaluate the log-probability of the transitions under the current policy
        # The result is fed, in part, to the discriminator; it is computed under no_grad because if the policy
        # is the discriminator itself (as for ASQF) we do not want gradients to flow through it
        with torch.no_grad():
            trans_log_probas = policy.get_log_prob_from_obs_action_pairs(obs=obs_var, action=act_index_var)
            expert_log_probas = policy.get_log_prob_from_obs_action_pairs(obs=expert_obs_var,
                                                                          action=expert_act_index_var)

        n_trans = len(obs)
        n_expert = self.n_expert

        # Train discriminator
        for it_update in TrainingIterator(n_epochs_per_update):
            shuffled_idxs_trans = torch.randperm(n_trans, device=self.device)

            for i, it_batch in enumerate(TrainingIterator(n_trans // batch_size)):

                # an epoch is defined over the collected transition data, not over the expert data

                batch_idxs_trans = shuffled_idxs_trans[batch_size * i: batch_size * (i + 1)]
                batch_idxs_expert = torch.tensor(random.sample(range(n_expert), k=batch_size), device=self.device)

                # lprobs_batch is the log-prob of the collected (obs, act) pairs under the current policy

                obs_batch = obs_var.get_from_index(batch_idxs_trans)
                act_batch = act_var[batch_idxs_trans]
                lprobs_batch = trans_log_probas[batch_idxs_trans]

                # expert_lprobs_batch is the log-prob of the experts' (obs, act) pairs under the current policy

                expert_obs_batch = expert_obs_var.get_from_index(batch_idxs_expert)
                expert_act_batch = expert_act_var[batch_idxs_expert]
                expert_lprobs_batch = expert_log_probas[batch_idxs_expert]

                labels = torch.zeros((batch_size * 2, 1), device=self.device)
                labels[batch_size:] = 1.0  # expert is one
                total_obs_batch = torch_cat_obs([obs_batch, expert_obs_batch], dim=0)
                total_act_batch = torch.cat([act_batch, expert_act_batch], dim=0)

                total_lprobs_batch = torch.cat([lprobs_batch, expert_lprobs_batch], dim=0)

                loss = self.discriminator.get_classification_loss(obs=total_obs_batch, action=total_act_batch,
                                                                  log_pi_a=total_lprobs_batch, target=labels)

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_norm_clip)
                self.optimizer.step()

                it_update.record('d_loss', loss.cpu().data.numpy())

        return_dict = {}

        return_dict.update(it_update.pop_all_means())

        return return_dict
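
get_classification_loss is defined on the discriminator and is not shown here. Since it receives the policy log-probabilities, one standard AIRL-style form it could take, shown purely as an assumption, is a logistic classifier whose logits are f(s, a) - log pi(a | s):

import torch.nn.functional as F

def airl_classification_loss(f_values, log_pi_a, target):
    # Assumed AIRL-style discriminator: D(s, a) = exp(f) / (exp(f) + pi(a|s)),
    # i.e. a logistic classifier with logits f(s, a) - log pi(a|s), trained with
    # binary cross-entropy against expert/agent labels.
    # f_values, log_pi_a and target are assumed to share the same shape.
    logits = f_values - log_pi_a
    return F.binary_cross_entropy_with_logits(logits, target)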
Example #11
 def act(self, obs, sample):
     obs = obs_to_torch(obs, device=self.device, unsqueeze_dim=0)
     act = self.discriminator.pi.act(obs, sample,
                                     return_log_pi=False).detach()
     return act[0].cpu().numpy()
Example #12
    def fit(self, data, batch_size, n_epochs_per_update, logger, **kwargs):
        """
        Train the discriminator to distinguish expert from learner.
        """
        agent_obs, agent_act = data[0], data[1]

        if isinstance(self.act_space, Discrete):
            # convert actions from integer representation to one-hot representation

            agent_act = onehot_from_index(agent_act.astype(int), self.action_dim)

        assert self.data_e is not None
        assert self.n_expert is not None

        # create the torch variables

        agent_obs_var = obs_to_torch(agent_obs, device=self.device)
        expert_obs_var = self.data_e[0]

        act_var = to_torch(agent_act, device=self.device)
        expert_act_var = self.data_e[1]

        # Train discriminator for n_epochs_per_update

        n_trans = len(agent_obs)
        n_expert = self.n_expert

        for it_update in TrainingIterator(n_epochs_per_update):  # epoch loop

            shuffled_idxs_trans = torch.randperm(n_trans, device=self.device)
            for i, it_batch in enumerate(TrainingIterator(n_trans // batch_size)):  # mini-batch loop

                # an epoch is defined over the collected transition data, not over the expert data

                batch_idxs_trans = shuffled_idxs_trans[batch_size * i: batch_size * (i + 1)]
                batch_idxs_expert = torch.tensor(random.sample(range(n_expert), k=batch_size), device=self.device)

                # get mini-batch of agent transitions

                obs_batch = agent_obs_var.get_from_index(batch_idxs_trans)
                act_batch = act_var[batch_idxs_trans]

                # get mini-batch of expert transitions

                expert_obs_batch = expert_obs_var.get_from_index(batch_idxs_expert)
                expert_act_batch = expert_act_var[batch_idxs_expert]

                labels = torch.zeros((batch_size * 2, 1), device=self.device)
                labels[batch_size:] = 1.0  # expert is one
                total_obs_batch = torch_cat_obs([obs_batch, expert_obs_batch], dim=0)
                total_act_batch = torch.cat([act_batch, expert_act_batch], dim=0)

                loss = self.discriminator.get_classification_loss(obs=total_obs_batch, action=total_act_batch,
                                                                  target=labels)

                if self.gradient_penalty_coef != 0.0:
                    grad_penalty = self.discriminator.get_grad_penality(
                        obs_e=expert_obs_batch, obs_l=obs_batch, act_e=expert_act_batch, act_l=act_batch,
                        gradient_penalty_coef=self.gradient_penalty_coef
                    )
                    loss += grad_penalty

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_norm_clip)
                self.optimizer.step()

                it_update.record('d_loss', loss.cpu().data.numpy())

        return_dict = {}
        return_dict.update(it_update.pop_all_means())

        return return_dict
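
get_grad_penality is also defined on the discriminator rather than on this page. A typical gradient-penalty term, sketched here under the assumptions that observations are flat tensors and that the discriminator can be called as disc(obs, act), interpolates between expert and learner samples and penalizes gradient norms away from 1:

import torch

def gradient_penalty(disc, obs_e, obs_l, act_e, act_l, coef):
    # Hypothetical WGAN-GP-style penalty on interpolated expert/learner samples.
    eps = torch.rand(act_e.size(0), 1, device=act_e.device)
    obs_mix = (eps * obs_e + (1.0 - eps) * obs_l).requires_grad_(True)
    act_mix = (eps * act_e + (1.0 - eps) * act_l).requires_grad_(True)
    scores = disc(obs_mix, act_mix)  # illustrative call signature
    grads = torch.autograd.grad(scores.sum(), [obs_mix, act_mix], create_graph=True)
    grad_norm = torch.cat([g.view(g.size(0), -1) for g in grads], dim=1).norm(2, dim=1)
    return coef * ((grad_norm - 1.0) ** 2).mean()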
Example #13
    def train_model(self, experiences):
        obs, act, new_obs, rew, mask = experiences
        obs, new_obs = [
            obs_to_torch(o, device=self.device) for o in (obs, new_obs)
        ]
        act, rew, mask = [
            to_torch(e, device=self.device) for e in (act, rew, mask)
        ]

        # Define critic loss.
        cat = torch.cat((obs['obs_vec'], act), dim=-1)

        q1 = self.q1(cat).squeeze(1)  # N
        q2 = self.q2(cat).squeeze(1)  # N

        with torch.no_grad():
            new_act, new_log_prob = self.pi.act(new_obs,
                                                sample=True,
                                                return_log_pi=True)
            new_cat = torch.cat((new_obs['obs_vec'], new_act), dim=-1).detach()
            tq1 = self.tq1(new_cat).squeeze(1)  # N
            tq2 = self.tq2(new_cat).squeeze(1)  # N

        tq = torch.min(tq1, tq2)  # N
        # print(rew.shape, mask.shape, tq.shape, self._alpha, new_log_prob.shape)
        target = rew + self.gamma * mask * (tq - self._alpha * new_log_prob)

        q1_loss = (0.5 * (target - q1)**2).mean()
        q2_loss = (0.5 * (target - q2)**2).mean()

        # Define actor loss.
        act2, log_prob2 = self.pi.act(obs, sample=True, return_log_pi=True)
        cat2 = torch.cat((obs['obs_vec'], act2), dim=-1)

        q1 = self.q1(cat2).squeeze(1)  # N
        q2 = self.q2(cat2).squeeze(1)  # N
        q = torch.min(q1, q2)  # N

        pi_loss = (self._alpha * log_prob2 - q).mean()

        # Update non-target networks.
        alpha_loss = torch.zeros(1, device=self.device)  # reported as zero when alpha is not learned
        if self.learn_alpha:
            # Define alpha loss
            alpha_loss = -self._log_alpha * (log_prob2.detach() +
                                             self.target_entropy).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self._alpha = torch.exp(self._log_alpha).detach()

        self.pi.optim.zero_grad()
        pi_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.pi.parameters(),
                                       self.grad_norm_clip)
        self.pi.optim.step()

        self.q1.optim.zero_grad()
        q1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q1.parameters(),
                                       self.grad_norm_clip)
        self.q1.optim.step()

        self.q2.optim.zero_grad()
        q2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q2.parameters(),
                                       self.grad_norm_clip)
        self.q2.optim.step()

        # Update target networks.
        self.update_targets()

        return {
            'alpha_loss': alpha_loss.detach().cpu().numpy(),
            'pi_loss': pi_loss.detach().cpu().numpy(),
            'q1_loss': q1_loss.detach().cpu().numpy(),
            'q2_loss': q2_loss.detach().cpu().numpy(),
            'pi_entropy': -log_prob2.mean().detach().cpu().numpy(),
        }
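
update_targets is called at the end of train_model but is not shown on this page; a common implementation is Polyak averaging of the target critics toward the online critics. The sketch below is an assumption (including the value of tau), not the repository's code:

import torch

def update_targets(self, tau=0.005):
    # Hypothetical soft update: target <- (1 - tau) * target + tau * online.
    with torch.no_grad():
        for target_net, online_net in ((self.tq1, self.q1), (self.tq2, self.q2)):
            for tp, p in zip(target_net.parameters(), online_net.parameters()):
                tp.mul_(1.0 - tau).add_(tau * p)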
Example #14
    def train_model(self, experiences):
        observations = obs_to_torch(obs=experiences[0], device=self.device)
        actions = to_torch(data=experiences[1], device=self.device)
        next_observations = obs_to_torch(obs=experiences[2],
                                         device=self.device)
        rewards = experiences[3]
        masks = experiences[4]

        old_values = self.critic(observations)
        old_next_values = self.critic(next_observations)
        returns, advants = self.get_gae(rewards=rewards,
                                        masks=masks,
                                        values=old_values.detach(),
                                        next_values=old_next_values.detach())

        old_policy = self.actor.get_log_prob_from_obs_action_pairs(
            action=actions, obs=observations)

        criterion = torch.nn.MSELoss()
        n = len(observations)

        for it_update in TrainingIterator(self.epochs_per_update):
            shuffled_idxs = torch.randperm(n, device=self.device)

            for i, it_batch in enumerate(TrainingIterator(n //
                                                          self.batch_size)):
                batch_idxs = shuffled_idxs[self.batch_size *
                                           i:self.batch_size * (i + 1)]

                inputs = ObsDict({
                    obs_name: obs_value[batch_idxs]
                    for obs_name, obs_value in observations.items()
                })
                actions_samples = actions[batch_idxs]
                returns_samples = returns.unsqueeze(1)[batch_idxs]
                advants_samples = advants.unsqueeze(1)[batch_idxs]
                oldvalue_samples = old_values[batch_idxs].detach()

                values = self.critic(inputs)
                clipped_values = oldvalue_samples + \
                                 torch.clamp(values - oldvalue_samples,
                                             -self.update_clip_param,
                                             self.update_clip_param)
                critic_loss1 = criterion(clipped_values, returns_samples)
                critic_loss2 = criterion(values, returns_samples)
                critic_loss = torch.max(critic_loss1, critic_loss2)

                loss, ratio = self.surrogate_loss(advants_samples, inputs,
                                                  old_policy.detach(),
                                                  actions_samples, batch_idxs)
                clipped_ratio = torch.clamp(ratio,
                                            1.0 - self.update_clip_param,
                                            1.0 + self.update_clip_param)
                clipped_loss = clipped_ratio * advants_samples
                actor_loss = -torch.min(loss, clipped_loss).mean(0)

                loss = actor_loss + self.critic_loss_coeff * critic_loss

                self.critic.optim.zero_grad()
                self.actor.optim.zero_grad()

                loss.backward()
                nn.utils.clip_grad_norm_([p for p in self.actor.parameters()] +
                                         [p for p in self.critic.parameters()],
                                         self.grad_norm_clip)

                self.critic.optim.step()
                self.actor.optim.step()
                it_update.record('loss', to_numpy(loss))

        vals = it_update.pop('loss')
        return_dict = {'loss': np.mean(vals)}

        return_dict.update({
            'ppo_rewards_mean': np.mean(rewards),
            'ppo_rewards_max': np.max(rewards),
            'ppo_rewards_min': np.min(rewards),
            'ppo_rewards_std': np.std(rewards),
            'ppo_rewards_median': np.median(rewards),
        })
        return return_dict
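
get_gae is not defined on this page either; Generalized Advantage Estimation is usually computed backwards over the batch as below. The gamma and lam values and the advantage normalization are assumptions rather than the repository's settings:

import torch

def get_gae(rewards, masks, values, next_values, gamma=0.99, lam=0.95):
    # Hypothetical GAE sketch: delta_t = r_t + gamma * mask_t * V(s_{t+1}) - V(s_t),
    # A_t = delta_t + gamma * lam * mask_t * A_{t+1}, returns = A + V.
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    masks = torch.as_tensor(masks, dtype=torch.float32)
    values = values.view(-1)
    next_values = next_values.view(-1)
    advants = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * masks[t] * next_values[t] - values[t]
        gae = delta + gamma * lam * masks[t] * gae
        advants[t] = gae
    returns = advants + values
    advants = (advants - advants.mean()) / (advants.std() + 1e-8)  # common normalization
    return returns, advants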
Example #15
    def train_model(self, experiences):
        observations = obs_to_torch(experiences[0], device=self.device)
        actions = experiences[1]
        next_states = obs_to_torch(experiences[2], device=self.device)
        rewards = to_torch(experiences[3], device=self.device)
        masks = to_torch(experiences[4], device=self.device)

        q1_s = self.q1(observations)
        q2_s = self.q2(observations)

        q_s = torch.min(q1_s, q2_s)

        alpha = self.log_alpha.value.exp()

        alpha_loss = (-(F.softmax(q_s / alpha, dim=1) * q_s).sum(1) + alpha *
                      (-self.target_entropy +
                       log_sum_exp(q_s / alpha, dim=1, keepdim=False)))
        alpha_loss = alpha_loss.mean(0)
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(self.log_alpha.parameters(),
                                       self.grad_norm_clip)
        self.alpha_optimizer.step()

        # q values of current state action pairs

        q1_s_a = q1_s.gather(dim=1,
                             index=to_torch(actions,
                                            type=int,
                                            device=self.device)).squeeze(1)
        q2_s_a = q2_s.gather(dim=1,
                             index=to_torch(actions,
                                            type=int,
                                            device=self.device)).squeeze(1)

        # target q values

        q1_sp = self.tq1(next_states)
        q2_sp = self.tq2(next_states)

        target_q = torch.min(q1_sp, q2_sp)

        pi_entropy = (-F.softmax(target_q / alpha, dim=1) *
                      F.log_softmax(target_q / alpha, dim=1)).sum(1)
        target = rewards + masks * self.gamma * (
            (F.softmax(target_q / alpha, dim=1) * target_q).sum(1) +
            alpha * pi_entropy)

        # losses

        q1_loss = ((q1_s_a - target.detach())**2).mean(0)
        q2_loss = ((q2_s_a - target.detach())**2).mean(0)

        # backprop

        self.q1.optim.zero_grad()
        q1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q1.parameters(),
                                       self.grad_norm_clip)
        self.q1.optim.step()

        self.q2.optim.zero_grad()
        q2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q2.parameters(),
                                       self.grad_norm_clip)
        self.q2.optim.step()

        return_dict = {}
        return_dict.update({
            'q1_loss': to_numpy(q1_loss),
            'q2_loss': to_numpy(q2_loss),
            'alpha_loss': to_numpy(alpha_loss),
            'pi_entropy': to_numpy(pi_entropy.mean()),
            'q_s': to_numpy(q_s.mean()),
            'alpha': to_numpy(alpha)
        })

        # update targets networks

        self.update_targets()

        return return_dict
Example #16
 def act(self, obs, sample):
     assert self.discriminator.use_multi_head
      # the resulting softmax is the same whether the logits come from Q-values or from advantages
     logits = self.discriminator.h(obs_to_torch(obs, unsqueeze_dim=0, device=self.device))
     return int(
         CategoricalPolicy.act_from_logits(logits=logits, sample=sample, return_log_pi=False).cpu().numpy())
Example #17
    def fit(self, data, batch_size, n_epochs_per_update, logger, **kwargs):

        # GET OBSERVATIONS / ACTIONS FOR LEARNER (l) AND EXPERT (e)

        obs_l, act_l, _, _, mask_l = data
        obs_e, act_e, mask_e = self.data_e

        if self.break_traj_to_windows:
            if self.window_over_episode:
                window_start_l = self.mask_to_window_start(mask_l)
                window_start_e = self.mask_to_window_start(mask_e)
                window_idx_l = [
                    list(range(s, s + self.window_size))
                    for s in window_start_l
                ]
                window_idx_e = [
                    list(range(s, s + self.window_size))
                    for s in window_start_e
                ]
                batch_size_e = batch_size
                batch_size_l = batch_size
            else:
                ep_start_l, ep_end_l = self.mask_to_episodes_limits(mask_l)
                window_start_l, window_end_l = self.split_episodes_to_windows(
                    ep_start_l, ep_end_l)
                ep_start_e, ep_end_e = self.mask_to_episodes_limits(mask_e)
                window_start_e, window_end_e = self.split_episodes_to_windows(
                    ep_start_e, ep_end_e)
                window_idx_l = [
                    list(range(s, e + 1))
                    for s, e in zip(window_start_l, window_end_l)
                ]
                window_idx_e = [
                    list(range(s, e + 1))
                    for s, e in zip(window_start_e, window_end_e)
                ]
                batch_size_l = int(
                    len(window_start_l) /
                    (kwargs['config'].d_episodes_between_updates / batch_size))
                batch_size_e = int(
                    len(window_start_e) /
                    (kwargs['config'].d_episodes_between_updates / batch_size))
                # we increase the batch_size so that even if the episode length increases, the number of steps
                # per epoch remains constant. This is because for asaf-w on episodes (which cannot window across
                # episodes) the batch_size is defined in terms of trajectories and not in terms of windows. Note
                # that this is not needed for asaf-w on transitions (which can window across episodes), since the
                # episode length no longer influences the optimization.
        else:
            window_start_l, window_end_l = self.mask_to_episodes_limits(mask_l)
            window_start_e, window_end_e = self.mask_to_episodes_limits(mask_e)
            window_idx_l = [
                list(range(s, e + 1))
                for s, e in zip(window_start_l, window_end_l)
            ]
            window_idx_e = [
                list(range(s, e + 1))
                for s, e in zip(window_start_e, window_end_e)
            ]
            batch_size_e = batch_size
            batch_size_l = batch_size

        n_windows_l, n_windows_e = len(window_start_l), len(window_start_e)

        # COMPUTE ADVANTAGES USING (OLD) POLICY

        obs_l = obs_to_torch(obs_l, device=self.device)
        act_l = to_torch(act_l, device=self.device)

        with torch.no_grad():
            old_adv_l = self.discriminator.pi.get_log_prob_from_obs_action_pairs(
                obs=obs_l, action=act_l).squeeze()
            old_adv_e = self.discriminator.pi.get_log_prob_from_obs_action_pairs(
                obs=obs_e, action=act_e).squeeze()

        # TRAIN DISCRIMINATOR

        for it_update in TrainingIterator(n_epochs_per_update):
            shuffled_window_idx_l = np.random.permutation(n_windows_l)

            for i in range(n_windows_l // batch_size_l):

                # Select a mini-batch of learner windows (expert windows are sampled uniformly below)

                mb_rand_window_num_l = shuffled_window_idx_l[i *
                                                             batch_size_l:(i +
                                                                           1) *
                                                             batch_size_l]
                mb_rand_window_num_e = np.random.randint(
                    low=0, high=n_windows_e, size=batch_size_e
                )  # expert window indices are sampled uniformly at each iteration

                # Get the indices of the corresponding windows and flatten them into one long list

                mb_window_idx_l = [
                    window_idx_l[num] for num in mb_rand_window_num_l
                ]
                mb_window_idx_e = [
                    window_idx_e[num] for num in mb_rand_window_num_e
                ]

                mb_flat_idx_l = list(itertools.chain(*mb_window_idx_l))
                mb_flat_idx_e = list(itertools.chain(*mb_window_idx_e))

                # Count window lengths

                mb_window_len_l = [len(window) for window in mb_window_idx_l]
                mb_window_len_e = [len(window) for window in mb_window_idx_e]

                # Concatenate collected and expert data

                mb_obs = torch_cat_obs((obs_l.get_from_index(mb_flat_idx_l),
                                        obs_e.get_from_index(mb_flat_idx_e)),
                                       dim=0)
                mb_act = torch.cat(
                    (act_l[mb_flat_idx_l], act_e[mb_flat_idx_e]), dim=0)
                mb_window_len = mb_window_len_l + mb_window_len_e

                # Compute (new) advantages using (current) policy and concatenate old advantages

                mb_new_adv = self.discriminator.pi.get_log_prob_from_obs_action_pairs(
                    obs=mb_obs, action=mb_act)
                mb_old_adv = torch.cat(
                    (old_adv_l[mb_flat_idx_l], old_adv_e[mb_flat_idx_e]),
                    dim=0).view(-1, 1)

                # Sum the advantages "window-wise" to be left with a vector of window-sums (length=2*mb_size)
                # This implements the sum of exponents in discriminator: q_pi = \prod_t exp(adv_t) = exp(\sum_t adv_t)

                new_sum_adv = torch.stack(
                    [s.sum(0) for s in torch.split(mb_new_adv, mb_window_len)],
                    dim=0).view(-1, 1)
                old_sum_adv = torch.stack(
                    [s.sum(0) for s in torch.split(mb_old_adv, mb_window_len)],
                    dim=0).view(-1, 1)

                # Computes the structured discriminator's binary cross-entropy loss
                # prob for agent:  log D = log-numerator - log-denominator = adv_pi - log ( exp(adv_pi) + exp(adv_g) )
                # prob for expert:  log D = log-numerator - log-denominator = adv_g - log ( exp(adv_pi) + exp(adv_g) )

                to_sum = torch.cat((new_sum_adv, old_sum_adv), dim=1)
                log_denominator = to_sum.logsumexp(dim=1, keepdim=True)
                target = torch.zeros_like(log_denominator)
                target[
                    batch_size_l:] = 1  # second half of the data is from the expert

                loss = -(target * (new_sum_adv - log_denominator) +
                         (1 - target) *
                         (old_sum_adv - log_denominator)).mean(0)

                # Backpropagation and gradient step

                self.discriminator.pi.optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_value_(
                    self.discriminator.pi.parameters(), self.grad_value_clip)
                torch.nn.utils.clip_grad_norm_(
                    self.discriminator.pi.parameters(), self.grad_norm_clip)
                self.discriminator.pi.optim.step()

                # Book-keeping

                it_update.record('d_loss', loss.cpu().data.numpy())

        vals = it_update.pop('d_loss')

        return_dict = {
            'd_loss': np.mean(vals),
            'd_loss_max': np.max(vals),
            'd_loss_min': np.min(vals),
            'd_loss_std': np.std(vals)
        }

        return return_dict