Example #1
    def step(self):
        # Does a rollout.
        t = self.I.step_count % self.nsteps
        epinfos = []
        sess = tf.get_default_session()

        for l in range(self.I.nlump):
            obs, prevrews, news, infos = self.env_get(l)
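            # Bump the TF reward counter once for every nonzero extrinsic reward returned for this lump.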
            if prevrews is not None:
                for i in range((prevrews != 0).sum()):
                    sess.run(self.inc_rew_counter)

            # Debug copy of the first environment's observation, scaled and cast to uint8; not used below.
            _obs = 255 * np.squeeze(obs[0]).astype(np.float32)
            _obs = _obs.astype(np.uint8)

            for env_pos_in_lump, info in enumerate(infos):
                if 'episode' in info:
                    # Information like rooms visited is added to info on end of episode.
                    epinfos.append(info['episode'])
                    info_with_places = info['episode']
                    try:
                        info_with_places['places'] = info['episode']['visited_rooms']
                    except:
                        import ipdb; ipdb.set_trace()
                    self.I.buf_epinfos[env_pos_in_lump+l*self.I.lump_stride][t] = info_with_places

            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli
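            # sli selects this lump's slice of the vectorized envs in the rollout buffers; memsli mirrors it
            # for the recurrent memory state (slice(None) when the policy keeps no state).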

            dict_obs = self.stochpol.ensure_observation_is_dict(obs)
            with logger.ProfileKV("policy_inference"):
                # Calls the policy and value function on current observation.
                acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(
                    dict_obs, news, self.I.mem_state[memsli], update_obs_stats=self.update_ob_stats_every_step)

            self.env_step(l, acs)

            # Update buffer with transition.
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_obs[k][sli, t] = dict_obs[k]
            self.I.buf_news[sli, t] = news
            self.I.buf_vpreds_int[sli, t] = vpreds_int
            self.I.buf_vpreds_ext[sli, t] = vpreds_ext
            self.I.buf_nlps[sli, t] = nlps
            self.I.buf_acs[sli, t] = acs
            self.I.buf_ent[sli, t] = ent

            if t > 0:
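                # Rewards returned with the current observation belong to the previous action, hence index t-1.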
                self.I.buf_rews_ext[sli, t-1] = prevrews

        self.I.step_count += 1
        if t == self.nsteps - 1 and not self.disable_policy_update:
            # We need to take one extra step so every transition has a reward.

            for l in range(self.I.nlump):
                sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
                memsli = slice(None) if self.I.mem_state is NO_STATES else sli
                nextobs, rews, nextnews, _ = self.env_get(l)
                if rews is not None:
                    for i in range((rews != 0).sum()):
                        sess.run(self.inc_rew_counter)
                dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs)
                for k in self.stochpol.ph_ob_keys:
                    self.I.buf_ob_last[k][sli] = dict_nextobs[k]
                self.I.buf_new_last[sli] = nextnews
                with logger.ProfileKV("policy_inference"):
                    _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli], update_obs_stats=False)
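                # These last-step value predictions provide the bootstrap values when returns are computed in update().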
                self.I.buf_rews_ext[sli, t] = rews

            # Calculate the intrinsic rewards for the rollout.
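            # The feed contains the nsteps observations plus the extra last observation, so an intrinsic
            # reward is available for every transition in the rollout.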
            fd = {}
            fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1)
            fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                       self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
            fd[self.stochpol.ph_ac] = self.I.buf_acs

            self.I.buf_rews_int[:] = tf.get_default_session().run(self.stochpol.int_rew, fd)

            if not self.update_ob_stats_every_step:
                # Update observation normalization parameters after the rollout is completed.
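                # Only the most recent frame of each stacked observation ([..., -1:]) feeds the running mean/std.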
                obs_ = self.I.buf_obs[None].astype(np.float32)
                self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:,:,:,-1:])
            if not self.testing:
                update_info = self.update()
            else:
                update_info = {}
            self.I.seg_init_mem_state = copy(self.I.mem_state)
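            # Aggregate logging stats across workers on comm_log: counters are summed, per-episode deque stats are averaged.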
            global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
            global_deque_mean = dict_gather(self.comm_log, { n : safemean(dvs) for n,dvs in self.I.statlists.items() }, op='mean')
            update_info.update(global_i_stats)
            update_info.update(global_deque_mean)
            self.global_tcount = global_i_stats['tcount']
            for infos_ in self.I.buf_epinfos:
                infos_.clear()
        else:
            update_info = {}

        # Some reporting logic.
        for epinfo in epinfos:
            if self.testing:
                self.I.statlists['eprew_test'].append(epinfo['r'])
                self.I.statlists['eplen_test'].append(epinfo['l'])
            else:
                if "visited_rooms" in epinfo:
                    self.local_rooms += list(epinfo["visited_rooms"])
                    self.local_rooms = sorted(list(set(self.local_rooms)))
                    score_multiple = self.I.venvs[0].score_multiple
                    if score_multiple is None:
                        score_multiple = 1000
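                    # Round the episode return to a multiple of score_multiple so each score level is recorded once in self.scores.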
                    rounded_score = int(epinfo["r"] / score_multiple) * score_multiple
                    self.scores.append(rounded_score)
                    self.scores = sorted(list(set(self.scores)))
                    self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"]))

                self.I.statlists['eprew'].append(epinfo['r'])
                if self.local_best_ret is None:
                    self.local_best_ret = epinfo["r"]
                elif epinfo["r"] > self.local_best_ret:
                    self.local_best_ret = epinfo["r"]

                self.I.statlists['eplen'].append(epinfo['l'])
                self.I.stats['epcount'] += 1
                self.I.stats['tcount'] += epinfo['l']
                self.I.stats['rewtotal'] += epinfo['r']
                # self.I.stats["best_ext_ret"] = self.best_ret

        return {'update' : update_info}

Example #2

    def step(self):
        #Does a rollout.
        t = self.I.step_count % self.nsteps
        epinfos = []

        self.check_goto_next_policy()

        for l in range(self.I.nlump):
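            # env_get in this variant also returns ec_rews (an additional per-env reward signal), ram_states
            # and monitor_rews; ram_states and monitor_rews are not used in this method.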
            obs, prevrews, ec_rews, news, infos, ram_states, monitor_rews = self.env_get(l)

            for env_pos_in_lump, info in enumerate(infos):
                if 'episode' in info:
                    #Information like rooms visited is added to info on end of episode.
                    epinfos.append(info['episode'])
                    info_with_places = info['episode']
                    try:
                        info_with_places['places'] = info['episode']['visited_rooms']
                    except:
                        import ipdb; ipdb.set_trace()
                    self.I.buf_epinfos[env_pos_in_lump+l*self.I.lump_stride][t] = info_with_places


                    self.check_episode(env_pos_in_lump+l*self.I.lump_stride)

            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli

            dict_obs = self.stochpol.ensure_observation_is_dict(obs)
            with logger.ProfileKV("policy_inference"):
                #Calls the policy and value function on current observation.
                acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(
                    dict_obs, news, self.I.mem_state[memsli], update_obs_stats=self.update_ob_stats_every_step)

            self.env_step(l, acs)

            #Update buffer with transition.
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_obs[k][sli, t] = dict_obs[k]
            self.I.buf_news[sli, t] = news
            self.I.buf_vpreds_int[sli, t] = vpreds_int
            self.I.buf_vpreds_ext[sli, t] = vpreds_ext
            self.I.buf_nlps[sli, t] = nlps
            self.I.buf_acs[sli, t] = acs
            self.I.buf_ent[sli, t] = ent

            if t > 0:
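                # filter_rew post-processes each env's extrinsic reward using extra wrapper info (unclipped reward,
                # position, door type); note it iterates over all nenvs, which assumes a single lump (nenvs == lump_stride).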

                prevrews = [self.filter_rew(prevrews[k], infos[k]['unclip_rew'], infos[k]['position'], infos[k]['open_door_type'], k)
                            for k in range(self.I.nenvs)]
                prevrews = np.asarray(prevrews)
                #print(prevrews)

                self.I.buf_rews_ext[sli, t-1] = prevrews
                self.I.buf_rews_ec[sli, t-1] = ec_rews

                # Oracle mode: the intrinsic reward is derived from the agent's position via update_rnd
                # (the commented-out lines show the count-based oracle_visited_count variant).
                if self.rnd_type == 'oracle':
                    #buf_rews_int = [
                    #    self.I.oracle_visited_count.update_position(infos[k]['position'])
                    #    for k in range(self.I.nenvs)]

                    buf_rews_int = [
                        self.update_rnd(infos[k]['position'], k)
                        for k in range(self.I.nenvs)]
                    #print(buf_rews_int)

                    buf_rews_int = np.array(buf_rews_int)
                    self.I.buf_rews_int[sli, t-1] = buf_rews_int

        self.I.step_count += 1
        if t == self.nsteps - 1 and not self.disable_policy_update:
            #We need to take one extra step so every transition has a reward.
            for l in range(self.I.nlump):
                sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
                memsli = slice(None) if self.I.mem_state is NO_STATES else sli
                nextobs, rews, ec_rews, nextnews, infos, ram_states, monitor_rews = self.env_get(l)
                dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs)
                for k in self.stochpol.ph_ob_keys:
                    self.I.buf_ob_last[k][sli] = dict_nextobs[k]
                self.I.buf_new_last[sli] = nextnews
                with logger.ProfileKV("policy_inference"):
                    _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli], update_obs_stats=False)
                

                rews = [self.filter_rew(rews[k], infos[k]['unclip_rew'], infos[k]['position'], infos[k]['open_door_type'], k)
                        for k in range(self.I.nenvs)]
                rews = np.asarray(rews)

                self.I.buf_rews_ext[sli, t] = rews
                self.I.buf_rews_ec[sli, t] = ec_rews

                if self.rnd_type == 'oracle':
                    #buf_rews_int = [
                    #    self.I.oracle_visited_count.update_position(infos[k]['position'])
                    #    for k in range(self.I.nenvs)]

                    buf_rews_int = [
                        self.update_rnd(infos[k]['position'], k)
                        for k in range(self.I.nenvs)]

                    buf_rews_int = np.array(buf_rews_int)
                    self.I.buf_rews_int[sli, t] = buf_rews_int

            if self.rnd_type == 'rnd':
                #compute RND
                fd = {}
                fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1)
                fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                           self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
                fd[self.stochpol.ph_ac] = self.I.buf_acs
                # RND bonus per transition, gated elementwise by the per-step bonus buffer buf_rews_ec.
                self.I.buf_rews_int[:] = tf_util.get_session().run(self.stochpol.int_rew, fd) * self.I.buf_rews_ec
            elif self.rnd_type == 'oracle':
                #compute oracle count-based reward (buf_rews_int was already filled per step above)
                fd = {}
            else:
                raise ValueError('Unknown exploration reward: {}'.format(self.rnd_type))
            #Calculate the intrinsic rewards for the rollout (for each step).
            '''
            envsperbatch = self.I.nenvs // self.nminibatches

            #fd = {}
            #[nenvs, nstep+1, h,w,stack]
            #fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1)
            start = 0
            while start < self.I.nenvs:
                end = start + envsperbatch
                mbenvinds = slice(start, end, None)

                fd = {}
                fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1)
                fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                           self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
                fd[self.stochpol.ph_ac] = self.I.buf_acs[mbenvinds]

                # if dead, we set rew_int to zero
                #self.I.buf_rews_int[mbenvinds] = (1 -self.I.buf_news[mbenvinds]) * self.sess.run(self.stochpol.int_rew, fd)
                rews_int = tf_util.get_session().run(self.stochpol.int_rew, fd)
                self.I.buf_rews_int[mbenvinds] = rews_int * self.I.buf_rews_ec[mbenvinds]

                start += envsperbatch
            '''
            if not self.update_ob_stats_every_step:
                #Update observation normalization parameters after the rollout is completed.
                obs_ = self.I.buf_obs[None].astype(np.float32)
                self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:,:,:,-1:])
                feed = {self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                        self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5,
                        self.stochpol.ph_count: self.stochpol.ob_rms.count}
                # Copy the refreshed ob_rms statistics into the graph via assign_op.
                self.sess.run(self.assign_op, feed)

            if not self.testing:
                logger.info(self.I.cur_gen_idx, self.I.rews_found_by_contemporary)
                update_info = self.update()

                # After the update, synchronize the oracle visitation counters so all workers share the same counts.
                self.I.oracle_visited_count.sync()
                self.I.cur_oracle_visited_count.sync()
                self.I.cur_oracle_visited_count_for_next_gen.sync()

            else:
                update_info = {}
            self.I.seg_init_mem_state = copy(self.I.mem_state)
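            # Snapshot the recurrent state at the segment boundary (the initial state for the next segment).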
            global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
            global_deque_mean = dict_gather(self.comm_log, { n : np.mean(dvs) for n,dvs in self.I.statlists.items() }, op='mean')
            update_info.update(global_i_stats)
            update_info.update(global_deque_mean)
            self.global_tcount = global_i_stats['tcount']
            for infos_ in self.I.buf_epinfos:
                infos_.clear()
        else:
            update_info = {}

        #Some reporting logic.
        for epinfo in epinfos:
            if self.testing:
                self.I.statlists['eprew_test'].append(epinfo['r'])
                self.I.statlists['eplen_test'].append(epinfo['l'])
            else:
                if "visited_rooms" in epinfo:
                    self.local_rooms += list(epinfo["visited_rooms"])
                    self.local_rooms = sorted(list(set(self.local_rooms)))
                    score_multiple = self.I.venvs[0].score_multiple
                    if score_multiple is None:
                        score_multiple = 1000
                    rounded_score = int(epinfo["r"] / score_multiple) * score_multiple
                    self.scores.append(rounded_score)
                    self.scores = sorted(list(set(self.scores)))
                    self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"]))

                self.I.statlists['eprew'].append(epinfo['r'])
                if self.local_best_ret is None:
                    self.local_best_ret = epinfo["r"]
                elif epinfo["r"] > self.local_best_ret:
                    self.local_best_ret = epinfo["r"]

                self.I.statlists['eplen'].append(epinfo['l'])
                self.I.stats['epcount'] += 1
                self.I.stats['tcount'] += epinfo['l']
                self.I.stats['rewtotal'] += epinfo['r']
                # self.I.stats["best_ext_ret"] = self.best_ret


        return {'update' : update_info}