class Agent:
    """
    Agent used for evaluation rollouts: loads the pretrained world model
    (and controller, if available) and selects actions deterministically.
    """
    def __init__(self):

        # Locate the pretrained VAE, MDRNN and controller checkpoints
        vae_file, rnn_file, ctrl_file = \
            [join("./training", m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']]

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state = [
            torch.load(fname, map_location={'cuda:0': str(device)})
            for fname in (vae_file, rnn_file)
        ]

        for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
            logger.info("Loading {} at epoch {} "
                        "with test loss {}".format(m, s['epoch'],
                                                   s['precision']))

        self.vae = VAE(3, LSIZE).to(device).double()
        self.vae.load_state_dict(vae_state['state_dict'])

        self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device).double()
        # The MDRNN checkpoint uses nn.LSTM parameter names; drop the "_l0"
        # layer suffix so the keys match the LSTMCell inside MDRNNCell.
        self.mdrnn.load_state_dict(
            {k.replace('_l0', ''): v
             for k, v in rnn_state['state_dict'].items()})

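        # The world model is used for inference only: freeze VAE and MDRNN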
        for p in self.vae.parameters():
            p.requires_grad = False
        for p in self.mdrnn.parameters():
            p.requires_grad = False

        self.net = Controller(LSIZE, RSIZE, ASIZE).to(device).double()
        # load controller if it was previously saved
        if exists(ctrl_file):
            ctrl_state = torch.load(ctrl_file,
                                    map_location={'cuda:0': str(device)})
            logger.info("Loading Controller with reward {}".format(
                ctrl_state['reward']))
            self.net.load_state_dict(ctrl_state['state_dict'])

    def select_action(self, state, hidden):
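        # state: observation tensor fed to the VAE;
        # hidden: (h, c) tuple of the MDRNN cell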

        with torch.no_grad():
            _, latent_mu, _ = self.vae(state)
            alpha, beta = self.net(latent_mu, hidden[0])[0]

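        # Deterministic action: the mean of the Beta(alpha, beta) policy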
        action = alpha / (alpha + beta)

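        # Step the MDRNN to advance the recurrent state for the next time step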
        _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
        action = action.squeeze().cpu().numpy()
        return action, next_hidden

    def load_param(self):
        # Controller weights saved by the PPO training agent
        self.net.load_state_dict(torch.load(
            'param/ppo_net_params.pkl', map_location={'cuda:0': str(device)}))
class Agent:
    """
    Agent that collects transitions and trains the controller with PPO.
    """
    max_grad_norm = 0.5  # grad-clipping threshold (clipping is currently disabled below)
    clip_param = 0.1     # epsilon in the PPO clipped surrogate loss
    ppo_epoch = 10       # optimisation passes over each full buffer
    buffer_capacity, batch_size = 1500, 128  # transitions per update / minibatch size

    def __init__(self):

        # Locate the pretrained VAE, MDRNN and controller checkpoints
        vae_file, rnn_file, ctrl_file = \
            [join("./training", m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']]

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state = [
            torch.load(fname, map_location={'cuda:0': str(device)})
            for fname in (vae_file, rnn_file)]

        for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
            logger.info("Loading {} at epoch {} "
                        "with test loss {}".format(m, s['epoch'],
                                                   s['precision']))

        self.vae = VAE(3, LSIZE).to(device).double()
        self.vae.load_state_dict(vae_state['state_dict'])

        self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device).double()
        # Drop the "_l0" layer suffix so LSTM parameter names match the LSTMCell
        self.mdrnn.load_state_dict(
            {k.replace('_l0', ''): v for k, v in rnn_state['state_dict'].items()})
    
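        # The world model is used for inference only: freeze VAE and MDRNN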
        for p in self.vae.parameters():
            p.requires_grad = False
        for p in self.mdrnn.parameters():
            p.requires_grad = False
        self.net = Controller(LSIZE, RSIZE, ASIZE).to(device).double()
        # load controller if it was previously saved
        if exists(ctrl_file):
            ctrl_state = torch.load(ctrl_file, map_location={'cuda:0': str(device)})
            logger.info("Loading Controller with reward {}".format(
                ctrl_state['reward']))
            self.net.load_state_dict(ctrl_state['state_dict'])
        
        self.training_step = 0
       
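        # Replay buffer backed by a numpy structured array; `transition` is
        # assumed to be a structured dtype defined at module level
        # (a commented sketch appears after this class)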
        self.buffer = np.empty(self.buffer_capacity, dtype=transition)
        self.counter = 0

        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)

    def select_action(self, state, hidden):
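        # state: observation tensor fed to the VAE;
        # hidden: (h, c) tuple of the MDRNN cell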
        
        with torch.no_grad():
            _, latent_mu, _ = self.vae(state)
            alpha, beta = self.net(latent_mu, hidden[0])[0]
        
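        # Sample a stochastic action from the Beta policy (support in [0, 1])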
        dist = Beta(alpha, beta)
        action = dist.sample()
        a_logp = dist.log_prob(action).sum(dim=1)

        a_logp = a_logp.item()
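        # Step the MDRNN to advance the recurrent state for the next time step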
        _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden)
        
        return action.squeeze().cpu().numpy(), a_logp, latent_mu, next_hidden

    def save_param(self):
        torch.save(self.net.state_dict(), 'param/ppo_net_params.pkl')

    def store(self, transition):
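        # Append one transition; return True once the buffer is full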
        self.buffer[self.counter] = transition
        self.counter += 1
        if self.counter == self.buffer_capacity:
            self.counter = 0
            return True
        else:
            return False

    def update(self):
        self.training_step += 1
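        # Unpack the structured replay buffer into double-precision tensors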
        mu = torch.tensor(self.buffer['mu'], dtype=torch.double).to(device)
        hidden = torch.tensor(self.buffer['hidden'],
                              dtype=torch.double).to(device).view(-1, RSIZE)
        a = torch.tensor(self.buffer['a'], dtype=torch.double).to(device)
        r = torch.tensor(self.buffer['r'],
                         dtype=torch.double).to(device).view(-1, 1)
        mu_ = torch.tensor(self.buffer['mu_'], dtype=torch.double).to(device)
        hidden_ = torch.tensor(self.buffer['hidden_'],
                               dtype=torch.double).to(device).view(-1, RSIZE)

        old_a_logp = torch.tensor(self.buffer['a_logp'],
                                  dtype=torch.double).to(device).view(-1, 1)

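        # One-step TD targets and advantages, computed once with gradients off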
        with torch.no_grad():
            target_v = r + args.gamma * self.net(mu_, hidden_)[1]
            adv = target_v - self.net(mu, hidden)[1]
            # adv = (adv - adv.mean()) / (adv.std() + 1e-8)

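        # Several PPO epochs of minibatch updates over the collected buffer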
        for _ in range(self.ppo_epoch):
            for index in BatchSampler(
                    SubsetRandomSampler(range(self.buffer_capacity)),
                    self.batch_size, False):

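                # Log-probs under the current policy; the ratio to the stored
                # (old) log-probs gives the PPO importance weight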
                alpha, beta = self.net(mu[index], hidden[index])[0]
                dist = Beta(alpha, beta)
                a_logp = dist.log_prob(a[index]).sum(dim=1, keepdim=True)
                ratio = torch.exp(a_logp - old_a_logp[index])

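                # Clipped surrogate policy loss plus a smooth-L1 value loss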
                surr1 = ratio * adv[index]
                surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                    1.0 + self.clip_param) * adv[index]
                action_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.smooth_l1_loss(
                    self.net(mu[index], hidden[index])[1], target_v[index])
                loss = action_loss + 2. * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
                self.optimizer.step()
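

# NOTE: `transition` (used by Agent.store / Agent.update above) is expected to
# be a numpy structured dtype defined elsewhere in the repository. A minimal
# sketch, with per-field shapes inferred from how update() reshapes the
# buffer, could look like:
#
# transition = np.dtype([
#     ('mu', np.float64, (LSIZE,)),       # VAE latent mean at time t
#     ('hidden', np.float64, (RSIZE,)),   # MDRNN hidden state at time t
#     ('a', np.float64, (ASIZE,)),        # sampled action
#     ('a_logp', np.float64),             # log-probability of the action
#     ('r', np.float64),                  # reward received
#     ('mu_', np.float64, (LSIZE,)),      # VAE latent mean at time t+1
#     ('hidden_', np.float64, (RSIZE,)),  # MDRNN hidden state at time t+1
# ])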