def transform_mdp_to_bamdp_rollouts(vae, args, obs, actions, rewards, next_obs, terminals):
    '''
    Augments a batch of MDP rollouts with the belief over the task, turning
    them into BAMDP rollouts: the posterior mean and log-variance of the task
    embedding are concatenated to every observation.

    :param vae: VAE whose encoder produces the task posterior
    :param args: experiment arguments (uses task_embedding_size, trajectory_len, belief_rewards)
    :param obs: shape (trajectory_len, n_rollouts, dim)
    :param actions: shape (trajectory_len, n_rollouts, action_dim)
    :param rewards: shape (trajectory_len, n_rollouts, 1)
    :param next_obs: shape (trajectory_len, n_rollouts, dim)
    :param terminals: shape (trajectory_len, n_rollouts, 1)
    :return: augmented observations, belief rewards (or None), augmented next observations
    '''
    augmented_obs = ptu.zeros(
        (obs.shape[0], obs.shape[1], obs.shape[2] + 2 * args.task_embedding_size))
    augmented_next_obs = ptu.zeros(
        (obs.shape[0], obs.shape[1], obs.shape[2] + 2 * args.task_embedding_size))
    if args.belief_rewards:
        belief_rewards = ptu.zeros_like(rewards)
    else:
        belief_rewards = None

    # initialise the belief with the prior (one prior per rollout in the batch)
    with torch.no_grad():
        _, mean, logvar, hidden_state = vae.encoder.prior(batch_size=obs.shape[1])
    augmented_obs[0, :, :] = torch.cat((obs[0], mean[0], logvar[0]), dim=-1)

    for step in range(args.trajectory_len):
        # update the posterior with the latest transition
        _, mean, logvar, hidden_state = utl.update_encoding(
            encoder=vae.encoder,
            obs=next_obs[step].unsqueeze(0),
            action=actions[step].unsqueeze(0),
            reward=rewards[step].unsqueeze(0),
            done=terminals[step].unsqueeze(0),
            hidden_state=hidden_state)
        # augment data with the updated belief
        augmented_next_obs[step, :, :] = torch.cat(
            (next_obs[step], mean, logvar), dim=-1)
        if args.belief_rewards:
            with torch.no_grad():
                belief_rewards[step, :, :] = vae.compute_belief_reward(
                    mean.unsqueeze(dim=0),
                    logvar.unsqueeze(dim=0),
                    obs[step].unsqueeze(dim=0),
                    next_obs[step].unsqueeze(dim=0),
                    actions[step].unsqueeze(dim=0))

    # the augmented obs at step t+1 is the augmented next obs at step t
    augmented_obs[1:, :, :] = augmented_next_obs[:-1, :, :].clone()

    return augmented_obs, belief_rewards, augmented_next_obs
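# Usage sketch (illustrative, not part of the original code): how the function
# above might be applied to a batch of offline rollouts. `vae`, `args`, and the
# repo's `ptu` helpers are assumed to come from the surrounding codebase; the
# assertion only restates the shape contract from the docstring.
def _example_transform_usage(vae, args, rollouts):
    obs, actions, rewards, next_obs, terminals = rollouts
    aug_obs, belief_rew, aug_next_obs = transform_mdp_to_bamdp_rollouts(
        vae, args, obs, actions, rewards, next_obs, terminals)
    # belief (posterior mean + logvar) is appended along the last axis
    assert aug_obs.shape[-1] == obs.shape[-1] + 2 * args.task_embedding_size
    return aug_obs, belief_rew, aug_next_obs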
def eval_vae(dataset, vae, args):
    """Roll the encoder over random trajectories and compare the reward
    decoder's predictions against the ground-truth rewards."""
    reward_preds = np.zeros((NUM_EVAL_TASKS, args.trajectory_len))
    rewards = np.zeros((NUM_EVAL_TASKS, args.trajectory_len))
    random_tasks = np.random.choice(len(dataset), NUM_EVAL_TASKS)  # which tasks to evaluate

    for task_idx, task in enumerate(random_tasks):
        traj_idx_random = np.random.choice(dataset[task][0].shape[1])  # which trajectory to evaluate

        # get prior parameters
        with torch.no_grad():
            task_sample, task_mean, task_logvar, hidden_state = vae.encoder.prior(batch_size=1)

        for step in range(args.trajectory_len):
            # update encoding
            task_sample, task_mean, task_logvar, hidden_state = utl.update_encoding(
                encoder=vae.encoder,
                obs=ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0),
                action=ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0),
                reward=ptu.FloatTensor(dataset[task][2][step, traj_idx_random]).unsqueeze(0),
                done=ptu.FloatTensor(dataset[task][4][step, traj_idx_random]).unsqueeze(0),
                hidden_state=hidden_state)

            rewards[task_idx, step] = dataset[task][2][step, traj_idx_random].item()
            reward_preds[task_idx, step] = ptu.get_numpy(
                vae.reward_decoder(
                    task_sample.unsqueeze(0),
                    ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][0][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0).unsqueeze(0))[0, 0])

    return rewards, reward_preds
def eval_vae(dataset, vae, args):
    """Variant of eval_vae that additionally queries the state decoder on a
    fixed grid of state-action pairs (for heatmap visualisation)."""
    num_eval_tasks = 10
    reward_preds = np.zeros((num_eval_tasks, args.trajectory_len))
    rewards = np.zeros((num_eval_tasks, args.trajectory_len))
    random_tasks = np.random.choice(len(dataset), num_eval_tasks)  # which tasks to evaluate

    states, actions = get_heatmap_params()
    state_preds = np.zeros((num_eval_tasks, states.shape[0]))

    for task_idx, task in enumerate(random_tasks):
        traj_idx_random = np.random.choice(dataset[task][0].shape[1])  # which trajectory to evaluate

        # get prior parameters
        with torch.no_grad():
            task_sample, task_mean, task_logvar, hidden_state = vae.encoder.prior(batch_size=1)

        for step in range(args.trajectory_len):
            # update encoding
            task_sample, task_mean, task_logvar, hidden_state = utl.update_encoding(
                encoder=vae.encoder,
                obs=ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0),
                action=ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0),
                reward=ptu.FloatTensor(dataset[task][2][step, traj_idx_random]).unsqueeze(0),
                done=ptu.FloatTensor(dataset[task][4][step, traj_idx_random]).unsqueeze(0),
                hidden_state=hidden_state)

            rewards[task_idx, step] = dataset[task][2][step, traj_idx_random].item()
            reward_preds[task_idx, step] = ptu.get_numpy(
                vae.reward_decoder(
                    task_sample.unsqueeze(0),
                    ptu.FloatTensor(dataset[task][3][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][0][step, traj_idx_random]).unsqueeze(0).unsqueeze(0),
                    ptu.FloatTensor(dataset[task][1][step, traj_idx_random]).unsqueeze(0).unsqueeze(0))[0, 0])

        # query the state decoder on the heatmap grid with the final task belief
        prediction = ptu.get_numpy(
            vae.state_decoder(
                task_sample.expand((1, states.shape[0], task_sample.shape[-1])),
                ptu.FloatTensor(states).unsqueeze(0),
                ptu.FloatTensor(actions).unsqueeze(0))).squeeze()
        for i in range(states.shape[0]):
            state_preds[task_idx, i] = 1 if np.linalg.norm(prediction[i, :]) > 1 else 0

    return rewards, reward_preds, state_preds, random_tasks
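# Illustrative helper (not part of the original code): a quick way to visualise
# the output of either eval_vae variant by plotting predicted vs. ground-truth
# rewards per evaluated task. Assumes matplotlib is available; the default file
# name is arbitrary.
def _plot_reward_predictions(rewards, reward_preds, save_path='reward_preds.png'):
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    for i in range(rewards.shape[0]):
        ax.plot(rewards[i], color='C0', alpha=0.5, label='true' if i == 0 else None)
        ax.plot(reward_preds[i], color='C1', alpha=0.5, label='predicted' if i == 0 else None)
    ax.set_xlabel('env step')
    ax.set_ylabel('reward')
    ax.legend()
    fig.savefig(save_path)
    plt.close(fig)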
def load_and_render(self, load_iter):
    save_path = os.path.join(
        '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
        'models')
    self.policy.actor_critic = torch.load(
        os.path.join(save_path, "policy{0}.pt".format(load_iter)))
    self.vae.encoder = torch.load(
        os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

    args = self.args
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_processes = 1
    num_episodes = 100
    num_steps = 1999

    # initialise environments
    envs = make_vec_envs(
        env_name=args.env_name,
        seed=args.seed,
        num_processes=num_processes,
        gamma=args.policy_gamma,
        log_dir=args.agent_log_dir,
        device=device,
        allow_early_resets=False,
        episodes_per_task=self.args.max_rollouts_per_task,
        obs_rms=None,
        ret_rms=None,
    )

    # reset latent state to prior
    latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(num_processes)

    for episode_idx in range(num_episodes):
        (prev_obs_raw, prev_obs_normalised) = envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action, _ = utl.select_action(
                    args=self.args,
                    policy=self.policy,
                    obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw,
                    latent_sample=latent_sample,
                    latent_mean=latent_mean,
                    latent_logvar=latent_logvar,
                    deterministic=True)

            # observe reward and next obs
            (next_obs_raw, next_obs_normalised), (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action)

            # render
            envs.venv.venv.envs[0].env.env.env.env.render()

            # update the hidden state
            latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                encoder=self.vae.encoder,
                next_obs=next_obs_raw,
                action=action,
                reward=rew_raw,
                done=None,
                hidden_state=hidden_state)

            prev_obs_normalised = next_obs_normalised
            prev_obs_raw = next_obs_raw

            if done[0]:
                break
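# Hypothetical invocation (for illustration only): render a trained agent from
# a saved checkpoint. The learner class name and iteration number below are
# assumptions, not the repo's actual API; in this code, load_and_render is a
# method of the object that also owns self.policy and self.vae.
#
#   learner = MetaLearner(args)   # hypothetical constructor
#   learner.load_and_render(load_iter=10000)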
def train(self): """ Given some stream of environments and a logger (tensorboard), (meta-)trains the policy. """ start_time = time.time() # reset environments (prev_obs_raw, prev_obs_normalised) = self.envs.reset() prev_obs_raw = prev_obs_raw.to(device) prev_obs_normalised = prev_obs_normalised.to(device) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw) self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised) self.policy_storage.to(device) vae_is_pretrained = False for self.iter_idx in range(self.args.num_updates): # First, re-compute the hidden states given the current rollouts (since the VAE might've changed) # compute latent embedding (will return prior if current trajectory is empty) with torch.no_grad(): latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory( ) # check if we flushed the policy storage assert len(self.policy_storage.latent_mean) == 0 # add this initial hidden state to the policy storage self.policy_storage.hidden_states[0].copy_(hidden_state) self.policy_storage.latent_samples.append(latent_sample.clone()) self.policy_storage.latent_mean.append(latent_mean.clone()) self.policy_storage.latent_logvar.append(latent_logvar.clone()) # rollout policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action, action_log_prob = utl.select_action( args=self.args, policy=self.policy, obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, deterministic=False, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) # observe reward and next obs (next_obs_raw, next_obs_normalised), ( rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action) tasks = torch.FloatTensor([info['task'] for info in infos]).to(device) done = torch.from_numpy(np.array( done, dtype=int)).to(device).float().view((-1, 1)) # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) # compute next embedding (for next loop and/or value prediction bootstrap) latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_obs_raw, action=action, reward=rew_raw, done=done, hidden_state=hidden_state) # before resetting, update the embedding and add to vae buffer # (last state might include useful task info) if not (self.args.disable_decoder and self.args.disable_stochasticity_in_latent): self.vae.rollout_storage.insert(prev_obs_raw.clone(), action.detach().clone(), next_obs_raw.clone(), rew_raw.clone(), done.clone(), tasks.clone()) # add the obs before reset to the policy storage # (only used to recompute embeddings if rlloss is backpropagated through encoder) self.policy_storage.next_obs_raw[step] = next_obs_raw.clone() self.policy_storage.next_obs_normalised[ step] = next_obs_normalised.clone() # reset environments that are done done_indices = np.argwhere( done.cpu().detach().flatten()).flatten() if len(done_indices) == self.args.num_processes: [next_obs_raw, next_obs_normalised] = self.envs.reset() if not self.args.sample_embeddings: latent_sample = latent_sample else: for i in done_indices: [next_obs_raw[i], next_obs_normalised[i]] = self.envs.reset(index=i) if not 
self.args.sample_embeddings: latent_sample[i] = latent_sample[i] # # add experience to policy buffer self.policy_storage.insert( obs_raw=next_obs_raw, obs_normalised=next_obs_normalised, actions=action, action_log_probs=action_log_prob, rewards_raw=rew_raw, rewards_normalised=rew_normalised, value_preds=value, masks=masks_done, bad_masks=bad_masks, done=done, hidden_states=hidden_state.squeeze(0).detach(), latent_sample=latent_sample.detach(), latent_mean=latent_mean.detach(), latent_logvar=latent_logvar.detach(), ) prev_obs_normalised = next_obs_normalised prev_obs_raw = next_obs_raw self.frames += self.args.num_processes # --- UPDATE --- if self.args.precollect_len <= self.frames: # check if we are pre-training the VAE if self.args.pretrain_len > 0 and not vae_is_pretrained: for _ in range(self.args.pretrain_len): self.vae.compute_vae_loss(update=True) vae_is_pretrained = True # otherwise do the normal update (policy + vae) else: train_stats = self.update( obs=prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # log run_stats = [action, action_log_prob, value] if train_stats is not None: self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update()
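# Sketch (illustrative, not the repo's actual compute_returns): how the
# masks_done / bad_masks convention above is typically consumed when computing
# returns. masks_done zeroes the bootstrap at true episode ends, while
# bad_masks marks time-limit truncations ('bad_transition'), where the target
# should fall back to the value estimate rather than a zero-value terminal.
# This mirrors the common PPO rollout-storage convention and is only a sketch.
def _discounted_returns(rewards, values, masks, bad_masks, gamma=0.99):
    # rewards[t], masks[t], bad_masks[t] describe transition t;
    # values has length T + 1, with values[T] the bootstrap value.
    T = len(rewards)
    returns = [None] * T
    R = values[T]
    for t in reversed(range(T)):
        R = rewards[t] + gamma * R * masks[t]
        # where the episode was cut by the time limit rather than truly
        # terminating, replace the target with the value estimate
        R = R * bad_masks[t] + (1 - bad_masks[t]) * values[t]
        returns[t] = R
    return returns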
def evaluate(args, policy, ret_rms, iter_idx, tasks, encoder=None, num_episodes=None):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    # any overflow and will be discarded at the end, because we need to wait until
    # all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(
        env_name,
        seed=args.seed * 42 + iter_idx,
        num_processes=num_processes,
        gamma=args.policy_gamma,
        device=device,
        rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
        episodes_per_task=num_episodes,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=ret_rms,
        tasks=tasks,
        add_done_info=args.max_rollouts_per_task > 1,
    )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(device)

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=True)

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]

            if encoder is not None:
                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=encoder,
                    next_obs=state,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes (zero-indexed, so no +1)
                task_count[i] = min(task_count[i] + 1, num_episodes)

            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

    envs.close()

    return returns_per_episode[:, :num_episodes]
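# Usage sketch (illustrative): how evaluate() might be called from a training
# loop, and how the per-episode returns it produces could be reduced to a
# single scalar. `ret_rms` and `encoder` are assumed to come from the learner;
# tasks=None lets the environments sample their own tasks.
#
#   returns = evaluate(args, policy, ret_rms, iter_idx, tasks=None, encoder=encoder)
#   # returns has shape (num_processes, num_episodes); average the last episode:
#   mean_final_return = returns[:, -1].mean().item()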
def train(self): """ Main Meta-Training loop """ start_time = time.time() # reset environments prev_state, belief, task = utl.reset_env(self.envs, self.args) # insert initial observation / embeddings to rollout storage self.policy_storage.prev_state[0].copy_(prev_state) # log once before training with torch.no_grad(): self.log(None, None, start_time) for self.iter_idx in range(self.num_updates): # First, re-compute the hidden states given the current rollouts (since the VAE might've changed) with torch.no_grad(): latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory( ) # add this initial hidden state to the policy storage assert len(self.policy_storage.latent_mean ) == 0 # make sure we emptied buffers self.policy_storage.hidden_states[0].copy_(hidden_state) self.policy_storage.latent_samples.append(latent_sample.clone()) self.policy_storage.latent_mean.append(latent_mean.clone()) self.policy_storage.latent_logvar.append(latent_logvar.clone()) # rollout policies for a few steps for step in range(self.args.policy_num_steps): # sample actions from policy with torch.no_grad(): value, action = utl.select_action( args=self.args, policy=self.policy, state=prev_state, belief=belief, task=task, deterministic=False, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) # take step in the environment [next_state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step( self.envs, action, self.args) done = torch.from_numpy(np.array( done, dtype=int)).to(device).float().view((-1, 1)) # create mask for episode ends masks_done = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]).to(device) # bad_mask is true if episode ended because time limit was reached bad_masks = torch.FloatTensor( [[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]).to(device) with torch.no_grad(): # compute next embedding (for next loop and/or value prediction bootstrap) latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding( encoder=self.vae.encoder, next_obs=next_state, action=action, reward=rew_raw, done=done, hidden_state=hidden_state) # before resetting, update the embedding and add to vae buffer # (last state might include useful task info) if not (self.args.disable_decoder and self.args.disable_kl_term): self.vae.rollout_storage.insert( prev_state.clone(), action.detach().clone(), next_state.clone(), rew_raw.clone(), done.clone(), task.clone() if task is not None else None) # add the obs before reset to the policy storage self.policy_storage.next_state[step] = next_state.clone() # reset environments that are done done_indices = np.argwhere(done.cpu().flatten()).flatten() if len(done_indices) > 0: next_state, belief, task = utl.reset_env( self.envs, self.args, indices=done_indices, state=next_state) # TODO: deal with resampling for posterior sampling algorithm # latent_sample = latent_sample # latent_sample[i] = latent_sample[i] # add experience to policy buffer self.policy_storage.insert( state=next_state, belief=belief, task=task, actions=action, rewards_raw=rew_raw, rewards_normalised=rew_normalised, value_preds=value, masks=masks_done, bad_masks=bad_masks, done=done, hidden_states=hidden_state.squeeze(0), latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar, ) prev_state = next_state self.frames += self.args.num_processes # --- UPDATE --- if self.args.precollect_len <= self.frames: # check if we are pre-training the VAE if self.args.pretrain_len > 
self.iter_idx: for p in range(self.args.num_vae_updates_per_pretrain): self.vae.compute_vae_loss( update=True, pretrain_index=self.iter_idx * self.args.num_vae_updates_per_pretrain + p) # otherwise do the normal update (policy + vae) else: train_stats = self.update(state=prev_state, belief=belief, task=task, latent_sample=latent_sample, latent_mean=latent_mean, latent_logvar=latent_logvar) # log run_stats = [ action, self.policy_storage.action_log_probs, value ] with torch.no_grad(): self.log(run_stats, train_stats, start_time) # clean up after update self.policy_storage.after_update() self.envs.close()
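# Hypothetical entry point (illustrative): the repo presumably builds the
# learner from parsed arguments and calls train(); the class and helper names
# below are assumptions, not the repo's actual API.
#
#   if __name__ == '__main__':
#       args = parse_args()           # hypothetical argument parser
#       learner = MetaLearner(args)   # the class that owns train()
#       learner.train()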