    def learn(self, writer, i_iter):
        """learn model"""
        memory, log = self.collector.collect_samples(self.min_batch_size)

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalars(
            "reinforce", {
                "total reward": log['total_reward'],
                "average reward": log['avg_reward'],
                "min reward": log['min_episode_reward'],
                "max reward": log['max_episode_reward'],
                "num steps": log['num_steps']
            }, i_iter)

        batch = memory.sample()  # sample all items in memory

        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        p_loss = torch.empty(1)
        for _ in range(self.reinforce_epochs):
            p_loss = reinforce_step(self.policy_net, self.optimizer_p,
                                    batch_state, batch_action, batch_reward,
                                    batch_mask, self.gamma)
        return p_loss
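reinforce_step, FLOAT, and device are project helpers that the snippet above relies on but does not define. A minimal sketch under assumed definitions (notably a get_log_prob(states, actions) method on the policy network) might look like this:

import torch

# Hypothetical project-level helpers assumed by the snippets on this page.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLOAT = torch.FloatTensor
LONG = torch.LongTensor


def reinforce_step(policy_net, optimizer_p, states, actions, rewards, masks, gamma):
    """One REINFORCE update over a flattened batch of transitions.

    Assumes policy_net exposes get_log_prob(states, actions); masks[t] == 0
    marks the last step of an episode.
    """
    rewards, masks = rewards.view(-1), masks.view(-1)

    # discounted return for every time step, computed backwards through the batch
    returns = torch.zeros_like(rewards)
    running_return = 0.0
    for t in reversed(range(rewards.size(0))):
        running_return = rewards[t] + gamma * masks[t] * running_return
        returns[t] = running_return
    # standardize returns to reduce gradient variance
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    log_probs = policy_net.get_log_prob(states, actions).view(-1)
    policy_loss = -(log_probs * returns).mean()

    optimizer_p.zero_grad()
    policy_loss.backward()
    optimizer_p.step()
    return policy_loss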
Example #2
    def learn(self, writer, i_iter):
        """learn model"""
        memory, log = self.collector.collect_samples(self.min_batch_size)

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalar("total reward", log['total_reward'], i_iter)
        writer.add_scalar("average reward", log['avg_reward'], i_iter)
        writer.add_scalar("min reward", log['min_episode_reward'], i_iter)
        writer.add_scalar("max reward", log['max_episode_reward'], i_iter)
        writer.add_scalar("num steps", log['num_steps'], i_iter)

        batch = memory.sample()  # sample all items in memory

        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        with torch.no_grad():
            batch_value = self.ac_net.get_value(batch_state)

        batch_advantage, batch_return = estimate_advantages(
            batch_reward, batch_mask, batch_value, self.gamma, self.tau)

        alg_step_stats = a2c_step(self.ac_net, self.optimizer_ac, batch_state,
                                  batch_action, batch_return, batch_advantage,
                                  self.value_net_coeff, self.entropy_coeff)

        return alg_step_stats
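a2c_step and estimate_advantages are likewise not shown here; estimate_advantages is reused by several of the later snippets, so a minimal GAE-style sketch may be useful, assuming reward, mask, and value tensors that can all be reshaped to [N, 1]:

import torch


def estimate_advantages(rewards, masks, values, gamma, tau):
    """GAE-style advantage estimation over a flattened batch of transitions.

    masks[t] == 0 marks the end of an episode, which cuts the bootstrap term.
    Returns (advantages, returns), both shaped [N, 1].
    """
    # normalize shapes so element-wise arithmetic cannot broadcast accidentally
    rewards, masks, values = rewards.view(-1, 1), masks.view(-1, 1), values.view(-1, 1)

    deltas = torch.zeros_like(rewards)
    advantages = torch.zeros_like(rewards)

    prev_value = 0.0
    prev_advantage = 0.0
    for t in reversed(range(rewards.size(0))):
        deltas[t] = rewards[t] + gamma * prev_value * masks[t] - values[t]
        advantages[t] = deltas[t] + gamma * tau * prev_advantage * masks[t]
        prev_value = values[t]
        prev_advantage = advantages[t]

    # returns are computed before advantage standardization
    returns = values + advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns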
Example #3
    def learn(self, writer, i_iter):
        """learn model"""
        memory, log = self.collector.collect_samples(self.min_batch_size)

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalar("total reward", log['total_reward'], i_iter)
        writer.add_scalar("average reward", log['avg_reward'], i_iter)
        writer.add_scalar("min reward", log['min_episode_reward'], i_iter)
        writer.add_scalar("max reward", log['max_episode_reward'], i_iter)
        writer.add_scalar("num steps", log['num_steps'], i_iter)

        batch = memory.sample()  # sample all items in memory

        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_mask = FLOAT(batch.mask).to(device)
        batch_log_prob = FLOAT(batch.log_prob).to(device)

        with torch.no_grad():
            batch_value = self.value_net(batch_state)

        batch_advantage, batch_return = estimate_advantages(
            batch_reward, batch_mask, batch_value, self.gamma, self.tau)

        # update by TRPO
        trpo_step(self.policy_net, self.value_net, batch_state, batch_action,
                  batch_return, batch_advantage, batch_log_prob, self.max_kl,
                  self.damping, 1e-3, None)
    def update(self, batch):
        batch_state = FLOAT(batch.state).to(device)
        batch_action = LONG(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_next_state = FLOAT(batch.next_state).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        duelingdqn_step(self.value_net, self.optimizer, self.value_net_target, batch_state, batch_action,
                        batch_reward, batch_next_state, batch_mask, self.gamma)
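duelingdqn_step is not included in these snippets. A minimal sketch consistent with the call above (the dueling value/advantage decomposition is assumed to live inside value_net, so the update rule itself is the standard TD step):

import torch


def duelingdqn_step(value_net, optimizer, value_net_target, states, actions,
                    rewards, next_states, masks, gamma):
    """One TD update of the Q-network against a frozen target network."""
    rewards, masks = rewards.view(-1), masks.view(-1)

    # Q(s, a) for the actions that were actually taken
    q_values = value_net(states).gather(1, actions.view(-1, 1)).squeeze(-1)
    with torch.no_grad():
        q_next = value_net_target(next_states).max(dim=1)[0]
        q_target = rewards + gamma * masks * q_next

    loss = torch.nn.functional.mse_loss(q_values, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss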
Example #5
    def update(self, batch, global_steps):
        batch_state = FLOAT(batch.state).to(device)
        batch_action = LONG(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_next_state = FLOAT(batch.next_state).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        doubledqn_step(self.value_net, self.optimizer, self.value_net_target,
                       batch_state, batch_action, batch_reward,
                       batch_next_state, batch_mask, self.gamma, self.polyak,
                       global_steps % self.update_target_gap == 0)
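A matching sketch of doubledqn_step, under the same assumed Q-network interface; the defining detail is that the online network selects the next action while the target network evaluates it, with a Polyak target update gated by the boolean flag passed above:

import torch


def doubledqn_step(value_net, optimizer, value_net_target, states, actions,
                   rewards, next_states, masks, gamma, polyak, update_target):
    """Double DQN update: online net picks the next action, target net evaluates it."""
    rewards, masks = rewards.view(-1), masks.view(-1)

    q_values = value_net(states).gather(1, actions.view(-1, 1)).squeeze(-1)
    with torch.no_grad():
        next_actions = value_net(next_states).argmax(dim=1, keepdim=True)
        q_next = value_net_target(next_states).gather(1, next_actions).squeeze(-1)
        q_target = rewards + gamma * masks * q_next

    loss = torch.nn.functional.mse_loss(q_values, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if update_target:
        # Polyak-average the target network toward the online network
        with torch.no_grad():
            for p_t, p in zip(value_net_target.parameters(), value_net.parameters()):
                p_t.data.mul_(polyak).add_((1.0 - polyak) * p.data)
    return loss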
Example #6
    def update(self, batch):
        """learn model"""
        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_next_state = FLOAT(batch.next_state).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        # update by DDPG
        ddpg_step(self.policy_net, self.policy_net_target, self.value_net,
                  self.value_net_target, self.optimizer_p, self.optimizer_v,
                  batch_state, batch_action, batch_reward, batch_next_state,
                  batch_mask, self.gamma, self.polyak)
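ddpg_step is not shown either; a minimal sketch assuming a deterministic policy_net(state) and a critic called as value_net(state, action) returning Q-values of shape [N, 1]:

import torch


def ddpg_step(policy_net, policy_net_target, value_net, value_net_target,
              optimizer_p, optimizer_v, states, actions, rewards,
              next_states, masks, gamma, polyak):
    """One DDPG update: critic TD step, deterministic policy gradient, soft target update."""
    rewards, masks = rewards.view(-1, 1), masks.view(-1, 1)

    # critic: regress Q(s, a) onto the bootstrapped target
    with torch.no_grad():
        target_q = rewards + gamma * masks * value_net_target(
            next_states, policy_net_target(next_states))
    value_loss = torch.nn.functional.mse_loss(value_net(states, actions), target_q)
    optimizer_v.zero_grad()
    value_loss.backward()
    optimizer_v.step()

    # actor: maximize Q(s, pi(s)) by minimizing its negative
    policy_loss = -value_net(states, policy_net(states)).mean()
    optimizer_p.zero_grad()
    policy_loss.backward()
    optimizer_p.step()

    # soft (Polyak) target updates
    with torch.no_grad():
        for target, source in ((policy_net_target, policy_net),
                               (value_net_target, value_net)):
            for p_t, p in zip(target.parameters(), source.parameters()):
                p_t.data.mul_(polyak).add_((1.0 - polyak) * p.data)
    return value_loss, policy_loss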
    def update(self, batch, k_iter):
        """learn model"""
        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_next_state = FLOAT(batch.next_state).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        # update by SAC
        sac_step(self.policy_net, self.value_net, self.value_net_target,
                 self.q_net_1, self.q_net_2, self.optimizer_p,
                 self.optimizer_v, self.optimizer_q_1, self.optimizer_q_2,
                 batch_state, batch_action, batch_reward, batch_next_state,
                 batch_mask, self.gamma, self.polyak,
                 k_iter % self.target_update_delay == 0)
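sac_step as called above matches the original SAC formulation with a separate state-value network. A minimal sketch, with the entropy temperature alpha added as an assumed keyword argument (it is not passed at the call site) and policy_net.rsample assumed to return (action, log_prob) as in the later choose_action snippet:

import torch


def sac_step(policy_net, value_net, value_net_target, q_net_1, q_net_2,
             optimizer_p, optimizer_v, optimizer_q_1, optimizer_q_2,
             states, actions, rewards, next_states, masks, gamma, polyak,
             update_target, alpha=0.2):
    """Soft Actor-Critic update (variant with a separate state-value network)."""
    rewards, masks = rewards.view(-1, 1), masks.view(-1, 1)

    # Q-functions: bootstrap from the target value network
    with torch.no_grad():
        target_q = rewards + gamma * masks * value_net_target(next_states)
    q1_loss = torch.nn.functional.mse_loss(q_net_1(states, actions), target_q)
    optimizer_q_1.zero_grad()
    q1_loss.backward()
    optimizer_q_1.step()
    q2_loss = torch.nn.functional.mse_loss(q_net_2(states, actions), target_q)
    optimizer_q_2.zero_grad()
    q2_loss.backward()
    optimizer_q_2.step()

    # value and policy updates use freshly sampled (reparameterized) actions
    new_actions, log_probs = policy_net.rsample(states)
    log_probs = log_probs.view(-1, 1)
    q_min = torch.min(q_net_1(states, new_actions), q_net_2(states, new_actions))

    # soft state value V(s) regressed toward E[min Q(s, a) - alpha * log pi(a|s)]
    value_loss = torch.nn.functional.mse_loss(
        value_net(states), (q_min - alpha * log_probs).detach())
    optimizer_v.zero_grad()
    value_loss.backward()
    optimizer_v.step()

    # policy: minimize alpha * log pi(a|s) - min Q(s, a)
    policy_loss = (alpha * log_probs - q_min).mean()
    optimizer_p.zero_grad()
    policy_loss.backward()
    optimizer_p.step()

    if update_target:
        # soft update of the value target network
        with torch.no_grad():
            for p_t, p in zip(value_net_target.parameters(), value_net.parameters()):
                p_t.data.mul_(polyak).add_((1.0 - polyak) * p.data)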
Example #8
    def choose_action(self, state):
        """select action"""
        state = FLOAT(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, _ = self.policy_net.rsample(state)
        action = action.cpu().numpy()[0]
        return action, None
def collect_samples(pid, queue, env, policy, render, running_state,
                    custom_reward, min_batch_size):
    torch.randn(pid)  # advances the global RNG by a pid-dependent amount so parallel workers de-synchronize
    log = dict()
    memory = Memory()
    num_steps = 0
    num_episodes = 0

    min_episode_reward = float('inf')
    max_episode_reward = float('-inf')
    total_reward = 0

    while num_steps < min_batch_size:
        state = env.reset()
        episode_reward = 0
        if running_state:
            state = running_state(state)

        for t in range(10000):
            if render:
                env.render()
            state_tensor = FLOAT(state).unsqueeze(0)
            with torch.no_grad():
                action, log_prob = policy.get_action_log_prob(state_tensor)
            action = action.cpu().numpy()[0]
            log_prob = log_prob.cpu().numpy()[0]
            next_state, reward, done, _ = env.step(action)
            if custom_reward:
                reward = custom_reward(state, action)
            episode_reward += reward

            if running_state:
                next_state = running_state(next_state)

            mask = 0 if done else 1
            # ('state', 'action', 'reward', 'next_state', 'mask', 'log_prob')
            memory.push(state, action, reward, next_state, mask, log_prob)
            num_steps += 1
            if done or num_steps >= min_batch_size:
                break

            state = next_state

        # num_steps += (t + 1)
        num_episodes += 1
        total_reward += episode_reward
        min_episode_reward = min(episode_reward, min_episode_reward)
        max_episode_reward = max(episode_reward, max_episode_reward)

    log['num_steps'] = num_steps
    log['num_episodes'] = num_episodes
    log['total_reward'] = total_reward
    log['avg_reward'] = total_reward / num_episodes
    log['max_episode_reward'] = max_episode_reward
    log['min_episode_reward'] = min_episode_reward

    if queue is not None:
        queue.put([pid, memory, log])
    else:
        return memory, log
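Memory is used by collect_samples and the learn methods but never defined in these snippets. Given how push and sample are used (six fields per transition, sample() returning column-wise batches that FLOAT(batch.state) can consume), a minimal sketch might be:

import random
from collections import namedtuple

# field order matches the push(...) call in collect_samples above
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'reward', 'next_state', 'mask', 'log_prob'))


class Memory:
    """Simple trajectory buffer: push transitions, read them back column-wise."""

    def __init__(self):
        self.memory = []

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size=None):
        if batch_size is None:
            # return the whole buffer, transposed into a Transition of tuples
            return Transition(*zip(*self.memory))
        batch = random.sample(self.memory, batch_size)
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)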
    def update(self, batch, k_iter):
        """learn model"""
        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_next_state = FLOAT(batch.next_state).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        # update by TD3
        td3_step(self.policy_net, self.policy_net_target, self.value_net_1,
                 self.value_net_target_1, self.value_net_2,
                 self.value_net_target_2, self.optimizer_p, self.optimizer_v_1,
                 self.optimizer_v_2, batch_state, batch_action, batch_reward,
                 batch_next_state, batch_mask, self.gamma, self.polyak,
                 self.target_action_noise_std, self.target_action_noise_clip,
                 self.action_high, k_iter % self.policy_update_delay == 0)
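td3_step is also not defined in these snippets; a minimal sketch assuming twin critics called as value_net(state, action) and a symmetric scalar action bound action_high:

import torch


def td3_step(policy_net, policy_net_target, value_net_1, value_net_target_1,
             value_net_2, value_net_target_2, optimizer_p, optimizer_v_1,
             optimizer_v_2, states, actions, rewards, next_states, masks,
             gamma, polyak, noise_std, noise_clip, action_high, update_policy):
    """TD3: clipped double-Q critics, target policy smoothing, delayed actor update."""
    rewards, masks = rewards.view(-1, 1), masks.view(-1, 1)

    # target policy smoothing: add clipped noise to the target action
    with torch.no_grad():
        noise = (torch.randn_like(actions) * noise_std).clamp(-noise_clip, noise_clip)
        next_actions = (policy_net_target(next_states) + noise).clamp(-action_high, action_high)
        target_q = torch.min(value_net_target_1(next_states, next_actions),
                             value_net_target_2(next_states, next_actions))
        target_q = rewards + gamma * masks * target_q

    # update both critics toward the shared clipped target
    for value_net, optimizer_v in ((value_net_1, optimizer_v_1),
                                   (value_net_2, optimizer_v_2)):
        value_loss = torch.nn.functional.mse_loss(value_net(states, actions), target_q)
        optimizer_v.zero_grad()
        value_loss.backward()
        optimizer_v.step()

    if update_policy:
        # delayed actor update through the first critic, then soft target updates
        policy_loss = -value_net_1(states, policy_net(states)).mean()
        optimizer_p.zero_grad()
        policy_loss.backward()
        optimizer_p.step()

        with torch.no_grad():
            for target, source in ((policy_net_target, policy_net),
                                   (value_net_target_1, value_net_1),
                                   (value_net_target_2, value_net_2)):
                for p_t, p in zip(target.parameters(), source.parameters()):
                    p_t.data.mul_(polyak).add_((1.0 - polyak) * p.data)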
Example #11
    def choose_action(self, state):
        """select action"""
        state = FLOAT(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, _ = self.ac_net.get_action_log_prob(state)

        action = action.cpu().numpy()[0]
        return action
Example #12
    def choose_action(self, state):
        state = FLOAT(state).unsqueeze(0).to(device)
        if np.random.uniform() <= self.epsilon:
            # exploit: greedy action from the Q-network
            with torch.no_grad():
                action = self.value_net.get_action(state)
            action = action.cpu().numpy()[0]
        else:
            # explore: uniformly random action
            action = np.random.randint(0, self.num_actions)
        return action
Example #13
    def choose_action(self, state, noise_scale):
        """select action"""
        state = FLOAT(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, _ = self.policy_net.get_action_log_prob(state)
        action = action.cpu().numpy()[0]
        # add noise
        noise = noise_scale * np.random.randn(self.num_actions)
        action += noise
        action = np.clip(action, -self.action_high, self.action_high)
        return action
    def learn(self, writer, i_iter):
        """learn model"""
        memory, log = self.collector.collect_samples(self.min_batch_size)

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalars(
            "vpg", {
                "total reward": log['total_reward'],
                "average reward": log['avg_reward'],
                "min reward": log['min_episode_reward'],
                "max reward": log['max_episode_reward'],
                "num steps": log['num_steps']
            }, i_iter)

        batch = memory.sample()  # sample all items in memory

        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_mask = FLOAT(batch.mask).to(device)

        with torch.no_grad():
            batch_value = self.value_net(batch_state)

        batch_advantage, batch_return = estimate_advantages(
            batch_reward, batch_mask, batch_value, self.gamma, self.tau)
        v_loss, p_loss = vpg_step(self.policy_net, self.value_net,
                                  self.optimizer_p, self.optimizer_v,
                                  self.vpg_epochs, batch_state, batch_action,
                                  batch_return, batch_advantage, 1e-3)
        return v_loss, p_loss
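vpg_step is not shown; a minimal sketch matching the call above, again assuming policy_net.get_log_prob(states, actions):

import torch


def vpg_step(policy_net, value_net, optimizer_policy, optimizer_value,
             optim_value_iternum, states, actions, returns, advantages, l2_reg):
    """Vanilla policy gradient: advantage-weighted log-probability loss plus an
    L2-regularized value regression."""
    value_loss = torch.zeros(1)
    for _ in range(optim_value_iternum):
        values_pred = value_net(states)
        value_loss = torch.nn.functional.mse_loss(values_pred.view(-1), returns.view(-1))
        for param in value_net.parameters():
            value_loss = value_loss + param.pow(2).sum() * l2_reg
        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

    # flatten to 1-D so the element-wise product cannot broadcast accidentally
    log_probs = policy_net.get_log_prob(states, actions).view(-1)
    policy_loss = -(log_probs * advantages.view(-1)).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()
    return value_loss, policy_loss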
    def value_objective_grad_func(value_net_flat_params):
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy().astype(np.float64)
        return objective_value_loss_grad
Example #16
    def value_objective_func(value_net_flat_params):
        """
        get value_net loss
        :param value_net_flat_params: numpy
        :return:
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        objective_value_loss = value_loss.item()
        # print("Current value loss: ", objective_value_loss)
        return objective_value_loss
Example #17
def trpo_step(policy_net,
              value_net,
              states,
              actions,
              returns,
              advantages,
              old_log_probs,
              max_kl,
              damping,
              l2_reg,
              optimizer_value=None):
    """
    Update by TRPO algorithm
    """
    """update critic"""
    def value_objective_func(value_net_flat_params):
        """
        get value_net loss
        :param value_net_flat_params: numpy
        :return:
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        objective_value_loss = value_loss.item()
        # print("Current value loss: ", objective_value_loss)
        return objective_value_loss

    def value_objective_grad_func(value_net_flat_params):
        """
        objective function for scipy optimizing 
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy().astype(np.float64)
        return objective_value_loss_grad

    if optimizer_value is None:
        """ 
        update by scipy optimizing, for detail about L-BFGS-B: ref: 
        https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb
        """
        value_net_flat_params_old = get_flat_params(
            value_net).detach().cpu().numpy().astype(
                np.float64)  # initial guess
        res = opt.minimize(value_objective_func,
                           value_net_flat_params_old,
                           method='L-BFGS-B',
                           jac=value_objective_grad_func,
                           options={
                               "maxiter": 30,
                               "disp": False
                           })
        # print("Call L-BFGS-B, result: ", res)
        value_net_flat_params_new = res.x
        set_flat_params(value_net, FLOAT(value_net_flat_params_new))

    else:
        """
        update by gradient descent
        """
        for _ in range(10):
            values_pred = value_net(states)
            value_loss = nn.MSELoss()(values_pred, returns)
            # weight decay
            for param in value_net.parameters():
                value_loss += param.pow(2).sum() * l2_reg
            optimizer_value.zero_grad()
            value_loss.backward()
            optimizer_value.step()
    """update policy"""
    update_policy(policy_net, states, actions, old_log_probs, advantages,
                  max_kl, damping)
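update_policy, which performs the actual TRPO natural-gradient step, is not included here. Its numerical core is solving F x = g with conjugate gradients using only Fisher-vector products; a minimal sketch of that solver, where fvp_func is an assumed callable returning the damped Fisher-vector product of the KL divergence:

import torch


def conjugate_gradient(fvp_func, b, n_iters=10, tol=1e-10):
    """Solve F x = b with conjugate gradients, given only matrix-vector products F v."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        fvp = fvp_func(p)
        alpha = rdotr / torch.dot(p, fvp)
        x += alpha * p
        r -= alpha * fvp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

The resulting step direction is then scaled to satisfy the max_kl trust-region constraint and accepted via a backtracking line search on the surrogate objective.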
Example #18
    def choose_action(self, state):
        """select action"""
        state = FLOAT(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy_net.get_action_log_prob(state)
        return action, log_prob
    def learn(self, writer, i_iter):
        """learn model"""
        memory, log = self.collector.collect_samples(self.min_batch_size)

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalars(
            "ppo", {
                "total reward": log['total_reward'],
                "average reward": log['avg_reward'],
                "min reward": log['min_episode_reward'],
                "max reward": log['max_episode_reward'],
                "num steps": log['num_steps']
            }, i_iter)

        batch = memory.sample()  # sample all items in memory
        #  ('state', 'action', 'reward', 'next_state', 'mask', 'log_prob')
        batch_state = FLOAT(batch.state).to(device)
        batch_action = FLOAT(batch.action).to(device)
        batch_reward = FLOAT(batch.reward).to(device)
        batch_mask = FLOAT(batch.mask).to(device)
        batch_log_prob = FLOAT(batch.log_prob).to(device)

        with torch.no_grad():
            batch_value = self.value_net(batch_state)

        batch_advantage, batch_return = estimate_advantages(
            batch_reward, batch_mask, batch_value, self.gamma, self.tau)
        v_loss, p_loss = torch.empty(1), torch.empty(1)

        for _ in range(self.ppo_epochs):
            if self.ppo_mini_batch_size:
                batch_size = batch_state.shape[0]
                mini_batch_num = int(
                    math.ceil(batch_size / self.ppo_mini_batch_size))

                # update with shuffled mini-batches
                index = torch.randperm(batch_size)
                for i in range(mini_batch_num):
                    ind = index[i * self.ppo_mini_batch_size:
                                min(batch_size,
                                    (i + 1) * self.ppo_mini_batch_size)]
                    state, action = batch_state[ind], batch_action[ind]
                    returns, advantages = batch_return[ind], batch_advantage[ind]
                    old_log_pis = batch_log_prob[ind]

                    v_loss, p_loss = ppo_step(
                        self.policy_net, self.value_net, self.optimizer_p,
                        self.optimizer_v, 1, state, action, returns,
                        advantages, old_log_pis, self.clip_epsilon, 1e-3)
            else:
                v_loss, p_loss = ppo_step(self.policy_net, self.value_net,
                                          self.optimizer_p, self.optimizer_v,
                                          1, batch_state, batch_action,
                                          batch_return, batch_advantage,
                                          batch_log_prob, self.clip_epsilon,
                                          1e-3)

        return v_loss, p_loss
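ppo_step itself is not shown. A minimal sketch of the clipped-surrogate variant matching the positional call above (note the GAIL snippet further down expects a variant that additionally returns an entropy-bonus term), again assuming policy_net.get_log_prob(states, actions):

import torch


def ppo_step(policy_net, value_net, optimizer_policy, optimizer_value,
             optim_value_iternum, states, actions, returns, advantages,
             old_log_probs, clip_epsilon, l2_reg):
    """PPO update: L2-regularized value regression plus the clipped surrogate objective."""
    # value update
    value_loss = torch.zeros(1)
    for _ in range(optim_value_iternum):
        values_pred = value_net(states)
        value_loss = torch.nn.functional.mse_loss(values_pred.view(-1), returns.view(-1))
        for param in value_net.parameters():
            value_loss = value_loss + param.pow(2).sum() * l2_reg
        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

    # policy update with the clipped surrogate objective
    log_probs = policy_net.get_log_prob(states, actions).view(-1)
    ratio = torch.exp(log_probs - old_log_probs.view(-1))
    advantages = advantages.view(-1)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    return value_loss, policy_loss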
    def learn(self, writer, i_iter):
        memory, log = self.collector.collect_samples(
            self.config["train"]["generator"]["sample_batch_size"])

        self.policy.train()
        self.value.train()
        self.discriminator.train()

        print(
            f"Iter: {i_iter}, num steps: {log['num_steps']}, total reward: {log['total_reward']: .4f}, "
            f"min reward: {log['min_episode_reward']: .4f}, max reward: {log['max_episode_reward']: .4f}, "
            f"average reward: {log['avg_reward']: .4f}, sample time: {log['sample_time']: .4f}"
        )

        # record reward information
        writer.add_scalar("gail/average reward", log['avg_reward'], i_iter)
        writer.add_scalar("gail/num steps", log['num_steps'], i_iter)

        # collect generated batch
        # gen_batch = self.collect_samples(self.config["ppo"]["sample_batch_size"])
        gen_batch = memory.sample()
        gen_batch_state = FLOAT(gen_batch.state).to(
            device)  # [batch size, state size]
        gen_batch_action = FLOAT(gen_batch.action).to(
            device)  # [batch size, action size]
        gen_batch_old_log_prob = FLOAT(gen_batch.log_prob).to(
            device)  # [batch size, 1]
        gen_batch_mask = FLOAT(gen_batch.mask).to(device)  # [batch, 1]

        ####################################################
        # update discriminator
        ####################################################
        d_optim_i_iters = self.config["train"]["discriminator"]["optim_step"]
        if i_iter % d_optim_i_iters == 0:
            for step, (expert_batch_state, expert_batch_action) in enumerate(
                    self.expert_dataset.train_loader):
                if step >= d_optim_i_iters:
                    break
                # calculate probs and logits
                gen_prob, gen_logits = self.discriminator(
                    gen_batch_state, gen_batch_action)
                expert_prob, expert_logits = self.discriminator(
                    expert_batch_state.to(device),
                    expert_batch_action.to(device))

                # calculate accuracy
                gen_acc = torch.mean((gen_prob < 0.5).float())
                expert_acc = torch.mean((expert_prob > 0.5).float())

                # calculate regression loss
                expert_labels = torch.ones_like(expert_prob)
                gen_labels = torch.zeros_like(gen_prob)
                e_loss = self.discriminator_func(expert_prob,
                                                 target=expert_labels)
                g_loss = self.discriminator_func(gen_prob, target=gen_labels)
                d_loss = e_loss + g_loss

                # calculate entropy loss
                logits = torch.cat([gen_logits, expert_logits], 0)
                entropy = ((1. - torch.sigmoid(logits)) * logits -
                           torch.nn.functional.logsigmoid(logits)).mean()
                ent_coeff = self.config["train"]["discriminator"]["ent_coeff"]
                entropy_loss = -ent_coeff * entropy

                total_loss = d_loss + entropy_loss

                self.optimizer_discriminator.zero_grad()
                total_loss.backward()
                self.optimizer_discriminator.step()

            writer.add_scalar('discriminator/d_loss', d_loss.item(), i_iter)
            writer.add_scalar("discriminator/e_loss", e_loss.item(), i_iter)
            writer.add_scalar("discriminator/g_loss", g_loss.item(), i_iter)
            writer.add_scalar("discriminator/ent", entropy.item(), i_iter)
            writer.add_scalar('discriminator/expert_acc', expert_acc.item(), i_iter)
            writer.add_scalar('discriminator/gen_acc', gen_acc.item(), i_iter)

        ####################################################
        # update policy by ppo [mini_batch]
        ####################################################

        with torch.no_grad():
            gen_batch_value = self.value(gen_batch_state)
            d_out, _ = self.discriminator(gen_batch_state, gen_batch_action)
            gen_batch_reward = -torch.log(1 - d_out + 1e-6)

        gen_batch_advantage, gen_batch_return = estimate_advantages(
            gen_batch_reward, gen_batch_mask, gen_batch_value,
            self.config["train"]["generator"]["gamma"],
            self.config["train"]["generator"]["tau"])

        ppo_optim_i_iters = self.config["train"]["generator"]["optim_step"]
        ppo_mini_batch_size = self.config["train"]["generator"][
            "mini_batch_size"]

        for _ in range(ppo_optim_i_iters):
            if ppo_mini_batch_size > 0:
                gen_batch_size = gen_batch_state.shape[0]
                optim_iter_num = int(
                    math.ceil(gen_batch_size / ppo_mini_batch_size))
                perm = torch.randperm(gen_batch_size)

                for i in range(optim_iter_num):
                    ind = perm[slice(
                        i * ppo_mini_batch_size,
                        min((i + 1) * ppo_mini_batch_size, gen_batch_size))]
                    mini_batch_state = gen_batch_state[ind]
                    mini_batch_action = gen_batch_action[ind]
                    mini_batch_advantage = gen_batch_advantage[ind]
                    mini_batch_return = gen_batch_return[ind]
                    mini_batch_old_log_prob = gen_batch_old_log_prob[ind]

                    v_loss, p_loss, ent_loss = ppo_step(
                        policy_net=self.policy,
                        value_net=self.value,
                        optimizer_policy=self.optimizer_policy,
                        optimizer_value=self.optimizer_value,
                        optim_value_iternum=self.config["value"]
                        ["optim_value_iter"],
                        states=mini_batch_state,
                        actions=mini_batch_action,
                        returns=mini_batch_return,
                        old_log_probs=mini_batch_old_log_prob,
                        advantages=mini_batch_advantage,
                        clip_epsilon=self.config["train"]["generator"]
                        ["clip_ratio"],
                        l2_reg=self.config["value"]["l2_reg"])
            else:
                v_loss, p_loss, ent_loss = ppo_step(
                    policy_net=self.policy,
                    value_net=self.value,
                    optimizer_policy=self.optimizer_policy,
                    optimizer_value=self.optimizer_value,
                    optim_value_iternum=self.config["value"]
                    ["optim_value_iter"],
                    states=gen_batch_state,
                    actions=gen_batch_action,
                    returns=gen_batch_return,
                    old_log_probs=gen_batch_old_log_prob,
                    advantages=gen_batch_advantage,
                    clip_epsilon=self.config["train"]["generator"]
                    ["clip_ratio"],
                    l2_reg=self.config["value"]["l2_reg"])

        writer.add_scalar('generator/p_loss', p_loss, i_iter)
        writer.add_scalar('generator/v_loss', v_loss, i_iter)
        writer.add_scalar('generator/ent_loss', ent_loss, i_iter)

        print(f" Training episode:{i_iter} ".center(80, "#"))
        print('d_gen_prob:', gen_prob.mean().item())
        print('d_expert_prob:', expert_prob.mean().item())
        print('d_loss:', d_loss.item())
        print('e_loss:', e_loss.item())
        print("d/bernoulli_entropy:", entropy.item())
    def __init__(self,
                 expert_data_path,
                 train_fraction=0.7,
                 traj_limitation=-1,
                 shuffle=True,
                 batch_size=64,
                 num_workers=multiprocessing.cpu_count()):
        """
        Custom dataset deal with gail expert dataset
        """
        traj_data = np.load(expert_data_path, allow_pickle=True)

        # support both naming conventions used by expert datasets
        # ('state'/'obs', 'action'/'acs', 'ep_reward'/'ep_rets')
        states = traj_data['state'] if 'state' in traj_data else traj_data['obs']
        actions = traj_data['action'] if 'action' in traj_data else traj_data['acs']
        self.ep_ret = traj_data['ep_reward'] if 'ep_reward' in traj_data else traj_data['ep_rets']
        if traj_limitation < 0:
            traj_limitation = len(self.ep_ret)
        self.ep_ret = self.ep_ret[:traj_limitation]

        # states, actions: shape (N, L, ) + S where N = # episodes, L = episode length
        # and S is the environment observation/action space.
        # Flatten to (N * L, prod(S))
        if len(states.shape) > 2:
            self.states = np.reshape(states, [-1, np.prod(states.shape[2:])])
            self.actions = np.reshape(actions,
                                      [-1, np.prod(actions.shape[2:])])
        else:
            self.states = np.vstack(states)
            self.actions = np.vstack(actions)

        self._num_states = self.states.shape[-1]
        self._num_actions = self.actions.shape[-1]

        self.avg_ret = sum(self.ep_ret) / len(self.ep_ret)
        self.std_ret = np.std(np.array(self.ep_ret))
        self.shuffle = shuffle

        assert len(self.states) == len(self.actions), \
            "The number of actions and observations differ, please check your expert dataset"
        self.num_traj = len(self.ep_ret)  # ep_ret was already truncated to traj_limitation
        self.num_transition = len(self.states)

        self.data_loader = DataLoader(TensorDataset(
            FLOAT(self.states),
            FLOAT(self.actions),
        ),
                                      shuffle=self.shuffle,
                                      batch_size=batch_size,
                                      num_workers=num_workers)

        self.train_loader = DataLoader(TensorDataset(
            FLOAT(self.states[:int(self.num_transition * train_fraction), :]),
            FLOAT(self.actions[:int(self.num_transition * train_fraction), :]),
        ),
                                       shuffle=self.shuffle,
                                       batch_size=batch_size,
                                       num_workers=num_workers)
        self.val_loader = DataLoader(TensorDataset(
            FLOAT(self.states[int(self.num_transition * train_fraction):, :]),
            FLOAT(self.actions[int(self.num_transition * train_fraction):, :]),
        ),
                                     shuffle=self.shuffle,
                                     batch_size=batch_size,
                                     num_workers=num_workers)

        self.log_info()
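log_info is called at the end of __init__ but not defined in the snippet; a minimal sketch that prints the statistics computed above:

    def log_info(self):
        """Print summary statistics of the loaded expert dataset."""
        print(f"Total trajectories: {self.num_traj}")
        print(f"Total transitions: {self.num_transition}")
        print(f"Average episode return: {self.avg_ret:.2f} (std {self.std_ret:.2f})")
        print(f"State dim: {self._num_states}, action dim: {self._num_actions}")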