コード例 #1
0
ファイル: train_dqn.py プロジェクト: youjp/Hierarchical-DQN
def make_agent(agent_type, env, num_clusters, use_extra_travel_penalty,
               use_extra_bit, use_controller_dqn, use_intrinsic_timeout,
               use_memory, memory_size, pretrain_controller):
    """Construct the agent selected by ``agent_type``.

    Args:
        agent_type: ``'dqn'`` for a flat ``dqn.DqnAgent`` or ``'h_dqn'`` for
            a ``hierarchical_dqn.HierarchicalDqnAgent``.
        env: Accepted for interface compatibility; not read by this function
            (the action count is hard-coded to 2 instead of being taken from
            ``env.action_space.n``).
        num_clusters: Passed to ``clustering.get_cluster_fn`` as
            ``n_clusters``.
        use_extra_travel_penalty: Forwarded to the hierarchical agent.
        use_extra_bit: Passed to ``clustering.get_cluster_fn`` as
            ``extra_bit`` and to the hierarchical agent as
            ``use_extra_bit_for_subgoal_center``.
        use_controller_dqn: Forwarded to the hierarchical agent.
        use_intrinsic_timeout: Forwarded to the hierarchical agent.
        use_memory: Forwarded to the hierarchical agent.
        memory_size: Forwarded to the hierarchical agent.
        pretrain_controller: Forwarded to the hierarchical agent.

    Returns:
        The constructed agent, or ``None`` when ``agent_type`` is neither
        ``'dqn'`` nor ``'h_dqn'``.
    """
    if agent_type == 'dqn':
        # State and action sizes are fixed here rather than read from env.
        return dqn.DqnAgent(state_dims=[2], num_actions=2)

    if agent_type != 'h_dqn':
        return None

    # The clustering helper yields, in order: the meta-controller state
    # function, the subgoal check function, the subgoal count and the
    # subgoals themselves.
    cluster_info = clustering.get_cluster_fn(
        n_clusters=num_clusters, extra_bit=use_extra_bit)
    state_fn, subgoal_check, subgoal_count, subgoal_list = cluster_info

    return hierarchical_dqn.HierarchicalDqnAgent(
        state_sizes=[subgoal_count, [2]],
        agent_types=['tabular', 'network'],
        subgoals=subgoal_list,
        num_subgoals=subgoal_count,
        num_primitive_actions=2,  # hard-coded instead of env.action_space.n
        meta_controller_state_fn=state_fn,
        check_subgoal_fn=subgoal_check,
        use_extra_travel_penalty=use_extra_travel_penalty,
        use_extra_bit_for_subgoal_center=use_extra_bit,
        use_controller_dqn=use_controller_dqn,
        use_intrinsic_timeout=use_intrinsic_timeout,
        use_memory=use_memory,
        memory_size=memory_size,
        pretrain_controller=pretrain_controller)
コード例 #2
0
def make_agent(agent_type, env):
    """Construct the agent selected by ``agent_type``.

    Args:
        agent_type: ``'dqn'`` for a flat ``dqn.DqnAgent`` or ``'h_dqn'`` for
            a ``hierarchical_dqn.HierarchicalDqnAgent``.
        env: Accepted for interface compatibility; not read by this function
            (the action count is hard-coded to 2 instead of being taken from
            ``env.action_space.n``).

    Returns:
        The constructed agent, or ``None`` when ``agent_type`` is neither
        ``'dqn'`` nor ``'h_dqn'``.
    """
    if agent_type == 'dqn':
        # State and action sizes are fixed here rather than read from env.
        return dqn.DqnAgent(state_dims=[2], num_actions=2)

    if agent_type != 'h_dqn':
        return None

    # Cluster configuration is fixed: four clusters, no extra bit.
    cluster_info = clustering.get_cluster_fn(n_clusters=4, extra_bit=False)
    state_fn, subgoal_check, subgoal_count, subgoal_list = cluster_info

    return hierarchical_dqn.HierarchicalDqnAgent(
        state_sizes=[[subgoal_count], 2],
        subgoals=subgoal_list,
        num_subgoals=subgoal_count,
        num_primitive_actions=2,  # hard-coded instead of env.action_space.n
        meta_controller_state_fn=state_fn,
        check_subgoal_fn=subgoal_check)
コード例 #3
0
def make_agent(agent_type, env, load=True):
    """Construct the agent selected by ``agent_type``.

    Args:
        agent_type: ``'dqn'`` for a flat ``dqn.DqnAgent`` or ``'h_dqn'`` for
            a ``hierarchical_dqn.HierarchicalDqnAgent``.
        env: Environment whose ``action_space.n`` gives the number of
            primitive actions and, for ``'h_dqn'``, whose
            ``observation_space.shape`` is passed as ``state_sizes``.
        load: Forwarded to the hierarchical agent as ``load``; unused for
            ``'dqn'``.

    Returns:
        The constructed agent, or ``None`` when ``agent_type`` is neither
        ``'dqn'`` nor ``'h_dqn'``.
    """
    if agent_type == 'dqn':
        return dqn.DqnAgent(state_dims=[2], num_actions=env.action_space.n)

    if agent_type != 'h_dqn':
        return None

    # NOTE(review): `check_subgoal` and `subgoals` are not assigned inside
    # this function; they presumably resolve to module-level names --
    # confirm they are defined before this branch runs.
    state_fn = None
    subgoal_check = check_subgoal
    subgoal_count = 2

    # Disabled alternatives kept for reference:
    # subgoals = [[-.7, -.2], [-1, 0], [.5, .2], [1, 0]]
    # clustering.get_cluster_fn(n_clusters=num_clusters, extra_bit=use_extra_bit)

    return hierarchical_dqn.HierarchicalDqnAgent(
        state_sizes=env.observation_space.shape,
        subgoals=subgoals,
        num_subgoals=subgoal_count,
        num_primitive_actions=env.action_space.n,
        meta_controller_state_fn=state_fn,
        check_subgoal_fn=subgoal_check,
        load=load)