예제 #1
0
def get_ddpgtom_agent(env, agent_id, hidden_layer_sizes,
                      max_replay_buffer_size):
    """Construct a DDPG-ToM agent for one agent slot of a multi-agent env.

    The agent bundles a deterministic policy conditioned on predicted
    opponent actions, a centralized Q-function over the joint action
    space, an opponent-policy model, and an indexed replay buffer that
    also stores opponent actions.

    Args:
        env: environment exposing ``env_specs`` with per-agent
            observation/action spaces (project type — not stdlib).
        agent_id: index of the agent this instance controls.
        hidden_layer_sizes: hidden layer sizes shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.

    Returns:
        A configured ``DDPGToMAgent``.
    """
    specs = env.env_specs
    obs_space = specs.observation_space[agent_id]
    act_space = specs.action_space[agent_id]
    # Flattened dimension of all opponents' actions, hoisted since it is
    # needed by the policy, the opponent model, and the replay buffer.
    opp_dim = specs.action_space.opponent_flat_dim(agent_id)

    # Policy sees own observation plus (predicted) opponent actions.
    policy = DeterministicMLPPolicy(
        input_shapes=(obs_space.shape, (opp_dim, )),
        output_shape=act_space.shape,
        hidden_layer_sizes=hidden_layer_sizes,
        name='policy_agent_{}'.format(agent_id))

    # Centralized critic over the full joint action space.
    qf = MLPValueFunction(
        input_shapes=(obs_space.shape, (specs.action_space.flat_dim, )),
        output_shape=(1, ),
        hidden_layer_sizes=hidden_layer_sizes,
        name='qf_agent_{}'.format(agent_id))

    # Model of the opponents' joint action, conditioned on own observation.
    opponent_policy = DeterministicMLPPolicy(
        input_shapes=(obs_space.shape, ),
        output_shape=(opp_dim, ),
        hidden_layer_sizes=hidden_layer_sizes,
        name='opponent_policy_agent_{}'.format(agent_id))

    replay_buffer = IndexedReplayBuffer(
        observation_dim=obs_space.shape[0],
        action_dim=act_space.shape[0],
        opponent_action_dim=opp_dim,
        max_replay_buffer_size=max_replay_buffer_size)

    return DDPGToMAgent(
        env_specs=specs,
        policy=policy,
        qf=qf,
        opponent_policy=opponent_policy,
        replay_buffer=replay_buffer,
        exploration_strategy=OUExploration(act_space),
        gradient_clipping=10.,
        agent_id=agent_id,
    )
예제 #2
0
def get_ddpg_agent(env,
                   agent_id,
                   hidden_layer_sizes,
                   max_replay_buffer_size,
                   policy_type='dete'):
    """Construct a decentralized DDPG agent for one agent slot.

    Args:
        env: environment exposing ``env_specs`` with per-agent
            observation/action spaces (project type — not stdlib).
        agent_id: index of the agent this instance controls.
        hidden_layer_sizes: hidden layer sizes shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``'dete'`` for a deterministic policy with OU
            exploration, or ``'gumble'`` for a relaxed-softmax
            (Gumbel-style) policy with no extra exploration noise.

    Returns:
        A configured ``DDPGAgent``.

    Raises:
        ValueError: if ``policy_type`` is not a recognized value.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    action_space = env.env_specs.action_space[agent_id]
    if policy_type == 'dete':
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == 'gumble':
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        # Previously an unknown value fell through and crashed later with
        # an opaque UnboundLocalError; fail fast with a clear message.
        raise ValueError(
            "Unknown policy_type {!r}; expected 'dete' or 'gumble'.".format(
                policy_type))
    return DDPGAgent(
        env_specs=env.env_specs,
        policy=policy_fn(input_shapes=(observation_space.shape, ),
                         output_shape=action_space.shape,
                         hidden_layer_sizes=hidden_layer_sizes,
                         name='policy_agent_{}'.format(agent_id)),
        qf=MLPValueFunction(input_shapes=(observation_space.shape,
                                          action_space.shape),
                            output_shape=(1, ),
                            hidden_layer_sizes=hidden_layer_sizes,
                            name='qf_agent_{}'.format(agent_id)),
        replay_buffer=IndexedReplayBuffer(
            observation_dim=observation_space.shape[0],
            action_dim=action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.,
        agent_id=agent_id,
    )
예제 #3
0
def get_pr2_agent(env,
                  agent_id,
                  hidden_layer_sizes,
                  max_replay_buffer_size,
                  policy_type="deter"):
    """Construct a PR2 agent (policy with opponent-conditioned critic).

    Args:
        env: environment exposing ``env_specs`` with per-agent
            observation/action spaces (project type — not stdlib).
        agent_id: index of the agent this instance controls.
        hidden_layer_sizes: hidden layer sizes shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``"dete"`` (or the historical default ``"deter"``)
            for a deterministic policy with OU exploration, or
            ``"gumble"`` for a relaxed-softmax (Gumbel-style) policy
            with no extra exploration noise.

    Returns:
        A configured ``PR2Agent``.

    Raises:
        ValueError: if ``policy_type`` is not a recognized value.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    action_space = env.env_specs.action_space[agent_id]
    opponent_action_shape = (
        env.env_specs.action_space.opponent_flat_dim(agent_id), )
    # BUG FIX: the default policy_type "deter" matched neither branch, so
    # calling this function with defaults raised UnboundLocalError on
    # policy_fn. Accept both spellings to stay backward-compatible, and
    # fail fast on anything else. (A stray debug print was also removed.)
    if policy_type in ("dete", "deter"):
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == "gumble":
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        raise ValueError(
            "Unknown policy_type {!r}; expected 'dete'/'deter' or "
            "'gumble'.".format(policy_type))
    return PR2Agent(
        env_specs=env.env_specs,
        policy=policy_fn(
            input_shapes=(observation_space.shape, ),
            output_shape=action_space.shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="policy_agent_{}".format(agent_id),
        ),
        # Joint critic: own observation, own action, opponents' actions.
        qf=MLPValueFunction(
            input_shapes=(
                observation_space.shape,
                action_space.shape,
                opponent_action_shape,
            ),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="qf_agent_{}".format(agent_id),
        ),
        # Individual critic over own observation/action only.
        ind_qf=MLPValueFunction(
            input_shapes=(observation_space.shape, action_space.shape),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="ind_qf_agent_{}".format(agent_id),
        ),
        replay_buffer=IndexedReplayBuffer(
            observation_dim=observation_space.shape[0],
            action_dim=action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size,
            opponent_action_dim=opponent_action_shape[0],
        ),
        # Opponent model conditioned on own observation and own action.
        opponent_policy=policy_fn(
            input_shapes=(observation_space.shape, action_space.shape),
            output_shape=opponent_action_shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="opponent_policy_agent_{}".format(agent_id),
        ),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.0,
        agent_id=agent_id,
    )
예제 #4
0
def get_commnet_agent(env,
                      agent_id,
                      hidden_layer_sizes,
                      max_replay_buffer_size,
                      policy_type="deter"):
    """Construct a fully-centralized CommNet agent over all ``n`` agents.

    The policy and value function consume the stacked observations (and
    actions) of every agent, so the replay buffer stores flattened
    ``n * dim`` vectors with per-agent rewards and terminals.

    Args:
        env: environment exposing ``env_specs`` with per-agent
            observation/action spaces and ``agent_num`` (project type).
        agent_id: index used for network naming.
        hidden_layer_sizes: hidden layer sizes shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``"deter"`` for a deterministic policy with OU
            exploration, or ``"gumble"`` for a relaxed-softmax
            (Gumbel-style) policy with no extra exploration noise.

    Returns:
        A configured ``FullyCentralizedAgent``.

    Raises:
        ValueError: if ``policy_type`` is not a recognized value.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    n = env.env_specs.agent_num
    action_space = env.env_specs.action_space[agent_id]
    if policy_type == "deter":
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == "gumble":
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        # Previously an unknown value fell through and crashed later with
        # an opaque UnboundLocalError; fail fast with a clear message.
        raise ValueError(
            "Unknown policy_type {!r}; expected 'deter' or 'gumble'.".format(
                policy_type))
    return FullyCentralizedAgent(
        env_specs=env.env_specs,
        policy=policy_fn(
            # Stacked observations of all n agents.
            input_shapes=((n, ) + observation_space.shape, ),
            output_shape=action_space.shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="policy_agent_{}".format(agent_id),
        ),
        qf=CommNetValueFunction(
            input_shapes=((n, ) + observation_space.shape,
                          (n, ) + action_space.shape),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="qf_agent_{}".format(agent_id),
        ),
        replay_buffer=IndexedReplayBuffer(
            # Flattened joint observation/action; one reward and terminal
            # flag per agent.
            observation_dim=n * observation_space.shape[0],
            action_dim=n * action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size,
            reward_dim=n,
            terminal_dim=n,
        ),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.0,
        agent_id=agent_id,
    )