Example #1
0
def demo(run=-1, length=None, test=True, N=None, env=None, agent=None, d=0):
    """Step an agent through an environment, encoding one frame per step.

    Args:
        run: Passed to ``storing.load`` to fetch saved agent weights. Only
            used when ``agent`` is None.
        length: Number of steps to take. When None, the loop instead stops
            the first time ``world.reset.any()`` is true.
        test: Forwarded to the agent call.
        N: Forwarded to ``recording.ParallelEncoder``.
        env: Environment to step. When None, ``explorer.Explorer(d + 1)`` is
            created.
        agent: Agent to query. When None, ``Agent(env)`` is built, moved to
            CUDA, and loaded non-strictly from the stored run.
        d: Index passed to ``env.state`` to pick the state that is encoded.

    Returns:
        The encoder, after the recording context has exited and its
        ``notebook()`` method has been called.
    """
    if env is None:
        env = explorer.Explorer(d + 1)

    # NOTE(review): the environment is reset here and again below; this first
    # result is never read. Confirm whether agent construction relies on it.
    world = env.reset()
    if agent is None:
        agent = Agent(env).cuda()
        checkpoint = storing.load(run)
        agent.load_state_dict(checkpoint['agent'], strict=False)

    world = env.reset()
    taken = 0
    open_ended = length is None
    with recording.ParallelEncoder(env.plot_state, N=N) as encoder, \
            tqdm(total=length) as progress:
        while True:
            # world[None] adds a leading axis for the agent call; squeeze(0)
            # drops the leading axis from the result again.
            decision = agent(
                world[None], sample=True, test=test, value=True).squeeze(0)
            world = env.step(decision)
            taken += 1
            progress.update(1)

            # Open-ended runs stop on the first reset, before that step's
            # frame is encoded.
            if open_ended and world.reset.any():
                break

            frame = arrdict.arrdict(**env.state(d), decision=decision)
            encoder(arrdict.numpyify(frame))

            if taken == length:
                break
    encoder.notebook()
    return encoder
Example #2
0
def display(scenery, e=0):
    """Draw the lines and lights of one scenery state on fresh axes.

    Args:
        scenery: Object whose ``state(e)`` result is plotted.
        e: Index forwarded to ``scenery.state``.

    Returns:
        The matplotlib figure that owns the axes drawn on.
    """
    axes = plt.axes()

    # The plotting helpers receive a numpy-converted arrdict with the state
    # stored under the 'scenery' key.
    wrapped = arrdict.arrdict(scenery=scenery.state(e))
    snapshot = arrdict.numpyify(wrapped)

    plotting.plot_lines(axes, snapshot, zoom=False)
    plotting.plot_lights(axes, snapshot)
    plotting.adjust_view(axes, snapshot, zoom=False)

    return axes.figure
Example #3
0
 def dataframe(self, **kwargs):
     """Solve and tabulate the per-state results as a pandas DataFrame.

     Keyword arguments are forwarded unchanged to ``self.solve``. The frame
     has the columns ``name``, ``obs``, ``term``, ``start``, ``value``,
     ``policy`` and ``successor``, and its index is named ``'idx'``.
     """
     solution = self.solve(**kwargs)

     # Index the transition tensor by (state, chosen action) and take the
     # argmax over the last axis to get one successor index per state.
     state_idx = torch.arange(self.n_states, device=self.device)
     next_idx = self._trans[state_idx, solution.policy].argmax(-1)
     next_names = [self._names[i] for i in next_idx]

     # Each observation becomes a tuple of 2-decimal strings.
     formatted_obs = [
         tuple(f'{x:.2f}' for x in row)
         for row in arrdict.numpyify(self._obs)
     ]

     columns = {
         'name': self._names,
         'obs': formatted_obs,
         'term': self._terminal,
         'start': self._start,
         'value': solution.value,
         'policy': solution.policy,
         'successor': next_names,
     }
     table = pd.DataFrame(arrdict.numpyify(columns)).sort_index()
     table.index.name = 'idx'
     return table
Example #4
0
 def display(self, e=0):
     """Plot the state selected by ``e`` and return the plotting result.

     ``self.state(e=e)`` is converted with ``arrdict.numpyify`` before being
     handed to ``self.plot_state``.
     """
     snapshot = self.state(e=e)
     return self.plot_state(arrdict.numpyify(snapshot))
Example #5
0
 def display(self, e=None, **kwargs):
     """Plot this object with ``self.plot_worlds`` and return the axes.

     Args:
         e: Forwarded to ``plot_worlds`` as the ``e`` keyword.
         **kwargs: Extra keyword arguments forwarded to ``plot_worlds``.

     Returns:
         The axes returned by ``plot_worlds``; its figure has been closed
         via ``plt.close`` by the time it is returned.
     """
     as_numpy = arrdict.numpyify(arrdict.arrdict(self))
     ax = self.plot_worlds(as_numpy, e=e, **kwargs)
     # Presumably closed so an interactive backend does not also render the
     # figure on its own - TODO confirm.
     plt.close(ax.figure)
     return ax
Example #6
0
def record_worlds(worlds, N=0):
    """Encode one frame per leading-axis entry of ``worlds``.

    Args:
        worlds: Object exposing ``plot_worlds`` and, once converted with
            ``arrdict.numpyify``, a ``board`` array whose first dimension
            gives the number of frames.
        N: Forwarded to ``recording.ParallelEncoder``.

    Returns:
        The encoder, after the recording context has exited.
    """
    frames = arrdict.numpyify(worlds)
    plotter = plot_all(worlds.plot_worlds)
    with recording.ParallelEncoder(plotter, N=N, fps=1) as encoder:
        n_frames = frames.board.shape[0]
        for idx in range(n_frames):
            encoder(frames[idx])
    return encoder
Example #7
0
def _dataframe(traj):
    if isinstance(traj, dict):
        return [([k] + kk, vv) for k, v in traj.items()
                for kk, vv in _dataframe(v)]
    if isinstance(traj, torch.Tensor):
        return [([], pd.Series([x for x in arrdict.numpyify(traj)]))]