예제 #1
0
 def test_basic(self):
     """Check that reset() with history=0 returns a (1, 3, 30) observation.

     NOTE(review): presumably 3 == 2 * lanes_side + 1 lanes and
     30 == patches_ahead + patches_behind patches -- confirm against env.
     """
     traffic = env.DeepTraffic(
         lanes_side=1,
         patches_ahead=20,
         patches_behind=10,
         history=0,
     )
     first_obs = traffic.reset()
     self.assertEqual(first_obs.shape, (1, 3, 30))
예제 #2
0
 def test_hist(self):
     """Check the observation layout returned by reset() with history=3.

     The observation must have shape (19, 3, 30). Channels 4, 9 and 14 must be
     all ones, and the four channels following each of them must be all zeros.
     """
     traffic = env.DeepTraffic(
         lanes_side=1,
         patches_ahead=20,
         patches_behind=10,
         history=3,
     )
     first_obs = traffic.reset()
     self.assertEqual(first_obs.shape, (19, 3, 30))
     # NOTE(review): presumably each history step occupies a 5-channel slot
     # (one all-ones marker channel followed by four empty channels) --
     # confirm against env.DeepTraffic.
     for marker in (4, 9, 14):
         np.testing.assert_array_equal(first_obs[marker], 1.0)
         np.testing.assert_array_equal(first_obs[marker + 1:marker + 5], 0.0)
예제 #3
0
def test_agent(ini, net, steps=1000, rounds=5, device=torch.device('cpu')):
    """Evaluate ``net`` greedily on freshly created DeepTraffic environments.

    Runs ``rounds`` independent rounds. Each round builds a new
    ``env.DeepTraffic`` from the ``env_*`` settings on ``ini``, resets it, and
    then takes ``steps`` greedy actions (argmax over the network output).
    The value of ``current_speed()`` is recorded before every step.

    Args:
        ini: settings object providing ``env_lanes_side``,
            ``env_patches_ahead``, ``env_patches_behind``, ``env_history``
            and ``env_obs``.
        net: callable mapping a batched observation tensor to per-action
            values; only row 0 of its output is used.
        steps: number of actions taken per round. With ``steps <= 0`` the
            per-round mean is ``np.mean([])``, i.e. NaN.
        rounds: number of independent evaluation rounds.
        device: device the observation tensor is moved to before calling
            ``net``.

    Returns:
        Tuple ``(mean, std)`` over the per-round mean speeds.
    """
    round_means = []
    # Evaluation only: without no_grad every forward pass would build an
    # autograd graph, wasting memory and time across steps * rounds calls.
    with torch.no_grad():
        for _ in range(rounds):
            test_env = env.DeepTraffic(lanes_side=ini.env_lanes_side,
                                       patches_ahead=ini.env_patches_ahead,
                                       patches_behind=ini.env_patches_behind,
                                       history=ini.env_history,
                                       obs=ini.env_obs)
            obs = test_env.reset()
            speed_hist = []
            for _ in range(steps):
                speed_hist.append(test_env.current_speed())
                # Add the batch dimension in numpy: torch.tensor() on a
                # Python list of ndarrays is very slow (PyTorch warns).
                obs_v = torch.tensor(np.asarray(obs)[np.newaxis]).to(device)
                q_v = net(obs_v)[0]
                act_idx = torch.argmax(q_v).item()
                obs, _, _, _ = test_env.step(act_idx)
            round_means.append(np.mean(speed_hist))
    return np.mean(round_means), np.std(round_means)
예제 #4
0
                        action='store_true',
                        help="Display model and exit")
    args = parser.parse_args()
    ini = config.Settings(args.ini)

    device = torch.device("cuda" if ini.train_cuda else "cpu")

    if not args.show_model:
        name = pathlib.Path(args.ini).stem + "-" + args.name
        save_path = pathlib.Path("saves") / name
        save_path.mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(comment="-" + name)

    e = env.DeepTraffic(lanes_side=ini.env_lanes_side,
                        patches_ahead=ini.env_patches_ahead,
                        patches_behind=ini.env_patches_behind,
                        history=ini.env_history,
                        obs=ini.env_obs)
    orig_env = e
    obs_shape = e.obs_shape
    e = gym.wrappers.TimeLimit(e, max_episode_steps=ini.env_steps_limit)

    log.info("Environment created, obs shape %s", obs_shape)
    model_class = model.MODELS[ini.train_model]
    net = model_class(obs_shape, e.action_space.n).to(device)
    log.info("Model: %s", net)

    if args.show_model:
        sys.exit(0)

    tgt_net = ptan.agent.TargetNet(net)