def test_basic(self):
    # With history disabled, the observation returned by reset() must be a
    # single plane of shape (3, 30).
    # NOTE(review): presumably 3 == 2 * lanes_side + 1 and
    # 30 == patches_ahead + patches_behind -- confirm against env.DeepTraffic.
    env_kwargs = dict(
        lanes_side=1,
        patches_ahead=20,
        patches_behind=10,
        history=0,
    )
    traffic = env.DeepTraffic(**env_kwargs)
    initial_obs = traffic.reset()
    self.assertEqual(initial_obs.shape, (1, 3, 30))
def test_hist(self):
    # With history=3 the observation returned by reset() must carry 19
    # planes of shape (3, 30).
    traffic = env.DeepTraffic(
        lanes_side=1,
        patches_ahead=20,
        patches_behind=10,
        history=3,
    )
    initial_obs = traffic.reset()
    self.assertEqual(initial_obs.shape, (19, 3, 30))

    # Planes 4..18 are checked as three consecutive groups of five planes:
    # right after reset the first plane of each group is all ones and the
    # remaining four are all zeros.
    # NOTE(review): these look like one-hot encodings of past actions with
    # action 0 as the initial filler -- confirm against env.DeepTraffic.
    group_size = 5
    for group_start in (4, 9, 14):
        group_end = group_start + group_size
        np.testing.assert_array_equal(initial_obs[group_start], 1.0)
        np.testing.assert_array_equal(
            initial_obs[group_start + 1:group_end], 0.0)
def test_agent(ini, net, steps=1000, rounds=5, device=torch.device('cpu')):
    """Evaluate the greedy policy of ``net`` on fresh DeepTraffic environments.

    For each round a new environment is built from the ``env_*`` settings in
    ``ini`` and driven for ``steps`` steps, always taking the action with the
    highest network output. The speed reported by the environment is sampled
    before every step.

    Args:
        ini: settings object providing ``env_lanes_side``,
            ``env_patches_ahead``, ``env_patches_behind``, ``env_history``
            and ``env_obs``.
        net: callable model; it is fed a batch holding one observation and
            the first row of its output is used as per-action scores.
        steps: number of environment steps per round.
        rounds: number of independent rounds.
        device: torch device the observation tensor is moved to before the
            forward pass.

    Returns:
        Tuple ``(mean, std)`` of the per-round mean speeds.
    """
    round_means = []
    for _ in range(rounds):
        test_env = env.DeepTraffic(
            lanes_side=ini.env_lanes_side,
            patches_ahead=ini.env_patches_ahead,
            patches_behind=ini.env_patches_behind,
            history=ini.env_history,
            obs=ini.env_obs,
        )
        obs = test_env.reset()
        speed_hist = []
        # Pure evaluation: no gradients are needed, so skip autograd
        # bookkeeping for every forward pass.
        with torch.no_grad():
            for _ in range(steps):
                speed_hist.append(test_env.current_speed())
                # Add the batch axis in numpy first; building a tensor from
                # a Python list of ndarrays is a slow path in PyTorch.
                # Assumes obs is an ndarray -- TODO confirm against env.
                obs_v = torch.tensor(np.expand_dims(obs, 0)).to(device)
                q_v = net(obs_v)[0]
                act_idx = torch.argmax(q_v).item()
                # NOTE(review): the done flag returned by step() is ignored
                # and the environment is never reset within a round --
                # confirm the bare environment never terminates.
                obs, _, _, _ = test_env.step(act_idx)
        round_means.append(np.mean(speed_hist))
    return np.mean(round_means), np.std(round_means)
action='store_true', help="Display model and exit") args = parser.parse_args() ini = config.Settings(args.ini) device = torch.device("cuda" if ini.train_cuda else "cpu") if not args.show_model: name = pathlib.Path(args.ini).stem + "-" + args.name save_path = pathlib.Path("saves") / name save_path.mkdir(parents=True, exist_ok=True) writer = SummaryWriter(comment="-" + name) e = env.DeepTraffic(lanes_side=ini.env_lanes_side, patches_ahead=ini.env_patches_ahead, patches_behind=ini.env_patches_behind, history=ini.env_history, obs=ini.env_obs) orig_env = e obs_shape = e.obs_shape e = gym.wrappers.TimeLimit(e, max_episode_steps=ini.env_steps_limit) log.info("Environment created, obs shape %s", obs_shape) model_class = model.MODELS[ini.train_model] net = model_class(obs_shape, e.action_space.n).to(device) log.info("Model: %s", net) if args.show_model: sys.exit(0) tgt_net = ptan.agent.TargetNet(net)