Esempio n. 1
0
def human():
    """Host a two-participant game and wait for a remote player to join.

    Reserves five consecutive ports starting at ``FLAGS.config_port``, starts
    the game process, creates the game and publishes the settings through
    ``lan_sc2_env.tcp_server``.  When ``FLAGS.remote`` is set, ports are
    forwarded with ``lan_sc2_env.forward_ports`` and traffic is relayed
    between the tcp connection and a udp socket on daemon threads.  The local
    player then joins and either the human renderer runs for one episode
    (``FLAGS.render``) or the game is stepped/observed until a player result
    shows up.

    The ``finally`` clause closes whichever of the tcp connection, game
    process, udp socket and forwarding process were actually created; a
    ``KeyboardInterrupt`` is swallowed so Ctrl-C ends the session quietly.

    Raises:
        SystemExit: if any of the five ports is not free.
    """
    run_config = run_configs.get()
    map_inst = maps.get(FLAGS.map)

    if not (FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size):
        logging.info(
            "Use --rgb_screen_size and --rgb_minimap_size if you want rgb "
            "observations.")

    # One tcp port followed by a (game, base) pair for each of the 2 players.
    ports = [FLAGS.config_port + offset for offset in range(5)]
    if any(not portpicker.is_port_free(port) for port in ports):
        sys.exit("Need 5 free ports after the config port.")
    tcp_port, server_game, server_base, client_game, client_base = ports

    sc_process = None
    tunnel = None
    tcp_link = None
    udp_link = None
    try:
        sc_process = run_config.start(
            extra_ports=ports[1:],
            timeout_seconds=300,
            host=FLAGS.host,
            window_loc=(50, 50))

        # This dict is what gets handed to the tcp server for the joining
        # client, so keys and layout are kept exactly as the protocol expects.
        config = {
            "remote": FLAGS.remote,
            "game_version": sc_process.version.game_version,
            "realtime": FLAGS.realtime,
            "map_name": map_inst.name,
            "map_path": map_inst.path,
            "map_data": map_inst.data(run_config),
            "ports": {
                "server": {"game": server_game, "base": server_base},
                "client": {"game": client_game, "base": client_base},
            },
        }

        create_req = sc_pb.RequestCreateGame(
            realtime=config["realtime"],
            local_map=sc_pb.LocalMap(map_path=config["map_path"]))
        # Two participant slots: the local player and the one joining.
        for _ in range(2):
            create_req.player_setup.add(type=sc_pb.Participant)

        ctrl = sc_process.controller
        ctrl.save_map(config["map_path"], config["map_data"])
        ctrl.create_game(create_req)

        if FLAGS.remote:
            tunnel = lan_sc2_env.forward_ports(
                FLAGS.remote, sc_process.host, [client_base],
                [tcp_port, server_base])

        banner = "-" * 80
        print(banner)
        print("Join: play_vs_agent --host %s --config_port %s" %
              (sc_process.host, tcp_port))
        print(banner)

        host_addr = lan_sc2_env.Addr(sc_process.host, tcp_port)
        tcp_link = lan_sc2_env.tcp_server(host_addr, config)

        if FLAGS.remote:
            udp_link = lan_sc2_env.udp_server(
                lan_sc2_env.Addr(sc_process.host, client_game))

            # Relay in both directions between the tcp link and the udp socket.
            server_game_addr = lan_sc2_env.Addr(sc_process.host, server_game)
            lan_sc2_env.daemon_thread(
                lan_sc2_env.tcp_to_udp,
                (tcp_link, udp_link, server_game_addr))
            lan_sc2_env.daemon_thread(
                lan_sc2_env.udp_to_tcp, (udp_link, tcp_link))

        join_req = sc_pb.RequestJoinGame()
        join_req.shared_port = 0  # unused
        join_req.server_ports.game_port = server_game
        join_req.server_ports.base_port = server_base
        join_req.client_ports.add(game_port=client_game,
                                  base_port=client_base)

        join_req.race = sc2_env.Race[FLAGS.user_race]
        join_req.player_name = FLAGS.user_name
        if FLAGS.render:
            # Observation options are only requested when rendering locally.
            options = join_req.options
            options.raw = True
            options.score = True
            if FLAGS.feature_screen_size and FLAGS.feature_minimap_size:
                feature_layer = options.feature_layer
                feature_layer.width = 24
                FLAGS.feature_screen_size.assign_to(feature_layer.resolution)
                FLAGS.feature_minimap_size.assign_to(
                    feature_layer.minimap_resolution)
            if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size:
                rgb = options.render
                FLAGS.rgb_screen_size.assign_to(rgb.resolution)
                FLAGS.rgb_minimap_size.assign_to(rgb.minimap_resolution)
        ctrl.join_game(join_req)

        if FLAGS.render:
            renderer = renderer_human.RendererHuman(
                fps=FLAGS.fps, render_feature_grid=False)
            renderer.run(run_configs.get(), ctrl, max_episodes=1)
        else:
            # Still step forward so the Mac/Windows renderer works.
            while True:
                tick_started = time.time()
                if not FLAGS.realtime:
                    ctrl.step()

                if ctrl.observe().player_result:
                    break
                # Sleep away whatever is left of this frame's time budget.
                remaining = tick_started - time.time() + 1 / FLAGS.fps
                time.sleep(max(0, remaining))
    except KeyboardInterrupt:
        pass
    finally:
        if tcp_link:
            tcp_link.close()
        if sc_process:
            sc_process.close()
        if udp_link:
            udp_link.close()
        if tunnel:
            # Ask politely first, give it up to ~5 seconds, then force it.
            tunnel.terminate()
            checks_left = 5
            while checks_left and tunnel.poll() is None:
                time.sleep(1)
                checks_left -= 1
            if tunnel.poll() is None:
                tunnel.kill()
                tunnel.wait()
Esempio n. 2
0
def server():
    """Host a game, run a local agent in it, and wait for one opponent to join.

    The five ports come from ``FLAGS.port0`` .. ``FLAGS.port4``: ``port0`` is
    used for the settings tcp server, ``port1``/``port2`` are the server
    game/base ports and ``port3``/``port4`` the client game/base ports.

    After the game process is started and the game created, the settings are
    published through ``lan_sc2_env.tcp_server``, the local side joins with
    ``FLAGS.agent_race``, and the agent class named by ``FLAGS.agent``
    (``"module.ClassName"``) is imported and driven by ``run_loop`` for
    ``FLAGS.max_steps`` inside a ``LanServerSC2Env``.  A
    ``lan_server_sc2_env.RestartException`` raised by the loop is swallowed.
    A replay is saved when ``FLAGS.save_replay`` is set.

    The tcp connection and the game process are closed in ``finally`` if they
    were created.

    Raises:
        SystemExit: if any of the five configured ports is not free; the
            message lists the ports that are in use.
    """
    run_config = run_configs.get()

    map_inst = maps.get(FLAGS.map)

    if not FLAGS.rgb_screen_size or not FLAGS.rgb_minimap_size:
        logging.info(
            "Use --rgb_screen_size and --rgb_minimap_size if you want rgb "
            "observations.")

    # ports[0] serves the settings; ports[1:] are the server and client
    # (game, base) pairs handed to the game process.
    ports = [FLAGS.port0, FLAGS.port1, FLAGS.port2, FLAGS.port3, FLAGS.port4]
    # The ports are independent flags here (there is no config port), so name
    # the ones that are actually taken instead of a generic message.
    busy_ports = [p for p in ports if not portpicker.is_port_free(p)]
    if busy_ports:
        sys.exit("Need 5 free ports (port0..port4); not free: %s" %
                 ", ".join(str(p) for p in busy_ports))

    proc = None
    tcp_conn = None

    try:
        proc = run_config.start(extra_ports=ports[1:],
                                timeout_seconds=300,
                                host=FLAGS.host,
                                window_loc=(50, 50))

        tcp_port = ports[0]
        # Settings handed to the tcp server for the joining client.
        settings = {
            "remote": False,
            "game_version": proc.version.game_version,
            "realtime": FLAGS.realtime,
            "map_name": map_inst.name,
            "map_path": map_inst.path,
            "map_data": map_inst.data(run_config),
            "ports": {
                "server": {
                    "game": ports[1],
                    "base": ports[2]
                },
                "client": {
                    "game": ports[3],
                    "base": ports[4]
                },
            }
        }

        create = sc_pb.RequestCreateGame(
            realtime=settings["realtime"],
            local_map=sc_pb.LocalMap(map_path=settings["map_path"]),
            disable_fog=FLAGS.disable_fog)
        # Two participant slots: the local agent and the one joining.
        create.player_setup.add(type=sc_pb.Participant)
        create.player_setup.add(type=sc_pb.Participant)

        controller = proc.controller
        controller.save_map(settings["map_path"], settings["map_data"])
        controller.create_game(create)

        print("-" * 80)
        print("Join: agent_vs_agent --host %s --config_port %s" %
              (proc.host, tcp_port))
        print("-" * 80)

        tcp_conn = lan_sc2_env.tcp_server(
            lan_sc2_env.Addr(proc.host, tcp_port), settings)

        join = sc_pb.RequestJoinGame()
        join.shared_port = 0  # unused
        join.server_ports.game_port = settings["ports"]["server"]["game"]
        join.server_ports.base_port = settings["ports"]["server"]["base"]
        join.client_ports.add(game_port=settings["ports"]["client"]["game"],
                              base_port=settings["ports"]["client"]["base"])

        join.race = sc2_env.Race[FLAGS.agent_race]
        join.options.raw = True
        join.options.score = True
        # Feature-layer and rgb observations are each requested only when
        # both of their size flags are set.
        if FLAGS.feature_screen_size and FLAGS.feature_minimap_size:
            fl = join.options.feature_layer
            fl.width = 24
            FLAGS.feature_screen_size.assign_to(fl.resolution)
            FLAGS.feature_minimap_size.assign_to(fl.minimap_resolution)

        if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size:
            FLAGS.rgb_screen_size.assign_to(join.options.render.resolution)
            FLAGS.rgb_minimap_size.assign_to(
                join.options.render.minimap_resolution)

        controller.join_game(join)

        # The environment wraps the controller that has already joined above.
        with lan_server_sc2_env.LanServerSC2Env(
                race=sc2_env.Race[FLAGS.agent_race],
                step_mul=FLAGS.step_mul,
                agent_interface_format=sc2_env.parse_agent_interface_format(
                    feature_screen=FLAGS.feature_screen_size,
                    feature_minimap=FLAGS.feature_minimap_size,
                    rgb_screen=FLAGS.rgb_screen_size,
                    rgb_minimap=FLAGS.rgb_minimap_size,
                    action_space=FLAGS.action_space,
                    use_feature_units=FLAGS.use_feature_units),
                visualize=False,
                controller=controller,
                map_name=FLAGS.map) as env:
            # FLAGS.agent is "package.module.ClassName".
            agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
            agent_cls = getattr(importlib.import_module(agent_module),
                                agent_name)
            agent_kwargs = {}
            if FLAGS.agent_config:
                agent_kwargs['config_path'] = FLAGS.agent_config
            agents = [agent_cls(**agent_kwargs)]

            try:
                run_loop(agents, env, FLAGS.max_steps)
            except lan_server_sc2_env.RestartException:
                # Deliberately ignored; execution falls through to the
                # optional replay save and cleanup.
                pass

            if FLAGS.save_replay:
                env.save_replay(agent_cls.__name__)
    finally:
        if tcp_conn:
            tcp_conn.close()
        if proc:
            proc.close()