import numpy as np
import torch
from torch import nn

import rlkit.torch.pytorch_util as ptu
from rlkit.torch.networks import Mlp
# NOTE: the paths below assume the rlkit fork this repo builds on; stock
# rlkit's ArgmaxDiscretePolicy has no `use_preactivation` flag, and
# TRPOSupSepTrainer / get_traffic_path_information are project additions.
from rlkit.policies.argmax import ArgmaxDiscretePolicy
from rlkit.samplers.data_collector import MdpPathCollector
from rlkit.torch.torch_rl_algorithm import TorchOnlineRLAlgorithm
from log_path import get_traffic_path_information  # assumed project module


# GNN variant: the policy and the supervised learner both operate on
# graph-structured inputs. Relies on a module-level `args` for the
# environment name (see the __main__ sketch at the bottom of the file).
def experiment(variant):
    from traffic.make_env import make_env
    expl_env = make_env(args.exp_name, **variant['env_kwargs'])
    eval_env = make_env(args.exp_name, **variant['env_kwargs'])
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num  # number of surrounding vehicles to label
    label_dim = expl_env.label_dim  # size of each label's one-hot encoding

    # Graph builder for the policy: each of the ego + max_veh_num nodes
    # carries the 4 raw vehicle features plus its label.
    from graph_builder_multi import MultiTrafficGraphBuilder
    policy_gb = MultiTrafficGraphBuilder(
        input_dim=4 + label_dim,
        node_num=expl_env.max_veh_num + 1,
        ego_init=torch.tensor([0., 1.]),
        other_init=torch.tensor([1., 0.]),
    )
    if variant['gnn_kwargs']['attention']:
        from gnn_attention_net import GNNAttentionNet
        gnn_class = GNNAttentionNet
    else:
        from gnn_net import GNNNet
        gnn_class = GNNNet
    policy_gnn = gnn_class(
        pre_graph_builder=policy_gb,
        node_dim=variant['gnn_kwargs']['node'],
        num_conv_layers=variant['gnn_kwargs']['layer'],
        hidden_activation=variant['gnn_kwargs']['activation'],
    )
    # Ego readout: take node 0's embedding and map it to action logits.
    from layers import FlattenLayer, SelectLayer
    policy = nn.Sequential(
        policy_gnn,
        SelectLayer(1, 0),
        FlattenLayer(),
        nn.ReLU(),
        nn.Linear(variant['gnn_kwargs']['node'], action_dim),
    )

    # Separate GNN for the supervised branch; its node inputs omit the labels.
    sup_gb = MultiTrafficGraphBuilder(
        input_dim=4,
        node_num=expl_env.max_veh_num + 1,
        ego_init=torch.tensor([0., 1.]),
        other_init=torch.tensor([1., 0.]),
    )
    from gnn_net import GNNNet
    sup_gnn = GNNNet(
        pre_graph_builder=sup_gb,
        node_dim=variant['gnn_kwargs']['node'],
        num_conv_layers=variant['gnn_kwargs']['layer'],
        hidden_activation=variant['gnn_kwargs']['activation'],
    )
    # Predict a label for every non-ego node (indices 1..max_veh_num).
    sup_learner = nn.Sequential(
        sup_gnn,
        SelectLayer(1, np.arange(1, expl_env.max_veh_num + 1)),
        nn.ReLU(),
        nn.Linear(variant['gnn_kwargs']['node'], label_dim),
    )
    from sup_sep_softmax_policy import SupSepSoftmaxPolicy
    policy = SupSepSoftmaxPolicy(policy, sup_learner, label_num, label_dim)
    print('parameters: ', np.sum([p.view(-1).shape[0] for p in policy.parameters()]))

    vf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=1,
    )
    vf_criterion = nn.MSELoss()
    eval_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    expl_policy = policy

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    # Exploration rollouts also record supervision labels.
    from sup_sep_rollout import sup_sep_rollout
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
        rollout_fn=sup_sep_rollout,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        observation_dim=obs_dim,
        label_dim=label_num,
        max_replay_buffer_size=int(1e6),
    )

    from rlkit.torch.vpg.trpo_sup_sep import TRPOSupSepTrainer
    trainer = TRPOSupSepTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
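

# A minimal sketch of the `variant` dict the GNN experiment above consumes.
# The key names come straight from the code; every value below is an
# illustrative assumption, not a setting from the original experiments.
EXAMPLE_GNN_VARIANT = dict(
    env_kwargs=dict(),  # forwarded to traffic.make_env.make_env
    gnn_kwargs=dict(
        node=16,            # per-node embedding width (assumed)
        layer=3,            # number of graph conv layers (assumed)
        activation='relu',  # hidden activation for the GNN (assumed)
        attention=False,    # True swaps GNNNet for GNNAttentionNet
    ),
    trainer_kwargs=dict(),    # forwarded to TRPOSupSepTrainer
    algorithm_kwargs=dict(),  # forwarded to TorchOnlineRLAlgorithm
)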


# MLP baseline: same training setup as above, but the policy and the
# supervised learner are plain fully connected networks over the flat
# observation (no graph structure). If both definitions share one module,
# this one shadows the GNN variant.
def experiment(variant):
    from traffic.make_env import make_env
    expl_env = make_env(args.exp_name, **variant['env_kwargs'])
    eval_env = make_env(args.exp_name, **variant['env_kwargs'])
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim

    # The policy input augments the observation with label_dim * label_num
    # + label_dim extra features (the supervised label predictions; the
    # exact layout is defined by SupSepSoftmaxPolicy).
    policy = nn.Sequential(
        nn.Linear(obs_dim + int(label_dim * label_num + label_dim), 32),
        nn.ReLU(),
        nn.Linear(32, 32),
        nn.ReLU(),
        nn.Linear(32, action_dim),
    )
    # Predict one label_dim-way distribution per vehicle from the flat obs.
    from layers import ReshapeLayer
    sup_learner = nn.Sequential(
        nn.Linear(obs_dim, 32),
        nn.ReLU(),
        nn.Linear(32, 32),
        nn.ReLU(),
        nn.Linear(32, int(label_num * label_dim)),
        ReshapeLayer(shape=(label_num, label_dim)),
    )
    from sup_sep_softmax_policy import SupSepSoftmaxPolicy
    policy = SupSepSoftmaxPolicy(policy, sup_learner, label_num, label_dim)
    print('parameters: ', np.sum([p.view(-1).shape[0] for p in policy.parameters()]))

    vf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=1,
    )
    vf_criterion = nn.MSELoss()
    eval_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    expl_policy = policy

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    from sup_sep_rollout import sup_sep_rollout
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
        rollout_fn=sup_sep_rollout,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        observation_dim=obs_dim,
        label_dim=label_num,
        max_replay_buffer_size=int(1e6),
    )

    from rlkit.torch.vpg.trpo_sup_sep import TRPOSupSepTrainer
    trainer = TRPOSupSepTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
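

# A hedged entry-point sketch: both `experiment` definitions read a
# module-level `args` for the environment name, which the original scripts
# presumably build with argparse. Everything below is an assumption about
# that setup, not code recovered from the repo; trainer_kwargs and
# algorithm_kwargs must be filled with the fork-specific settings
# (e.g. epoch and path-length schedules) before a real run.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('exp_name', type=str)  # environment name for make_env
    args = parser.parse_args()
    ptu.set_gpu_mode(torch.cuda.is_available())
    variant = dict(
        env_kwargs=dict(),
        trainer_kwargs=dict(),
        algorithm_kwargs=dict(),
    )
    experiment(variant)  # resolves to the MLP variant defined last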