def reset(self):
    # Sample an initial state if several are given, otherwise use the single fixed one.
    if isinstance(self.init_state, list):
        self.cur_state = np.random.choice(self.init_state)
    else:
        self.cur_state = self.init_state
    self.__timesteps = 0
    obs = flat_to_one_hot(self.cur_state, ndim=self.nstates)
    return self._wrap_obs(obs)
def step(self, a):
    # Sample the next state from the transition distribution for (current state, action).
    transition_probs = self.transitions[self.cur_state, a]
    next_state = np.random.choice(np.arange(self.nstates), p=transition_probs)
    r = self.reward[self.cur_state, a]
    self.cur_state = next_state
    obs = flat_to_one_hot(self.cur_state, ndim=self.nstates)
    self.__timesteps += 1
    # Terminate on a positive reward (if configured) or when the time limit is exceeded.
    done = False
    if (self.terminate_on_reward and r > 0) or (self.__timesteps > self.max_timesteps):
        done = True
    return self._wrap_obs(obs), r, done, {}
def plot_costs(self, paths, cost_fn, dirname=None, itr=0, policy=None, use_traj_paths=False):
    if not use_traj_paths:
        # Iterate through every (state, action) pair - makes sense for non-RNN costs.
        obses = []
        acts = []
        for (x, a) in itertools.product(range(self.nstates), range(self.nactions)):
            obs = flat_to_one_hot(x, ndim=self.nstates)
            act = flat_to_one_hot(a, ndim=self.nactions)
            obses.append(obs)
            acts.append(act)
        path = {'observations': np.array(obses), 'actions': np.array(acts)}
        if policy is not None:
            if hasattr(policy, 'set_env_infos'):
                policy.set_env_infos(path.get('env_infos', {}))
            actions, agent_infos = policy.get_actions(path['observations'])
            path['agent_infos'] = agent_infos
        paths = [path]
    # Evaluate the cost function and plot each returned quantity as an |S| x |A| grid.
    plots = cost_fn.debug_eval(paths, policy=policy)
    for plot in plots:
        plots[plot] = plots[plot].squeeze()
    for plot in plots:
        data = plots[plot]
        data = np.reshape(data, (self.nstates, self.nactions))
        self.plot_data(data, dirname=dirname, fname=plot + '_itr%d', itr=itr)
def initial_state_distribution(self):
    return flat_to_one_hot(self.init_state, ndim=self.nstates)
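# A minimal usage sketch (hedged, kept as a comment so it does not alter the module):
# it assumes this class is a tabular environment -- hypothetically named TabularEnv
# here -- exposing the reset()/step() methods above and an `nactions` attribute.
#
#     env = TabularEnv(...)                       # hypothetical constructor
#     obs = env.reset()
#     done = False
#     while not done:
#         a = np.random.randint(env.nactions)     # uniform random action
#         obs, r, done, info = env.step(a)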