def _set_worker_converters(worker: RolloutWorker) -> None:
    """Install per-player action-space converters on a rollout worker.

    For each of the two players (indices 0 and 1) a
    ``RestrictedToBaseGameActionSpaceConverter`` is built around the worker's
    ``"delegate_policy"`` and registered on every env of the worker through
    ``env.set_action_conversion``. The resulting list (indexed by player) is
    then stored on the delegate policy as ``player_converters``.

    ``player_to_base_game_action_specs``, ``weights_cache`` and
    ``create_get_pure_strat_cached`` are read from the surrounding scope.

    Args:
        worker: Rollout worker whose envs and delegate policy are configured.
    """
    worker_delegate_policy = worker.policy_map["delegate_policy"]
    player_converters = []
    for p in range(2):
        player_converter = RestrictedToBaseGameActionSpaceConverter(
            delegate_policy=worker_delegate_policy,
            policy_specs=player_to_base_game_action_specs[p],
            load_policy_spec_fn=create_get_pure_strat_cached(cache=weights_cache),
        )
        player_converters.append(player_converter)
        # Bind the current loop values as lambda defaults. A plain closure
        # would look up `p` / `player_converter` at call time, which is only
        # correct if foreach_env runs the callback before the next iteration.
        worker.foreach_env(
            lambda env, player=p, converter=player_converter:
                env.set_action_conversion(player, converter)
        )
    worker_delegate_policy.player_converters = player_converters
def _set_conversions(worker: RolloutWorker):
    """Register restricted-to-base-game action converters on all worker envs.

    Every env of ``worker`` must be a ``RestrictedGame``. For each agent in
    ``agent_id_to_restricted_game_specs`` that has at least one policy spec,
    a ``RestrictedToBaseGameActionSpaceConverter`` wrapping the worker's
    delegate policy is created and handed to the env via
    ``set_action_conversion``. Agents with an empty spec collection are left
    untouched.

    ``agent_id_to_restricted_game_specs``, ``delegate_policy_id`` and
    ``load_policy_spec_fn`` are read from the surrounding scope.
    """

    def _apply_to_env(env):
        assert isinstance(env, RestrictedGame)
        for agent, specs in agent_id_to_restricted_game_specs.items():
            # Nothing to convert for agents without restricted-game specs.
            if len(specs) <= 0:
                continue
            delegate = worker.policy_map[delegate_policy_id]
            action_converter = RestrictedToBaseGameActionSpaceConverter(
                delegate_policy=delegate,
                policy_specs=specs,
                load_policy_spec_fn=load_policy_spec_fn,
            )
            env.set_action_conversion(agent_id=agent, converter=action_converter)

    worker.foreach_env(_apply_to_env)
def _set_worker_converters(worker: RolloutWorker) -> None:
    """Install per-player observation converters on a rollout worker.

    Each ``(player, converter)`` pair in ``player_converters`` is registered
    on every env of the worker through ``env.set_obs_conversion_dict``, and
    the whole mapping is then stored on the delegate policy as
    ``player_converters``.

    ``player_converters`` and ``delegate_policy_id`` are read from the
    surrounding scope.

    Args:
        worker: Rollout worker whose envs and delegate policy are configured.
    """
    worker_delegate_policy = worker.policy_map[delegate_policy_id]
    for p, player_converter in player_converters.items():
        # Bind the current loop values as lambda defaults. A plain closure
        # would look up `p` / `player_converter` at call time, which is only
        # correct if foreach_env runs the callback before the next iteration.
        worker.foreach_env(
            lambda env, player=p, converter=player_converter:
                env.set_obs_conversion_dict(player, converter)
        )
    worker_delegate_policy.player_converters = player_converters
def _set_worker_converters(worker: RolloutWorker) -> None:
    """Install per-player observation converters on a rollout worker.

    For players 0 and 1, ``player_converters[p]`` is registered on every env
    of the worker through ``env.set_obs_conversion_dict``. The
    ``player_converters`` container is then stored on the worker's
    ``"delegate_policy"`` as ``player_converters``.

    ``player_converters`` is read from the surrounding scope.

    Args:
        worker: Rollout worker whose envs and delegate policy are configured.
    """
    worker_delegate_policy = worker.policy_map["delegate_policy"]
    for p in range(2):
        # Bind the current player index and its converter as lambda defaults.
        # A plain closure would look up `p` at call time, which is only
        # correct if foreach_env runs the callback before the next iteration.
        worker.foreach_env(
            lambda env, player=p, converter=player_converters[p]:
                env.set_obs_conversion_dict(player, converter)
        )
    worker_delegate_policy.player_converters = player_converters