コード例 #1
0
ファイル: benchmark_replay.py プロジェクト: yeclairer/soowa
def main(argv):
  """Step through a replay, transform each observation, and print the stopwatch.

  Args:
    argv: Positional command-line arguments; nothing beyond the first entry
      is accepted.

  Raises:
    app.UsageError: If more than one positional argument is supplied.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  stopwatch.sw.enable()

  # Build the interface options requested through the flags.
  interface = sc_pb.InterfaceOptions()
  interface.raw = FLAGS.use_feature_units or FLAGS.use_raw_units
  interface.score = True
  interface.feature_layer.width = 24
  if FLAGS.feature_screen_size and FLAGS.feature_minimap_size:
    layer = interface.feature_layer
    FLAGS.feature_screen_size.assign_to(layer.resolution)
    FLAGS.feature_minimap_size.assign_to(layer.minimap_resolution)
  if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size:
    render = interface.render
    FLAGS.rgb_screen_size.assign_to(render.resolution)
    FLAGS.rgb_minimap_size.assign_to(render.minimap_resolution)

  # The default run config is only used to read the replay; the config that
  # launches the game is then selected from the version stored in the replay.
  replay_bytes = run_configs.get().replay_data(FLAGS.replay)
  request = sc_pb.RequestStartReplay(
      replay_data=replay_bytes, options=interface, observed_player_id=1)
  run_config = run_configs.get(
      version=replay.get_replay_version(replay_bytes))

  try:
    wants_rgb = interface.HasField("render")
    with run_config.start(want_rgb=wants_rgb) as controller:
      info = controller.replay_info(replay_bytes)
      print(" Replay info ".center(60, "-"))
      print(info)
      print("-" * 60)

      # An explicit --map_path takes precedence over the path in the replay.
      map_file = FLAGS.map_path or info.local_map_path
      if map_file:
        request.map_data = run_config.map_data(map_file)
      controller.start_replay(request)

      transformer = features.features_from_game_info(
          game_info=controller.game_info(),
          use_feature_units=FLAGS.use_feature_units,
          use_raw_units=FLAGS.use_raw_units,
          use_unit_counts=interface.raw,
          use_camera_position=False,
          action_space=actions.ActionSpace.FEATURES)

      # Keep stepping until an observation carries a player result.
      finished = False
      while not finished:
        controller.step(FLAGS.step_mul)
        observation = controller.observe()
        transformer.transform_obs(observation)
        finished = bool(observation.player_result)
  except KeyboardInterrupt:
    # Ctrl-C only ends the run early; the stopwatch is still printed below.
    pass

  print(stopwatch.sw)
コード例 #2
0
ファイル: replay_actions.py プロジェクト: yeclairer/soowa
def main(unused_argv):
    """Collect action stats from the replays found under ``FLAGS.replays``.

    Starts a ``stats_printer`` thread, a thread that feeds replay paths into a
    bounded queue, and up to ``FLAGS.parallel`` ``ReplayProcessor`` workers,
    then waits for the replay queue to be fully processed.

    Args:
        unused_argv: Ignored.
    """
    run_config = run_configs.get()

    replays_arg = FLAGS.replays
    if not gfile.Exists(replays_arg):
        sys.exit("{} doesn't exist.".format(replays_arg))

    stats_queue = multiprocessing.Queue()
    stats_thread = threading.Thread(target=stats_printer, args=(stats_queue,))
    try:
        # Buffering everything straight into a JoinableQueue was observed to
        # keep the program from exiting, so the paths are collected into a
        # list here and trickled into the queue from a separate thread. The
        # list is built synchronously so that work is known to exist before
        # the SC2 processes run; otherwise the replay_queue.join() below would
        # succeed without doing any work, and the program would exit.
        print("Getting replay list:", replays_arg)
        replay_paths = sorted(run_config.replay_paths(replays_arg))
        print(len(replay_paths), "replays found.")
        if not replay_paths:
            return

        if not FLAGS["sc2_version"].present:  # ie not set explicitly.
            first_replay = run_config.replay_data(replay_paths[0])
            version = replay.get_replay_version(first_replay)
            run_config = run_configs.get(version=version)
            print("Assuming version:", version.game_version)

        print()

        stats_thread.start()
        replay_queue = multiprocessing.JoinableQueue(FLAGS.parallel * 10)
        filler = threading.Thread(
            target=replay_queue_filler, args=(replay_queue, replay_paths))
        filler.daemon = True
        filler.start()

        worker_count = min(len(replay_paths), FLAGS.parallel)
        for worker_id in range(worker_count):
            worker = ReplayProcessor(
                worker_id, run_config, replay_queue, stats_queue)
            worker.daemon = True
            worker.start()
            # Stagger startups, otherwise they seem to conflict somehow.
            time.sleep(1)

        replay_queue.join()  # Wait for the queue to empty.
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, exiting.")
    finally:
        stats_queue.put(None)  # Tell the stats_thread to print and exit.
        if stats_thread.is_alive():
            stats_thread.join()
コード例 #3
0
ファイル: replay_info.py プロジェクト: yeclairer/soowa
def _replay_info(replay_path):
  """Query a replay for information."""
  if not replay_path.lower().endswith("sc2replay"):
    print("Must be a replay.")
    return

  run_config = run_configs.get()
  data = run_config.replay_data(replay_path)
  ver = replay.get_replay_version(data)
  FLAGS.set_default("sc2_version", ver.game_version)
  run_config = run_configs.get()  # In case the version changed.
  print("Launching version:", run_config.version.game_version)
  with run_config.start(want_rgb=False) as controller:
    info = controller.replay_info(data)
  print("-" * 60)
  print(info)
コード例 #4
0
def main(_):
    """Play every replay under ``FLAGS.input_dir`` to the end and re-save it.

    A single game instance is launched with the version read from the first
    replay (in sorted order). Each replay is then started with
    ``record_replay=True``, stepped until a player result is observed, and the
    replay returned by the controller is written to ``FLAGS.output_dir`` under
    the original file name.

    Args:
        _: Unused command-line arguments.
    """
    run_config = run_configs.get()

    replay_list = sorted(run_config.replay_paths(FLAGS.input_dir))
    print(len(replay_list), "replays found.\n")
    if not replay_list:
        # Nothing to do; without this guard replay_list[0] below would raise
        # IndexError on an empty input directory.
        return

    version = replay.get_replay_version(run_config.replay_data(replay_list[0]))
    run_config = run_configs.get(version=version)  # Replace the run config.

    with run_config.start(want_rgb=False) as controller:
        for replay_path in replay_list:
            replay_data = run_config.replay_data(replay_path)
            info = controller.replay_info(replay_data)

            print(" Starting replay: ".center(60, "-"))
            print("Path:", replay_path)
            print("Size:", len(replay_data), "bytes")
            print(" Replay info: ".center(60, "-"))
            print(info)
            print("-" * 60)

            start_replay = sc_pb.RequestStartReplay(
                replay_data=replay_data,
                options=sc_pb.InterfaceOptions(score=True),
                record_replay=True,
                observed_player_id=1)
            if info.local_map_path:
                start_replay.map_data = run_config.map_data(
                    info.local_map_path, len(info.player_info))
            controller.start_replay(start_replay)

            # Step in large increments until the game reports a result.
            while True:
                controller.step(1000)
                obs = controller.observe()
                if obs.player_result:
                    print("Stepped", obs.observation.game_loop, "game loops")
                    break

            # From here on replay_data holds the newly recorded replay.
            replay_data = controller.save_replay()

            replay_save_loc = os.path.join(FLAGS.output_dir,
                                           os.path.basename(replay_path))
            with open(replay_save_loc, "wb") as f:
                f.write(replay_data)

            print("Wrote replay, ", len(replay_data), " bytes to:",
                  replay_save_loc)
コード例 #5
0
ファイル: replay_info.py プロジェクト: yeclairer/soowa
def _replay_index(replay_dir):
  """Print one comma-separated summary line per replay found in a directory.

  The game is launched once, using the version read from the first replay
  found. Replays whose info request fails, or whose info carries an error,
  are collected and listed at the end.

  Args:
    replay_dir: Directory of replays; resolved with
      ``run_config.abs_replay_path`` before being searched.
  """
  run_config = run_configs.get()
  replay_dir = run_config.abs_replay_path(replay_dir)
  print("Checking:", replay_dir)
  replay_paths = list(run_config.replay_paths(replay_dir))
  print("Found %s replays" % len(replay_paths))

  if not replay_paths:
    return

  # Make the first replay's game version the default before picking a config.
  first_version = replay.get_replay_version(
      run_config.replay_data(replay_paths[0]))
  FLAGS.set_default("sc2_version", first_version.game_version)
  run_config = run_configs.get()  # In case the version changed.
  print("Launching version:", run_config.version.game_version)

  header = (
      "filename",
      "version",
      "map_name",
      "game_duration_loops",
      "players",
      "P1-outcome",
      "P1-race",
      "P1-apm",
      "P2-race",
      "P2-apm",
  )
  bad_replays = []
  with run_config.start(want_rgb=False) as controller:
    print("-" * 60)
    print(",".join(header))

    try:
      for path in replay_paths:
        name = os.path.basename(path)
        replay_bytes = run_config.replay_data(path)
        try:
          info = controller.replay_info(replay_bytes)
        except remote_controller.RequestError as e:
          bad_replays.append("%s: %s" % (name, e))
          continue

        if info.HasField("error"):
          print("failed:", name, info.error, info.error_details)
          bad_replays.append(name)
          continue

        players = info.player_info
        row = [
            name,
            info.game_version,
            info.map_name,
            info.game_duration_loops,
            len(players),
            sc_pb.Result.Name(players[0].player_result.result),
            sc_common.Race.Name(players[0].player_info.race_actual),
            players[0].player_apm,
        ]
        # The second player's columns are only added when a second player
        # exists.
        if len(players) >= 2:
          row.extend((
              sc_common.Race.Name(players[1].player_info.race_actual),
              players[1].player_apm,
          ))
        print(u",".join(str(field) for field in row))
    except KeyboardInterrupt:
      # Ctrl-C stops the scan; the error summary below is still printed.
      pass
    finally:
      if bad_replays:
        print("\n")
        print("Replays with errors:")
        print("\n".join(bad_replays))
コード例 #6
0
def main(argv):
    """Compare the observations from multiple binaries.

    Every entry of ``argv`` after the first is a target: either something
    ``_is_remote`` accepts, which is split on ":" into host and port and
    connected to, or a binary name that is launched with the version read
    from ``FLAGS.replay``. All targets replay the same file in lockstep; when
    ``FLAGS.diff`` is set, the static data and each step's observations of
    every target are diffed against those of the first target.

    Args:
        argv: Command-line arguments; ``argv[1:]`` are the targets.

    Exits via ``sys.exit`` with a usage message when no target is given.
    """
    if len(argv) <= 1:
        sys.exit(
            "Please specify binaries to run / to connect to. For binaries to run, "
            "specify the executable name. For remote connections, specify "
            "<hostname>:<port>. The version must match the replay.")

    targets = argv[1:]

    # Fixed interface options shared by every target.
    interface = sc_pb.InterfaceOptions()
    interface.raw = True
    interface.raw_affects_selection = True
    interface.raw_crop_to_playable_area = True
    interface.score = True
    interface.show_cloaked = True
    interface.show_placeholders = True
    interface.feature_layer.width = 24
    interface.feature_layer.resolution.x = 48
    interface.feature_layer.resolution.y = 48
    interface.feature_layer.minimap_resolution.x = 48
    interface.feature_layer.minimap_resolution.y = 48
    interface.feature_layer.crop_to_playable_area = True
    interface.feature_layer.allow_cheating_layers = True

    run_config = run_configs.get()
    replay_data = run_config.replay_data(FLAGS.replay)
    start_replay = sc_pb.RequestStartReplay(replay_data=replay_data,
                                            options=interface,
                                            observed_player_id=1,
                                            realtime=False)
    version = replay.get_replay_version(replay_data)

    # timers and controllers are index-aligned with targets; procs only holds
    # the processes launched locally (remote targets add no entry).
    timers = []
    controllers = []
    procs = []
    for target in targets:
        timer = stopwatch.StopWatch()
        timers.append(timer)
        with timer("launch"):
            if _is_remote(target):
                host, port = target.split(":")
                controllers.append(
                    remote_controller.RemoteController(host, int(port)))
            else:
                # Launch the replay's version with the binary swapped for the
                # one named by this target.
                proc = run_configs.get(version=version._replace(
                    binary=target)).start(want_rgb=False)
                procs.append(proc)
                controllers.append(proc.controller)

    # Per-target count of diff reports, and a count per diffing path.
    diff_counts = [0] * len(controllers)
    diff_paths = all_collections_generated_classes.Counter()

    try:
        print("-" * 80)
        print(controllers[0].replay_info(replay_data))
        print("-" * 80)

        for controller, t in zip(controllers, timers):
            with t("start_replay"):
                controller.start_replay(start_replay)

        # Check the static data.
        static_data = []
        for controller, t in zip(controllers, timers):
            with t("data"):
                static_data.append(controller.data_raw())

        if FLAGS.diff:
            # Target 0 is the baseline; keys are indices into targets, hence
            # the enumerate start of 1.
            diffs = {
                i: proto_diff.compute_diff(static_data[0], d)
                for i, d in enumerate(static_data[1:], 1)
            }
            if any(diffs.values()):
                print(" Diff in static data ".center(80, "-"))
                for i, diff in diffs.items():
                    if diff:
                        print(targets[i])
                        diff_counts[i] += 1
                        print(diff.report(truncate_to=FLAGS.truncate))
                        for path in diff.all_diffs():
                            diff_paths[
                                path.with_anonymous_array_indices()] += 1
            else:
                print("No diffs in static data.")

        # Run some steps, checking speed and diffing the observations.
        for _ in range(FLAGS.count):
            for controller, t in zip(controllers, timers):
                with t("step"):
                    controller.step(FLAGS.step_mul)

            obs = []
            for controller, t in zip(controllers, timers):
                with t("observe"):
                    obs.append(controller.observe())

            if FLAGS.diff:
                # Clear the fields _clear_non_deterministic_fields handles
                # before comparing the observations.
                for o in obs:
                    _clear_non_deterministic_fields(o)

                # Same scheme as the static data: compare against target 0.
                diffs = {
                    i: proto_diff.compute_diff(obs[0], o)
                    for i, o in enumerate(obs[1:], 1)
                }
                if any(diffs.values()):
                    print((" Diff on step: %s " %
                           obs[0].observation.game_loop).center(80, "-"))
                    for i, diff in diffs.items():
                        if diff:
                            print(targets[i])
                            diff_counts[i] += 1
                            print(
                                diff.report(
                                    [image_differencer.image_differencer],
                                    truncate_to=FLAGS.truncate))
                            for path in diff.all_diffs():
                                diff_paths[
                                    path.with_anonymous_array_indices()] += 1

            # Stop once the first target's observation carries a result.
            if obs[0].player_result:
                break
    except KeyboardInterrupt:
        # Ctrl-C ends the comparison early; cleanup and summaries still run.
        pass
    finally:
        # Quit and close every controller, then close the local processes.
        for c in controllers:
            c.quit()
            c.close()

        for p in procs:
            p.close()

    if FLAGS.diff:
        print(" Diff Counts by target ".center(80, "-"))
        for target, count in zip(targets, diff_counts):
            print(" %5d %s" % (count, target))
        print()

        print(" Diff Counts by observation path ".center(80, "-"))
        for path, count in diff_paths.most_common(100):
            print(" %5d %s" % (count, path))
        print()

    print(" Timings ".center(80, "-"))
    for v, t in zip(targets, timers):
        print(v)
        print(t)
コード例 #7
0
def main(unused_argv):
    """Time ``controller.observe()`` for each interface config and print a table.

    For every ``(config, interface)`` pair in ``configs`` a game is launched,
    either replaying ``FLAGS.replay`` or creating a game on ``FLAGS.map``, and
    stepped up to ``FLAGS.count`` times while the duration of each observe
    call is recorded. The per-config timelines are printed as comma-separated
    columns in milliseconds, followed by the global stopwatch.

    Args:
        unused_argv: Ignored.
    """
    stopwatch.sw.enable()

    results = []
    try:
        for config, interface in configs:
            print((" Starting: %s " % config).center(60, "-"))
            timeline = []

            run_config = run_configs.get()

            if FLAGS.replay:
                replay_data = run_config.replay_data(FLAGS.replay)
                start_replay = sc_pb.RequestStartReplay(
                    replay_data=replay_data,
                    options=interface,
                    disable_fog=False,
                    observed_player_id=2)
                version = replay.get_replay_version(replay_data)
                run_config = run_configs.get(
                    version=version)  # Replace the run config.
            else:
                map_inst = maps.get(FLAGS.map)
                create = sc_pb.RequestCreateGame(
                    realtime=False,
                    disable_fog=False,
                    random_seed=1,
                    local_map=sc_pb.LocalMap(
                        map_path=map_inst.path,
                        map_data=map_inst.data(run_config)))
                create.player_setup.add(type=sc_pb.Participant)
                create.player_setup.add(type=sc_pb.Computer,
                                        race=sc_common.Terran,
                                        difficulty=sc_pb.VeryEasy)
                join = sc_pb.RequestJoinGame(options=interface,
                                             race=sc_common.Protoss)

            with run_config.start(
                    want_rgb=interface.HasField("render")) as controller:

                if FLAGS.replay:
                    info = controller.replay_info(replay_data)
                    print(" Replay info ".center(60, "-"))
                    print(info)
                    print("-" * 60)
                    if info.local_map_path:
                        start_replay.map_data = run_config.map_data(
                            info.local_map_path)
                    controller.start_replay(start_replay)
                else:
                    controller.create_game(create)
                    controller.join_game(join)

                for _ in range(FLAGS.count):
                    controller.step(FLAGS.step_mul)
                    # Only the observe call itself is timed.
                    start = time.time()
                    obs = controller.observe()
                    timeline.append(time.time() - start)
                    if obs.player_result:
                        break

            results.append((config, timeline))
    except KeyboardInterrupt:
        # Ctrl-C stops the benchmark; configs that finished are still shown.
        pass

    # zip(*results) yields nothing when no config finished (Ctrl-C during the
    # first config, or an empty configs list), and unpacking that into two
    # names would raise ValueError, so fall back to empty columns.
    if results:
        names, values = zip(*results)
    else:
        names, values = (), ()

    print("\n\nTimeline:\n")
    print(",".join(names))
    # One row per step index; zip stops at the shortest timeline.
    for times in zip(*values):
        print(",".join("%0.2f" % (t * 1000) for t in times))

    print(stopwatch.sw)
コード例 #8
0
ファイル: play.py プロジェクト: Hotpotfish/pysc2
def main(unused_argv):
  """Run SC2 to play a game or a replay.

  Exactly one of ``FLAGS.map`` and ``FLAGS.replay`` must be set. With a map,
  a game against a computer player is created and joined; with a replay, the
  replay is started on the game version it was recorded with. The game is
  then driven either by the pygame renderer (``FLAGS.render``) or by a plain
  step/observe loop that prints the score and result when it ends.

  Args:
    unused_argv: Ignored.

  Raises:
    SystemExit: Via ``sys.exit`` for unsupported flag combinations.
    maps.lib.NoMapError: If the map is unknown and ``FLAGS.battle_net_map``
      is not set.
  """
  if FLAGS.trace:
    stopwatch.sw.trace()
  elif FLAGS.profile:
    stopwatch.sw.enable()

  if (FLAGS.map and FLAGS.replay) or (not FLAGS.map and not FLAGS.replay):
    sys.exit("Must supply either a map or replay.")

  if FLAGS.replay and not FLAGS.replay.lower().endswith("sc2replay"):
    sys.exit("Replay must end in .SC2Replay.")

  if FLAGS.realtime and FLAGS.replay:
    # TODO(tewalds): Support realtime in replays once the game supports it.
    sys.exit("realtime isn't possible for replays yet.")

  if FLAGS.render and (FLAGS.realtime or FLAGS.full_screen):
    sys.exit("disable pygame rendering if you want realtime or full_screen.")

  if platform.system() == "Linux" and (FLAGS.realtime or FLAGS.full_screen):
    sys.exit("realtime and full_screen only make sense on Windows/MacOS.")

  if not FLAGS.render and FLAGS.render_sync:
    sys.exit("render_sync only makes sense with pygame rendering on.")

  run_config = run_configs.get()

  interface = sc_pb.InterfaceOptions()
  interface.raw = FLAGS.render
  interface.raw_affects_selection = True
  interface.raw_crop_to_playable_area = True
  interface.score = True
  interface.show_cloaked = True
  interface.show_burrowed_shadows = True
  interface.show_placeholders = True
  if FLAGS.feature_screen_size and FLAGS.feature_minimap_size:
    interface.feature_layer.width = FLAGS.feature_camera_width
    FLAGS.feature_screen_size.assign_to(interface.feature_layer.resolution)
    FLAGS.feature_minimap_size.assign_to(
        interface.feature_layer.minimap_resolution)
    interface.feature_layer.crop_to_playable_area = True
    interface.feature_layer.allow_cheating_layers = True
  if FLAGS.render and FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size:
    FLAGS.rgb_screen_size.assign_to(interface.render.resolution)
    FLAGS.rgb_minimap_size.assign_to(interface.render.minimap_resolution)

  max_episode_steps = FLAGS.max_episode_steps

  if FLAGS.map:
    create = sc_pb.RequestCreateGame(
        realtime=FLAGS.realtime,
        disable_fog=FLAGS.disable_fog)
    try:
      map_inst = maps.get(FLAGS.map)
    except maps.lib.NoMapError:
      # An unknown map is only acceptable as a Battle.net map name.
      if FLAGS.battle_net_map:
        create.battlenet_map_name = FLAGS.map
      else:
        raise
    else:
      # A non-zero per-map episode length overrides the flag value.
      if map_inst.game_steps_per_episode:
        max_episode_steps = map_inst.game_steps_per_episode
      if FLAGS.battle_net_map:
        create.battlenet_map_name = map_inst.battle_net
      else:
        create.local_map.map_path = map_inst.path
        create.local_map.map_data = map_inst.data(run_config)

    create.player_setup.add(type=sc_pb.Participant)
    create.player_setup.add(type=sc_pb.Computer,
                            race=sc2_env.Race[FLAGS.bot_race],
                            difficulty=sc2_env.Difficulty[FLAGS.difficulty],
                            ai_build=sc2_env.BotBuild[FLAGS.bot_build])
    join = sc_pb.RequestJoinGame(
        options=interface, race=sc2_env.Race[FLAGS.user_race],
        player_name=FLAGS.user_name)
    version = None
  else:
    replay_data = run_config.replay_data(FLAGS.replay)
    start_replay = sc_pb.RequestStartReplay(
        replay_data=replay_data,
        options=interface,
        disable_fog=FLAGS.disable_fog,
        observed_player_id=FLAGS.observed_player)
    version = replay.get_replay_version(replay_data)
    run_config = run_configs.get(version=version)  # Replace the run config.

  with run_config.start(
      full_screen=FLAGS.full_screen,
      window_size=FLAGS.window_size,
      want_rgb=interface.HasField("render")) as controller:
    if FLAGS.map:
      controller.create_game(create)
      controller.join_game(join)
    else:
      info = controller.replay_info(replay_data)
      print(" Replay info ".center(60, "-"))
      print(info)
      print("-" * 60)
      # An explicit --map_path takes precedence over the path in the replay.
      map_path = FLAGS.map_path or info.local_map_path
      if map_path:
        start_replay.map_data = run_config.map_data(map_path,
                                                    len(info.player_info))
      controller.start_replay(start_replay)

    if FLAGS.render:
      renderer = renderer_human.RendererHuman(
          fps=FLAGS.fps, step_mul=FLAGS.step_mul,
          render_sync=FLAGS.render_sync, video=FLAGS.video)
      renderer.run(
          run_config, controller, max_game_steps=FLAGS.max_game_steps,
          game_steps_per_episode=max_episode_steps,
          save_replay=FLAGS.save_replay)
    else:  # Still step forward so the Mac/Windows renderer works.
      # Stays None if Ctrl-C arrives before the first observation returns;
      # without this binding the summary below would raise UnboundLocalError.
      obs = None
      try:
        while True:
          frame_start_time = time.time()
          if not FLAGS.realtime:
            controller.step(FLAGS.step_mul)
          obs = controller.observe()

          if obs.player_result:
            break
          # Sleep off whatever remains of this frame's 1/fps time budget.
          time.sleep(max(0, frame_start_time + 1 / FLAGS.fps - time.time()))
      except KeyboardInterrupt:
        pass
      if obs is not None:
        print("Score: ", obs.observation.score.score)
        print("Result: ", obs.player_result)
        if FLAGS.map and FLAGS.save_replay:
          replay_save_loc = run_config.save_replay(
              controller.save_replay(), "local", FLAGS.map)
          print("Replay saved to:", replay_save_loc)
          # Save scores so we know how the human player did.
          with open(replay_save_loc.replace("SC2Replay", "txt"), "w") as f:
            f.write("{}\n".format(obs.observation.score.score))

  if FLAGS.profile:
    print(stopwatch.sw)