コード例 #1
0
def experiment(variant):
    """Assemble a shared-encoder softmax policy on ``SimpleSupEnv`` and train it.

    An encoder, an action decoder and a supervised head are combined into a
    ``SupSoftmaxPolicy``; the policy, an MLP value function and a
    ``SupReplayBuffer`` are handed to ``TRPOSupTrainer``, which is then run by
    ``TorchOnlineRLAlgorithm``.

    Args:
        variant: Mapping providing ``'env_kwars'`` (keyword arguments for
            ``SimpleSupEnv``; the key is spelled exactly like this),
            ``'trainer_kwargs'`` and ``'algorithm_kwargs'``.
    """
    from simple_sup import SimpleSupEnv
    env_settings = variant['env_kwars']
    expl_env = SimpleSupEnv(**env_settings)
    eval_env = SimpleSupEnv(**env_settings)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n

    # Width of the feature vector consumed by both heads.
    feature_dim = 16
    encoder = nn.Sequential(nn.Linear(obs_dim, feature_dim), nn.ReLU())
    decoder = nn.Linear(feature_dim, action_dim)
    from layers import ReshapeLayer
    sup_head = nn.Linear(feature_dim, action_dim)
    sup_learner = nn.Sequential(sup_head, ReshapeLayer(shape=(1, action_dim)))
    from sup_softmax_policy import SupSoftmaxPolicy
    policy = SupSoftmaxPolicy(encoder, decoder, sup_learner)

    vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32])
    vf_criterion = nn.MSELoss()

    # Evaluation wraps the policy in ArgmaxDiscretePolicy; exploration uses
    # the policy object itself.
    greedy_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    eval_path_collector = MdpPathCollector(eval_env, greedy_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        label_dim=1,
    )

    from rlkit.torch.vpg.trpo_sup import TRPOSupTrainer
    trainer = TRPOSupTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #2
0
def experiment(variant):
    """Build (or reload) a GNN+LSTM supervised softmax policy and train it.

    When ``variant['load_kwargs']['load']`` is truthy the policy and value
    function are read from ``<load_dir>/params.pkl``; otherwise a
    ``SupSoftmaxLSTMPolicy`` backed by ``GNNLSTM2Net`` and a fresh ``Mlp``
    value function are constructed. Training runs through
    ``PPOSupVanillaTrainer`` and ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'trainer_kwargs'``
            (must contain ``'max_path_length'``), ``'algorithm_kwargs'``,
            ``'load_kwargs'`` (``'load'``, and ``'load_dir'`` when loading),
            and, when building from scratch, ``'lstm_kwargs'``
            (``'hidden_dim'``, ``'num_layers'``) and ``'gnn_kwargs'``
            (``'node_dim'``, ``'conv_type'``, ``'num_layers'``,
            ``'activation'``).
    """
    from traffic.make_env import make_env
    expl_env = make_env(args.exp_name, **variant['env_kwargs'])
    eval_env = make_env(args.exp_name, **variant['env_kwargs'])
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim
    max_path_length = variant['trainer_kwargs']['max_path_length']

    if variant['load_kwargs']['load']:
        # NOTE: torch.load unpickles whole Python objects here, so load_dir
        # must point at a checkpoint from a trusted source.
        load_dir = variant['load_kwargs']['load_dir']
        load_data = torch.load(load_dir + '/params.pkl', map_location='cpu')
        policy = load_data['trainer/policy']
        vf = load_data['trainer/value_function']
    else:
        lstm_kwargs = variant['lstm_kwargs']
        gnn_kwargs = variant['gnn_kwargs']
        hidden_dim = lstm_kwargs['hidden_dim']
        num_lstm_layers = lstm_kwargs['num_layers']
        node_dim = gnn_kwargs['node_dim']

        # One graph node per other vehicle plus one more (index 0 is the one
        # the action decoder selects below).
        node_num = expl_env.max_veh_num + 1
        input_node_dim = int(obs_dim / node_num)

        # Initial previous-action and recurrent state, all zeros. The latent
        # tuple is ordered (h1, c1, h2, c2).
        a_0 = np.zeros(action_dim)
        latent_shape = (node_num, hidden_dim * num_lstm_layers)
        latent_0 = tuple(np.zeros(latent_shape) for _ in range(4))

        from lstm_net import LSTMNet
        lstm1_ego = LSTMNet(input_node_dim, action_dim, hidden_dim,
                            num_lstm_layers)
        lstm1_other = LSTMNet(input_node_dim, 0, hidden_dim, num_lstm_layers)
        lstm2_ego = LSTMNet(node_dim, 0, hidden_dim, num_lstm_layers)
        lstm2_other = LSTMNet(node_dim, 0, hidden_dim, num_lstm_layers)
        from graph_builder import TrafficGraphBuilder
        gb = TrafficGraphBuilder(
            input_dim=hidden_dim,
            node_num=node_num,
            ego_init=torch.tensor([0., 1.]),
            other_init=torch.tensor([1., 0.]),
        )
        from gnn_net import GNNNet
        gnn = GNNNet(
            pre_graph_builder=gb,
            node_dim=node_dim,
            conv_type=gnn_kwargs['conv_type'],
            num_conv_layers=gnn_kwargs['num_layers'],
            hidden_activation=gnn_kwargs['activation'],
        )
        from gnn_lstm2_net import GNNLSTM2Net
        policy_net = GNNLSTM2Net(node_num, gnn, lstm1_ego, lstm1_other,
                                 lstm2_ego, lstm2_other)
        from layers import FlattenLayer, SelectLayer
        # Action head reads node 0; supervised head reads nodes 1..node_num-1.
        decoder = nn.Sequential(SelectLayer(-2, 0), FlattenLayer(2), nn.ReLU(),
                                nn.Linear(hidden_dim, action_dim))
        sup_learner = nn.Sequential(
            SelectLayer(-2, np.arange(1, node_num)),
            nn.ReLU(),
            nn.Linear(hidden_dim, label_dim),
        )
        from sup_softmax_lstm_policy import SupSoftmaxLSTMPolicy
        policy = SupSoftmaxLSTMPolicy(
            a_0=a_0,
            latent_0=latent_0,
            obs_dim=obs_dim,
            action_dim=action_dim,
            lstm_net=policy_net,
            decoder=decoder,
            sup_learner=sup_learner,
        )
        print('parameters: ',
              np.sum([p.view(-1).shape[0] for p in policy.parameters()]))

        vf = Mlp(
            hidden_sizes=[32, 32],
            input_size=obs_dim,
            output_size=1,
        )

    vf_criterion = nn.MSELoss()
    from rlkit.torch.policies.make_deterministic import MakeDeterministic
    eval_policy = MakeDeterministic(policy)
    expl_policy = policy

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
    )

    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        observation_dim=obs_dim,
        action_dim=action_dim,
        label_dim=label_num,
        max_replay_buffer_size=int(1e6),
        max_path_length=max_path_length,
        recurrent=True,
    )

    from rlkit.torch.vpg.ppo_sup_vanilla import PPOSupVanillaTrainer
    trainer = PPOSupVanillaTrainer(policy=policy,
                                   value_function=vf,
                                   vf_criterion=vf_criterion,
                                   replay_buffer=replay_buffer,
                                   recurrent=True,
                                   **variant['trainer_kwargs'])
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #3
0
ファイル: ppo_sup_gnn_multi_0.py プロジェクト: maxiaoba/rlkit
def experiment(variant):
    """Train a GNN policy whose decoder also sees per-vehicle label predictions.

    A single ``GNNNet`` is shared by the ego encoder and by one supervised
    ``SoftmaxPolicy`` head per other vehicle. A ``ConcatLayer`` joins the ego
    features with every head's output before the ``Mlp`` decoder. Training
    uses ``PPOSupTrainer`` driven by ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'no_gradient'``
            (when truthy, every encoder except the first is marked as not
            needing gradients), ``'trainer_kwargs'`` and
            ``'algorithm_kwargs'``.
    """
    from traffic.make_env import make_env
    expl_env = make_env(args.exp_name, **variant['env_kwargs'])
    eval_env = make_env(args.exp_name, **variant['env_kwargs'])
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    max_veh_num = expl_env.max_veh_num

    # GNN node width; the supervised heads and the decoder input size below
    # must stay in sync with it.
    node_dim = 16
    # Each supervised head emits this many logits.
    sup_out_dim = 2

    from graph_builder_multi import MultiTrafficGraphBuilder
    gb = MultiTrafficGraphBuilder(
        input_dim=4,
        node_num=max_veh_num + 1,
        ego_init=torch.tensor([0., 1.]),
        other_init=torch.tensor([1., 0.]),
    )
    from gnn_net import GNNNet
    gnn = GNNNet(
        pre_graph_builder=gb,
        node_dim=node_dim,
        num_conv_layers=3,
    )

    from layers import SelectLayer
    # Encoder 0 reads node 0; encoder i (i >= 1) is the supervised head for
    # node i. All of them share the same gnn instance.
    encoders = [nn.Sequential(gnn, SelectLayer(1, 0), nn.ReLU())]
    sup_learners = []
    for node_idx in range(1, max_veh_num + 1):
        sup_net = nn.Sequential(
            gnn,
            SelectLayer(1, node_idx),
            nn.ReLU(),
            nn.Linear(node_dim, sup_out_dim),
        )
        sup_learner = SoftmaxPolicy(sup_net, learn_temperature=False)
        sup_learners.append(sup_learner)
        encoders.append(sup_learner)

    decoder = Mlp(
        input_size=int(node_dim + sup_out_dim * max_veh_num),
        output_size=action_dim,
        hidden_sizes=[],
    )
    from layers import ConcatLayer
    need_gradients = np.array([True] * len(encoders))
    if variant['no_gradient']:
        need_gradients[1:] = False
    policy = nn.Sequential(
        ConcatLayer(encoders, need_gradients=list(need_gradients), dim=1),
        decoder,
    )
    policy = SoftmaxPolicy(policy, learn_temperature=False)

    vf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=1,
    )
    vf_criterion = nn.MSELoss()
    eval_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    expl_policy = policy

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        observation_dim=obs_dim,
        label_dims=[1] * max_veh_num,
        max_replay_buffer_size=int(1e6),
    )

    from rlkit.torch.vpg.ppo_sup import PPOSupTrainer
    trainer = PPOSupTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        sup_learners=sup_learners,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #4
0
def experiment(variant):
    """Train a GNN policy with a separate GNN supervised learner.

    The policy network and the supervised learner each get their own
    ``MultiTrafficGraphBuilder`` and GNN. The policy's graph builder takes
    ``4 + label_dim`` input features, the supervised one takes ``4``. Both
    are wrapped in ``SupSepSoftmaxPolicy`` and trained with
    ``TRPOSupSepTrainer`` through ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'gnn_kwargs'``
            (``'attention'``, ``'node'``, ``'layer'``, ``'activation'``),
            ``'trainer_kwargs'`` and ``'algorithm_kwargs'``.
    """
    from traffic.make_env import make_env
    expl_env = make_env(args.exp_name, **variant['env_kwargs'])
    eval_env = make_env(args.exp_name, **variant['env_kwargs'])
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim

    gnn_kwargs = variant['gnn_kwargs']

    from graph_builder_multi import MultiTrafficGraphBuilder
    policy_gb = MultiTrafficGraphBuilder(
        input_dim=4 + label_dim,
        node_num=expl_env.max_veh_num + 1,
        ego_init=torch.tensor([0., 1.]),
        other_init=torch.tensor([1., 0.]),
    )
    # Only the policy GNN is switchable to the attention variant; the
    # supervised GNN below is always a plain GNNNet.
    if gnn_kwargs['attention']:
        from gnn_attention_net import GNNAttentionNet
        gnn_class = GNNAttentionNet
    else:
        from gnn_net import GNNNet
        gnn_class = GNNNet
    policy_gnn = gnn_class(
        pre_graph_builder=policy_gb,
        node_dim=gnn_kwargs['node'],
        num_conv_layers=gnn_kwargs['layer'],
        hidden_activation=gnn_kwargs['activation'],
    )
    from layers import FlattenLayer, SelectLayer
    policy = nn.Sequential(
        policy_gnn, SelectLayer(1, 0), FlattenLayer(), nn.ReLU(),
        nn.Linear(gnn_kwargs['node'], action_dim))

    sup_gb = MultiTrafficGraphBuilder(
        input_dim=4,
        node_num=expl_env.max_veh_num + 1,
        ego_init=torch.tensor([0., 1.]),
        other_init=torch.tensor([1., 0.]),
    )
    # Imported again here because the attention branch above does not bring
    # GNNNet into scope.
    from gnn_net import GNNNet
    sup_gnn = GNNNet(
        pre_graph_builder=sup_gb,
        node_dim=gnn_kwargs['node'],
        num_conv_layers=gnn_kwargs['layer'],
        hidden_activation=gnn_kwargs['activation'],
    )
    sup_learner = nn.Sequential(
        sup_gnn,
        SelectLayer(1, np.arange(1, expl_env.max_veh_num + 1)),
        nn.ReLU(),
        nn.Linear(gnn_kwargs['node'], label_dim),
    )
    from sup_sep_softmax_policy import SupSepSoftmaxPolicy
    policy = SupSepSoftmaxPolicy(policy, sup_learner, label_num, label_dim)
    print('parameters: ',
          np.sum([p.view(-1).shape[0] for p in policy.parameters()]))

    vf = Mlp(
        hidden_sizes=[32, 32],
        input_size=obs_dim,
        output_size=1,
    )
    vf_criterion = nn.MSELoss()
    eval_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    expl_policy = policy

    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    # Exploration uses the custom sup_sep_rollout; evaluation keeps the
    # collector's default rollout function.
    from sup_sep_rollout import sup_sep_rollout
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
        rollout_fn=sup_sep_rollout,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        observation_dim=obs_dim,
        label_dim=label_num,
        max_replay_buffer_size=int(1e6),
    )

    from rlkit.torch.vpg.trpo_sup_sep import TRPOSupSepTrainer
    trainer = TRPOSupSepTrainer(policy=policy,
                                value_function=vf,
                                vf_criterion=vf_criterion,
                                replay_buffer=replay_buffer,
                                **variant['trainer_kwargs'])
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #5
0
def experiment(variant):
    """Train a ``CombineNet`` policy fed by an MLP encoder and two supervised heads.

    One ``Mlp`` encodes the observation; two further ``Mlp`` networks wrapped
    in ``SoftmaxPolicy`` act both as supervised learners and as extra
    encoders. ``CombineNet`` joins them with an ``Mlp`` decoder, and the
    result is trained by ``PPOSupTrainer`` via ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'no_gradient'``
            (forwarded to ``CombineNet``), ``'policy_kwargs'`` (forwarded to
            the outer ``SoftmaxPolicy``), ``'trainer_kwargs'`` and
            ``'algorithm_kwargs'``.
    """
    from traffic.make_env import make_env
    env_name = args.exp_name
    env_kwargs = variant['env_kwargs']
    expl_env = make_env(env_name, **env_kwargs)
    eval_env = make_env(env_name, **env_kwargs)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n

    # First encoder: plain observation features.
    obs_encoder = Mlp(input_size=obs_dim, output_size=32, hidden_sizes=[32])
    encoders = [obs_encoder]

    # Two supervised heads; each one is also appended to the encoder list.
    sup_learners = []
    for _ in range(2):
        head_net = Mlp(input_size=obs_dim, output_size=2, hidden_sizes=[32])
        head = SoftmaxPolicy(head_net, learn_temperature=False)
        sup_learners.append(head)
        encoders.append(head)

    # Decoder input = 32 encoder features + 2 outputs from each of 2 heads.
    decoder = Mlp(
        hidden_sizes=[],
        input_size=int(32 + 2 * 2),
        output_size=action_dim,
    )
    module = CombineNet(
        encoders=encoders,
        decoder=decoder,
        no_gradient=variant['no_gradient'],
    )
    policy = SoftmaxPolicy(module, **variant['policy_kwargs'])

    vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32, 32])
    vf_criterion = nn.MSELoss()

    greedy_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    eval_path_collector = MdpPathCollector(eval_env, greedy_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        label_dims=[1, 1],
    )

    from rlkit.torch.vpg.ppo_sup import PPOSupTrainer
    trainer = PPOSupTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        sup_learners=sup_learners,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #6
0
ファイル: ppo_sup_lstm.py プロジェクト: maxiaoba/rlkit
def experiment(variant):
    """Build (or reload) an LSTM supervised softmax policy and train it.

    With ``variant['load_kwargs']['load']`` truthy, the policy and value
    function come from ``<load_dir>/params.pkl``. Otherwise an ``LSTMNet``
    with a linear action decoder and a linear+reshape supervised head is
    wrapped in ``SupSoftmaxLSTMPolicy``, alongside a fresh ``Mlp`` value
    function. Training uses ``PPOSupTrainer`` with ``recurrent=True`` via
    ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'trainer_kwargs'``
            (must contain ``'max_path_length'``), ``'algorithm_kwargs'``,
            ``'load_kwargs'`` (``'load'``, and ``'load_dir'`` when loading)
            and, when building from scratch, ``'lstm_kwargs'``
            (``'hidden_dim'``, ``'num_layers'``).
    """
    from traffic.make_env import make_env
    env_name = args.exp_name
    env_kwargs = variant['env_kwargs']
    expl_env = make_env(env_name, **env_kwargs)
    eval_env = make_env(env_name, **env_kwargs)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim
    max_path_length = variant['trainer_kwargs']['max_path_length']

    if not variant['load_kwargs']['load']:
        hidden_dim = variant['lstm_kwargs']['hidden_dim']
        num_layers = variant['lstm_kwargs']['num_layers']

        # Zero initial previous-action and (h, c) recurrent state.
        a_0 = np.zeros(action_dim)
        state_size = hidden_dim * num_layers
        latent_0 = (np.zeros(state_size), np.zeros(state_size))

        from lstm_net import LSTMNet
        lstm_net = LSTMNet(obs_dim, action_dim, hidden_dim, num_layers)
        decoder = nn.Linear(hidden_dim, action_dim)
        from layers import ReshapeLayer
        # Supervised head: flat logits reshaped to (label_num, label_dim).
        sup_linear = nn.Linear(hidden_dim, int(label_num * label_dim))
        sup_learner = nn.Sequential(
            sup_linear,
            ReshapeLayer(shape=(label_num, label_dim)),
        )
        from sup_softmax_lstm_policy import SupSoftmaxLSTMPolicy
        policy = SupSoftmaxLSTMPolicy(
            a_0=a_0,
            latent_0=latent_0,
            obs_dim=obs_dim,
            action_dim=action_dim,
            lstm_net=lstm_net,
            decoder=decoder,
            sup_learner=sup_learner,
        )
        param_sizes = [p.view(-1).shape[0] for p in policy.parameters()]
        print('parameters: ', np.sum(param_sizes))

        vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32, 32])
    else:
        checkpoint_dir = variant['load_kwargs']['load_dir']
        checkpoint = torch.load(checkpoint_dir + '/params.pkl',
                                map_location='cpu')
        policy = checkpoint['trainer/policy']
        vf = checkpoint['trainer/value_function']

    vf_criterion = nn.MSELoss()
    from rlkit.torch.policies.make_deterministic import MakeDeterministic
    deterministic_policy = MakeDeterministic(policy)

    eval_path_collector = MdpPathCollector(eval_env, deterministic_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        recurrent=True,
        max_path_length=max_path_length,
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        action_dim=action_dim,
        label_dim=label_num,
    )

    from rlkit.torch.vpg.ppo_sup import PPOSupTrainer
    trainer = PPOSupTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        replay_buffer=replay_buffer,
        recurrent=True,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #7
0
def experiment(variant):
    """Train a shared-encoder ``SupSoftmaxPolicy`` on a traffic environment.

    A two-layer ReLU encoder feeds a linear action decoder and a
    linear+reshape supervised head. The policy is trained by
    ``TRPOSupTrainer`` through ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'trainer_kwargs'`` and
            ``'algorithm_kwargs'``.
    """
    from traffic.make_env import make_env
    env_name = args.exp_name
    env_kwargs = variant['env_kwargs']
    expl_env = make_env(env_name, **env_kwargs)
    eval_env = make_env(env_name, **env_kwargs)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim

    # Width of the hidden features shared by the two heads.
    feature_dim = 32
    encoder = nn.Sequential(
        nn.Linear(obs_dim, feature_dim),
        nn.ReLU(),
        nn.Linear(feature_dim, feature_dim),
        nn.ReLU(),
    )
    decoder = nn.Linear(feature_dim, action_dim)
    from layers import ReshapeLayer
    # Supervised head: flat logits reshaped to (label_num, label_dim).
    sup_linear = nn.Linear(feature_dim, int(label_num * label_dim))
    sup_learner = nn.Sequential(
        sup_linear,
        ReshapeLayer(shape=(label_num, label_dim)),
    )
    from sup_softmax_policy import SupSoftmaxPolicy
    policy = SupSoftmaxPolicy(encoder, decoder, sup_learner)
    param_sizes = [p.view(-1).shape[0] for p in policy.parameters()]
    print('parameters: ', np.sum(param_sizes))

    vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32, 32])
    vf_criterion = nn.MSELoss()

    greedy_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)
    eval_path_collector = MdpPathCollector(eval_env, greedy_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        label_dim=label_num,
    )

    from rlkit.torch.vpg.trpo_sup import TRPOSupTrainer
    trainer = TRPOSupTrainer(policy=policy,
                             value_function=vf,
                             vf_criterion=vf_criterion,
                             replay_buffer=replay_buffer,
                             **variant['trainer_kwargs'])
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #8
0
ファイル: ppo_sup_sep2_multi.py プロジェクト: maxiaoba/rlkit
def experiment(variant):
    """Build (or reload) a separate-network supervised policy and train it.

    With ``variant['load_kwargs']['load']`` truthy, the policy and value
    function come from ``<load_dir>/params.pkl``. Otherwise two independent
    MLPs are created: a policy network whose input is the observation plus
    ``label_dim * label_num + label_dim`` extra features, and a supervised
    learner on the raw observation. They are wrapped in
    ``SupSepSoftmaxPolicy``. Training uses ``PPOSupSepTrainer`` via
    ``TorchOnlineRLAlgorithm``.

    The environment name is taken from the module-level ``args.exp_name``.

    Args:
        variant: Mapping providing ``'env_kwargs'``, ``'load_kwargs'``
            (``'load'``, and ``'load_dir'`` when loading), ``'mlp_kwargs'``
            (``'hidden'``, only when building from scratch),
            ``'trainer_kwargs'`` and ``'algorithm_kwargs'``.
    """
    from traffic.make_env import make_env
    env_name = args.exp_name
    env_kwargs = variant['env_kwargs']
    expl_env = make_env(env_name, **env_kwargs)
    eval_env = make_env(env_name, **env_kwargs)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim

    if not variant['load_kwargs']['load']:
        hidden_dim = variant['mlp_kwargs']['hidden']

        policy_in_dim = obs_dim + int(label_dim * label_num + label_dim)
        policy_net = nn.Sequential(
            nn.Linear(policy_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        from layers import ReshapeLayer
        # Supervised learner: flat logits reshaped to (label_num, label_dim).
        sup_learner = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, int(label_num * label_dim)),
            ReshapeLayer(shape=(label_num, label_dim)),
        )
        from sup_sep_softmax_policy import SupSepSoftmaxPolicy
        policy = SupSepSoftmaxPolicy(policy_net, sup_learner, label_num,
                                     label_dim)
        param_sizes = [p.view(-1).shape[0] for p in policy.parameters()]
        print('parameters: ', np.sum(param_sizes))

        vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32, 32])
    else:
        checkpoint_dir = variant['load_kwargs']['load_dir']
        checkpoint = torch.load(checkpoint_dir + '/params.pkl',
                                map_location='cpu')
        policy = checkpoint['trainer/policy']
        vf = checkpoint['trainer/value_function']

    vf_criterion = nn.MSELoss()
    greedy_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)

    eval_path_collector = MdpPathCollector(eval_env, greedy_policy)
    # Only the exploration collector uses the custom sup_sep_rollout.
    from sup_sep_rollout import sup_sep_rollout
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
        rollout_fn=sup_sep_rollout,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        label_dim=label_num,
    )

    from rlkit.torch.vpg.ppo_sup_sep import PPOSupSepTrainer
    trainer = PPOSupSepTrainer(policy=policy,
                               value_function=vf,
                               vf_criterion=vf_criterion,
                               replay_buffer=replay_buffer,
                               **variant['trainer_kwargs'])
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #9
0
def experiment(variant):
    """Train a separate-network supervised policy on ``SimpleSupEnv``.

    Two independent MLPs are built: a policy network whose input is the
    observation plus ``label_dim * label_num`` extra features, and a
    supervised learner on the raw observation. They are combined by
    ``SupSepSoftmaxPolicy`` and trained with ``PPOSupSepTrainer`` through
    ``TorchOnlineRLAlgorithm``.

    Args:
        variant: Mapping providing ``'env_kwars'`` (keyword arguments for
            ``SimpleSupEnv``; the key is spelled exactly like this),
            ``'hidden_dim'``, ``'trainer_kwargs'`` and
            ``'algorithm_kwargs'``.
    """
    from simple_sup import SimpleSupEnv
    env_settings = variant['env_kwars']
    expl_env = SimpleSupEnv(**env_settings)
    eval_env = SimpleSupEnv(**env_settings)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    label_num = expl_env.label_num
    label_dim = expl_env.label_dim

    hidden_dim = variant['hidden_dim']
    policy_in_dim = obs_dim + int(label_dim * label_num)
    policy_net = nn.Sequential(
        nn.Linear(policy_in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, action_dim),
    )
    from layers import ReshapeLayer
    # Supervised learner: flat logits reshaped to (label_num, label_dim).
    sup_learner = nn.Sequential(
        nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, int(label_num * label_dim)),
        ReshapeLayer(shape=(label_num, label_dim)))
    from sup_sep_softmax_policy import SupSepSoftmaxPolicy
    policy = SupSepSoftmaxPolicy(policy_net, sup_learner, label_num, label_dim)
    param_sizes = [p.view(-1).shape[0] for p in policy.parameters()]
    print('parameters: ', np.sum(param_sizes))

    vf = Mlp(input_size=obs_dim, output_size=1, hidden_sizes=[32])
    vf_criterion = nn.MSELoss()
    greedy_policy = ArgmaxDiscretePolicy(policy, use_preactivation=True)

    eval_path_collector = MdpPathCollector(eval_env, greedy_policy)
    # Only the exploration collector uses the custom sup_sep_rollout.
    from sup_sep_rollout import sup_sep_rollout
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
        rollout_fn=sup_sep_rollout,
    )
    from sup_replay_buffer import SupReplayBuffer
    replay_buffer = SupReplayBuffer(
        max_replay_buffer_size=int(1e6),
        observation_dim=obs_dim,
        label_dim=label_num,
    )

    from rlkit.torch.vpg.ppo_sup_sep import PPOSupSepTrainer
    trainer = PPOSupSepTrainer(
        policy=policy,
        value_function=vf,
        vf_criterion=vf_criterion,
        replay_buffer=replay_buffer,
        **variant['trainer_kwargs']
    )
    algorithm = TorchOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()