Example #1
import tensorflow as tf  # remaining classes (SimpleReplayBuffer, GaussianPolicy, ...) are assumed to come from the surrounding project


def rom_agent(model_name,
              i,
              env,
              M,
              u_range,
              base_kwargs,
              g=False,
              mu=1.5,
              game_name='matrix'):
    print(model_name)
    # Defaults: per-agent (non-joint) data, tanh-squashed continuous actions,
    # opponent modelling enabled; the 'particle' branch below switches to softmax actions.
    joint = False
    squash = True
    opponent_modelling = True
    squash_func = tf.tanh
    correct_tanh = True
    sampling = False
    # TODO deal with particle problem.
    if 'particle' in game_name:
        sampling = True
        squash = True
        squash_func = tf.nn.softmax
        correct_tanh = False

    # Replay buffer holding only agent i's own transitions (joint=False).
    pool = SimpleReplayBuffer(env.env_specs,
                              max_replay_buffer_size=1e6,
                              joint=joint,
                              agent_id=i)

    # Opponent model over the other agents' actions, and the conditional policy
    # for agent i that is built on top of it.
    opponent_policy = GaussianPolicy(env.env_specs,
                                     hidden_layer_sizes=(M, M),
                                     squash=True,
                                     joint=False,
                                     agent_id=i,
                                     name='opponent_policy')
    conditional_policy = GaussianConditionalPolicy(
        env.env_specs,
        cond_policy=opponent_policy,
        hidden_layer_sizes=(M, M),
        name='gaussian_conditional_policy',
        opponent_policy=False,
        squash=True,
        joint=False,
        agent_id=i)
    # Target copies of both policies, used when computing TD targets.
    with tf.variable_scope('target_levelk_{}'.format(i), reuse=True):
        target_opponent_policy = GaussianPolicy(env.env_specs,
                                                hidden_layer_sizes=(M, M),
                                                squash=True,
                                                joint=False,
                                                agent_id=i,
                                                name='target_opponent_policy')
        target_conditional_policy = GaussianConditionalPolicy(
            env.env_specs,
            cond_policy=target_opponent_policy,
            hidden_layer_sizes=(M, M),
            name='target_gaussian_conditional_policy',
            opponent_policy=False,
            squash=True,
            joint=False,
            agent_id=i)

    # Joint Q-function over agent i's and the opponents' actions, plus its target copy.
    joint_qf = NNJointQFunction(env_spec=env.env_specs,
                                hidden_layer_sizes=[M, M],
                                joint=joint,
                                agent_id=i)
    target_joint_qf = NNJointQFunction(env_spec=env.env_specs,
                                       hidden_layer_sizes=[M, M],
                                       name='target_joint_qf',
                                       agent_id=i)

    plotter = None

    # Assemble the ROMMEO learner for agent i.
    agent = ROMMEO(base_kwargs=base_kwargs,
                   agent_id=i,
                   env=env,
                   pool=pool,
                   joint_qf=joint_qf,
                   target_joint_qf=target_joint_qf,
                   policy=conditional_policy,
                   opponent_policy=opponent_policy,
                   target_policy=target_conditional_policy,
                   plotter=plotter,
                   policy_lr=1e-2,
                   qf_lr=1e-2,
                   joint=True,
                   value_n_particles=16,
                   kernel_fn=adaptive_isotropic_gaussian_kernel,
                   kernel_n_particles=32,
                   kernel_update_ratio=0.5,
                   td_target_update_interval=1,
                   discount=0.95,
                   reward_scale=1,
                   tau=0.01,
                   save_full_state=False,
                   opponent_modelling=True)
    return agent
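
A minimal usage sketch, assuming a trainer-style base_kwargs dict and a two-player
environment object exposing env_specs; the environment class and all keyword values
below are illustrative placeholders, and only the rom_agent signature is taken from
the snippet above.

# Hypothetical usage of rom_agent -- MatrixGame and the base_kwargs contents are assumptions.
base_kwargs = dict(epoch_length=100,      # assumed trainer settings
                   n_epochs=200,
                   n_train_repeat=1,
                   eval_render=False)
env = MatrixGame()                        # placeholder two-player env exposing .env_specs
agents = [rom_agent(model_name='rommeo',
                    i=agent_id,
                    env=env,
                    M=64,                 # hidden layer width
                    u_range=1.0,          # action range
                    base_kwargs=base_kwargs,
                    game_name='matrix')
          for agent_id in range(2)]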
Example #2
import tensorflow as tf  # project-specific classes used below are assumed to be imported from the surrounding codebase


def pr2ac_agent(model_name,
                i,
                env,
                M,
                u_range,
                base_kwargs,
                k=0,
                g=False,
                mu=1.5,
                game_name='matrix',
                aux=True):
    # Same defaults as rom_agent: per-agent data and tanh-squashed continuous actions;
    # 'particle' games switch to softmax-squashed sampled actions below.
    joint = False
    squash = True
    squash_func = tf.tanh
    correct_tanh = True
    sampling = False
    if 'particle' in game_name:
        sampling = True
        squash = True
        squash_func = tf.nn.softmax
        correct_tanh = False

    pool = SimpleReplayBuffer(env.env_specs,
                              max_replay_buffer_size=1e6,
                              joint=joint,
                              agent_id=i)

    # Opponent conditional policy (the opponent model); it is the building block
    # for the level-k policies constructed below.
    opponent_conditional_policy = StochasticNNConditionalPolicy(
        env.env_specs,
        hidden_layer_sizes=(M, M),
        name='opponent_conditional_policy',
        squash=squash,
        squash_func=squash_func,
        sampling=sampling,
        u_range=u_range,
        joint=joint,
        agent_id=i)

    # g=True: build a policy for every reasoning level 1..k and combine them
    # with a GeneralizedMultiLevelPolicy.
    if g:
        policies = []
        target_policies = []
        for kk in range(1, k + 1):
            policy, target_policy = get_level_k_policy(
                env,
                kk,
                M,
                i,
                u_range,
                opponent_conditional_policy,
                game_name=game_name)
            policies.append(policy)
            target_policies.append(target_policy)
        policy = GeneralizedMultiLevelPolicy(env.env_specs,
                                             policies=policies,
                                             agent_id=i,
                                             k=k,
                                             mu=mu)
        # target-network copies of the per-level policies
        target_policy = GeneralizedMultiLevelPolicy(env.env_specs,
                                                    policies=target_policies,
                                                    agent_id=i,
                                                    k=k,
                                                    mu=mu,
                                                    correct_tanh=correct_tanh)
    else:
        # k == 0: plain level-0 deterministic policy, no recursive reasoning.
        if k == 0:
            policy = DeterministicNNPolicy(env.env_specs,
                                           hidden_layer_sizes=(M, M),
                                           squash=squash,
                                           squash_func=squash_func,
                                           sampling=sampling,
                                           u_range=u_range,
                                           joint=False,
                                           agent_id=i)
            target_policy = DeterministicNNPolicy(env.env_specs,
                                                  hidden_layer_sizes=(M, M),
                                                  name='target_policy',
                                                  squash=squash,
                                                  squash_func=squash_func,
                                                  sampling=sampling,
                                                  u_range=u_range,
                                                  joint=False,
                                                  agent_id=i)
        # k > 0: level-k policy built from the opponent conditional policy.
        elif k > 0:
            policy, target_policy = get_level_k_policy(
                env,
                k,
                M,
                i,
                u_range,
                opponent_conditional_policy,
                game_name=game_name)

    # Joint Q over both agents' actions (with a target copy) plus an individual Q for agent i.
    joint_qf = NNJointQFunction(env_spec=env.env_specs,
                                hidden_layer_sizes=[M, M],
                                joint=joint,
                                agent_id=i)
    target_joint_qf = NNJointQFunction(env_spec=env.env_specs,
                                       hidden_layer_sizes=[M, M],
                                       name='target_joint_qf',
                                       joint=joint,
                                       agent_id=i)

    qf = NNQFunction(env_spec=env.env_specs,
                     hidden_layer_sizes=[M, M],
                     joint=False,
                     agent_id=i)
    plotter = None

    # Assemble the MAVBAC (PR2-AC) learner for agent i.
    agent = MAVBAC(base_kwargs=base_kwargs,
                   agent_id=i,
                   env=env,
                   pool=pool,
                   joint_qf=joint_qf,
                   target_joint_qf=target_joint_qf,
                   qf=qf,
                   policy=policy,
                   target_policy=target_policy,
                   conditional_policy=opponent_conditional_policy,
                   plotter=plotter,
                   policy_lr=1e-2,
                   qf_lr=1e-2,
                   joint=False,
                   value_n_particles=16,
                   kernel_fn=adaptive_isotropic_gaussian_kernel,
                   kernel_n_particles=32,
                   kernel_update_ratio=0.5,
                   td_target_update_interval=1,
                   discount=0.95,
                   reward_scale=1,
                   tau=0.01,
                   save_full_state=False,
                   k=k,
                   aux=aux)
    return agent
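
For comparison, a similarly hedged sketch of calling pr2ac_agent with a recursion
depth k; env and base_kwargs are the same kind of placeholders as above, and only
the pr2ac_agent signature comes from the snippet.

# Hypothetical usage of pr2ac_agent -- env and base_kwargs are placeholders as above.
agents = [pr2ac_agent(model_name='pr2ac',
                      i=agent_id,
                      env=env,
                      M=64,
                      u_range=1.0,
                      base_kwargs=base_kwargs,
                      k=2,                # level-k recursion depth (k=0 -> plain deterministic policy)
                      g=False,            # True would mix levels 1..k via GeneralizedMultiLevelPolicy
                      mu=1.5,             # weighting parameter for that mixture
                      game_name='matrix',
                      aux=True)
          for agent_id in range(2)]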