def get_ddpgtom_agent(env, agent_id, hidden_layer_sizes, max_replay_buffer_size):
    """Build a DDPG-with-Theory-of-Mind agent for one player of a multi-agent env.

    The agent's policy conditions on its own observation plus the (predicted)
    flattened opponent actions; a separate opponent-policy network produces
    that prediction from the observation alone.

    Args:
        env: environment exposing ``env_specs`` with per-agent observation and
            action spaces.
        agent_id: index of this agent within the environment.
        hidden_layer_sizes: hidden-layer widths shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.

    Returns:
        A configured ``DDPGToMAgent``.
    """
    specs = env.env_specs
    obs_space = specs.observation_space[agent_id]
    act_space = specs.action_space[agent_id]
    # Flattened dimensionality of all *other* agents' actions.
    opp_dim = specs.action_space.opponent_flat_dim(agent_id)

    policy = DeterministicMLPPolicy(
        input_shapes=(obs_space.shape, (opp_dim, )),
        output_shape=act_space.shape,
        hidden_layer_sizes=hidden_layer_sizes,
        name='policy_agent_{}'.format(agent_id))
    # Centralized critic: sees this observation and the joint action of all agents.
    qf = MLPValueFunction(
        input_shapes=(obs_space.shape, (specs.action_space.flat_dim, )),
        output_shape=(1, ),
        hidden_layer_sizes=hidden_layer_sizes,
        name='qf_agent_{}'.format(agent_id))
    opponent_policy = DeterministicMLPPolicy(
        input_shapes=(obs_space.shape, ),
        output_shape=(opp_dim, ),
        hidden_layer_sizes=hidden_layer_sizes,
        name='opponent_policy_agent_{}'.format(agent_id))
    replay_buffer = IndexedReplayBuffer(
        observation_dim=obs_space.shape[0],
        action_dim=act_space.shape[0],
        opponent_action_dim=opp_dim,
        max_replay_buffer_size=max_replay_buffer_size)

    return DDPGToMAgent(
        env_specs=specs,
        policy=policy,
        qf=qf,
        opponent_policy=opponent_policy,
        replay_buffer=replay_buffer,
        exploration_strategy=OUExploration(act_space),
        gradient_clipping=10.,
        agent_id=agent_id,
    )
def get_ddpg_agent(env, agent_id, hidden_layer_sizes, max_replay_buffer_size, policy_type='dete'):
    """Build a standard DDPG agent for one player of a multi-agent env.

    Args:
        env: environment exposing ``env_specs`` with per-agent observation and
            action spaces.
        agent_id: index of this agent within the environment.
        hidden_layer_sizes: hidden-layer widths shared by policy and critic.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``'dete'`` for a deterministic MLP policy with
            Ornstein-Uhlenbeck exploration, or ``'gumble'`` for a relaxed
            (Gumbel-)softmax policy with no extra exploration strategy.

    Returns:
        A configured ``DDPGAgent``.

    Raises:
        ValueError: if ``policy_type`` is not one of the supported values.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    action_space = env.env_specs.action_space[agent_id]
    if policy_type == 'dete':
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == 'gumble':
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        # Fail fast with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "Unknown policy_type {!r}; expected 'dete' or 'gumble'".format(policy_type))
    return DDPGAgent(
        env_specs=env.env_specs,
        policy=policy_fn(input_shapes=(observation_space.shape, ),
                         output_shape=action_space.shape,
                         hidden_layer_sizes=hidden_layer_sizes,
                         name='policy_agent_{}'.format(agent_id)),
        qf=MLPValueFunction(input_shapes=(observation_space.shape,
                                          action_space.shape),
                            output_shape=(1, ),
                            hidden_layer_sizes=hidden_layer_sizes,
                            name='qf_agent_{}'.format(agent_id)),
        replay_buffer=IndexedReplayBuffer(
            observation_dim=observation_space.shape[0],
            action_dim=action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.,
        agent_id=agent_id,
    )
def get_pr2_agent(env, agent_id, hidden_layer_sizes, max_replay_buffer_size, policy_type="deter"):
    """Build a PR2 (probabilistic recursive reasoning) agent for one player.

    Args:
        env: environment exposing ``env_specs`` with per-agent observation and
            action spaces.
        agent_id: index of this agent within the environment.
        hidden_layer_sizes: hidden-layer widths shared by all networks.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``'deter'`` (or the legacy alias ``'dete'``) for a
            deterministic MLP policy with Ornstein-Uhlenbeck exploration, or
            ``'gumble'`` for a relaxed (Gumbel-)softmax policy with no extra
            exploration strategy.

    Returns:
        A configured ``PR2Agent``.

    Raises:
        ValueError: if ``policy_type`` is not one of the supported values.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    action_space = env.env_specs.action_space[agent_id]
    opponent_action_shape = (
        env.env_specs.action_space.opponent_flat_dim(agent_id), )
    # BUG FIX: the default policy_type is "deter" but the branch only tested
    # for "dete", so calling with defaults crashed with UnboundLocalError.
    # Accept both spellings for backward compatibility.
    if policy_type in ("dete", "deter"):
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == "gumble":
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        raise ValueError(
            "Unknown policy_type {!r}; expected 'deter' or 'gumble'".format(policy_type))
    return PR2Agent(
        env_specs=env.env_specs,
        policy=policy_fn(
            input_shapes=(observation_space.shape, ),
            output_shape=action_space.shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="policy_agent_{}".format(agent_id),
        ),
        # Joint critic: conditions on own action and all opponents' actions.
        qf=MLPValueFunction(
            input_shapes=(
                observation_space.shape,
                action_space.shape,
                opponent_action_shape,
            ),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="qf_agent_{}".format(agent_id),
        ),
        # Individual critic: own observation/action only.
        ind_qf=MLPValueFunction(
            input_shapes=(observation_space.shape, action_space.shape),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="ind_qf_agent_{}".format(agent_id),
        ),
        replay_buffer=IndexedReplayBuffer(
            observation_dim=observation_space.shape[0],
            action_dim=action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size,
            opponent_action_dim=opponent_action_shape[0],
        ),
        # Models opponents' response to this agent's observation and action.
        opponent_policy=policy_fn(
            input_shapes=(observation_space.shape, action_space.shape),
            output_shape=opponent_action_shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="opponent_policy_agent_{}".format(agent_id),
        ),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.0,
        agent_id=agent_id,
    )
def get_commnet_agent(env, agent_id, hidden_layer_sizes, max_replay_buffer_size, policy_type="deter"):
    """Build a fully centralized CommNet-style agent over all ``n`` players.

    Observations and actions of all agents are stacked along a leading
    ``(n, ...)`` axis; the replay buffer stores the flattened joint data with
    per-agent reward and terminal entries.

    Args:
        env: environment exposing ``env_specs`` with per-agent observation and
            action spaces and ``agent_num``.
        agent_id: index of this agent within the environment.
        hidden_layer_sizes: hidden-layer widths shared by policy and critic.
        max_replay_buffer_size: capacity of the replay buffer.
        policy_type: ``'deter'`` for a deterministic MLP policy with
            Ornstein-Uhlenbeck exploration, or ``'gumble'`` for a relaxed
            (Gumbel-)softmax policy with no extra exploration strategy.

    Returns:
        A configured ``FullyCentralizedAgent``.

    Raises:
        ValueError: if ``policy_type`` is not one of the supported values.
    """
    observation_space = env.env_specs.observation_space[agent_id]
    n = env.env_specs.agent_num
    action_space = env.env_specs.action_space[agent_id]
    if policy_type == "deter":
        policy_fn = DeterministicMLPPolicy
        exploration_strategy = OUExploration(action_space)
    elif policy_type == "gumble":
        policy_fn = RelaxedSoftmaxMLPPolicy
        exploration_strategy = None
    else:
        # Fail fast with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "Unknown policy_type {!r}; expected 'deter' or 'gumble'".format(policy_type))
    return FullyCentralizedAgent(
        env_specs=env.env_specs,
        policy=policy_fn(
            input_shapes=((n, ) + observation_space.shape, ),
            output_shape=action_space.shape,
            hidden_layer_sizes=hidden_layer_sizes,
            name="policy_agent_{}".format(agent_id),
        ),
        qf=CommNetValueFunction(
            input_shapes=((n, ) + observation_space.shape,
                          (n, ) + action_space.shape),
            output_shape=(1, ),
            hidden_layer_sizes=hidden_layer_sizes,
            name="qf_agent_{}".format(agent_id),
        ),
        replay_buffer=IndexedReplayBuffer(
            observation_dim=n * observation_space.shape[0],
            action_dim=n * action_space.shape[0],
            max_replay_buffer_size=max_replay_buffer_size,
            reward_dim=n,
            terminal_dim=n,
        ),
        exploration_strategy=exploration_strategy,
        gradient_clipping=10.0,
        agent_id=agent_id,
    )