def step(self):
    # Does a rollout of self.nsteps steps across all environment lumps.
    t = self.I.step_count % self.nsteps
    epinfos = []
    sess = tf.get_default_session()
    for l in range(self.I.nlump):
        obs, prevrews, news, infos = self.env_get(l)
        if prevrews is not None:
            # Increment the reward counter once per nonzero extrinsic reward.
            for i in range((prevrews != 0).sum()):
                sess.run(self.inc_rew_counter)
        # Debug-only round-trip of the first observation; the result is unused.
        _obs = 255 * np.squeeze(obs[0]).astype(np.float32)
        _obs = _obs.astype(np.uint8)
        for env_pos_in_lump, info in enumerate(infos):
            if 'episode' in info:
                # Information like rooms visited is added to info at the end of an episode.
                epinfos.append(info['episode'])
                info_with_places = info['episode']
                try:
                    info_with_places['places'] = info['episode']['visited_rooms']
                except KeyError:
                    # Drop into the debugger if the wrapper failed to report rooms.
                    import ipdb; ipdb.set_trace()
                self.I.buf_epinfos[env_pos_in_lump + l * self.I.lump_stride][t] = info_with_places

        sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
        memsli = slice(None) if self.I.mem_state is NO_STATES else sli
        dict_obs = self.stochpol.ensure_observation_is_dict(obs)
        with logger.ProfileKV("policy_inference"):
            # Call the policy and value function on the current observation.
            acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(
                dict_obs, news, self.I.mem_state[memsli],
                update_obs_stats=self.update_ob_stats_every_step)
        self.env_step(l, acs)

        # Update the rollout buffers with this transition.
        for k in self.stochpol.ph_ob_keys:
            self.I.buf_obs[k][sli, t] = dict_obs[k]
        self.I.buf_news[sli, t] = news
        self.I.buf_vpreds_int[sli, t] = vpreds_int
        self.I.buf_vpreds_ext[sli, t] = vpreds_ext
        self.I.buf_nlps[sli, t] = nlps
        self.I.buf_acs[sli, t] = acs
        self.I.buf_ent[sli, t] = ent
        if t > 0:
            self.I.buf_rews_ext[sli, t - 1] = prevrews

    self.I.step_count += 1
    if t == self.nsteps - 1 and not self.disable_policy_update:
        # Take one extra step so every transition has a reward.
        for l in range(self.I.nlump):
            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli
            nextobs, rews, nextnews, _ = self.env_get(l)
            if rews is not None:
                for i in range((rews != 0).sum()):
                    sess.run(self.inc_rew_counter)
            dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs)
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_ob_last[k][sli] = dict_nextobs[k]
            self.I.buf_new_last[sli] = nextnews
            with logger.ProfileKV("policy_inference"):
                _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = \
                    self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli],
                                       update_obs_stats=False)
            self.I.buf_rews_ext[sli, t] = rews

        # Calculate the intrinsic rewards for the rollout.
        fd = {}
        fd[self.stochpol.ph_ob[None]] = np.concatenate(
            [self.I.buf_obs[None], self.I.buf_ob_last[None][:, None]], 1)
        fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                   self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
        fd[self.stochpol.ph_ac] = self.I.buf_acs
        self.I.buf_rews_int[:] = tf.get_default_session().run(self.stochpol.int_rew, fd)

        if not self.update_ob_stats_every_step:
            # Update observation normalization parameters after the rollout is completed.
            obs_ = self.I.buf_obs[None].astype(np.float32)
            self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:, :, :, -1:])

        if not self.testing:
            update_info = self.update()
        else:
            update_info = {}
        self.I.seg_init_mem_state = copy(self.I.mem_state)
        global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
        global_deque_mean = dict_gather(self.comm_log,
                                        {n: safemean(dvs) for n, dvs in self.I.statlists.items()},
                                        op='mean')
        update_info.update(global_i_stats)
        update_info.update(global_deque_mean)
        self.global_tcount = global_i_stats['tcount']
        for infos_ in self.I.buf_epinfos:
            infos_.clear()
    else:
        update_info = {}

    # Some reporting logic.
    for epinfo in epinfos:
        if self.testing:
            self.I.statlists['eprew_test'].append(epinfo['r'])
            self.I.statlists['eplen_test'].append(epinfo['l'])
        else:
            if "visited_rooms" in epinfo:
                self.local_rooms += list(epinfo["visited_rooms"])
                self.local_rooms = sorted(list(set(self.local_rooms)))
                score_multiple = self.I.venvs[0].score_multiple
                if score_multiple is None:
                    score_multiple = 1000
                rounded_score = int(epinfo["r"] / score_multiple) * score_multiple
                self.scores.append(rounded_score)
                self.scores = sorted(list(set(self.scores)))
                self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"]))
            self.I.statlists['eprew'].append(epinfo['r'])
            if self.local_best_ret is None:
                self.local_best_ret = epinfo["r"]
            elif epinfo["r"] > self.local_best_ret:
                self.local_best_ret = epinfo["r"]
            self.I.statlists['eplen'].append(epinfo['l'])
            self.I.stats['epcount'] += 1
            self.I.stats['tcount'] += epinfo['l']
            self.I.stats['rewtotal'] += epinfo['r']
            # self.I.stats["best_ext_ret"] = self.best_ret
    return {'update': update_info}
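# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the agent). `stochpol.int_rew` above is the
# RND intrinsic reward: the prediction error between a fixed, randomly
# initialized target network and a trained predictor, evaluated on
# observations whitened with `ob_rms` and clipped to [-5, 5]. The helper below
# is a minimal NumPy sketch of that bonus; the linear "networks" and all names
# here are hypothetical, chosen only to show the shape of the computation.
import numpy as np

def rnd_bonus_sketch(obs, ob_mean, ob_std, w_target, w_pred):
    """obs: [T, d] raw observations; returns a per-step bonus of shape [T]."""
    x = np.clip((obs - ob_mean) / (ob_std + 1e-8), -5.0, 5.0)  # RND-style whitening
    phi_target = x @ w_target  # fixed random embedding (never trained)
    phi_pred = x @ w_pred      # embedding of the trained predictor
    return ((phi_target - phi_pred) ** 2).mean(axis=-1)

# Example: an untrained predictor yields a large bonus everywhere; as w_pred
# is regressed toward w_target on visited states, the bonus decays.
_rng = np.random.default_rng(0)
_demo_obs = _rng.normal(size=(8, 16))
_bonus = rnd_bonus_sketch(_demo_obs, _demo_obs.mean(0), _demo_obs.std(0),
                          _rng.normal(size=(16, 32)), np.zeros((16, 32)))
assert _bonus.shape == (8,)
# ---------------------------------------------------------------------------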
def step(self):
    # Does a rollout; this variant adds reward filtering, episodic-curiosity
    # rewards, and an oracle count-based alternative to the RND bonus.
    t = self.I.step_count % self.nsteps
    epinfos = []
    self.check_goto_next_policy()
    for l in range(self.I.nlump):
        obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = self.env_get(l)
        for env_pos_in_lump, info in enumerate(infos):
            if 'episode' in info:
                # Information like rooms visited is added to info at the end of an episode.
                epinfos.append(info['episode'])
                info_with_places = info['episode']
                try:
                    info_with_places['places'] = info['episode']['visited_rooms']
                except KeyError:
                    # Drop into the debugger if the wrapper failed to report rooms.
                    import ipdb; ipdb.set_trace()
                self.I.buf_epinfos[env_pos_in_lump + l * self.I.lump_stride][t] = info_with_places
                self.check_episode(env_pos_in_lump + l * self.I.lump_stride)

        sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
        memsli = slice(None) if self.I.mem_state is NO_STATES else sli
        dict_obs = self.stochpol.ensure_observation_is_dict(obs)
        with logger.ProfileKV("policy_inference"):
            # Call the policy and value function on the current observation.
            acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(
                dict_obs, news, self.I.mem_state[memsli],
                update_obs_stats=self.update_ob_stats_every_step)
        self.env_step(l, acs)

        # Update the rollout buffers with this transition.
        for k in self.stochpol.ph_ob_keys:
            self.I.buf_obs[k][sli, t] = dict_obs[k]
        self.I.buf_news[sli, t] = news
        self.I.buf_vpreds_int[sli, t] = vpreds_int
        self.I.buf_vpreds_ext[sli, t] = vpreds_ext
        self.I.buf_nlps[sli, t] = nlps
        self.I.buf_acs[sli, t] = acs
        self.I.buf_ent[sli, t] = ent
        if t > 0:
            # Filter the clipped reward using the unclipped reward, the agent's
            # position, and door-opening events (the comprehension assumes a
            # single lump, since it spans all environments).
            prevrews = np.asarray([
                self.filter_rew(prevrews[k], infos[k]['unclip_rew'],
                                infos[k]['position'], infos[k]['open_door_type'], k)
                for k in range(self.I.nenvs)])
            self.I.buf_rews_ext[sli, t - 1] = prevrews
            self.I.buf_rews_ec[sli, t - 1] = ec_rews
            if self.rnd_type == 'oracle':
                # Count-based bonus from the RAM-decoded position, e.g.
                # self.I.oracle_visited_count.update_position(infos[k]['position']).
                buf_rews_int = np.array([self.update_rnd(infos[k]['position'], k)
                                         for k in range(self.I.nenvs)])
                self.I.buf_rews_int[sli, t - 1] = buf_rews_int

    self.I.step_count += 1
    if t == self.nsteps - 1 and not self.disable_policy_update:
        # Take one extra step so every transition has a reward.
        for l in range(self.I.nlump):
            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli
            nextobs, rews, ec_rews, nextnews, infos, ram_states, monitor_rews = self.env_get(l)
            dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs)
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_ob_last[k][sli] = dict_nextobs[k]
            self.I.buf_new_last[sli] = nextnews
            with logger.ProfileKV("policy_inference"):
                _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = \
                    self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli],
                                       update_obs_stats=False)
            rews = np.asarray([
                self.filter_rew(rews[k], infos[k]['unclip_rew'],
                                infos[k]['position'], infos[k]['open_door_type'], k)
                for k in range(self.I.nenvs)])
            self.I.buf_rews_ext[sli, t] = rews
            self.I.buf_rews_ec[sli, t] = ec_rews
            if self.rnd_type == 'oracle':
                buf_rews_int = np.array([self.update_rnd(infos[k]['position'], k)
                                         for k in range(self.I.nenvs)])
                self.I.buf_rews_int[sli, t] = buf_rews_int

        # Calculate the intrinsic rewards for the rollout.
        if self.rnd_type == 'rnd':
            # Compute the RND bonus for the whole rollout in one pass, gated
            # by the episodic-curiosity reward.
            fd = {}
            fd[self.stochpol.ph_ob[None]] = np.concatenate(
                [self.I.buf_obs[None], self.I.buf_ob_last[None][:, None]], 1)
            fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                       self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
            fd[self.stochpol.ph_ac] = self.I.buf_acs
            self.I.buf_rews_int[:] = tf_util.get_session().run(self.stochpol.int_rew, fd) * self.I.buf_rews_ec
        elif self.rnd_type == 'oracle':
            # Oracle count-based rewards were already stored step by step above.
            pass
        else:
            raise ValueError('Unknown exploration reward: {}'.format(self.rnd_type))

        # Disabled alternative: compute the intrinsic rewards minibatch by minibatch.
        '''
        envsperbatch = self.I.nenvs // self.nminibatches
        start = 0
        while start < self.I.nenvs:
            end = start + envsperbatch
            mbenvinds = slice(start, end, None)
            fd = {}
            # [nenvs, nstep+1, h, w, stack]
            fd[self.stochpol.ph_ob[None]] = np.concatenate(
                [self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1)
            fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                       self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
            fd[self.stochpol.ph_ac] = self.I.buf_acs[mbenvinds]
            # If dead, set the intrinsic reward to zero:
            # self.I.buf_rews_int[mbenvinds] = (1 - self.I.buf_news[mbenvinds]) * self.sess.run(self.stochpol.int_rew, fd)
            rews_int = tf_util.get_session().run(self.stochpol.int_rew, fd)
            self.I.buf_rews_int[mbenvinds] = rews_int * self.I.buf_rews_ec[mbenvinds]
            start += envsperbatch
        '''

        if not self.update_ob_stats_every_step:
            # Update observation normalization parameters after the rollout is completed.
            obs_ = self.I.buf_obs[None].astype(np.float32)
            self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:, :, :, -1:])
            feed = {self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                    self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5,
                    self.stochpol.ph_count: self.stochpol.ob_rms.count}
            self.sess.run(self.assign_op, feed)
        if not self.testing:
            logger.info(self.I.cur_gen_idx, self.I.rews_found_by_contemporary)
            update_info = self.update()
            # Synchronize the oracle visit counters across workers.
            self.I.oracle_visited_count.sync()
            self.I.cur_oracle_visited_count.sync()
            self.I.cur_oracle_visited_count_for_next_gen.sync()
        else:
            update_info = {}
        self.I.seg_init_mem_state = copy(self.I.mem_state)
        global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
        global_deque_mean = dict_gather(self.comm_log,
                                        {n: np.mean(dvs) for n, dvs in self.I.statlists.items()},
                                        op='mean')
        update_info.update(global_i_stats)
        update_info.update(global_deque_mean)
        self.global_tcount = global_i_stats['tcount']
        for infos_ in self.I.buf_epinfos:
            infos_.clear()
    else:
        update_info = {}

    # Some reporting logic.
    for epinfo in epinfos:
        if self.testing:
            self.I.statlists['eprew_test'].append(epinfo['r'])
            self.I.statlists['eplen_test'].append(epinfo['l'])
        else:
            if "visited_rooms" in epinfo:
                self.local_rooms += list(epinfo["visited_rooms"])
                self.local_rooms = sorted(list(set(self.local_rooms)))
                score_multiple = self.I.venvs[0].score_multiple
                if score_multiple is None:
                    score_multiple = 1000
                rounded_score = int(epinfo["r"] / score_multiple) * score_multiple
                self.scores.append(rounded_score)
                self.scores = sorted(list(set(self.scores)))
                self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"]))
            self.I.statlists['eprew'].append(epinfo['r'])
            if self.local_best_ret is None:
                self.local_best_ret = epinfo["r"]
            elif epinfo["r"] > self.local_best_ret:
                self.local_best_ret = epinfo["r"]
            self.I.statlists['eplen'].append(epinfo['l'])
            self.I.stats['epcount'] += 1
            self.I.stats['tcount'] += epinfo['l']
            self.I.stats['rewtotal'] += epinfo['r']
            # self.I.stats["best_ext_ret"] = self.best_ret
    return {'update': update_info}
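# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the agent). In the 'oracle' branch above,
# `update_rnd(position, k)` is assumed to return a count-based bonus derived
# from the agent's RAM-decoded position (compare the commented-out
# `oracle_visited_count.update_position` call). Below is a minimal sketch
# under the common 1/sqrt(visit-count) assumption; the class name, bonus form,
# and position granularity are all hypothetical.
from collections import defaultdict
import numpy as np

class OracleVisitCountSketch:
    """Tabular visit counter keyed by a hashable position, e.g. (room, x, y)."""
    def __init__(self):
        self.counts = defaultdict(int)

    def update_position(self, position):
        self.counts[position] += 1                    # record the visit
        return 1.0 / np.sqrt(self.counts[position])   # bonus decays with count

# Example: the first visit to a cell pays 1.0, the fourth pays 0.5.
_oracle = OracleVisitCountSketch()
assert _oracle.update_position((1, 3, 7)) == 1.0
# ---------------------------------------------------------------------------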