예제 #1
0
def experiment(variant):
    """Train a DQN agent on ``ToolsEnv`` with ``TorchHumanInputLifetimeRLAlgorithm``.

    Three independent ``ToolsEnv`` instances are built from
    ``variant['env_kwargs']`` (exploration, evaluation, rollout). When
    ``env_kwargs`` has no ``time_horizon`` (or it is 0) the run is treated as a
    "lifetime" run: evaluation then uses the exploratory policy as well.

    Args:
        variant: experiment config dict. Reads ``env_kwargs`` (optionally
            ``time_horizon``) and ``algo_kwargs`` (``layer_size``,
            ``trainer_kwargs``, ``replay_buffer_size``, ``algorithm_kwargs``;
            the whole ``algo_kwargs`` dict is also passed to ``gen_network``).

    Raises:
        ValueError: if a lifetime run is requested but the constructed
            environment reports a non-zero ``time_horizon``.
    """
    # Not referenced below; presumably imported for its side effects
    # (e.g. environment registration) — TODO confirm before removing.
    from rlkit.envs.gym_minigrid.gym_minigrid import envs  # noqa: F401

    env_kwargs = variant['env_kwargs']
    algo_kwargs = variant['algo_kwargs']

    expl_env = ToolsEnv(**env_kwargs)
    eval_env = ToolsEnv(**env_kwargs)
    rollout_env = ToolsEnv(**env_kwargs)

    action_dim = eval_env.action_space.n
    layer_size = algo_kwargs['layer_size']
    lifetime = env_kwargs.get('time_horizon', 0) == 0
    if lifetime and eval_env.time_horizon != 0:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError('cannot have time horizon for lifetime env')

    # Online and target Q-networks are built independently with identical settings.
    qf = gen_network(algo_kwargs, action_dim, layer_size)
    target_qf = gen_network(algo_kwargs, action_dim, layer_size)
    qf_criterion = nn.MSELoss()

    eval_policy = ArgmaxDiscretePolicy(qf)
    # NOTE(review): positional args (1e-5, 1, 0.1) presumably mean decay rate,
    # initial epsilon and final epsilon — confirm against EpsilonGreedyDecay.
    expl_policy = PolicyWrappedWithExplorationStrategy(
        EpsilonGreedyDecay(expl_env.action_space, 1e-5, 1, 0.1),
        eval_policy,
    )
    if lifetime:
        # In lifetime mode evaluation also acts with the exploratory policy.
        eval_policy = expl_policy

    if eval_env.time_horizon == 0:
        collector_class = LifetimeMdpPathCollector if lifetime else MdpPathCollector
    else:
        collector_class = MdpPathCollectorConfig
    eval_path_collector = collector_class(eval_env, eval_policy)
    expl_path_collector = collector_class(expl_env, expl_policy)

    trainer = DoubleDQNTrainer(qf=qf,
                               target_qf=target_qf,
                               qf_criterion=qf_criterion,
                               **algo_kwargs['trainer_kwargs'])
    replay_buffer = EnvReplayBuffer(algo_kwargs['replay_buffer_size'], expl_env)

    # The human-input lifetime algorithm is used regardless of ``lifetime``.
    algorithm = TorchHumanInputLifetimeRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        rollout_env=rollout_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **algo_kwargs['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
예제 #2
0
파일: dqn_gnn.py 프로젝트: maxiaoba/rlkit
def experiment(variant):
    """Train a Double-DQN agent with a GNN Q-network on a traffic environment.

    NOTE(review): the environment name is read from a module-level ``args``
    (``args.exp_name``), not from ``variant``; this function only works when
    such a global exists — TODO confirm / consider passing it explicitly.

    Args:
        variant: experiment config dict. Reads ``qf_kwargs`` (forwarded as
            ``post_mlp_kwargs`` to ``GNNNet``), ``epsilon``,
            ``trainer_kwargs``, ``replay_buffer_size`` and
            ``algorithm_kwargs``.
    """
    from traffic.make_env import make_env

    expl_env = make_env(args.exp_name)
    eval_env = make_env(args.exp_name)
    action_dim = eval_env.action_space.n

    # edge_index lists directed edges 0->1, 0->2, 1->0, 2->0, i.e. node 0 is
    # linked both ways to nodes 1 and 2. Presumably node 0 is the ego vehicle
    # (cf. ego_init / other_init) — confirm with TrafficGraphBuilder.
    gb = TrafficGraphBuilder(input_dim=4,
                             ego_init=torch.tensor([0., 1.]),
                             other_init=torch.tensor([1., 0.]),
                             edge_index=torch.tensor([[0, 0, 1, 2],
                                                      [1, 2, 0, 0]]))
    qf = GNNNet(pre_graph_builder=gb,
                node_dim=16,
                output_dim=action_dim,
                post_mlp_kwargs=variant['qf_kwargs'],
                num_conv_layers=3)
    # Target network starts as an exact, independent copy of the online one.
    target_qf = copy.deepcopy(qf)

    eval_policy = ArgmaxDiscretePolicy(qf)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        EpsilonGreedy(expl_env.action_space, variant['epsilon']),
        eval_policy,
    )
    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, expl_policy)

    qf_criterion = nn.MSELoss()
    trainer = DoubleDQNTrainer(qf=qf,
                               target_qf=target_qf,
                               qf_criterion=qf_criterion,
                               **variant['trainer_kwargs'])
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
예제 #3
0
def experiment(variant):
    """Train a Double-DQN agent with an MLP Q-network on a traffic environment.

    NOTE(review): the environment name is read from a module-level ``args``
    (``args.exp_name``), not from ``variant``; this function only works when
    such a global exists — TODO confirm / consider passing it explicitly.

    Args:
        variant: experiment config dict. Reads ``qf_kwargs`` (extra ``Mlp``
            kwargs), ``epsilon``, ``trainer_kwargs``, ``replay_buffer_size``
            and ``algorithm_kwargs``.
    """
    from traffic.make_env import make_env

    expl_env = make_env(args.exp_name)
    eval_env = make_env(args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.n

    qf = Mlp(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['qf_kwargs']
    )
    # Target network starts as an exact, independent copy of the online one.
    target_qf = copy.deepcopy(qf)

    eval_policy = ArgmaxDiscretePolicy(qf)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        EpsilonGreedy(expl_env.action_space, variant['epsilon']),
        eval_policy,
    )
    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, expl_policy)

    qf_criterion = nn.MSELoss()
    trainer = DoubleDQNTrainer(
        qf=qf,
        target_qf=target_qf,
        qf_criterion=qf_criterion,
        **variant['trainer_kwargs']
    )
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        # Traffic-specific per-path statistics for logging.
        log_path_function=get_traffic_path_information,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def _load_last_pickle(filepath):
    """Return the last object pickled into ``filepath``.

    The file is read object by object with ``pickle.load`` until ``EOFError``;
    every object except the final one is discarded.

    Raises:
        ValueError: if ``filepath`` contains no pickled object at all. Failing
            loudly here keeps callers from continuing with an undefined or
            stale value.

    Security: unpickling can execute arbitrary code, so only point this at
    trusted, locally produced snapshot files.
    """
    nothing_loaded = object()
    last = nothing_loaded
    with open(filepath, "rb") as openfile:
        while True:
            try:
                last = pickle.load(openfile)
            except EOFError:
                break
    if last is nothing_loaded:
        raise ValueError("no pickled objects found in {!r}".format(filepath))
    return last


def _load_pretrained_controllers(filepath, grid_size):
    """Load a pretrained policy snapshot and wrap it as (eval, expl) controllers.

    Uses the ``"exploration/policy"`` entry of the last snapshot in
    ``filepath``, keeps only ``rnn_hxs[0]`` (re-adding a leading dimension of
    size 1), and returns two ``CraftController(policy, n=grid_size)`` wrappers.
    """
    policies = _load_last_pickle(filepath)
    policy = policies["exploration/policy"]
    # Presumably rnn_hxs is (num_processes, hidden) and only one process is
    # run here, so the first hidden state is kept — TODO confirm.
    policy.rnn_hxs = policy.rnn_hxs[0].unsqueeze(0)
    # NOTE(review): both wrappers share the same policy object (and its
    # rnn_hxs); confirm CraftController does not mutate it in a way that
    # leaks state between evaluation and exploration.
    eval_controller = CraftController(policy, n=grid_size)
    expl_controller = CraftController(policy, n=grid_size)
    return eval_controller, expl_controller


def experiment(variant):
    """Train a DQN learner whose symbolic actions drive a planner and controllers.

    A Q-network with one output per symbolic action (``ANCILLARY_GOAL_SIZE``)
    is wrapped in ``LearnPlanPolicy`` together with an ``ENHSPPlanner`` and a
    ``PretrainedController`` built from two pretrained ``CraftController``
    policies ("collect" and "other") loaded from pickled run snapshots. The
    learner is trained with ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Reads ``env_name``, ``init_w``,
            ``b_init_value``, ``softmax`` (and ``temperature`` when it is
            true), ``anneal_schedule``, ``render``, ``experience_interval``,
            ``trainer_kwargs`` (``discount`` and ``single_plan_discounting``
            are read here; the whole dict is forwarded to the trainer),
            ``double_dqn``, ``replay_buffer_size`` and ``algorithm_kwargs``.
            Optional ``collect_policy_path`` / ``other_policy_path`` override
            the default snapshot locations.

    Raises:
        ValueError: if a policy snapshot file contains no pickled object.
    """
    setup_logger("name-of-experiment", variant=variant)
    ptu.set_gpu_mode(True)

    expl_env = gym.make(variant["env_name"], seed=5)
    eval_env = gym.make(variant["env_name"], seed=5)

    ANCILLARY_GOAL_SIZE = 16  # number of symbolic actions the learner chooses from
    GRID_SIZE = 31  # passed to CraftController as ``n``

    action_dim = ANCILLARY_GOAL_SIZE
    symbolic_action_space = gym.spaces.Discrete(ANCILLARY_GOAL_SIZE)
    # Same env but with the symbolic action space; only passed to the replay buffer.
    symb_env = gym.make(variant["env_name"])
    symb_env.action_space = symbolic_action_space

    # Only ``n`` (the Q-network input size) is needed from get_spaces.
    _, _, _, n, _, _, _ = common.get_spaces(expl_env)

    def make_qf():
        # Fresh network on each call so qf and target_qf are independent.
        return Mlp(
            input_size=n,
            output_size=action_dim,
            hidden_sizes=[256, 256],
            init_w=variant["init_w"],
            b_init_value=variant["b_init_value"],
        )

    qf = make_qf()
    target_qf = make_qf()

    planner = ENHSPPlanner()

    # Pretrained low-level controller snapshots (defaults are the original runs).
    collect_policy_path = variant.get(
        "collect_policy_path",
        "/home/achester/anaconda3/envs/goal-gen/.guild/runs/e77c75eed02e4b38a0a308789fbfcbd8/data/params.pkl",
    )
    other_policy_path = variant.get(
        "other_policy_path",
        # "other" controller (RC 55)
        "/home/achester/anaconda3/envs/goal-gen/.guild/runs/4989f4bcbadb4ac58c3668c068d63225/data/params.pkl",
    )
    eval_collect, expl_collect = _load_pretrained_controllers(
        collect_policy_path, GRID_SIZE)
    eval_other, expl_other = _load_pretrained_controllers(
        other_policy_path, GRID_SIZE)

    eval_controller = PretrainedController([eval_collect, eval_other])
    expl_controller = PretrainedController([expl_collect, expl_other])

    function_env = gym.make(variant["env_name"])

    qf_criterion = nn.MSELoss()
    if variant["softmax"]:
        eval_learner = SoftmaxDiscretePolicy(qf, variant["temperature"])
    else:
        eval_learner = ArgmaxDiscretePolicy(qf)

    expl_learner = PolicyWrappedWithExplorationStrategy(
        LinearEpsilonGreedy(symbolic_action_space,
                            anneal_schedule=variant["anneal_schedule"]),
        eval_learner,
    )

    # Eval and expl policies share the planner and the function env.
    eval_policy = LearnPlanPolicy(
        eval_learner,
        planner,
        eval_controller,
        num_processes=1,
        vectorised=False,
        env=function_env,
    )
    expl_policy = LearnPlanPolicy(
        expl_learner,
        planner,
        expl_controller,
        num_processes=1,
        vectorised=False,
        env=function_env,
    )

    trainer_kwargs = variant["trainer_kwargs"]
    collector_kwargs = dict(
        rollout=intermediate_rollout,
        render=variant["render"],
        single_plan_discounting=trainer_kwargs["single_plan_discounting"],
        experience_interval=variant["experience_interval"],
    )
    # Evaluation uses gamma=1 while exploration uses the trainer's discount
    # (presumably to report undiscounted evaluation returns — confirm).
    eval_path_collector = IntermediatePathCollector(
        eval_env, eval_policy, gamma=1, **collector_kwargs)
    expl_path_collector = IntermediatePathCollector(
        expl_env, expl_policy, gamma=trainer_kwargs["discount"],
        **collector_kwargs)

    trainer_class = DoubleDQNTrainer if variant["double_dqn"] else DQNTrainer
    trainer = trainer_class(qf=qf,
                            target_qf=target_qf,
                            qf_criterion=qf_criterion,
                            **trainer_kwargs)
    replay_buffer = PlanReplayBuffer(variant["replay_buffer_size"], symb_env)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])

    algorithm.to(ptu.device)
    algorithm.train()