Example #1
    def step(self):
        # Does a rollout.
        t = self.I.step_count % self.nsteps
        epinfos = []
        episodes_visited_rooms = []
        for l in range(self.I.nlump):
            obs, prevrews, news, infos = self.env_get(l)
            for env_pos_in_lump, info in enumerate(infos):
                if 'episode' in info:
                    # Information such as visited rooms is added to info at the end of an episode.
                    epinfos.append(info['episode'])
                    info_with_places = info['episode']
                    try:
                        # KC modification: report an empty places list instead of
                        # info['episode']['visited_rooms'].
                        info_with_places['places'] = []
                    except Exception:
                        # Drop into the debugger if the episode info is malformed.
                        import ipdb
                        ipdb.set_trace()
                    self.I.buf_epinfos[
                        env_pos_in_lump +
                        l * self.I.lump_stride][t] = info_with_places
                if 'room_first_visit' in info:
                    visited_rooms = [
                        room_loc for room_loc, first_visit in
                        info['room_first_visit'].items()
                        if first_visit is not None
                    ]
                    # self.I.buf_epinfos[env_pos_in_lump+l*self.I.lump_stride][t] = {
                    #     'visited_rooms': visited_rooms
                    # }
                    episodes_visited_rooms.append(visited_rooms)

            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli
            dict_obs = self.stochpol.ensure_observation_is_dict(obs)
            # kc modification: logger.ProfileKV was renamed to logger.profile_kv.
            with logger.profile_kv("policy_inference"):
                # Calls the policy and value function on current observation.
                acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[
                    memsli], ent = self.stochpol.call(
                        dict_obs,
                        news,
                        self.I.mem_state[memsli],
                        update_obs_stats=self.update_ob_stats_every_step)
            self.env_step(l, acs)

            # Update buffer with transition.
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_obs[k][sli, t] = dict_obs[k]
            self.I.buf_news[sli, t] = news
            self.I.buf_vpreds_int[sli, t] = vpreds_int
            self.I.buf_vpreds_ext[sli, t] = vpreds_ext
            self.I.buf_nlps[sli, t] = nlps
            self.I.buf_acs[sli, t] = acs
            self.I.buf_ent[sli, t] = ent

            if t > 0:
                self.I.buf_rews_ext[sli, t - 1] = prevrews

        self.I.step_count += 1
        if t == self.nsteps - 1 and not self.disable_policy_update:
            # We need to take one extra step so every transition has a reward.
            for l in range(self.I.nlump):
                sli = slice(l * self.I.lump_stride,
                            (l + 1) * self.I.lump_stride)
                memsli = slice(None) if self.I.mem_state is NO_STATES else sli
                nextobs, rews, nextnews, _ = self.env_get(l)
                dict_nextobs = self.stochpol.ensure_observation_is_dict(
                    nextobs)
                for k in self.stochpol.ph_ob_keys:
                    self.I.buf_ob_last[k][sli] = dict_nextobs[k]
                self.I.buf_new_last[sli] = nextnews
                # kc modification: logger.ProfileKV was renamed to logger.profile_kv.
                with logger.profile_kv("policy_inference"):
                    _, self.I.buf_vpred_int_last[
                        sli], self.I.buf_vpred_ext_last[
                            sli], _, _, _ = self.stochpol.call(
                                dict_nextobs,
                                nextnews,
                                self.I.mem_state[memsli],
                                update_obs_stats=False)
                self.I.buf_rews_ext[sli, t] = rews

            # Calculate the intrinsic rewards for the rollout.
            fd = {}
            fd[self.stochpol.ph_ob[None]] = np.concatenate(
                [self.I.buf_obs[None], self.I.buf_ob_last[None][:, None]], 1)
            fd.update({
                self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5
            })
            fd[self.stochpol.ph_ac] = self.I.buf_acs
            self.I.buf_rews_int[:] = tf.get_default_session().run(
                self.stochpol.int_rew, fd)

            if not self.update_ob_stats_every_step:
                # Update observation normalization parameters after the rollout is completed.
                obs_ = self.I.buf_obs[None].astype(np.float32)
                self.stochpol.ob_rms.update(
                    obs_.reshape((-1, *obs_.shape[2:]))[:, :, :, -1:])
            if not self.testing:
                update_info = self.update()
            else:
                update_info = {}
            self.I.seg_init_mem_state = copy(self.I.mem_state)
            global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
            global_deque_mean = dict_gather(
                self.comm_log,
                {n: np.mean(dvs)
                 for n, dvs in self.I.statlists.items()},
                op='mean')
            update_info.update(global_i_stats)
            update_info.update(global_deque_mean)
            self.global_tcount = global_i_stats['tcount']
            for infos_ in self.I.buf_epinfos:
                infos_.clear()
        else:
            update_info = {}

        # Some reporting logic.
        for epinfo, visited_rooms in zip(epinfos, episodes_visited_rooms):
            if self.testing:
                self.I.statlists['eprew_test'].append(epinfo['r'])
                self.I.statlists['eplen_test'].append(epinfo['l'])
            else:
                if visited_rooms:
                    self.local_rooms += list(visited_rooms)
                    self.local_rooms = sorted(list(set(self.local_rooms)))
                    self.I.statlists['eprooms'].append(len(visited_rooms))

                self.I.statlists['eprew'].append(epinfo['r'])
                if self.local_best_ret is None:
                    self.local_best_ret = epinfo["r"]
                elif epinfo["r"] > self.local_best_ret:
                    self.local_best_ret = epinfo["r"]

                self.I.statlists['eplen'].append(epinfo['l'])
                self.I.stats['epcount'] += 1
                self.I.stats['tcount'] += epinfo['l']
                self.I.stats['rewtotal'] += epinfo['r']
                self.I.stats["best_ext_ret"] = self.best_ret

        return {'update': update_info}
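
The rollout loop above writes prevrews one slot back because the reward returned together with observation t belongs to the transition taken at step t - 1; the extra env_get after the rollout fills the final column. Below is a minimal sketch of that alignment, where env_get_stub is a hypothetical stand-in for self.env_get and the buffer name mirrors buf_rews_ext above:

import numpy as np

nenvs, nsteps = 4, 5
buf_rews_ext = np.zeros((nenvs, nsteps))

def env_get_stub(t):
    # Hypothetical stand-in for self.env_get(l): the rewards delivered
    # together with observation t are the rewards for transition t - 1.
    return np.full(nenvs, float(t))

for t in range(nsteps):
    prevrews = env_get_stub(t)
    if t > 0:
        # The reward arriving at step t pays for the transition made at t - 1.
        buf_rews_ext[:, t - 1] = prevrews

# One extra environment read so the last transition also has a reward,
# mirroring the `t == self.nsteps - 1` branch in the method above.
buf_rews_ext[:, nsteps - 1] = env_get_stub(nsteps)
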
Example #2
    def step(self):
        # Does a rollout.
        t = self.I.step_count % self.nsteps
        epinfos = []
        for l in range(self.I.nlump):
            obs, prevrews, news, infos = self.env_get(l)
            for env_pos_in_lump, info in enumerate(infos):
                if 'episode' in info:
                    # Information such as visited rooms is added to info at the end of an episode.
                    epinfos.append(info['episode'])
                    info_with_places = info['episode']
                    try:
                        if 'visited_rooms' not in info['episode']:
                            info_with_places['places'] = {}
                        else:
                            info_with_places['places'] = info['episode'][
                                'visited_rooms']  # only set for the Montezuma environment

                    except Exception:
                        # Drop into the debugger if the episode info is malformed.
                        import ipdb
                        ipdb.set_trace()
                    self.I.buf_epinfos[
                        env_pos_in_lump +
                        l * self.I.lump_stride][t] = info_with_places

            sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
            memsli = slice(None) if self.I.mem_state is NO_STATES else sli
            dict_obs = self.stochpol.ensure_observation_is_dict(obs)
            with logger.profile_kv("policy_inference"):
                # Calls the policy and value function on current observation.
                acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[
                    memsli], ent = self.stochpol.call(
                        dict_obs,
                        news,
                        self.I.mem_state[memsli],
                        update_obs_stats=self.update_ob_stats_every_step)
            self.env_step(l, acs)

            # Update buffer with transition.
            for k in self.stochpol.ph_ob_keys:
                self.I.buf_obs[k][sli, t] = dict_obs[k]
            self.I.buf_news[sli, t] = news
            self.I.buf_vpreds_int[sli, t] = vpreds_int
            self.I.buf_vpreds_ext[sli, t] = vpreds_ext
            self.I.buf_nlps[sli, t] = nlps
            self.I.buf_acs[sli, t] = acs
            self.I.buf_ent[sli, t] = ent

            if t > 0:
                self.I.buf_rews_ext[sli, t - 1] = prevrews

        self.I.step_count += 1
        if t == self.nsteps - 1 and not self.disable_policy_update:
            # We need to take one extra step so every transition has a reward.
            for l in range(self.I.nlump):
                sli = slice(l * self.I.lump_stride,
                            (l + 1) * self.I.lump_stride)
                memsli = slice(None) if self.I.mem_state is NO_STATES else sli
                nextobs, rews, nextnews, _ = self.env_get(l)
                dict_nextobs = self.stochpol.ensure_observation_is_dict(
                    nextobs)
                for k in self.stochpol.ph_ob_keys:
                    self.I.buf_ob_last[k][sli] = dict_nextobs[k]
                self.I.buf_new_last[sli] = nextnews
                with logger.profile_kv("policy_inference"):
                    _, self.I.buf_vpred_int_last[
                        sli], self.I.buf_vpred_ext_last[
                            sli], _, _, _ = self.stochpol.call(
                                dict_nextobs,
                                nextnews,
                                self.I.mem_state[memsli],
                                update_obs_stats=False)
                self.I.buf_rews_ext[sli, t] = rews

            # Calculate the intrinsic rewards for the rollout.
            # fd = {}
            # fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1)
            # fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
            #            self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
            # fd[self.stochpol.ph_ac] = self.I.buf_acs
            # self.I.buf_rews_int[:] = tf.get_default_session().run(self.stochpol.int_rew, fd)
            # if self.action_balance_coef is not None:
            #     self.I.buf_rews_int_ab[:] = tf.get_default_session().run(self.stochpol.int_rew_ab, fd)

            # Use batches to reduce memory usage;
            # min(8, self.nminibatches) could be used instead of self.nminibatches.
            envsperbatch = self.I.nenvs // 2
            for start in range(0, self.I.nenvs, envsperbatch):
                end = start + envsperbatch
                mbenvinds = slice(start, end, None)

                fd = {}
                fd[self.stochpol.ph_ob[None]] = np.concatenate([
                    self.I.buf_obs[None][mbenvinds],
                    self.I.buf_ob_last[None][mbenvinds, None]
                ], 1)
                fd.update({
                    self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
                    self.stochpol.ph_std: self.stochpol.ob_rms.var**0.5
                })
                fd[self.stochpol.ph_ac] = self.I.buf_acs[mbenvinds]

                self.I.buf_rews_int[mbenvinds] = tf.get_default_session().run(
                    self.stochpol.int_rew, fd)
                if self.action_balance_coef is not None:
                    self.I.buf_rews_int_ab[mbenvinds] = tf.get_default_session(
                    ).run(self.stochpol.int_rew_ab, fd)

            if not self.update_ob_stats_every_step:
                # Update observation normalization parameters after the rollout is completed.
                obs_ = self.I.buf_obs[None].astype(np.float32)

                obs_ = obs_.reshape((-1, *self.ob_space.shape))

                if len(obs_.shape) == 4:
                    # Image observations: update the stats on the last channel only.
                    self.stochpol.ob_rms.update(obs_[:, :, :, -1:])
                elif len(obs_.shape) == 2:
                    # Vector observations: update the stats on the full observation.
                    self.stochpol.ob_rms.update(obs_)

            if not self.testing:
                update_info = self.update()
            else:
                update_info = {}
            self.I.seg_init_mem_state = copy(self.I.mem_state)
            global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
            global_deque_mean = dict_gather(
                self.comm_log,
                {n: np.mean(dvs)
                 for n, dvs in self.I.statlists.items()},
                op='mean')
            update_info.update(global_i_stats)
            update_info.update(global_deque_mean)
            self.global_tcount = global_i_stats['tcount']
            for infos_ in self.I.buf_epinfos:
                infos_.clear()
        else:
            update_info = {}

        # Some reporting logic.
        for epinfo in epinfos:
            if self.testing:
                self.I.statlists['eprew_test'].append(epinfo['r'])
                self.I.statlists['eplen_test'].append(epinfo['l'])
            else:
                if "visited_rooms" in epinfo:
                    self.local_rooms += list(epinfo["visited_rooms"])
                    self.local_rooms = sorted(list(set(self.local_rooms)))
                    score_multiple = self.I.venvs[0].score_multiple
                    if score_multiple is None:
                        score_multiple = 1000
                    rounded_score = int(
                        epinfo["r"] / score_multiple) * score_multiple
                    self.scores.append(rounded_score)
                    self.scores = sorted(list(set(self.scores)))
                    self.I.statlists['eprooms'].append(
                        len(epinfo["visited_rooms"]))

                self.I.statlists['eprew'].append(epinfo['r'])
                if self.local_best_ret is None:
                    self.local_best_ret = epinfo["r"]
                elif epinfo["r"] > self.local_best_ret:
                    self.local_best_ret = epinfo["r"]

                self.I.statlists['eplen'].append(epinfo['l'])
                self.I.stats['epcount'] += 1
                self.I.stats['tcount'] += epinfo['l']
                self.I.stats['rewtotal'] += epinfo['r']
                # self.I.stats["best_ext_ret"] = self.best_ret

                if 'win' in epinfo:
                    self.I.statlists['win'].append(epinfo['win'])

        return {'update': update_info}
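
The main change in this second variant is that the intrinsic-reward pass is split into chunks of environments rather than a single session run over the whole rollout, which bounds peak memory. Below is a rough sketch of that chunking pattern under illustrative assumptions: run_int_rew stands in for tf.get_default_session().run(self.stochpol.int_rew, fd), and the buffer shapes mimic the concatenated observations fed above:

import numpy as np

nenvs, nsteps = 8, 16
# Rollout observations plus the extra last observation, as in the feed above.
buf_obs = np.random.rand(nenvs, nsteps + 1, 8, 8, 1).astype(np.float32)
buf_rews_int = np.zeros((nenvs, nsteps), dtype=np.float32)

def run_int_rew(obs_chunk):
    # Illustrative stand-in for tf.get_default_session().run(int_rew, fd):
    # any per-transition novelty score of shape (chunk, nsteps) works here.
    return obs_chunk[:, 1:].mean(axis=(2, 3, 4))

envsperbatch = nenvs // 2
for start in range(0, nenvs, envsperbatch):
    mbenvinds = slice(start, start + envsperbatch)
    # Each chunk feeds only its own slice of the buffers, so the peak
    # memory of a single run scales with envsperbatch rather than nenvs.
    buf_rews_int[mbenvinds] = run_int_rew(buf_obs[mbenvinds])
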