def __init__(self, mdir, device, time_limit, explorer=False):
    """ Build vae, rnn, controller and environment.

    :args mdir: directory holding ``vae/best.tar``, ``mdrnn/best.tar`` and
        optionally ``ctrl/best.tar`` (``exp/best.tar`` in explorer mode)
    :args device: device the VAE, MDRNN and Controller are moved to
    :args time_limit: maximum number of timesteps of a rollout
    :args explorer: if True, the controller checkpoint is read from
        ``mdir/exp`` instead of ``mdir/ctrl``
    """
    self.explorer = explorer

    # Load controllers
    vae_file, rnn_file, ctrl_file = \
        [join(mdir, m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']]
    if self.explorer:
        ctrl_file = join(mdir, 'exp', 'best.tar')

    assert exists(vae_file) and exists(rnn_file),\
        "Either vae or mdrnn is untrained."

    vae_state, rnn_state = [
        torch.load(fname, map_location={'cuda:0': str(device)})
        for fname in (vae_file, rnn_file)
    ]

    for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
        print("Loading {} at epoch {} "
              "with test loss {}".format(m, s['epoch'], s['precision']))

    self.vae = VAE(3, LSIZE).to(device)
    self.vae.load_state_dict(vae_state['state_dict'])

    # MDRNNCell
    self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
    # The MDRNN checkpoint names carry an '_l0' suffix that the cell's
    # parameter names do not have: remove exactly that suffix.
    # str.strip('_l0') would instead remove ANY leading/trailing '_', 'l'
    # or '0' characters (e.g. 'linear.weight' -> 'inear.weight').
    self.mdrnn.load_state_dict(
        {(k[:-len('_l0')] if k.endswith('_l0') else k): v
         for k, v in rnn_state['state_dict'].items()})

    self.controller = Controller(LSIZE, RSIZE, ASIZE).to(device)

    # load controller if it was previously saved
    if exists(ctrl_file):
        ctrl_state = torch.load(ctrl_file,
                                map_location={'cuda:0': str(device)})
        print("Loading Controller with reward {}".format(
            ctrl_state['reward']))
        self.controller.load_state_dict(ctrl_state['state_dict'])

    self.env = gym.make('CarRacing-v0')
    self.device = device

    self.time_limit = time_limit

    # Sequence (non-cell) MDRNN built from the unmodified checkpoint.
    self.mdrnn_notcell = MDRNN(LSIZE, ASIZE, RSIZE, 5)
    self.mdrnn_notcell.to(device)
    self.mdrnn_notcell.load_state_dict(rnn_state['state_dict'])
# Load VAE vae_file = join(args.originallogdir, 'vae', 'best.tar') assert exists(vae_file), "No trained VAE in the originallogdir..." state = torch.load(vae_file, map_location={'cuda:0': str(device)}) print("Loading VAE at epoch {} " "with test error {}".format(state['epoch'], state['precision'])) vae = VAE(3, LSIZE).to(device) vae.load_state_dict(state['state_dict']) vae_optimizer = torch.optim.Adam(vae.parameters()) vae_scheduler = ReduceLROnPlateau(vae_optimizer, 'min', factor=0.5, patience=5) # Load RNN rnn_dir = join(args.originallogdir, 'mdrnn') rnn_file = join(rnn_dir, 'best.tar') assert exists(rnn_file), 'No trained MDNRNN in the originallogdir...' mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5) mdrnn.to(device) mdrnn_optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9) mdrnn_scheduler = ReduceLROnPlateau(mdrnn_optimizer, 'min', factor=0.5, patience=5) rnn_state = torch.load(rnn_file, map_location={'cuda:0': str(device)}) print("Loading MDRNN at epoch {} " "with test error {}".format(rnn_state["epoch"], rnn_state["precision"])) mdrnn.load_state_dict(rnn_state["state_dict"]) def collation_fn(rollouts): rollout_items = [[], [], [], [], []]
# Loading VAE (map_location lets a GPU-saved checkpoint load on any device)
state = torch.load(vae_file, map_location=device)
print("Loading VAE at epoch {} "
      "with test error {}".format(
          state['epoch'], state['precision']))

vae = VAE(3, LSIZE).to(device)
vae.load_state_dict(state['state_dict'])

# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

if not exists(rnn_dir):
    mkdir(rnn_dir)

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)

# Resume MDRNN training from its last best checkpoint unless disabled.
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file, map_location=device)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # Scheduler / early-stopping states are stored in the MDRNN checkpoint,
    # not in `state` (the VAE checkpoint loaded above).
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])
assert exists(vae_file), "No trained VAE in the logdir..."
state = torch.load(vae_file, map_location=device)
print("Loading VAE at epoch {} "
      "with val error {}".format(state['epoch'], state['precision']))

vae = VAE(3, LSIZE).to(device)
vae.load_state_dict(state['state_dict'])

# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

if not exists(rnn_dir):
    mkdir(rnn_dir)

mdrnn = MDRNN(latents=LSIZE, actions=ASIZE, hiddens=RSIZE, gaussians=5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)

# Resume MDRNN training from its last best checkpoint unless disabled.
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file, map_location=device)
    print("Loading MDRNN at epoch {} "
          "with val error {}".format(rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # Scheduler / early-stopping states are stored in the MDRNN checkpoint,
    # not in `state` (the VAE checkpoint loaded above).
    scheduler.load_state_dict(rnn_state['scheduler'])
    earlystopping.load_state_dict(rnn_state['earlystopping'])
class RolloutGenerator(object):
    """ Utility to generate rollouts.

    Encapsulate everything that is needed to generate rollouts in the TRUE ENV
    using a controller with previously trained VAE and MDRNN.

    :attr vae: VAE model loaded from mdir/vae
    :attr mdrnn: MDRNNCell loaded from mdir/mdrnn
    :attr mdrnn_notcell: sequence MDRNN loaded from the same mdir/mdrnn
        checkpoint
    :attr controller: Controller, either loaded from mdir/ctrl (mdir/exp in
        explorer mode) or randomly initialized
    :attr env: instance of the CarRacing-v0 gym environment
    :attr device: device used to run VAE, MDRNN and Controller
    :attr time_limit: rollouts have a maximum of time_limit timesteps
    :attr explorer: if True, rollouts are scored with the MDRNN prediction
        loss instead of the environment reward
    """
    def __init__(self, mdir, device, time_limit, explorer=False):
        """ Build vae, rnn, controller and environment. """
        self.explorer = explorer

        # Load controllers
        vae_file, rnn_file, ctrl_file = \
            [join(mdir, m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl']]
        if self.explorer:
            ctrl_file = join(mdir, 'exp', 'best.tar')

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state = [
            torch.load(fname, map_location={'cuda:0': str(device)})
            for fname in (vae_file, rnn_file)
        ]

        for m, s in (('VAE', vae_state), ('MDRNN', rnn_state)):
            print("Loading {} at epoch {} "
                  "with test loss {}".format(m, s['epoch'], s['precision']))

        self.vae = VAE(3, LSIZE).to(device)
        self.vae.load_state_dict(vae_state['state_dict'])

        # MDRNNCell
        self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
        # Remove exactly the '_l0' suffix of the checkpoint names.
        # str.strip('_l0') would instead remove ANY leading/trailing '_', 'l'
        # or '0' characters (e.g. 'linear.weight' -> 'inear.weight').
        self.mdrnn.load_state_dict(
            {(k[:-len('_l0')] if k.endswith('_l0') else k): v
             for k, v in rnn_state['state_dict'].items()})

        self.controller = Controller(LSIZE, RSIZE, ASIZE).to(device)

        # load controller if it was previously saved
        if exists(ctrl_file):
            ctrl_state = torch.load(ctrl_file,
                                    map_location={'cuda:0': str(device)})
            print("Loading Controller with reward {}".format(
                ctrl_state['reward']))
            self.controller.load_state_dict(ctrl_state['state_dict'])

        self.env = gym.make('CarRacing-v0')
        self.device = device

        self.time_limit = time_limit

        # Sequence (non-cell) MDRNN built from the unmodified checkpoint.
        self.mdrnn_notcell = MDRNN(LSIZE, ASIZE, RSIZE, 5)
        self.mdrnn_notcell.to(device)
        self.mdrnn_notcell.load_state_dict(rnn_state['state_dict'])

    # NOTE: to_latent / mdrnn_exp_reward mirror to_latent / get_loss of the
    # MDRNN training script, adapted to single-step rollouts.

    def to_latent(self, obs, next_obs):
        """ Transform observations to latent space.

        A sequence length of 1 is assumed (one environment step at a time).

        :args obs: 5D torch tensor (batch, 1, 3, SIZE, SIZE)
        :args next_obs: torch tensor holding the same number of
            (3, SIZE, SIZE) frames as obs

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 3D torch tensor (batch, 1, LSIZE)
            - latent_next_obs: 3D torch tensor (batch, 1, LSIZE)
        """
        # Batch size comes from the input rather than a module-level BSIZE,
        # so a single rollout step (batch of 1) always reshapes correctly.
        batch_size = obs.size(0)
        seq_len = 1
        with torch.no_grad():
            obs, next_obs = [
                f.upsample(x.view(-1, 3, SIZE, SIZE), size=RED_SIZE,
                           mode='bilinear', align_corners=True)
                for x in (obs, next_obs)
            ]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                self.vae(x)[1:] for x in (obs, next_obs)
            ]

            # Reparameterised sample: mu + sigma * eps.
            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(
                    batch_size, seq_len, LSIZE)
                for x_mu, x_logsigma in [(obs_mu, obs_logsigma),
                                         (next_obs_mu, next_obs_logsigma)]
            ]
        return latent_obs, latent_next_obs

    def mdrnn_exp_reward(self, latent_obs, action, reward, latent_next_obs,
                         hidden):
        """ Compute the MDRNN prediction loss used as exploration reward.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward))
            / (LSIZE + 2)
        The terminal BCE term of the training loss is left out, but the
        LSIZE + 2 normalisation of the training loss is kept.

        :args latent_obs: (1, LSIZE) torch tensor, as built by rollout
        :args action: (1, ASIZE) torch tensor
        :args reward: (1,) torch tensor holding the environment reward
        :args latent_next_obs: (1, LSIZE) torch tensor
        :args hidden: MDRNNCell hidden state, as passed to self.mdrnn

        :returns: the squeezed loss as a numpy array
        """
        mus, sigmas, logpi, rs, ds, next_hidden = self.mdrnn(
            action, latent_obs, hidden)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        mse = f.mse_loss(rs, reward)
        loss = (gmm + mse) / (LSIZE + 2)
        # detach() so the conversion also works outside torch.no_grad().
        return loss.detach().squeeze().cpu().numpy()

    def rollout(self, params, render=False):
        """ Execute a rollout and return its score.

        Load :params: into the controller and execute a single rollout. This
        is the main API of this class.

        :args params: parameters as a single 1D np array, or None to keep the
            current controller parameters
        :args render: if True, render the environment at every step

        :returns: minus the cumulative environment reward in controller mode,
            minus the cumulative MDRNN prediction loss in explorer mode
        """
        # copy params into the controller
        if params is not None:
            load_parameters(params, self.controller)

        obs = self.env.reset()

        # This first render is required !
        self.env.render()

        hidden = [torch.zeros(1, RSIZE).to(self.device) for _ in range(2)]

        cumulative = 0
        i = 0
        # Scoring a rollout never needs gradients; tracking them would also
        # make the `.numpy()` conversions below raise.
        with torch.no_grad():
            while True:
                obs = transform(obs).unsqueeze(0).to(self.device)

                # GET ACTION
                _, latent_mu, _ = self.vae(obs)
                action = self.controller(latent_mu, hidden[0])
                _, _, _, _, _, next_hidden = self.mdrnn(action, latent_mu,
                                                        hidden)
                action = action.squeeze().cpu().numpy()
                next_obs, reward, done, _ = self.env.step(action)

                if self.explorer:
                    latent_obs, latent_next_obs = self.to_latent(
                        obs.unsqueeze(0),
                        transform(next_obs).unsqueeze(0).to(self.device))
                    action = torch.from_numpy(action).unsqueeze(0).to(
                        self.device)
                    latent_obs = latent_obs.to(
                        self.device).squeeze().unsqueeze(0)
                    latent_next_obs = latent_next_obs.to(
                        self.device).squeeze().unsqueeze(0)
                    # Build the reward on self.device instead of hard-coding
                    # torch.cuda.FloatTensor (which fails on CPU and ignores
                    # the configured GPU index).
                    reward = torch.tensor([reward], dtype=torch.float32,
                                          device=self.device)
                    reward = self.mdrnn_exp_reward(latent_obs, action, reward,
                                                   latent_next_obs, hidden)

                obs = next_obs
                hidden = next_hidden

                if render:
                    self.env.render()

                cumulative += reward
                if done or i > self.time_limit:
                    return -cumulative
                i += 1
# constants
BSIZE = 8
SEQ_LEN = 999
epochs = 3000
torch.backends.cudnn.benchmark = True
learning_rate = 1e-4

# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')

if not exists(rnn_dir):
    mkdir(rnn_dir)

mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
# DataParallel expects the wrapped module's parameters on device_ids[0],
# hence the explicit move to cuda:1.
mdrnn = torch.nn.DataParallel(mdrnn, device_ids=[1, 2, 3, 4, 5, 6, 7])
mdrnn.cuda(1)
#mdrnn.to(device)
# Use the learning_rate constant so it is the single place to tune the LR.
optimizer = optim.Adam(mdrnn.parameters(), lr=learning_rate, betas=(0.9, 0.999))
# scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
# earlystopping = EarlyStopping('min', patience=30)

# Resume from the last best checkpoint unless disabled.
# NOTE(review): mdrnn is wrapped in DataParallel, so the checkpoint is
# presumably expected to hold 'module.'-prefixed keys -- confirm.
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
v_dataset_train, v_dataset_test, v_optimizer, v_scheduler, v_earlystopping, skip_train=True, max_train_epochs=1000)
# hardmaru finished VAE training with 10k rollouts and epoch=10.
# ctellec used 1k rollouts with max epoch=1000; 100 epochs would probably
# have been enough.

# 3-1. Build the (random) dataset used to train the MDN-RNN
m_dataset_train, m_dataset_test = make_mdrnn_dataset(rollout_root_dir)

# 3-2. Create the MDN-RNN model (M)
m_model = MDRNN(LSIZE, ASIZE, RSIZE, 5).to(device)  # pipaek : why gaussian=5?
m_optimizer = torch.optim.RMSprop(m_model.parameters(), lr=1e-3, alpha=.9)  # pipaek : hardmaru uses lr=1e-4.
m_scheduler = ReduceLROnPlateau(m_optimizer, 'min', factor=0.5, patience=5)
# patience 30 -> 5
# NOTE(review): the patience passed below is still 30; confirm which value is intended.
m_earlystopping = EarlyStopping('min', patience=30)

# 3-3. Train the MDN-RNN model (M)
m_model_train_proc(rnn_dir, m_model, v_model, m_dataset_train, m_dataset_test, m_optimizer, m_scheduler, m_earlystopping, skip_train=True,
def train_mdrnn(logdir, traindir, epochs=10, testdir=None, noreload=False):
    """ Train the MDRNN on latent encodings produced by a pretrained VAE.

    The VAE is read from ``logdir/vae/best.tar``. Every epoch, the MDRNN
    checkpoint is passed to ``save_checkpoint`` together with the paths
    ``logdir/mdrnn/checkpoint.tar`` and ``logdir/mdrnn/best.tar`` and an
    ``is_best`` flag set when the test loss improves.

    :args logdir: directory holding the ``vae`` and ``mdrnn`` sub-directories
    :args traindir: directory of training rollouts
    :args epochs: maximum number of epochs (cast to int)
    :args testdir: directory of test rollouts, defaults to ``traindir``
    :args noreload: if True, do not resume from an existing
        ``logdir/mdrnn/best.tar``
    """
    BSIZE = 80  # maybe should change this back to their initial one of 16
    SEQ_LEN = 32
    epochs = int(epochs)
    testdir = testdir if testdir else traindir

    torch.manual_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Loading VAE (map_location lets a GPU-saved checkpoint load on CPU)
    vae_file = join(logdir, 'vae', 'best.tar')
    assert exists(vae_file), "No trained VAE in the logdir..."
    state = torch.load(vae_file, map_location=device)
    print("Loading VAE at epoch {} "
          "with test error {}".format(
              state['epoch'], state['precision']))

    vae = VAE(3, LSIZE).to(device)
    vae.load_state_dict(state['state_dict'])

    # Loading model
    rnn_dir = join(logdir, 'mdrnn')
    rnn_file = join(rnn_dir, 'best.tar')

    if not exists(rnn_dir):
        mkdir(rnn_dir)

    mdrnn = MDRNN(LSIZE, ASIZE, RSIZE, 5)
    mdrnn.to(device)
    optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    earlystopping = EarlyStopping('min', patience=30)

    if exists(rnn_file) and not noreload:
        rnn_state = torch.load(rnn_file, map_location=device)
        print("Loading MDRNN at epoch {} "
              "with test error {}".format(
                  rnn_state["epoch"], rnn_state["precision"]))
        mdrnn.load_state_dict(rnn_state["state_dict"])
        optimizer.load_state_dict(rnn_state["optimizer"])
        # These states are saved in the MDRNN checkpoint (see the
        # save_checkpoint call below), not in the VAE checkpoint `state`.
        scheduler.load_state_dict(rnn_state['scheduler'])
        earlystopping.load_state_dict(rnn_state['earlystopping'])

    # Data Loading: (seq, H, W, C) uint8 frames -> (seq, C, H, W) in [0, 1]
    transform = transforms.Lambda(
        lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
    train_loader = DataLoader(
        RolloutSequenceDataset(traindir, SEQ_LEN, transform, buffer_size=30),
        batch_size=BSIZE, num_workers=8, shuffle=True)
    test_loader = DataLoader(
        RolloutSequenceDataset(testdir, SEQ_LEN, transform, train=False,
                               buffer_size=10),
        batch_size=BSIZE, num_workers=8)

    def to_latent(obs, next_obs):
        """ Transform observations to latent space.

        :args obs: 5D torch tensor (batch, SEQ_LEN, 3, SIZE, SIZE)
        :args next_obs: 5D torch tensor (batch, SEQ_LEN, 3, SIZE, SIZE)

        :returns: (latent_obs, latent_next_obs)
            - latent_obs: 3D torch tensor (batch, SEQ_LEN, LSIZE)
            - latent_next_obs: 3D torch tensor (batch, SEQ_LEN, LSIZE)
        """
        with torch.no_grad():
            obs, next_obs = [
                f.upsample(x.view(-1, 3, SIZE, SIZE), size=RED_SIZE,
                           mode='bilinear', align_corners=True)
                for x in (obs, next_obs)]

            (obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma) = [
                vae(x)[1:] for x in (obs, next_obs)]

            # -1 instead of BSIZE: the loaders keep the last, possibly
            # smaller, batch of each pass, which .view(BSIZE, ...) rejects.
            latent_obs, latent_next_obs = [
                (x_mu + x_logsigma.exp() * torch.randn_like(x_mu)).view(
                    -1, SEQ_LEN, LSIZE)
                for x_mu, x_logsigma in
                [(obs_mu, obs_logsigma), (next_obs_mu, next_obs_logsigma)]]
        return latent_obs, latent_next_obs

    def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
        """ Compute losses.

        The loss that is computed is:
        (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
             BCE(terminal, logit_terminal)) / (LSIZE + 2)
        The LSIZE + 2 factor is here to counteract the fact that the GMMLoss
        scales approximately linearily with LSIZE. All losses are averaged
        both on the batch and the sequence dimensions (the two first
        dimensions).

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args terminal: (BSIZE, SEQ_LEN) torch tensor
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

        :returns: dictionary of losses, containing the gmm, the mse, the bce
            and the averaged loss.
        """
        # The MDRNN consumes sequence-first tensors.
        latent_obs, action,\
            reward, terminal,\
            latent_next_obs = [arr.transpose(1, 0)
                               for arr in [latent_obs, action,
                                           reward, terminal,
                                           latent_next_obs]]
        mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
        gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce = f.binary_cross_entropy_with_logits(ds, terminal)
        mse = f.mse_loss(rs, reward)
        loss = (gmm + bce + mse) / (LSIZE + 2)
        return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)

    def data_pass(epoch, train):  # pylint: disable=too-many-locals
        """ One pass through the data; returns the average loss. """
        if train:
            mdrnn.train()
            loader = train_loader
        else:
            mdrnn.eval()
            loader = test_loader

        loader.dataset.load_next_buffer()

        cum_loss = 0
        cum_gmm = 0
        cum_bce = 0
        cum_mse = 0

        pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [
                arr.to(device) for arr in data]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward,
                                  terminal, latent_next_obs)

                optimizer.zero_grad()
                losses['loss'].backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward,
                                      terminal, latent_next_obs)

            cum_loss += losses['loss'].item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            cum_mse += losses['mse'].item()

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1),
                                     bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1),
                                     mse=cum_mse / (i + 1)))
            # Actual batch size, so the last partial batch is counted right.
            pbar.update(obs.size(0))
        pbar.close()
        return cum_loss * BSIZE / len(loader.dataset)

    train = partial(data_pass, train=True)
    test = partial(data_pass, train=False)

    # Best test loss so far. Must be initialised once, outside the epoch
    # loop: resetting it every epoch made every epoch "best".
    cur_best = None
    for e in range(epochs):
        train(e)
        test_loss = test(e)
        scheduler.step(test_loss)
        earlystopping.step(test_loss)

        # `is None` rather than falsiness: a loss of exactly 0.0 is valid.
        is_best = cur_best is None or test_loss < cur_best
        if is_best:
            cur_best = test_loss
        checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
        save_checkpoint({
            "state_dict": mdrnn.state_dict(),
            "optimizer": optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'earlystopping': earlystopping.state_dict(),
            "precision": test_loss,
            "epoch": e}, is_best, checkpoint_fname,
                        rnn_file)

        if earlystopping.stop:
            print("End of Training because of early stopping at epoch {}".format(e))
            break
def test_mdrnn_learning(self):
    """Check that an MDRNN can fit episodes generated by a SimulatedWorldModel.

    Generates ``num_episodes`` episodes of ``seq_len`` steps with a
    ``SimulatedWorldModel``, trains an ``MDRNN`` on them with Adam, and
    asserts that the mean loss over the last ``num_batch`` batches is below
    -3. Finally, runs the trained model in eval mode on the first episode
    and prints its GMM loss.
    """
    num_epochs = 300
    num_episodes = 400
    batch_size = 200
    action_dim = 2
    seq_len = 5
    state_dim = 2
    simulated_num_gaussian = 2
    mdrnn_num_gaussian = 2
    simulated_hidden_size = 3
    mdrnn_hidden_size = 10
    mdrnn_hidden_layer = 1
    adam_lr = 0.01

    # Episode buffers, indexed [episode, step, ...].
    cur_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
    next_state_mem = numpy.zeros((num_episodes, seq_len, state_dim))
    action_mem = numpy.zeros((num_episodes, seq_len, action_dim))
    reward_mem = numpy.zeros((num_episodes, seq_len))
    terminal_mem = numpy.zeros((num_episodes, seq_len))
    next_mus_mem = numpy.zeros(
        (num_episodes, seq_len, simulated_num_gaussian, state_dim))

    swm = SimulatedWorldModel(
        action_dim=action_dim,
        state_dim=state_dim,
        num_gaussian=simulated_num_gaussian,
        lstm_num_layer=1,
        lstm_hidden_dim=simulated_hidden_size,
    )

    # Rows of the identity matrix: one-hot actions, one picked at random per step.
    actions = torch.eye(action_dim)
    for e in range(num_episodes):
        swm.init_hidden(batch_size=1)
        next_state = torch.randn((1, 1, state_dim))
        for s in range(seq_len):
            cur_state = next_state
            action = torch.tensor(
                actions[numpy.random.randint(action_dim)]).view(
                    1, 1, action_dim)
            next_mus, reward = swm(action, cur_state)

            # Only the last step of an episode is terminal.
            terminal = 0
            if s == seq_len - 1:
                terminal = 1

            # Next state: the mean of a uniformly sampled mixture component.
            next_pi = torch.ones(
                simulated_num_gaussian) / simulated_num_gaussian
            index = Categorical(next_pi).sample((1, )).long().item()
            next_state = next_mus[0, 0, index].view(1, 1, state_dim)

            print(
                "{} cur_state: {}, action: {}, next_state: {}, reward: {}, terminal: {}"
                .format(e, cur_state, action, next_state, reward, terminal))
            print("next_pi: {}, sampled index: {}".format(next_pi, index))
            print("next_mus:", next_mus, "\n")

            cur_state_mem[e, s, :] = cur_state.detach().numpy()
            action_mem[e, s, :] = action.numpy()
            reward_mem[e, s] = reward.detach().numpy()
            terminal_mem[e, s] = terminal
            next_state_mem[e, s, :] = next_state.detach().numpy()
            next_mus_mem[e, s, :, :] = next_mus.detach().numpy()

    mdrnn = MDRNN(
        latents=state_dim,
        actions=action_dim,
        gaussians=mdrnn_num_gaussian,
        hiddens=mdrnn_hidden_size,
        layers=mdrnn_hidden_layer,
    )
    mdrnn.train()
    optimizer = torch.optim.Adam(mdrnn.parameters(), lr=adam_lr)
    num_batch = num_episodes // batch_size
    earlystopping = EarlyStopping('min', patience=30)

    # Per-batch loss history, accumulated across all epochs.
    cum_loss = []
    cum_gmm = []
    cum_bce = []
    cum_mse = []
    for e in range(num_epochs):
        for i in range(0, num_batch):
            mdrnn.init_hidden(batch_size=batch_size)
            optimizer.zero_grad()
            # Episodes are sampled with replacement.
            sample_indices = numpy.random.randint(num_episodes, size=batch_size)

            obs, action, reward, terminal, next_obs = \
                cur_state_mem[sample_indices], \
                action_mem[sample_indices], \
                reward_mem[sample_indices], \
                terminal_mem[sample_indices], \
                next_state_mem[sample_indices]
            obs, action, reward, terminal, next_obs = \
                torch.tensor(obs, dtype=torch.float), \
                torch.tensor(action, dtype=torch.float), \
                torch.tensor(reward, dtype=torch.float), \
                torch.tensor(terminal, dtype=torch.float), \
                torch.tensor(next_obs, dtype=torch.float)

            print("learning at epoch {} step {} best score {} counter {}".
                  format(e, i, earlystopping.best, earlystopping.num_bad_epochs))
            losses = self.get_loss(obs, action, reward, terminal, next_obs,
                                   state_dim, mdrnn)
            losses['loss'].backward()
            optimizer.step()

            cum_loss += [losses['loss'].item()]
            cum_gmm += [losses['gmm'].item()]
            cum_bce += [losses['bce'].item()]
            cum_mse += [losses['mse'].item()]
            print(
                "loss={loss:10.6f} bce={bce:10.6f} gmm={gmm:10.6f} mse={mse:10.6f}"
                .format(
                    loss=losses['loss'],
                    bce=losses['bce'],
                    gmm=losses['gmm'],
                    mse=losses['mse'],
                ))
            print(
                "cum loss={loss:10.6f} cum bce={bce:10.6f} cum gmm={gmm:10.6f} cum mse={mse:10.6f}"
                .format(
                    loss=numpy.mean(cum_loss),
                    bce=numpy.mean(cum_bce),
                    gmm=numpy.mean(cum_gmm),
                    mse=numpy.mean(cum_mse),
                ))
            print()

        # Early stopping on the mean loss of this epoch's batches; stop only
        # once that mean is also below the threshold asserted afterwards.
        earlystopping.step(numpy.mean(cum_loss[-num_batch:]))
        if numpy.mean(cum_loss[-num_batch:]) < -3. and earlystopping.stop:
            break

    assert numpy.mean(cum_loss[-num_batch:]) < -3.

    # Evaluate the trained model on the first stored episode.
    sample_indices = [0]
    mdrnn.init_hidden(batch_size=len(sample_indices))
    mdrnn.eval()
    obs, action, reward, terminal, next_obs = \
        cur_state_mem[sample_indices], \
        action_mem[sample_indices], \
        reward_mem[sample_indices], \
        terminal_mem[sample_indices], \
        next_state_mem[sample_indices]
    obs, action, reward, terminal, next_obs = \
        torch.tensor(obs, dtype=torch.float), \
        torch.tensor(action, dtype=torch.float), \
        torch.tensor(reward, dtype=torch.float), \
        torch.tensor(terminal, dtype=torch.float), \
        torch.tensor(next_obs, dtype=torch.float)
    transpose_obs, transpose_action, transpose_reward, transpose_terminal, transpose_next_obs = \
        self.transpose(obs, action, reward, terminal, next_obs)
    mus, sigmas, logpi, rs, ds = mdrnn(transpose_action, transpose_obs)
    # Mixture weights; computed but not used below.
    pi = torch.exp(logpi)
    gl = gmm_loss(transpose_next_obs, mus, sigmas, logpi)
    print(gl)
    print()