Example #1
    def test_custom_port_app(self, mock_unity3d):
        """Test if the base_port + worker_id is different
        for each environment when using custom ports"""

        _ = Unity3DEnv(file_name="app", port=5010)
        args, kwargs_first = mock_unity3d.call_args
        _ = Unity3DEnv(file_name="app", port=5010)
        args, kwargs_second = mock_unity3d.call_args
        self.assertNotEqual(
            kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
            kwargs_second.get("base_port") + kwargs_second.get("worker_id"))
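
The test methods in Examples #1-#4 arrive without their surrounding harness. A minimal sketch of that context follows; the patch target, class name, and import path are assumptions, not taken from the snippets:

import unittest
from unittest import mock

from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv


# Sketch (assumption): patch the MLAgents UnityEnvironment class that
# Unity3DEnv instantiates, so no actual Unity process is needed; each
# test method then receives the mock as its `mock_unity3d` argument.
# The exact patch target and Unity3DEnv import path vary by RLlib version.
@mock.patch("mlagents_envs.environment.UnityEnvironment")
class TestUnity3DEnvPorts(unittest.TestCase):
    # ... the test methods shown in Examples #1-#4 go here ...
    pass


if __name__ == "__main__":
    unittest.main()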
Example #2
    def test_ports_multi_app(self, mock_unity3d):
        """Test if the base_port + worker_id
        is different for each environment"""

        _ = Unity3DEnv(file_name="app", port=None)
        args, kwargs_first = mock_unity3d.call_args
        _ = Unity3DEnv(file_name="app", port=None)
        args, kwargs_second = mock_unity3d.call_args
        self.assertNotEqual(
            kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
            kwargs_second.get("base_port") + kwargs_second.get("worker_id"),
        )
Example #3
    def test_port_editor(self, mock_unity3d):
        """Test if the environment uses the editor port
         when no environment file is provided"""

        _ = Unity3DEnv(port=None)
        args, kwargs = mock_unity3d.call_args
        mock_unity3d.assert_called_once()
        self.assertEqual(5004, kwargs.get("base_port"))
Example #4
    def test_port_app(self, mock_unity3d):
        """Test if the environment uses the correct port
        when the environment file is provided"""

        _ = Unity3DEnv(file_name="app", port=None)
        args, kwargs = mock_unity3d.call_args
        mock_unity3d.assert_called_once()
        self.assertEqual(5005, kwargs.get("base_port"))
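
Taken together, Examples #1-#4 pin down a simple port scheme. The helper below is hypothetical, written only to restate the asserted behavior in one runnable place:

# Hypothetical restatement of what the tests assert: an env connects on
# base_port + worker_id, where base_port defaults to 5004 (editor, no
# file_name) or 5005 (binary app), an explicit `port=` overrides it, and
# worker_id grows with every Unity3DEnv instance, so concurrent envs
# never collide on a port.
def unity_connection_port(base_port, worker_id):
    return base_port + worker_id


assert unity_connection_port(5005, 0) != unity_connection_port(5005, 1)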
Example #5
        # We are a remote worker, or the local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            return PolicyServerInput(
                ioctx,
                SERVER_ADDRESS,
                args.port + ioctx.worker_index -
                (1 if ioctx.worker_index > 0 else 0),
            )
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    # Get the multi-agent policies dict and agent->policy
    # mapping-fn.
    policies, policy_mapping_fn = Unity3DEnv.get_policy_configs_for_game(
        args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Indicate that the Trainer we set up here doesn't need an actual env.
        # Allow spaces to be determined by the user (see below).
        "env": None,
        # Use the `PolicyServerInput` to generate experiences.
        "input": _input,
        # Use n worker processes to listen on different ports.
        "num_workers": args.num_workers,
        # Disable OPE, since the rollouts are coming from online clients.
        "off_policy_estimation_methods": {},
        # Other settings.
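
Example #5 breaks off inside the config dict. Downstream, such a server config is typically handed to a Trainer and stepped in a loop; a sketch, assuming PPO (the real script may choose the algorithm from a command-line flag):

    # Sketch (assumption): build the Trainer from the completed config
    # and train forever; experiences arrive via the PolicyServerInput
    # returned by `_input` above.
    from ray.rllib.agents.ppo import PPOTrainer

    trainer = PPOTrainer(config=config)
    while True:
        print(trainer.train())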
Example #6
)

if __name__ == "__main__":
    args = parser.parse_args()

    # Start the client for sending environment information (e.g. observations,
    # actions) to a policy server (listening on port 9900).
    client = PolicyClient(
        "http://" + args.server + ":" + str(args.port),
        inference_mode=args.inference_mode,
        update_interval=args.update_interval_local_mode,
    )

    # Start and reset the actual Unity3DEnv (either an already running
    # Unity3D editor or a binary (game) to be started automatically).
    env = Unity3DEnv(file_name=args.game, episode_horizon=args.horizon)
    obs = env.reset()
    eid = client.start_episode(training_enabled=not args.no_train)

    # Keep track of the total reward per episode.
    total_rewards_this_episode = 0.0

    # Loop infinitely through the env.
    while True:
        # Get actions from the Policy server given our current obs.
        actions = client.get_action(eid, obs)
        # Apply actions to our env.
        obs, rewards, dones, infos = env.step(actions)
        total_rewards_this_episode += sum(rewards.values())
        # Log rewards and single-agent dones.
        client.log_returns(eid, rewards, infos, multiagent_done_dict=dones)
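
The client loop is truncated before episode termination is handled. A typical continuation of the loop body, reusing the names from the snippet (a sketch following the usual PolicyClient pattern):

        # Sketch: once all agents are done, close this episode on the
        # server, reset the env, and start a fresh episode.
        if dones["__all__"]:
            print("Episode done: reward={}".format(total_rewards_this_episode))
            total_rewards_this_episode = 0.0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)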
Example #7
    "--horizon",
    type=int,
    default=3000,
    help="The max. number of `step()`s for any episode (per agent) before "
    "it'll be reset again automatically.")
parser.add_argument("--torch", action="store_true")

if __name__ == "__main__":
    ray.init()

    args = parser.parse_args()

    tune.register_env(
        "unity3d", lambda c: Unity3DEnv(
            file_name=c["file_name"],
            no_graphics=(args.env != "VisualHallway"
                         and c["file_name"] is not None),
            episode_horizon=c["episode_horizon"],
        ))

    # Get policies (different agent types; "behaviors" in MLAgents) and
    # the mappings from individual agents to Policies.
    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    config = {
        "env": "unity3d",
        "env_config": {
            "file_name": args.file_name,
            "episode_horizon": args.horizon,
        },
        # For running in editor, force to use just one Worker (we only have
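
Example #7 also ends mid-config. The registered "unity3d" env and this config would then typically be launched via Tune; a sketch, assuming PPO and an illustrative stopping criterion:

    # Sketch (assumption): kick off training with Tune on the env
    # registered above; the stop condition is illustrative only.
    tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": 100},
        verbose=1,
    )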
Example #8
parser.add_argument("--no-restore",
                    action="store_true",
                    help="Whether to load the Policy "
                    "weights from a previous checkpoint")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Create a fake-env for the server. This env will never be used (neither
    # for sampling, nor for evaluation) and its obs/action Spaces do not
    # matter either (multi-agent config below defines Spaces per Policy).
    register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))

    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Use the connector server to generate experiences.
        "input":
        (lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)),
        # Use a single worker process (w/ SyncSampler) to run the server.
        "num_workers":
        0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],

        # Other settings.
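
Example #8 ends mid-config as well. Given the --no-restore flag defined above, the script presumably builds a Trainer on the fake env and conditionally restores it; a sketch, assuming PPO (the checkpoint path is hypothetical):

    from ray.rllib.agents.ppo import PPOTrainer

    # Sketch (assumption): create the Trainer on the registered fake env
    # and, unless --no-restore was given, resume from an earlier run.
    trainer = PPOTrainer(env="fake_unity", config=config)
    if not args.no_restore:
        trainer.restore("/tmp/unity3d_checkpoint")  # hypothetical path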