def repalce_mission_with_omega(obs_input: DictObs, omega: torch.Tensor) -> DictObs:
    # obs = copy.deepcopy(obs_input)
    obs = obs_input.copy()
    if 'mission' in obs.keys():
        obs.pop('mission')
    obs.update({'omega': omega.clone()})
    return obs

def replace_goal_vector_with_z(obs_input: DictObs, z_latent: torch.Tensor) -> DictObs:
    """Return a copy of the observation with 'goal_vector' swapped for 'z_latent'."""
    # obs = copy.deepcopy(obs_input)
    obs = obs_input.copy()
    if 'goal_vector' in obs.keys():
        obs.pop('goal_vector')
    obs.update({'z_latent': z_latent.clone()})
    return obs

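# Illustrative usage of the two helpers above (a sketch with hypothetical
# tensor shapes, not part of the training pipeline): both return a shallow
# copy of the observation dict with one key swapped out, so the caller's
# observation is left untouched.
#
#   obs = DictObs({'image': torch.zeros(8, 3, 7, 7),
#                  'mission': torch.zeros(8, 16)})
#   omega = torch.randn(8, 4)
#   obs_omega = repalce_mission_with_omega(obs, omega)
#   assert 'omega' in obs_omega and 'mission' not in obs_omega
#   assert 'mission' in obs  # the original dict still has its mission key
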
def eval_success_td(self):
    """Roll out the policy on the validation envs and collect per-episode
    rewards and max room ids until num_eval_episodes episodes finish."""
    with torch.no_grad():
        self.val_envs.reset_config_rng()
        assert self.val_envs.get_attr('reset_on_done')[0]
        self.actor_critic.train()

        reset_output = self.val_envs.reset()
        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                       for key, obs_i in obs.items()})

        episode_counter = 0
        episode_rewards = np.zeros((self.args.num_eval_episodes))
        episode_mrids = np.zeros((self.args.num_eval_episodes))
        masks = torch.ones(self.num_processes_eff, 1).float().to(self.device)
        recurrent_hidden_states = torch.zeros(
            self.args.num_processes,
            self.actor_critic.recurrent_hidden_state_size).to(self.device)

        eval_done = False
        while not eval_done:
            z_latent, z_gauss_dist, value, action, \
                action_log_prob, recurrent_hidden_states = \
                self.actor_critic.act(
                    inputs=obs,
                    rnn_hxs=recurrent_hidden_states,
                    masks=masks,
                    do_z_sampling=False)

            cpu_actions = action.view(-1).cpu().numpy()
            obs, _, done, info = self.val_envs.step(cpu_actions)
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done]).to(self.device)
            obs = dict_stack_helper(obs)
            obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                           for key, obs_i in obs.items()})

            for batch_idx, info_item in enumerate(info):
                if 'prev_episode' in info_item.keys():
                    episode_rewards[episode_counter] = \
                        info_item['prev_episode']['info']['episode_reward']
                    episode_mrids[episode_counter] = \
                        info_item['prev_episode']['info']['max_room_id']
                    episode_counter += 1
                    if episode_counter >= self.args.num_eval_episodes:
                        eval_done = True
                        break

    episode_success = 1.0 * (episode_rewards > 0)
    return episode_rewards, episode_success, episode_mrids

def train_infobot_supervised(self, total_training_steps, start_iter):
    """Train loop"""
    print("=" * 36)
    print("Trainer initialized! Training information:")
    print("\t# of total_training_steps: {}".format(total_training_steps))
    # print("\t# of train envs: {}".format(len(self.train_envs)))
    print("\tnum_processes: {}".format(self.args.num_processes))
    print("\tnum_agents: {}".format(self.args.num_agents))
    # print("\tIterations per epoch: {}".format(self.num_batches_per_epoch))
    print("=" * 36)

    self.save_checkpoint(0)

    if self.args.model == 'hier':
        self.do_sampling = True
    elif self.args.model == 'cond':
        self.do_sampling = False

    next_save_on = 1 * self.args.save_interval
    self.actor_critic.train()

    # self.agent_pos = np.zeros(
    #     [self.args.num_steps + 1, self.num_processes_eff, 2], dtype='int')
    # self.visit_count = [np.ones(self.num_processes_eff)]
    # self.visit_count = np.ones(
    #     [self.args.num_steps, self.num_processes_eff], dtype='int')
    # self.heuristic_ds = np.zeros(
    #     [self.args.num_steps, self.num_processes_eff], dtype='int')

    reset_output = self.train_envs.reset()
    obs = reset_output[:, 0]
    info = reset_output[:, 1]
    obs = dict_stack_helper(obs)
    # info = dict_stack_helper(info)
    # curr_pos = np.stack([item['agent_pos'] for item in info], 0)
    # self.agent_pos[0] = curr_pos
    # [obs] = flatten_batch_dims(obs)

    obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                   for key, obs_i in obs.items()})
    self.rollouts.obs[0].copy_(obs)
    self.rollouts.to(self.device)

    # time_steps = torch.zeros(self.num_processes_eff, 1).long().to(self.device)
    # episode_rewards = torch.zeros(self.num_processes_eff, 1).to(self.device)
    episode_counter = 0
    episode_rewards = deque(maxlen=300)
    episode_mrids = deque(maxlen=300)
    ep_len = deque(maxlen=300)
    zz_kld = deque(maxlen=self.args.log_interval)
    zz_kl_loss = deque(maxlen=self.args.log_interval)
    effective_return = deque(maxlen=self.args.log_interval)

    masks = torch.ones(self.num_processes_eff, 1).float().to(self.device)
    recurrent_hidden_states = torch.zeros(
        self.args.num_steps + 1,
        self.args.num_processes,
        self.actor_critic.recurrent_hidden_state_size)

    num_updates = int(total_training_steps) // \
        (self.num_processes_eff * self.args.num_steps)

    def batch_iterator(start_idx):
        idx = start_idx
        for _ in range(start_idx, num_updates + self.args.log_interval):
            yield idx
            idx += 1

    start = time.time()
    for iter_id in batch_iterator(start_iter):
        self.actor_critic.train()
        self.rollouts.prev_final_mask.fill_(0)

        for step in range(self.args.num_steps):
            with torch.no_grad():
                z_latent, z_gauss_dist, value, action, \
                    action_log_prob, recurrent_hidden_states = \
                    self.actor_critic.act(
                        inputs=obs,
                        rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                        masks=self.rollouts.masks[step],
                        do_z_sampling=self.args.z_stochastic)

            z_eps = (z_latent - z_gauss_dist.loc) / z_gauss_dist.scale
            cpu_actions = action.view(-1).cpu().numpy()
            obs, reward, done, info = self.train_envs.step(cpu_actions)
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done]).to(self.device)
            episode_counter += done.sum()

            obs = dict_stack_helper(obs)
            obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                           for key, obs_i in obs.items()})

            # curr_pos = np.stack([item['agent_pos'] for item in info], 0)
            # curr_dir = np.stack([item['agent_dir'] for item in info], 0)
            # visit_count = np.stack([item['visit_count'] for item in info], 0)
            # self.agent_pos[step + 1] = curr_pos
            # self.visit_count[step] = visit_count
            # if 'is_heuristic_ds' in info[0].keys():
            #     is_heuristic_ds = np.stack(
            #         [item['is_heuristic_ds'] for item in info], 0)
            #     self.heuristic_ds[step] = is_heuristic_ds

            for batch_idx, info_item in enumerate(info):
                if 'prev_episode' in info_item.keys():
                    # prev_final_obs = info_item['prev_episode']['obs']
                    # prev_final_obs = DictObs(
                    #     {key: torch.from_numpy(obs_i).to(self.device)
                    #      for key, obs_i in prev_final_obs.items()})
                    # self.rollouts.prev_final_mask[step, batch_idx] = 1
                    # self.rollouts.prev_final_visit_count[step, batch_idx] = \
                    #     info_item['visit_count']
                    # self.rollouts.prev_final_heur_ds[step, batch_idx] = \
                    #     float(info_item['is_heuristic_ds'])
                    # self.rollouts.prev_final_obs[step, batch_idx].copy_(
                    #     prev_final_obs)
                    episode_rewards.append(
                        info_item['prev_episode']['info']['episode_reward'])
                    episode_mrids.append(
                        info_item['prev_episode']['info']['max_room_id'])
                    ep_len.append(
                        info_item['prev_episode']['info']['step_count'])

            reward = torch.from_numpy(reward[:, np.newaxis]).float()
            reward = reward.to(self.device)

            self.rollouts.insert(
                obs=obs,
                recurrent_hidden_states=recurrent_hidden_states,
                actions=action,
                action_log_probs=action_log_prob,
                value_preds=value,
                z_eps=z_eps,
                rewards=reward,
                masks=masks,
            )

        with torch.no_grad():
            next_value = self.actor_critic.get_value(
                inputs=self.rollouts.obs[-1],
                rnn_hxs=self.rollouts.recurrent_hidden_states[-1],
                masks=self.rollouts.masks[-1],
            ).detach()

        total_num_steps = (iter_id + 1) * \
            self.num_processes_eff * self.args.num_steps

        anneal_coeff = utils.kl_coefficient_curriculum(
            iter_id=total_num_steps,
            iters_per_epoch=1,
            start_after_epochs=self.args.kl_anneal_start_epochs,
            linear_growth_epochs=self.args.kl_anneal_growth_epochs,
        )
        q_start_flag = utils.q_start_curriculum(
            iter_id=total_num_steps,
            iters_per_epoch=1,
            start_after_epochs=self.args.q_start_epochs,
        )

        if not self.args.z_stochastic:
            infobot_coeff = 0
        else:
            infobot_coeff = utils.kl_coefficient_curriculum(
                iter_id=total_num_steps,
                iters_per_epoch=1,
                start_after_epochs=self.args.infobot_kl_start,
                linear_growth_epochs=self.args.infobot_kl_growth,
            )
            min_ib_coeff = min(self.args.infobot_beta_min,
                               self.args.infobot_beta)
            if self.args.infobot_beta > 0:
                infobot_coeff = max(
                    infobot_coeff, min_ib_coeff / self.args.infobot_beta)

        if not self.args.z_stochastic:
            infobot_coeff = 0

        if self.args.algo == 'a2c' or self.args.algo == 'acktr':
            # Conditional model
            value_loss, action_loss, dist_entropy, \
                action_log_probs_mean, ic_info = \
                self.agent.update_infobot_supervised(
                    rollouts=self.rollouts,
                    infobot_beta=self.args.infobot_beta,
                    next_value=next_value,
                    anneal_coeff=infobot_coeff,
                )
            ic_info.update({
                'anneal_coeff': infobot_coeff,
                'q_start_flag': q_start_flag,
            })
            zz_kld.append(ic_info['zz_kld'])
            zz_kl_loss.append(ic_info['zz_kl_loss'])
            effective_return.append(ic_info['effective_return'])
        else:
            raise ValueError("Unknown algo: {}".format(self.args.algo))

        self.rollouts.after_update()

        if iter_id % self.args.log_interval == 0:
            if len(episode_rewards) > 1:
                # cpu_rewards = episode_rewards.cpu().numpy()
                cpu_rewards = episode_rewards
                mrids = episode_mrids
                episode_length = ep_len
            else:
                cpu_rewards = np.array([0])
                mrids = np.array([0])
                episode_length = np.array([0])

            end = time.time()
            FPS = int(total_num_steps / (end - start))
            print(
                f"Updates {iter_id}, num timesteps {total_num_steps}, "
                f"FPS {FPS}, episodes: {episode_counter} \n "
                f"Last {len(cpu_rewards)} training episodes: "
                f"mean/median reward "
                f"{np.mean(cpu_rewards):.1f}/{np.median(cpu_rewards):.1f}, "
                f"min/max reward "
                f"{np.min(cpu_rewards):.1f}/{np.max(cpu_rewards):.1f}"
            )
            print(
                f" Max room id mean/median: "
                f"{np.mean(mrids):.1f}/{np.median(mrids):.1f}, "
                f"min/max: {np.min(mrids)}/{np.max(mrids)}"
            )

            train_success = 1.0 * (np.array(cpu_rewards) > 0)
            self.logger.plot_success(
                prefix="train_",
                total_num_steps=total_num_steps,
                rewards=cpu_rewards,
                success=train_success,
                mrids=mrids,
            )
            self.logger.viz.line(total_num_steps, FPS, "FPS", "FPS",
                                 xlabel="time_steps")
            self.logger.plot_quad_stats(x_val=total_num_steps,
                                        array=episode_length,
                                        plot_title="episode_length")
            self.logger.viz.line(total_num_steps, np.mean(effective_return),
                                 "effective_return", "mean",
                                 xlabel="time_steps")
            self.logger.viz.line(total_num_steps, np.mean(zz_kld),
                                 "zz_kl", "zz_kld", xlabel="time_steps")
            self.logger.viz.line(total_num_steps, np.mean(zz_kl_loss),
                                 "zz_kl", "zz_kl_loss", xlabel="time_steps")
            self.logger.viz.line(total_num_steps, infobot_coeff,
                                 "zz_kl", "anneal_coeff", xlabel="time_steps")
            self.logger.viz.line(total_num_steps, np.mean(dist_entropy),
                                 "policy_entropy", "entropy",
                                 xlabel="time_steps")

        if total_num_steps > self.next_val_after:
            print(f"Evaluating success at {total_num_steps} steps")
            self.next_val_after += self.args.val_interval
            val_rewards, val_success, val_mrids = self.eval_success_td()
            best_success_achieved = self.logger.plot_success(
                prefix="val_",
                total_num_steps=total_num_steps,
                rewards=val_rewards,
                success=val_success,
                mrids=val_mrids,
                track_best=True,
            )
            self.save_checkpoint(total_num_steps, fname="best_val_success.vd")

        if total_num_steps > next_save_on:
            next_save_on += self.args.save_interval
            self.save_checkpoint(total_num_steps)

def train(self, total_training_steps, start_iter):
    """Train loop"""
    print("=" * 36)
    print("Trainer initialized! Training information:")
    print("\t# of total_training_steps: {}".format(total_training_steps))
    # print("\t# of train envs: {}".format(len(self.train_envs)))
    print("\tnum_processes: {}".format(self.args.num_processes))
    print("\tnum_agents: {}".format(self.args.num_agents))
    # print("\tIterations per epoch: {}".format(self.num_batches_per_epoch))
    print("=" * 36)

    self.save_checkpoint(0)

    if self.args.model == 'hier':
        self.do_sampling = True
    elif self.args.model == 'cond':
        self.do_sampling = False

    self.actor_critic.train()

    self.agent_pos = np.zeros(
        [self.args.num_steps + 1, self.num_processes_eff, 2], dtype='int')
    # self.visit_count = [np.ones(self.num_processes_eff)]
    self.visit_count = np.ones(
        [self.args.num_steps, self.num_processes_eff], dtype='int')
    self.heuristic_ds = np.zeros(
        [self.args.num_steps, self.num_processes_eff], dtype='int')

    reset_output = self.train_envs.reset()
    obs = reset_output[:, 0]
    info = reset_output[:, 1]
    obs = dict_stack_helper(obs)
    # info = dict_stack_helper(info)
    curr_pos = np.stack([item['agent_pos'] for item in info], 0)
    self.agent_pos[0] = curr_pos
    # [obs] = flatten_batch_dims(obs)

    obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                   for key, obs_i in obs.items()})
    self.rollouts.obs[0].copy_(obs)
    self.rollouts.to(self.device)

    # time_steps = torch.zeros(self.num_processes_eff, 1).long().to(self.device)
    # episode_rewards = torch.zeros(self.num_processes_eff, 1).to(self.device)
    episode_counter = 0
    episode_rewards = deque(maxlen=300)
    episode_mrids = deque(maxlen=300)

    masks = torch.ones(self.num_processes_eff, 1).float().to(self.device)
    recurrent_hidden_states = torch.zeros(
        self.args.num_steps + 1,
        self.args.num_processes,
        self.actor_critic.recurrent_hidden_state_size)

    num_updates = int(total_training_steps) // \
        (self.num_processes_eff * self.args.num_steps)

    def batch_iterator(start_idx):
        idx = start_idx
        for _ in range(start_idx, num_updates + self.args.log_interval):
            yield idx
            idx += 1

    start = time.time()
    for iter_id in batch_iterator(start_iter):
        self.actor_critic.train()
        self.rollouts.prev_final_mask.fill_(0)

        for step in range(self.args.num_steps):
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = \
                    self.actor_critic.act(
                        inputs=obs,
                        rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                        masks=self.rollouts.masks[step])

            cpu_actions = action.view(-1).cpu().numpy()
            obs, reward, done, info = self.train_envs.step(cpu_actions)
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done]).to(self.device)
            episode_counter += done.sum()

            obs = dict_stack_helper(obs)
            obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                           for key, obs_i in obs.items()})

            curr_pos = np.stack([item['agent_pos'] for item in info], 0)
            curr_dir = np.stack([item['agent_dir'] for item in info], 0)
            visit_count = np.stack([item['visit_count'] for item in info], 0)
            self.agent_pos[step + 1] = curr_pos
            self.visit_count[step] = visit_count

            if 'is_heuristic_ds' in info[0].keys():
                is_heuristic_ds = np.stack(
                    [item['is_heuristic_ds'] for item in info], 0)
                self.heuristic_ds[step] = is_heuristic_ds

            for batch_idx, info_item in enumerate(info):
                if 'prev_episode' in info_item.keys():
                    prev_final_obs = info_item['prev_episode']['obs']
                    prev_final_obs = DictObs(
                        {key: torch.from_numpy(obs_i).to(self.device)
                         for key, obs_i in prev_final_obs.items()})
                    self.rollouts.prev_final_mask[step, batch_idx] = 1
                    self.rollouts.prev_final_visit_count[step, batch_idx] = \
                        info_item['visit_count']
                    self.rollouts.prev_final_heur_ds[step, batch_idx] = \
                        float(info_item['is_heuristic_ds'])
                    self.rollouts.prev_final_obs[step, batch_idx].copy_(
                        prev_final_obs)
                    episode_rewards.append(
                        info_item['prev_episode']['info']['episode_reward'])
                    episode_mrids.append(
                        info_item['prev_episode']['info']['max_room_id'])

            reward = torch.from_numpy(reward[:, np.newaxis]).float()
            reward = reward.to(self.device)
            # episode_rewards += reward
            # reward = torch.from_numpy(reward).float()
            # not_done = np.logical_not(done)
            # masks = torch.from_numpy(not_done.astype('float32')).unsqueeze(1)
            # masks = masks.to(self.device)

            self.rollouts.insert(
                obs=obs,
                recurrent_hidden_states=recurrent_hidden_states,
                actions=action,
                action_log_probs=action_log_prob,
                value_preds=value,
                rewards=reward,
                masks=masks,
            )

        with torch.no_grad():
            next_value = self.actor_critic.get_value(
                inputs=self.rollouts.obs[-1],
                rnn_hxs=self.rollouts.recurrent_hidden_states[-1],
                masks=self.rollouts.masks[-1],
            ).detach()

        anneal_coeff = utils.kl_coefficient_curriculum(
            iter_id=iter_id,
            iters_per_epoch=self.num_batches_per_epoch,
            start_after_epochs=self.args.kl_anneal_start_epochs,
            linear_growth_epochs=self.args.kl_anneal_growth_epochs,
        )
        q_start_flag = utils.q_start_curriculum(
            iter_id=iter_id,
            iters_per_epoch=self.num_batches_per_epoch,
            start_after_epochs=self.args.q_start_epochs,
        )

        if self.args.algo == 'a2c' or self.args.algo == 'acktr':
            # Conditional model
            value_loss, action_loss, dist_entropy, \
                action_log_probs_mean, ic_info, option_info = \
                self.agent.update(
                    rollouts=self.rollouts,
                    hier_mode=self.args.hier_mode,
                    use_intrinsic_control=False,
                    next_value=next_value,
                    option_space=self.args.option_space,
                    use_ib=self.args.use_infobot,
                    agent_pos=self.agent_pos,
                    bonus_z_encoder=self.z_encoder,
                    b_args=self.b_args,
                    bonus_type=self.args.bonus_type,
                    bonus_normalization=self.args.bonus_normalization,
                    heuristic_ds=self.heuristic_ds,
                    heuristic_coeff=self.args.bonus_heuristic_beta,
                    visit_count=self.visit_count,
                )
            ic_info.update({
                'anneal_coeff': anneal_coeff,
                # 'infobot_coeff': infobot_coeff,
                'q_start_flag': q_start_flag,
            })
            # if 'traj_ce_loss' in ic_info:
            #     traj_ce_loss.extend(ic_info['traj_ce_loss'])
        else:
            raise ValueError("Unknown algo: {}".format(self.args.algo))

        self.rollouts.after_update()

        total_num_steps = (iter_id + 1) * \
            self.num_processes_eff * self.args.num_steps

        if iter_id % self.args.log_interval == 0:
            if len(episode_rewards) > 1:
                # cpu_rewards = episode_rewards.cpu().numpy()
                cpu_rewards = episode_rewards
                mrids = episode_mrids
            else:
                cpu_rewards = np.array([0])
                mrids = np.array([-1])

            end = time.time()
            FPS = int(total_num_steps / (end - start))
            print(
                f"Updates {iter_id}, num timesteps {total_num_steps}, "
                f"FPS {FPS}, episodes: {episode_counter} \n "
                f"Last {len(cpu_rewards)} training episodes: "
                f"mean/median reward "
                f"{np.mean(cpu_rewards):.1f}/{np.median(cpu_rewards):.1f}, "
                f"min/max reward "
                f"{np.min(cpu_rewards):.1f}/{np.max(cpu_rewards):.1f}"
            )
            print(
                f" Max room id mean/median: "
                f"{np.mean(mrids):.1f}/{np.median(mrids):.1f}, "
                f"min/max: {np.min(mrids)}/{np.max(mrids)}"
            )

            train_success = 1.0 * (np.array(cpu_rewards) > 0)
            self.logger.plot_success(
                prefix="train_",
                total_num_steps=total_num_steps,
                rewards=cpu_rewards,
                success=train_success,
                mrids=mrids,
            )
            self.logger.viz.line(total_num_steps, FPS, "FPS", "FPS",
                                 xlabel="time_steps")

        if total_num_steps > self.next_val_after:
            print(f"Evaluating success at {total_num_steps} steps")
            self.next_val_after += self.args.val_interval
            val_rewards, val_success, val_mrids = self.eval_success()
            self.logger.plot_success(
                prefix="val_",
                total_num_steps=total_num_steps,
                rewards=val_rewards,
                success=val_success,
                mrids=val_mrids,
                track_best=True,
            )

def forward_step(self, step, omega_option, obs_base, ib_rnn_hxs, options_rhx):
    # Sample options if applicable
    if self.args.hier_mode == 'transfer':
        with torch.no_grad():
            if step % self.args.num_option_steps == 0:
                omega_option = None
                previous_options_rhx = options_rhx
                option_value, omega_option, option_log_probs, options_rhx = \
                    self.options_policy.act(
                        inputs=obs_base,
                        rnn_hxs=options_rhx,
                        masks=self.rollouts.masks[step])

                if self.args.option_space == 'discrete':
                    omega_option = omega_option.squeeze(-1)
                    omega_option = torch.eye(self.args.omega_option_dims)\
                        .to(self.device)[omega_option]

                self.rollouts.insert_option_t(
                    step=step,
                    omega_option_t=omega_option,
                    option_log_probs=option_log_probs,
                    option_value=option_value,
                    options_rhx=previous_options_rhx)

        obs_base = repalce_mission_with_omega(obs_base, omega_option)

    # Sample actions
    with torch.no_grad():
        value, action, action_log_prob, recurrent_hidden_states = \
            self.actor_critic.act(
                inputs=obs_base,
                rnn_hxs=self.rollouts.recurrent_hidden_states[step],
                masks=self.rollouts.masks[step])

    # Take actions, observe reward and next obs
    # cpu_actions = action.view(
    #     (self.args.num_processes, self.args.num_agents)).cpu().numpy()
    cpu_actions = action.view(-1).cpu().numpy()
    # obs, reward, _, info = self.train_envs.step(cpu_actions + 1)
    obs, reward, _, info = self.train_envs.step(cpu_actions)
    obs = dict_stack_helper(obs)
    obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                   for key, obs_i in obs.items()})

    if self.args.hier_mode == 'transfer' or self.args.model == 'cond':
        obs_base = obs
    else:
        if self.args.use_infobot:
            if self.args.hier_mode == 'infobot-supervised':
                z_latent, z_log_prob, z_dist, ib_rnn_hxs = \
                    self.actor_critic.encoder_forward(
                        obs=obs,
                        rnn_hxs=ib_rnn_hxs,
                        masks=self.rollouts.masks[step],
                        do_z_sampling=True)
                obs_base = replace_goal_vector_with_z(obs, z_latent)
            else:
                # Sample next z_t
                obs_omega = repalce_mission_with_omega(obs, omega_option)
                z_latent, z_log_prob, z_dist, ib_rnn_hxs = \
                    self.actor_critic.encoder_forward(
                        obs=obs_omega,
                        rnn_hxs=ib_rnn_hxs,
                        masks=self.rollouts.masks[step],
                        do_z_sampling=self.do_z_sampling)
                obs_base = repalce_omega_with_z(obs, z_latent)

            self.rollouts.insert_z_latent(
                z_latent=z_latent,
                z_logprobs=z_log_prob,
                z_dist=z_dist,
                ib_enc_hidden_states=ib_rnn_hxs)
        else:
            obs_base = repalce_mission_with_omega(obs, omega_option)

    done = np.stack([item['done'] for item in info], 0)

    if 'is_heuristic_ds' in info[0].keys():
        is_heuristic_ds = np.stack(
            [item['is_heuristic_ds'] for item in info], 0)
        self.heuristic_ds[step + 1] = is_heuristic_ds

    if not self.continuous_state_space:
        curr_pos = np.stack([item['agent_pos'] for item in info], 0)
        curr_dir = np.stack([item['agent_dir'] for item in info], 0)
        visit_count = np.stack([item['visit_count'] for item in info], 0)
        self.agent_pos[step + 1] = curr_pos
        # if 'current_room' in info[0]:
        #     self.current_room = np.stack(
        #         [item['current_room'] for item in info], 0)
        self.visit_count.append(visit_count)
        pos_velocity = None
    else:
        curr_pos = None
        curr_dir = None
        if self.args.env_name == 'mountain-car':
            pos_velocity = obs['pos-velocity']
        else:
            pos_velocity = np.zeros((self.num_processes_eff, 2))

    # [obs, reward] = utils.flatten_batch_dims(obs, reward)
    # print(step, done)
    # Extract the done flag from the info
    # done = np.concatenate([info_['done'] for info_ in info], 0)
    # if step == self.args.num_steps - 1:
    #     s_extract = lambda key_: np.array(
    #         [item[key_] for item in info])
    #     success_train = s_extract('success').astype('float')
    #     goal_index = s_extract('goal_index')
    #     success_0 = success_train[goal_index == 0]
    #     success_1 = success_train[goal_index == 1]
    #     # spl_train = s_extract('spl_values')

    # Shape assertions
    reward = torch.from_numpy(reward[:, np.newaxis]).float()
    # episode_rewards += reward
    cpu_reward = reward
    reward = reward.to(self.device)
    # reward = torch.from_numpy(reward).float()

    not_done = np.logical_not(done)
    self.total_time_steps += not_done.sum()
    masks = torch.from_numpy(not_done.astype('float32')).unsqueeze(1)
    masks = masks.to(self.device)

    for key in obs.keys():
        if obs[key].dim() == 5:
            obs[key] *= masks.type_as(obs[key])\
                .unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        elif obs[key].dim() == 4:
            obs[key] *= masks.type_as(obs[key]).unsqueeze(-1).unsqueeze(-1)
        elif obs[key].dim() == 1:
            obs[key] *= masks.type_as(obs[key]).squeeze(1)
        else:
            obs[key] *= masks.type_as(obs[key])

    self.rollouts.insert(obs, recurrent_hidden_states, action,
                         action_log_prob, value, reward, masks)

    return obs_base, omega_option, ib_rnn_hxs, options_rhx, cpu_reward, \
        curr_pos, curr_dir, pos_velocity, not_done

def on_episode_start(self):
    self.actor_critic.train()

    # obs = train_envs.reset()
    reset_output = self.train_envs.reset()
    obs = reset_output[:, 0]
    info = reset_output[:, 1]
    obs = dict_stack_helper(obs)
    info = dict_stack_helper(info)

    if not self.continuous_state_space:
        self.visit_count = [np.ones(self.num_processes_eff)]
        self.agent_pos = np.zeros(
            [self.args.num_steps + 1, self.num_processes_eff, 2], dtype='int')
        self.agent_pos[0] = info['agent_pos']
        info['pos_velocity'] = None
    else:
        self.agent_pos = None
        if self.args.env_name == 'mountain-car':
            info['pos_velocity'] = obs['pos-velocity']
        else:
            info['pos_velocity'] = np.zeros((self.num_processes_eff, 2))

    self.heuristic_ds = np.zeros(
        [self.args.num_steps + 1, self.num_processes_eff], dtype='int')

    # [obs] = flatten_batch_dims(obs)
    obs = DictObs({key: torch.from_numpy(obs_i).to(self.device)
                   for key, obs_i in obs.items()})

    if self.args.model == 'cond':
        omega_option = None
        q_dist_ref = None
        options_rhx = None
        ib_rnn_hx = None
        self.rollouts.obs[0].copy_(obs)
        return omega_option, obs, q_dist_ref, ib_rnn_hx, \
            options_rhx, info

    if self.args.use_infobot:
        ib_rnn_hx = self.rollouts.recurrent_hidden_states.new_zeros(
            self.num_processes_eff,
            self.actor_critic.encoder_recurrent_hidden_state_size)

        if self.args.hier_mode == 'infobot-supervised':
            omega_option = None
            q_dist_ref = None
            options_rhx = None
            self.rollouts.obs[0].copy_(obs)
            z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                self.actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=ib_rnn_hx,
                    masks=self.rollouts.masks[0],
                    do_z_sampling=True,
                )
            self.rollouts.insert_z_latent(z_latent, z_log_prob,
                                          z_dist, ib_rnn_hx)
            obs = replace_goal_vector_with_z(obs, z_latent)
            return omega_option, obs, q_dist_ref, ib_rnn_hx, \
                options_rhx, info
    else:
        ib_rnn_hx = None

    if self.args.option_space == 'continuous':
        option_log_probs = None
        if self.args.hier_mode == 'default':
            omega_option, q_dist, _, ldj = self.options_decoder(
                obs, do_sampling=self.do_sampling)
        elif self.args.hier_mode == 'vic':
            ldj = 0.0
            _shape = (self.num_processes_eff, self.args.omega_option_dims)
            _loc = torch.zeros(*_shape).to(self.device)
            _scale = torch.ones(*_shape).to(self.device)
            # if self.omega_dim_current < self.args.omega_option_dims:
            #     _scale[:, self.omega_dim_current:].fill_(1e-3)
            q_dist = ds.normal.Normal(loc=_loc, scale=_scale)

            if self.do_sampling:
                omega_option = q_dist.rsample()
            else:
                omega_option = q_dist.mean

            if self.args.ic_mode == 'diyan':
                _shape_t = (self.args.num_steps + 1, *_shape)
                _loc_t = torch.zeros(*_shape_t).to(self.device)
                _scale_t = torch.ones(*_shape_t).to(self.device)
                # if self.omega_dim_current < self.args.omega_option_dims:
                #     _scale_t[:, :, self.omega_dim_current:].fill_(1e-3)
                q_dist_ref = ds.normal.Normal(loc=_loc_t, scale=_scale_t)
            else:
                q_dist_ref = q_dist

            if self.args.use_infobot:
                obs_omega = repalce_mission_with_omega(obs, omega_option)
                z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                    self.actor_critic.encoder_forward(
                        obs=obs_omega,
                        rnn_hxs=ib_rnn_hx,
                        masks=self.rollouts.masks[0],
                        do_z_sampling=self.do_z_sampling)
        elif self.args.hier_mode == 'transfer':
            # omega_option, q_dist, _, ldj = self.options_policy(
            #     obs, do_sampling=self.do_sampling)
            omega_option = None
            q_dist_ref = None
        else:
            raise ValueError
    else:
        ldj = 0.0
        if self.args.hier_mode == 'default':
            with torch.no_grad():
                option_discrete, q_dist, option_log_probs = \
                    self.options_decoder(obs, do_sampling=self.do_sampling)
            if self.args.use_infobot:
                raise NotImplementedError
        elif self.args.hier_mode == 'vic':
            with torch.no_grad():
                _shape = (self.num_processes_eff,
                          self.args.omega_option_dims)
                uniform_probs = torch.ones(*_shape).to(self.device)
                if self.omega_dim_current < self.args.omega_option_dims:
                    uniform_probs[:, self.omega_dim_current:].fill_(0)
                uniform_probs = uniform_probs / uniform_probs.sum(
                    -1, keepdim=True)
                q_dist = distributions.FixedCategorical(probs=uniform_probs)
                option_discrete = q_dist.sample()
                # option_log_probs = q_dist.log_probs(option_discrete)

                if self.args.ic_mode == 'diyan':
                    _shape_t = (self.args.num_steps + 1, *_shape)
                    uniform_probs = torch.ones(*_shape_t).to(self.device)
                    if self.omega_dim_current < self.args.omega_option_dims:
                        uniform_probs[:, :, self.omega_dim_current:].fill_(0)
                    uniform_probs = uniform_probs / uniform_probs.sum(
                        -1, keepdim=True)
                    q_dist_ref = distributions.FixedCategorical(
                        probs=uniform_probs)
                else:
                    q_dist_ref = q_dist

            if self.args.use_infobot:
                omega_one_hot = torch.eye(
                    self.args.omega_option_dims)[option_discrete]
                omega_one_hot = omega_one_hot.float().to(self.device)
                obs_omega = repalce_mission_with_omega(obs, omega_one_hot)
                z_latent, z_log_prob, z_dist, ib_rnn_hx = \
                    self.actor_critic.encoder_forward(
                        obs=obs_omega,
                        rnn_hxs=ib_rnn_hx,
                        masks=self.rollouts.masks[0],
                        do_z_sampling=self.do_z_sampling)
        elif self.args.hier_mode in ['transfer', 'bonus']:
            omega_option = None
            q_dist_ref = None
        else:
            raise ValueError

        if self.args.hier_mode != 'transfer':
            option_np = option_discrete.squeeze(-1).cpu().numpy()
            option_one_hot = np.eye(self.args.omega_option_dims)[option_np]
            omega_option = torch.from_numpy(option_one_hot).float().to(
                self.device)

    if self.args.hier_mode == 'transfer':
        obs_base = obs
        if self.args.use_infobot:
            raise NotImplementedError
        else:
            pass
    else:
        if self.args.use_infobot:
            obs_base = repalce_omega_with_z(obs, z_latent)
            self.rollouts.insert_option(omega_option)
            self.rollouts.insert_z_latent(z_latent, z_log_prob,
                                          z_dist, ib_rnn_hx)
        else:
            obs_base = repalce_mission_with_omega(obs, omega_option)
            self.rollouts.insert_option(omega_option)

    if self.args.hier_mode == 'transfer':
        options_rhx = torch.zeros(
            self.num_processes_eff,
            self.options_policy.recurrent_hidden_state_size).to(self.device)
    else:
        options_rhx = None

    # self.omega_option = omega_option
    # self.obs_base = obs_base
    self.rollouts.obs[0].copy_(obs)

    return omega_option, obs_base, q_dist_ref, ib_rnn_hx, \
        options_rhx, info

def eval_success_simple(
    num_processes,
    num_steps,
    val_envs,
    actor_critic,
    device,
    num_episodes,
):
    ARGMAX_POLICY = True
    episode_count = 0
    return_list = []
    all_max_room = []

    val_envs.modify_attr('render_rgb', [False] * num_processes)
    val_envs.reset_config_rng()

    while episode_count < num_episodes:
        reward_list = []
        reset_output = val_envs.reset()
        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        info = dict_stack_helper(info)
        obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                       for key, obs_i in obs.items()})

        recurrent_hidden_states = torch.zeros(
            num_processes,
            actor_critic.recurrent_hidden_state_size).to(device)
        masks = torch.ones(num_processes, 1).to(device)

        for step in range(num_steps):
            _, action, _, recurrent_hidden_states = \
                actor_critic.act(
                    inputs=obs,
                    rnn_hxs=recurrent_hidden_states,
                    masks=masks,
                    deterministic=bool(ARGMAX_POLICY))

            cpu_actions = action.view(-1).cpu().numpy()
            obs, reward, _, info = val_envs.step(cpu_actions)
            reward_list.append(reward)

            obs = dict_stack_helper(obs)
            obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                           for key, obs_i in obs.items()})
            done = np.stack([item['done'] for item in info], 0)

        if 'max_room_id' in info[0]:
            max_room = np.stack([item['max_room_id'] for item in info], 0)
        else:
            max_room = np.ones((num_processes)) * -1
        all_max_room.append(max_room)

        episodic_return = np.stack(reward_list, 0).sum(0)
        return_list.append(episodic_return)
        episode_count += num_processes

    all_return = np.concatenate(return_list, 0)
    success = (all_return > 0).astype('float')
    all_max_room = np.concatenate(all_max_room, 0)
    return success, all_max_room, all_return

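# Example invocation of eval_success_simple (a sketch; the `trainer` object
# and its attributes are assumptions standing in for whatever training script
# owns the validation envs and the policy):
#
#   success, max_rooms, returns = eval_success_simple(
#       num_processes=trainer.args.num_processes,
#       num_steps=trainer.args.num_steps,
#       val_envs=trainer.val_envs,
#       actor_critic=trainer.actor_critic,
#       device=trainer.device,
#       num_episodes=trainer.args.num_eval_episodes,
#   )
#   print("val success rate: {:.3f}".format(success.mean()))
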
def eval_ib_kl(args, vis_env, actor_critic, device,
               omega_dim_current, num_samples=10):
    """Enumerate grid states and measure, per cell, the KL of the z-encoder
    from its prior and the divergence of each option policy from the
    option-marginal ("default") policy."""
    assert args.use_infobot != 0
    action_dims = vis_env.action_space.n

    if hasattr(vis_env.actions, 'forward'):
        action_space_type = 'pov'
    elif hasattr(vis_env.actions, 'up'):
        action_space_type = 'cardinal'

    vis_obs, vis_info = vis_env.reset()
    assert 'rgb_grid' in vis_info
    env_rgb_img = vis_info['rgb_grid'].transpose([2, 0, 1])
    env_rgb_img = np.flip(env_rgb_img, 1)

    all_obs = vis_env.enumerate_states()

    _rhx = torch.zeros(num_samples,
                       actor_critic.recurrent_hidden_state_size).to(device)
    _masks = torch.ones(1, num_samples, 1).to(device)

    def repeat_dict_obs(dict_obs, batch_size):
        out = {}
        for key, value in dict_obs.items():
            out[key] = np.broadcast_to(value[np.newaxis, :],
                                       (batch_size, *value.shape))
        return out

    grid_shape = (vis_env.width, vis_env.height)
    # Filled directly by the continuous-option branch below.
    kl_zz_grid = torch.zeros(*grid_shape).to(device)
    kl_zz_opt_grid = [torch.zeros(*grid_shape).to(device)
                      for _ in range(omega_dim_current)]
    kl_pi_def_grid = torch.zeros(*grid_shape).to(device)
    kl_pi_opt_grid = [torch.zeros(*grid_shape).to(device)
                      for _ in range(omega_dim_current)]
    pi_def_grid = torch.zeros((action_dims, *grid_shape)).to(device)
    pi_opt_grid = [torch.zeros((action_dims, *grid_shape)).to(device)
                   for _ in range(omega_dim_current)]

    if args.option_space == 'continuous':
        _shape = (num_samples, args.omega_option_dims)
        _loc = torch.zeros(*_shape).to(device)
        _scale = torch.ones(*_shape).to(device)
        omega_prior = ds.normal.Normal(loc=_loc, scale=_scale)

        _z_shape = (num_samples, args.z_latent_dims)
        _z_loc = torch.zeros(*_z_shape).to(device)
        _z_scale = torch.ones(*_z_shape).to(device)
        z_prior = ds.normal.Normal(loc=_z_loc, scale=_z_scale)

        for key, obs in all_obs.items():
            obs = repeat_dict_obs(obs, num_samples)
            omega_option = omega_prior.rsample()
            obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                           for key, obs_i in obs.items()})
            if 'mission' in obs.keys():
                obs.pop('mission')
            obs.update({'omega': omega_option})

            z_latent, z_log_prob, z_dist, _ = \
                actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=_rhx,
                    masks=_masks,
                    do_z_sampling=True)

            kld_zz = ds.kl.kl_divergence(z_dist, z_prior)
            # kld_zz = kld_zz.view(
            #     num_steps + 1, num_processes, z_latent_dims)
            kld_zz = torch.sum(kld_zz, 1).mean()
            kl_zz_grid[key.x, key.y] = kld_zz
    else:
        # _shape = (omega_dim_current, args.omega_option_dims)
        # uniform_probs = torch.ones(*_shape).to(device)
        _z_shape = (omega_dim_current * num_samples, args.z_latent_dims)
        _z_loc = torch.zeros(*_z_shape).to(device)
        _z_scale = torch.ones(*_z_shape).to(device)
        z_prior = ds.normal.Normal(loc=_z_loc, scale=_z_scale)

        # if omega_dim_current < args.omega_option_dims:
        #     uniform_probs[:, omega_dim_current:].fill_(0)
        # uniform_probs = uniform_probs / uniform_probs.sum(-1, keepdim=True)
        # omega_prior = FixedCategorical(probs=uniform_probs)
        # option_discrete = omega_prior.sample()

        omega_option = torch.eye(omega_dim_current).to(device)
        if omega_dim_current < args.omega_option_dims:
            _diff = args.omega_option_dims - omega_dim_current
            _pad = omega_option.new_zeros(omega_dim_current, _diff)
            omega_option = torch.cat([omega_option, _pad], 1)
        omega_option = omega_option.unsqueeze(0).repeat(num_samples, 1, 1)
        omega_option = omega_option.view(-1, *omega_option.shape[2:])

        for key, obs in all_obs.items():
            obs = repeat_dict_obs(obs, omega_option.shape[0])
            # omega_option = omega_prior.rsample()
            obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                           for key, obs_i in obs.items()})
            if 'mission' in obs.keys():
                obs.pop('mission')
            obs.update({'omega': omega_option})

            z_latent, z_log_prob, z_dist, _ = \
                actor_critic.encoder_forward(
                    obs=obs,
                    rnn_hxs=_rhx,
                    masks=_masks,
                    do_z_sampling=True)

            kld_zz = ds.kl.kl_divergence(z_dist, z_prior)
            kld_zz = kld_zz.view(num_samples, omega_dim_current,
                                 *kld_zz.shape[1:])
            kld_zz = kld_zz.sum(-1).mean(0)
            for opt_idx in range(omega_dim_current):
                kl_zz_opt_grid[opt_idx][key.x, key.y] = kld_zz[opt_idx]

            obs.pop('omega')
            obs.update({'z_latent': z_latent})
            _, action_dist, _, _ = \
                actor_critic.get_action_dist(
                    inputs=obs,
                    rnn_hxs=_rhx,
                    masks=_masks)

            action_probs = action_dist.probs
            action_probs = action_probs.view(
                num_samples, omega_dim_current,
                *action_probs.shape[1:]).mean(0)

            pi_opt, pi_kl = {}, {}
            for opt_idx in range(omega_dim_current):
                pi_opt[opt_idx] = FixedCategorical(
                    probs=action_probs[opt_idx])
                pi_opt_grid[opt_idx][:, key.x, key.y] = pi_opt[opt_idx].probs

            pi_def = FixedCategorical(probs=action_probs.mean(0))
            pi_def_grid[:, key.x, key.y] = pi_def.probs

            for opt_idx in range(omega_dim_current):
                pi_kl[opt_idx] = ds.kl.kl_divergence(pi_opt[opt_idx], pi_def)
                kl_pi_opt_grid[opt_idx][key.x, key.y] = pi_kl[opt_idx]

            pi_kl_avg = torch.stack(tuple(pi_kl.values()), 0).mean(0)
            kl_pi_def_grid[key.x, key.y] = pi_kl_avg

    pi_opt_grid = torch.stack(pi_opt_grid, 0)
    kl_pi_opt_grid = torch.stack(kl_pi_opt_grid, 0)
    kl_zz_opt_grid = torch.stack(kl_zz_opt_grid, 0)
    if args.option_space != 'continuous':
        # For discrete options, average the per-option z-KL grids; in the
        # continuous case kl_zz_grid was already filled above.
        kl_zz_grid = kl_zz_opt_grid.mean(0)

    return_dict = {
        'env_rgb_img': env_rgb_img,
        'pi_opt_grid': pi_opt_grid.cpu().numpy().transpose([0, 1, 3, 2]),
        'pi_def_grid': pi_def_grid.cpu().numpy().transpose([0, 2, 1]),
        'kl_zz_grid': kl_zz_grid.cpu().numpy().T,
        'kl_zz_opt_grid': kl_zz_opt_grid.cpu().numpy().transpose([0, 2, 1]),
        'kl_pi_def_grid': kl_pi_def_grid.cpu().numpy().T,
        'kl_pi_opt_grid': kl_pi_opt_grid.cpu().numpy().transpose([0, 2, 1]),
    }
    return return_dict

def eval_success(
    args,
    val_envs,
    vis_env,
    actor_critic,
    b_args,
    bonus_type,
    bonus_z_encoder,
    bonus_beta,
    bonus_normalization,
    device,
    num_episodes,
):
    ARGMAX_POLICY = True
    episode_count = 0
    return_list = []
    all_max_room = []

    val_envs.modify_attr('render_rgb', [True] * args.num_processes)
    val_envs.reset_config_rng()
    # vis_env.reset_config_rng()

    grid_shape = (vis_env.width, vis_env.height)
    kl_grid = torch.zeros(*grid_shape).to(device)
    bonus_grid = torch.zeros(*grid_shape).to(device)

    while episode_count < num_episodes:
        reward_list = []
        reset_output = val_envs.reset()
        obs = reset_output[:, 0]
        info = reset_output[:, 1]
        obs = dict_stack_helper(obs)
        info = dict_stack_helper(info)
        obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                       for key, obs_i in obs.items()})
        rgb_grid = info['rgb_grid']

        recurrent_hidden_states = torch.zeros(
            args.num_processes,
            actor_critic.recurrent_hidden_state_size).to(device)
        masks = torch.ones(args.num_processes, 1).to(device)

        agent_pos = [val_envs.get_attr('agent_pos')]
        all_masks = [np.array([True] * args.num_processes)]
        all_obs = [obs]
        all_vc = [np.ones(args.num_processes)]

        for step in range(args.num_steps):
            _, action, _, recurrent_hidden_states = \
                actor_critic.act(
                    inputs=obs,
                    rnn_hxs=recurrent_hidden_states,
                    masks=masks,
                    deterministic=bool(ARGMAX_POLICY))

            cpu_actions = action.view(-1).cpu().numpy()
            obs, reward, _, info = val_envs.step(cpu_actions)
            reward_list.append(reward)

            obs = dict_stack_helper(obs)
            obs = DictObs({key: torch.from_numpy(obs_i).to(device)
                           for key, obs_i in obs.items()})
            all_obs.append(obs)

            done = np.stack([item['done'] for item in info], 0)
            curr_pos = np.stack([item['agent_pos'] for item in info], 0)
            # curr_dir = np.stack([item['agent_dir'] for item in info], 0)
            visit_count = np.stack([item['visit_count'] for item in info], 0)
            all_vc.append(visit_count)
            agent_pos.append(curr_pos)
            all_masks.append(done == False)

        if 'max_room_id' in info[0]:
            max_room = np.stack([item['max_room_id'] for item in info], 0)
        else:
            max_room = np.ones((args.num_processes)) * -1

        agent_pos = np.stack(agent_pos, 0)
        all_masks = np.stack(all_masks, 0)
        all_max_room.append(max_room)

        stacked_obs = {}
        for key in all_obs[0].keys():
            stacked_obs[key] = torch.stack(
                [_obs[key] for _obs in all_obs], 0)
        stacked_obs = DictObs(stacked_obs)
        stacked_masks = np.stack(all_masks, 0).astype('float32')
        stacked_masks = torch.from_numpy(stacked_masks).to(device)
        stacked_visit_count = np.stack(all_vc, 0)

        if bonus_type != 'count':
            bonus_kld = bonus_kl_forward(
                bonus_type=bonus_type,
                obs=stacked_obs,
                b_args=b_args,
                bonus_z_encoder=bonus_z_encoder,
                masks=stacked_masks,
                bonus_normalization=bonus_normalization,
            )
        else:
            bonus_kld = stacked_masks.clone() * 0

        episodic_return = np.stack(reward_list, 0).sum(0)
        return_list.append(episodic_return)
        episode_count += args.num_processes

    VIS_COUNT = 1
    VIS_IDX = 0
    agent_pos = agent_pos[:, VIS_IDX]
    episode_length = all_masks[:, VIS_IDX].sum()
    rgb_env_image = rgb_grid[VIS_IDX]
    bonus_kld = bonus_kld[:, VIS_IDX]
    visit_count = stacked_visit_count[:, VIS_IDX]
    rgb_env_image = np.flip(rgb_env_image.transpose([2, 0, 1]), 1)

    vis_info = make_bonus_grid(
        bonus_beta=bonus_beta,
        agent_pos=agent_pos,
        kl_values=bonus_kld.squeeze(-1).cpu().numpy(),
        visit_count=visit_count,
        episode_length=episode_length,
        grid_shape=grid_shape,
    )
    vis_info['rgb_env_image'] = rgb_env_image

    all_return = np.concatenate(return_list, 0)
    success = (all_return > 0).astype('float')
    all_max_room = np.concatenate(all_max_room, 0)
    return success, all_max_room, vis_info