Ejemplo n.º 1
0
def enjoy():
    """Replay a saved policy in a rendered environment until its window closes.

    Builds one environment with ``make_env(args.env_name, args.seed, 0, True)``
    wrapped in ``DummyVecEnv``, loads the ``(actor_critic, ob_rms)`` pair from
    ``<save_path>/<args.env_name>.pt`` and then loops forever: act
    deterministically, step the environment, sleep 50 ms, refresh the stacked
    observation buffer and render.  The loop only ends through
    ``sys.exit(0)``, triggered when the object returned by the render call
    has a falsy ``window`` attribute.
    """
    # A single environment thunk, wrapped to expose the vectorized interface.
    vec_env = DummyVecEnv([make_env(args.env_name, args.seed, 0, True)])

    checkpoint_path = os.path.join(save_path, args.env_name + ".pt")
    # The second checkpoint element is unpacked but not used in this function.
    actor_critic, _ob_rms = torch.load(checkpoint_path)

    render_func = vec_env.envs[0].render

    # The observation buffer's first non-batch dimension is the environment's
    # own first dimension multiplied by the number of stacked observations.
    base_shape = vec_env.observation_space.shape
    stacked_shape = (base_shape[0] * args.num_stack,) + tuple(base_shape[1:])
    current_obs = torch.zeros(1, *stacked_shape)
    states = torch.zeros(1, actor_critic.state_size)
    masks = torch.zeros(1, 1)

    def push_observation(raw_obs):
        """Write ``raw_obs`` into the tail of ``current_obs``.

        When more than one observation is stacked, the existing contents are
        first shifted towards the front by one observation's width.
        """
        width = vec_env.observation_space.shape[0]
        newest = torch.from_numpy(raw_obs).float()
        if args.num_stack > 1:
            current_obs[:, :-width] = current_obs[:, width:]
        current_obs[:, -width:] = newest

    render_func('human')
    push_observation(vec_env.reset())

    while True:
        obs_var = Variable(current_obs, volatile=True)
        state_var = Variable(states, volatile=True)
        mask_var = Variable(masks, volatile=True)
        _value, action, _, next_states = actor_critic.act(
            obs_var, state_var, mask_var, deterministic=True)
        states = next_states.data
        cpu_actions = action.data.squeeze(1).cpu().numpy()

        # Advance the environment by one step with the chosen action.
        next_obs, _reward, done, _ = vec_env.step(cpu_actions)

        time.sleep(0.05)

        # Mask is 0 right after a finished episode and 1 otherwise.
        if done:
            masks.fill_(0.0)
        else:
            masks.fill_(1.0)

        # Multiplying by the mask zeroes the buffer when an episode ended;
        # a 4-D buffer needs the mask expanded by two trailing dimensions.
        if current_obs.dim() == 4:
            mask_scale = masks.unsqueeze(2).unsqueeze(2)
        else:
            mask_scale = masks
        current_obs *= mask_scale
        push_observation(next_obs)

        # Stop the process once the render call reports no window.
        if not render_func('human').window:
            sys.exit(0)
Ejemplo n.º 2
0
    obs = env.reset()
    update_current_obs(obs)

    while True:
        value, action, _, states = actor_critic.act(Variable(current_obs,
                                                             volatile=True),
                                                    Variable(states,
                                                             volatile=True),
                                                    Variable(masks,
                                                             volatile=True),
                                                    deterministic=True)
        states = states.data
        cpu_actions = action.data.squeeze(1).cpu().numpy()

        # Observation, reward and next obs
        obs, reward, done, _ = env.step(cpu_actions)

        time.sleep(0.05)

        masks.fill_(0.0 if done else 1.0)

        if current_obs.dim() == 4:
            current_obs *= masks.unsqueeze(2).unsqueeze(2)
        else:
            current_obs *= masks
        update_current_obs(obs)

        renderer = render_func('human')

        if not renderer.window:
            sys.exit(0)