class Agent(): """ Agent for training """ def __init__(self): # Loading world model and vae vae_file, rnn_file, ctrl_file = \ [join("./training", m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']] assert exists(vae_file) and exists(rnn_file),\ "Either vae or mdrnn is untrained." vae_state, rnn_state = [ torch.load(fname, map_location={'cuda:0': str(device)}) for fname in (vae_file, rnn_file) ] for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)): logger.info("Loading {} at epoch {} " "with test loss {}".format(m, s['epoch'], s['precision'])) self.vae = VAE(3, LSIZE).to(device).double() self.vae.load_state_dict(vae_state['state_dict']) self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device).double() self.mdrnn.load_state_dict( {k.strip('_l0'): v for k, v in rnn_state['state_dict'].items()}) for p in self.vae.parameters(): p.requires_grad = False for p in self.mdrnn.parameters(): p.requires_grad = False self.net = Controller(LSIZE, RSIZE, ASIZE).to(device).double() # load controller if it was previously saved if exists(ctrl_file): ctrl_state = torch.load(ctrl_file, map_location={'cuda:0': str(device)}) logger.info("Loading Controller with reward {}".format( ctrl_state['reward'])) self.net.load_state_dict(ctrl_state['state_dict']) def select_action(self, state, hidden): with torch.no_grad(): _, latent_mu, _ = self.vae(state) alpha, beta = self.net(latent_mu, hidden[0])[0] action = alpha / (alpha + beta) _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden) action = action.squeeze().cpu().numpy() return action, next_hidden def load_param(self): self.net.load_state_dict(torch.load('param/ppo_net_params.pkl'))
class Agent(): """ Agent for training """ max_grad_norm = 0.5 clip_param = 0.1 # epsilon in clipped loss ppo_epoch = 10 buffer_capacity, batch_size = 1500, 128 def __init__(self): # Loading world model and vae vae_file, rnn_file, ctrl_file = \ [join("./training", m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']] assert exists(vae_file) and exists(rnn_file),\ "Either vae or mdrnn is untrained." vae_state, rnn_state = [ torch.load(fname, map_location={'cuda:0': str(device)}) for fname in (vae_file, rnn_file)] for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)): logger.info("Loading {} at epoch {} " "with test loss {}".format( m, s['epoch'], s['precision'])) self.vae = VAE(3, LSIZE).to(device).double() self.vae.load_state_dict(vae_state['state_dict']) self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device).double() self.mdrnn.load_state_dict( {k.strip('_l0'): v for k, v in rnn_state['state_dict'].items()}) for p in self.vae.parameters(): p.requires_grad = False for p in self.mdrnn.parameters(): p.requires_grad = False self.net = Controller(LSIZE, RSIZE, ASIZE).to(device).double() # load controller if it was previously saved if exists(ctrl_file): ctrl_state = torch.load(ctrl_file, map_location={'cuda:0': str(device)}) logger.info("Loading Controller with reward {}".format( ctrl_state['reward'])) self.net.load_state_dict(ctrl_state['state_dict']) self.training_step = 0 self.buffer = np.empty(self.buffer_capacity, dtype=transition) self.counter = 0 self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3) def select_action(self, state, hidden): with torch.no_grad(): _, latent_mu, _ = self.vae(state) alpha, beta = self.net(latent_mu, hidden[0])[0] dist = Beta(alpha, beta) action = dist.sample() a_logp = dist.log_prob(action).sum(dim=1) a_logp = a_logp.item() _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu, hidden) return action.squeeze().cpu().numpy(), a_logp, latent_mu, next_hidden def save_param(self): torch.save(self.net.state_dict(), 'param/ppo_net_params.pkl') def store(self, transition): self.buffer[self.counter] = transition self.counter += 1 if self.counter == self.buffer_capacity: self.counter = 0 return True else: return False def update(self): self.training_step += 1 mu = torch.tensor(self.buffer['mu'], dtype=torch.double).to(device) hidden = torch.tensor(self.buffer['hidden'], dtype=torch.double).to(device).view(-1, RSIZE) a = torch.tensor(self.buffer['a'], dtype=torch.double).to(device) r = torch.tensor(self.buffer['r'], dtype=torch.double).to(device).view(-1, 1) mu_ = torch.tensor(self.buffer['mu_'], dtype=torch.double).to(device) hidden_ = torch.tensor(self.buffer['hidden_'], dtype=torch.double).to(device).view(-1, RSIZE) old_a_logp = torch.tensor(self.buffer['a_logp'], dtype=torch.double).to(device).view(-1, 1) with torch.no_grad(): target_v = r + args.gamma * self.net(mu_, hidden_)[1] adv = target_v - self.net(mu, hidden)[1] # adv = (adv - adv.mean()) / (adv.std() + 1e-8) for _ in range(self.ppo_epoch): for index in BatchSampler(SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, False): alpha, beta = self.net(mu[index], hidden[index])[0] dist = Beta(alpha, beta) a_logp = dist.log_prob(a[index]).sum(dim=1, keepdim=True) ratio = torch.exp(a_logp - old_a_logp[index]) surr1 = ratio * adv[index] surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv[index] action_loss = -torch.min(surr1, surr2).mean() value_loss = F.smooth_l1_loss(self.net(mu[index], hidden[index])[1], target_v[index]) loss = action_loss + 2. 
* value_loss self.optimizer.zero_grad() loss.backward() # nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm) self.optimizer.step()
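# --- Hedged sketch (an assumption, not from the source): a structured dtype
# the buffer above could use. The field names are fixed by how update() reads
# self.buffer; the shapes assume LSIZE-dim latents, RSIZE-dim hidden states
# and ASIZE-dim actions, matching the constants used above.
transition = np.dtype([
    ('mu', np.float64, (LSIZE,)),       # latent mean of the current frame
    ('hidden', np.float64, (RSIZE,)),   # MDRNN hidden state h before the step
    ('a', np.float64, (ASIZE,)),        # sampled action
    ('a_logp', np.float64),             # log-prob of the sampled action
    ('r', np.float64),                  # reward
    ('mu_', np.float64, (LSIZE,)),      # latent mean of the next frame
    ('hidden_', np.float64, (RSIZE,)),  # MDRNN hidden state after the step
])

# Typical driver loop (sketch; tuple field order must match the dtype above):
#   action, a_logp, mu, next_hidden = agent.select_action(state, hidden)
#   obs, reward, done, _ = env.step(action)
#   ... encode the next frame to get mu_ ...
#   if agent.store((mu, hidden[0].squeeze().cpu().numpy(), action, a_logp,
#                   reward, mu_, next_hidden[0].squeeze().cpu().numpy())):
#       agent.update()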