def get_mdrnn_cell(rnn_dir):
    """Build an ``MDRNNCell`` initialised from ``<rnn_dir>/best.tar``.

    :args rnn_dir: directory containing the MDRNN checkpoint ``best.tar``.
        The checkpoint must hold the weights under the ``'state_dict'`` key.

    :returns: an ``MDRNNCell(LSIZE, ASIZE, RSIZE, 5)`` moved to ``device``
        with the checkpoint weights loaded.

    :raises AssertionError: if the checkpoint file does not exist.
    """
    rnn_file = os.path.join(rnn_dir, 'best.tar')
    assert os.path.exists(rnn_file), \
        "No MDRNN model in the directory {}".format(rnn_dir)

    # map_location lets a checkpoint saved on another device (e.g. a GPU) be
    # loaded on whatever `device` this process runs on.
    state = torch.load(rnn_file, map_location=device)

    # Parameter names in the checkpoint carry a trailing '_l0' that the cell
    # does not use (presumably the layer index of a full RNN module -- confirm
    # against the training script). Remove exactly that suffix: str.strip('_l0')
    # would instead strip any of the characters '_', 'l', '0' from BOTH ends
    # of the key and can corrupt names that start or end with them.
    layer_suffix = '_l0'
    cell_state = {
        (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key): value
        for key, value in state['state_dict'].items()
    }

    mdrnn_cell = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
    mdrnn_cell.load_state_dict(cell_state)
    return mdrnn_cell
def __init__(self, directory):
    """Load the trained VAE decoder and MDRNN cell used by the simulation.

    :args directory: directory holding ``vae/best.tar`` and
        ``mdrnn/best.tar``; each checkpoint must provide the keys
        ``'epoch'``, ``'precision'`` and ``'state_dict'``.

    :raises AssertionError: if either checkpoint file is missing.
    """
    vae_file = join(directory, 'vae', 'best.tar')
    rnn_file = join(directory, 'mdrnn', 'best.tar')
    assert exists(vae_file), "No VAE model in the directory..."
    assert exists(rnn_file), "No MDRNN model in the directory..."

    # spaces
    self.action_space = spaces.Box(np.array([-1, 0, 0]),
                                   np.array([1, 1, 1]))
    self.observation_space = spaces.Box(low=0, high=255,
                                        shape=(RED_SIZE, RED_SIZE, 3),
                                        dtype=np.uint8)

    # load VAE (always onto CPU storage); only its decoder is kept
    vae = VAE(3, LSIZE)
    vae_state = torch.load(vae_file,
                           map_location=lambda storage, location: storage)
    print("Loading VAE at epoch {}, "
          "with test error {}...".format(
              vae_state['epoch'], vae_state['precision']))
    vae.load_state_dict(vae_state['state_dict'])
    self._decoder = vae.decoder

    # load MDRNN
    # The latent size must agree with the VAE and with self._lstate below,
    # so use LSIZE rather than a hard-coded 32. The 3 matches the
    # 3-component action space declared above.
    self._rnn = MDRNNCell(LSIZE, 3, RSIZE, 5)
    rnn_state = torch.load(rnn_file,
                           map_location=lambda storage, location: storage)
    print("Loading MDRNN at epoch {}, "
          "with test error {}...".format(
              rnn_state['epoch'], rnn_state['precision']))
    # Drop only a trailing '_l0' from parameter names. str.strip('_l0')
    # removes any of the characters '_', 'l', '0' from both ends and can
    # corrupt keys that begin or end with them.
    layer_suffix = '_l0'
    rnn_state_dict = {
        (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key): value
        for key, value in rnn_state['state_dict'].items()
    }
    self._rnn.load_state_dict(rnn_state_dict)

    # init state: random latent, zeroed recurrent state. Two distinct
    # tensors are built so the two entries never alias the same storage.
    self._lstate = torch.randn(1, LSIZE)
    self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

    # obs
    self._obs = None
    self._visual_obs = None

    # rendering
    self.monitor = None
    self.figure = None
def __init__(self, directory):
    """Load the trained VAE decoder and MDRNN cell used by the simulation.

    :args directory: directory holding ``vae/best.tar`` and
        ``mdrnn/best.tar``; each checkpoint must provide the keys
        ``'epoch'``, ``'precision'`` and ``'state_dict'``.

    :raises AssertionError: if either checkpoint file is missing.
    """
    vae_file = join(directory, 'vae', 'best.tar')
    rnn_file = join(directory, 'mdrnn', 'best.tar')
    assert exists(vae_file), "No VAE model in the directory..."
    assert exists(rnn_file), "No MDRNN model in the directory..."

    # spaces
    self.action_space = spaces.Box(np.array([-1, 0, 0]),
                                   np.array([1, 1, 1]))
    self.observation_space = spaces.Box(low=0, high=255,
                                        shape=(RED_SIZE, RED_SIZE, 3),
                                        dtype=np.uint8)

    # load VAE (always onto CPU storage); only its decoder is kept
    vae = VAE(3, LSIZE)
    vae_state = torch.load(vae_file,
                           map_location=lambda storage, location: storage)
    print("Loading VAE at epoch {}, "
          "with test error {}...".format(
              vae_state['epoch'], vae_state['precision']))
    vae.load_state_dict(vae_state['state_dict'])
    self._decoder = vae.decoder

    # load MDRNN
    # The latent size must agree with the VAE and with self._lstate below,
    # so use LSIZE rather than a hard-coded 32. The 3 matches the
    # 3-component action space declared above.
    self._rnn = MDRNNCell(LSIZE, 3, RSIZE, 5)
    rnn_state = torch.load(rnn_file,
                           map_location=lambda storage, location: storage)
    print("Loading MDRNN at epoch {}, "
          "with test error {}...".format(
              rnn_state['epoch'], rnn_state['precision']))
    # Drop only a trailing '_l0' from parameter names. str.strip('_l0')
    # removes any of the characters '_', 'l', '0' from both ends and can
    # corrupt keys that begin or end with them.
    layer_suffix = '_l0'
    rnn_state_dict = {
        (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key): value
        for key, value in rnn_state['state_dict'].items()
    }
    self._rnn.load_state_dict(rnn_state_dict)

    # init state: random latent, zeroed recurrent state. Two distinct
    # tensors are built so the two entries never alias the same storage.
    self._lstate = torch.randn(1, LSIZE)
    self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

    # obs
    self._obs = None
    self._visual_obs = None

    # rendering
    self.monitor = None
    self.figure = None
class SimulatedDoom(gym.Env): # pylint: disable=too-many-instance-attributes
    """
    Simulated Doom.

    Gym environment that uses a learnt VAE decoder and MDRNN cell to simulate
    an environment entirely in latent space. (The previous docstring was
    copied from the car-racing simulator; which Doom scenario the checkpoints
    correspond to is not visible here -- confirm with the training setup.)

    NOTE(review): `reset` returns nothing and `step` returns a 3-tuple
    (observation, reward, done); callers relying on the usual gym
    conventions should be aware of this.

    :args directory: directory from which the vae and mdrnn are loaded.
    """

    def __init__(self, directory):
        """Load ``vae/best.tar`` and ``mdrnn/best.tar`` from `directory`.

        :raises AssertionError: if either checkpoint file is missing.
        """
        vae_file = join(directory, 'vae', 'best.tar')
        rnn_file = join(directory, 'mdrnn', 'best.tar')
        assert exists(vae_file), "No VAE model in the directory..."
        assert exists(rnn_file), "No MDRNN model in the directory..."

        # spaces
        self.action_space = spaces.Box(low=-1, high=1, shape=(1, ))
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(RED_SIZE, RED_SIZE, 3),
                                            dtype=np.uint8)

        # load VAE (always onto CPU storage); only its decoder is kept
        vae = VAE(3, LSIZE)
        vae_state = torch.load(vae_file,
                               map_location=lambda storage, location: storage)
        print("Loading VAE at epoch {}, "
              "with test error {}...".format(vae_state['epoch'],
                                             vae_state['precision']))
        vae.load_state_dict(vae_state['state_dict'])
        self._decoder = vae.decoder

        # load MDRNN
        self._rnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5)
        rnn_state = torch.load(rnn_file,
                               map_location=lambda storage, location: storage)
        print("Loading MDRNN at epoch {}, "
              "with test error {}...".format(rnn_state['epoch'],
                                             rnn_state['precision']))
        # Drop only a trailing '_l0' from parameter names. str.strip('_l0')
        # removes any of the characters '_', 'l', '0' from both ends and can
        # corrupt keys that begin or end with them.
        layer_suffix = '_l0'
        rnn_state_dict = {
            (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key):
                value
            for key, value in rnn_state['state_dict'].items()
        }
        self._rnn.load_state_dict(rnn_state_dict)

        # init state: two distinct tensors so the entries never alias
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

        # obs
        self._obs = None
        self._visual_obs = None

        # rendering
        self.monitor = None
        self.figure = None

    def _ensure_monitor(self):
        """ Create the matplotlib figure/image once, on first use """
        # matplotlib is imported lazily so the class can be used headless
        # as long as nothing is rendered.
        import matplotlib.pyplot as plt
        if self.monitor is None:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3), dtype=np.uint8))

    def reset(self):
        """ Resetting: draw a new random latent and zero the recurrent state """
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

        # also make sure the monitor exists
        self._ensure_monitor()

    def step(self, action):
        """ One step forward

        :args action: action to apply; converted with ``torch.Tensor(action)``
            and given a leading batch dimension.

        :returns: (observation, reward, done) where observation is the
            decoded uint8 image, reward is ``r.item()`` and done is
            ``d.item() > 0``.
        """
        with torch.no_grad():
            action = torch.Tensor(action).unsqueeze(0)
            mu, _sigma, pi, r, d, n_h = self._rnn(action, self._lstate,
                                                  self._hstate)
            pi = pi.squeeze()
            # pi is exponentiated before use, i.e. treated as log-weights
            mixt = Categorical(torch.exp(pi)).sample().item()

            # The next latent is the mean of the sampled mixture component;
            # the noise term (sigma * randn_like(mu)) is currently disabled.
            self._lstate = mu[:, mixt, :]
            self._hstate = n_h

            self._obs = self._decoder(self._lstate)
            np_obs = self._obs.numpy()
            np_obs = np.clip(np_obs, 0, 1) * 255
            # channels-first -> channels-last for display
            np_obs = np.transpose(np_obs, (0, 2, 3, 1))
            np_obs = np_obs.squeeze()
            np_obs = np_obs.astype(np.uint8)
            self._visual_obs = np_obs

            return np_obs, r.item(), d.item() > 0

    def render(self): # pylint: disable=arguments-differ
        """ Rendering: show the most recent decoded observation """
        import matplotlib.pyplot as plt
        self._ensure_monitor()
        # Before the first step there is no frame yet; keep the blank image
        # instead of handing None to matplotlib.
        if self._visual_obs is not None:
            self.monitor.set_data(self._visual_obs)
        plt.pause(.01)
class SimulatedCarracing(gym.Env): # pylint: disable=too-many-instance-attributes
    """
    Simulated Car Racing.

    Gym environment using learnt VAE and MDRNN to simulate the
    CarRacing-v0 environment.

    NOTE(review): `reset` returns nothing and `step` returns a 3-tuple
    (observation, reward, done); callers relying on the usual gym
    conventions should be aware of this.

    :args directory: directory from which the vae and mdrnn are loaded.
    """

    def __init__(self, directory):
        """Load ``vae/best.tar`` and ``mdrnn/best.tar`` from `directory`.

        :raises AssertionError: if either checkpoint file is missing.
        """
        vae_file = join(directory, 'vae', 'best.tar')
        rnn_file = join(directory, 'mdrnn', 'best.tar')
        assert exists(vae_file), "No VAE model in the directory..."
        assert exists(rnn_file), "No MDRNN model in the directory..."

        # spaces
        self.action_space = spaces.Box(np.array([-1, 0, 0]),
                                       np.array([1, 1, 1]))
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(RED_SIZE, RED_SIZE, 3),
                                            dtype=np.uint8)

        # load VAE (always onto CPU storage); only its decoder is kept
        vae = VAE(3, LSIZE)
        vae_state = torch.load(vae_file,
                               map_location=lambda storage, location: storage)
        print("Loading VAE at epoch {}, "
              "with test error {}...".format(
                  vae_state['epoch'], vae_state['precision']))
        vae.load_state_dict(vae_state['state_dict'])
        self._decoder = vae.decoder

        # load MDRNN
        # The latent size must agree with the VAE and with self._lstate
        # below, so use LSIZE rather than a hard-coded 32. The 3 matches the
        # 3-component action space declared above.
        self._rnn = MDRNNCell(LSIZE, 3, RSIZE, 5)
        rnn_state = torch.load(rnn_file,
                               map_location=lambda storage, location: storage)
        print("Loading MDRNN at epoch {}, "
              "with test error {}...".format(
                  rnn_state['epoch'], rnn_state['precision']))
        # Drop only a trailing '_l0' from parameter names. str.strip('_l0')
        # removes any of the characters '_', 'l', '0' from both ends and can
        # corrupt keys that begin or end with them.
        layer_suffix = '_l0'
        rnn_state_dict = {
            (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key):
                value
            for key, value in rnn_state['state_dict'].items()
        }
        self._rnn.load_state_dict(rnn_state_dict)

        # init state: two distinct tensors so the entries never alias
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

        # obs
        self._obs = None
        self._visual_obs = None

        # rendering
        self.monitor = None
        self.figure = None

    def _ensure_monitor(self):
        """ Create the matplotlib figure/image once, on first use """
        # matplotlib is imported lazily so the class can be used headless
        # as long as nothing is rendered.
        import matplotlib.pyplot as plt
        if self.monitor is None:
            self.figure = plt.figure()
            self.monitor = plt.imshow(
                np.zeros((RED_SIZE, RED_SIZE, 3), dtype=np.uint8))

    def reset(self):
        """ Resetting: draw a new random latent and zero the recurrent state """
        self._lstate = torch.randn(1, LSIZE)
        self._hstate = [torch.zeros(1, RSIZE) for _ in range(2)]

        # also make sure the monitor exists
        self._ensure_monitor()

    def step(self, action):
        """ One step forward

        :args action: action to apply; converted with ``torch.Tensor(action)``
            and given a leading batch dimension.

        :returns: (observation, reward, done) where observation is the
            decoded uint8 image, reward is ``r.item()`` and done is
            ``d.item() > 0``.
        """
        with torch.no_grad():
            action = torch.Tensor(action).unsqueeze(0)
            mu, _sigma, pi, r, d, n_h = self._rnn(action, self._lstate,
                                                  self._hstate)
            pi = pi.squeeze()
            # pi is exponentiated before use, i.e. treated as log-weights
            mixt = Categorical(torch.exp(pi)).sample().item()

            # The next latent is the mean of the sampled mixture component;
            # the noise term (sigma * randn_like(mu)) is currently disabled.
            self._lstate = mu[:, mixt, :]
            self._hstate = n_h

            self._obs = self._decoder(self._lstate)
            np_obs = self._obs.numpy()
            np_obs = np.clip(np_obs, 0, 1) * 255
            # channels-first -> channels-last for display
            np_obs = np.transpose(np_obs, (0, 2, 3, 1))
            np_obs = np_obs.squeeze()
            np_obs = np_obs.astype(np.uint8)
            self._visual_obs = np_obs

            return np_obs, r.item(), d.item() > 0

    def render(self): # pylint: disable=arguments-differ
        """ Rendering: show the most recent decoded observation """
        import matplotlib.pyplot as plt
        self._ensure_monitor()
        # Before the first step there is no frame yet; keep the blank image
        # instead of handing None to matplotlib.
        if self._visual_obs is not None:
            self.monitor.set_data(self._visual_obs)
        plt.pause(.01)
def create_memory(env, agent, memory, steps, config, episodes=100):
    """Fill `memory` with transitions collected by running `agent` in `env`.

    For every step the observation is encoded with a pre-trained VAE and the
    recurrent state of a pre-trained MDRNN cell is advanced; the observation,
    the recurrent state before and after the step, the action and the done
    flags are passed to ``memory.add``.

    Episodes are capped at 125 steps. An episode whose total reward is below
    600 is discarded by rewinding ``memory.idx`` to its value at the start of
    the episode. The buffer is saved every time ``memory.idx`` is a multiple
    of 500.

    :args env: environment; ``env.reset("mediumClassic")`` must return
        ``(state, obs)`` and ``env.step(action)`` must return
        ``(next_state, reward, done, next_obs)``.
    :args agent: policy exposing ``eval()`` and ``act(state_tensor)``.
    :args memory: buffer exposing ``idx``, ``add(...)`` and
        ``save_memory(path)``.
    :args steps: stop as soon as ``memory.idx`` reaches this value.
    :args config: mapping; ``config["device"]`` is the torch device to use.
    :args episodes: maximum number of episodes to run.

    :returns: None.
    """
    print("Create buffer with size {} steps ".format(steps))
    agent.eval()

    # Model sizes and checkpoint location of the pre-trained world model.
    LSIZE = 200
    ASIZE = 1
    RSIZE = 256
    mdir = "15.02_l200_RGB"
    max_steps = 125           # hard cap on episode length
    min_episode_reward = 600  # episodes below this are dropped
    save_every = 500          # checkpoint the buffer every N entries
    device = config["device"]

    vae_file, rnn_file = [join(mdir, sub, 'best.tar')
                          for sub in ['vae', 'mdrnn']]
    vae_state, rnn_state = [
        torch.load(fname, map_location={'cuda:0': str(device)})
        for fname in (vae_file, rnn_file)
    ]
    for name, ckpt in (('VAE', vae_state), ('MDRNN', rnn_state)):
        print("Loading {} at epoch {} "
              "with test loss {}".format(name, ckpt['epoch'],
                                         ckpt['precision']))

    vae = VAE(3, LSIZE).to(device)
    vae.load_state_dict(vae_state['state_dict'])

    # Drop only a trailing '_l0' from parameter names. str.strip('_l0')
    # removes any of the characters '_', 'l', '0' from both ends and can
    # corrupt keys that begin or end with them.
    layer_suffix = '_l0'
    mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
    mdrnn.load_state_dict({
        (key[:-len(layer_suffix)] if key.endswith(layer_suffix) else key):
            value
        for key, value in rnn_state['state_dict'].items()
    })

    last_step = max_steps - 1
    for i in range(episodes):
        env.seed(i)
        state, obs = env.reset("mediumClassic")
        episode_reward = 0
        # remember where this episode starts so it can be rolled back
        index = memory.idx
        hidden = [torch.zeros(1, RSIZE).to(device) for _ in range(2)]

        for t in range(max_steps):
            # Use the configured device instead of a hard-coded CUDA tensor
            # type so the function also runs when config["device"] is CPU.
            state_tensor = state.clone().detach().to(
                device=device, dtype=torch.float32).div_(255)
            action = agent.act(state_tensor)
            action_rnn = torch.as_tensor(action, device=device).type(
                torch.int).unsqueeze(0).unsqueeze(0)

            states = torch.as_tensor(obs, device=device).unsqueeze(0)
            # Out-of-place division: as_tensor may share memory with `obs`,
            # and an in-place div_ could then rescale the stored observation.
            states = states.type(torch.float32) / 255

            # Pure inference: without no_grad every stored hidden state
            # would keep the autograd graph of the whole episode alive.
            with torch.no_grad():
                _, latent_mu, _ = vae(states)
                _, _, _, _, _, next_hidden = mdrnn(action_rnn, latent_mu,
                                                   hidden)

            next_state, reward, done, next_obs = env.step(action)

            # `done_no_max` ignores termination caused only by the step cap.
            done_no_max = done if t != last_step else False
            memory.add(obs, hidden[0], hidden[1], action, next_obs,
                       next_hidden[0], next_hidden[1], done, done_no_max)

            if memory.idx % save_every == 0:
                path = "pacman_expert_memory-{}".format(memory.idx)
                print("save memory to ", path)
                memory.save_memory(path)
            if memory.idx >= steps:
                return

            hidden = next_hidden
            state = next_state
            obs = next_obs
            episode_reward += reward

            if done or t == last_step:
                if episode_reward < min_episode_reward:
                    # discard this episode: later writes overwrite it
                    memory.idx = index
                print("Episode_reward {} and memory idx {}".format(
                    episode_reward, memory.idx))
                break
def data_pass(epoch, train, include_reward): # pylint: disable=too-many-locals
    """ One pass through the data

    :args epoch: epoch number, used only for the progress-bar label.
    :args train: if true, run on `train_loader` and update `mdrnn` with
        `optimizer`; otherwise run on `test_loader` without gradients.
    :args include_reward: forwarded to `get_loss`.

    :returns: the accumulated loss scaled by ``BSIZE / len(loader.dataset)``.
    """
    if train:
        mdrnn.train()
        loader = train_loader
    else:
        mdrnn.eval()
        loader = test_loader

    loader.dataset.load_next_buffer()

    cum_loss = 0
    cum_gmm = 0
    cum_bce = 0
    cum_mse = 0

    # Parameter names of the full model end in '_l0', the cell's do not.
    layer_suffix = '_l0'

    # The context manager closes the bar even if a batch raises.
    with tqdm(total=len(loader.dataset),
              desc="Epoch {}".format(epoch)) as pbar:
        for i, data in enumerate(loader):
            obs, action, reward, terminal, next_obs = [
                arr.to(device) for arr in data
            ]

            # transform obs
            latent_obs, latent_next_obs = to_latent(obs, next_obs)

            if train:
                losses = get_loss(latent_obs, action, reward, terminal,
                                  latent_next_obs, include_reward)

                # Snapshot the current weights into a single-step cell.
                # Remove exactly the '_l0' suffix: str.strip('_l0') strips
                # the characters '_', 'l', '0' from both ends of the key.
                mdrnnCell = MDRNNCell(LSIZE, ASIZE, RSIZE, 5)
                rnn_state_dict = {
                    (key[:-len(layer_suffix)]
                     if key.endswith(layer_suffix) else key): value
                    for key, value in mdrnn.state_dict().items()
                }
                mdrnnCell.load_state_dict(rnn_state_dict)

                interim_policy = train_C_given_M(mdrnnCell=mdrnnCell,
                                                 latent_dim=LSIZE,
                                                 hidden_dim=RSIZE,
                                                 action_dim=ASIZE)
                interim_policy_ope = ope(mdrnnCell, interim_policy)

                # NOTE(review): mdrnnCell holds a copy of the weights, so
                # the OPE term presumably contributes no gradient to
                # `mdrnn` and only shifts the reported loss -- confirm.
                loss = losses['loss'] - interim_policy_ope.squeeze(dim=0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    losses = get_loss(latent_obs, action, reward, terminal,
                                      latent_next_obs, include_reward)
                # `loss` used to be assigned only while training, so every
                # evaluation pass crashed with UnboundLocalError below.
                loss = losses['loss']

            cum_loss += loss.item()
            cum_gmm += losses['gmm'].item()
            cum_bce += losses['bce'].item()
            # 'mse' may be a plain number rather than a tensor
            cum_mse += losses['mse'].item() if hasattr(losses['mse'], 'item') else \
                losses['mse']

            pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
                                 "gmm={gmm:10.6f} mse={mse:10.6f}".format(
                                     loss=cum_loss / (i + 1),
                                     bce=cum_bce / (i + 1),
                                     gmm=cum_gmm / LSIZE / (i + 1),
                                     mse=cum_mse / (i + 1)))
            pbar.update(BSIZE)

    return cum_loss * BSIZE / len(loader.dataset)