Пример #1
0
    def testSpeed(self):
        """Time a tight loop of StopWatch entries under three configurations.

        Each of 10 rounds runs the loop with timing enabled, with tracing on,
        and with timing switched off inside an outer timed section. Nothing is
        asserted; the collected timings are printed for manual inspection.
        """
        iterations = 1000
        watch = stopwatch.StopWatch()

        def spin():
            # Enter and leave the "name" timer `iterations` times.
            for _ in range(iterations):
                with watch("name"):
                    pass

        for _ in range(10):
            # Phase 1: timing on, tracing off.
            watch.enabled, watch.trace = True, False
            with watch("enabled"):
                spin()

            # Phase 2: timing on, tracing on.
            watch.enabled, watch.trace = True, True
            with watch("trace"):
                spin()

            # Phase 3: enable first so the outer "disabled" timer is itself
            # captured, then switch timing off for the inner loop.
            watch.enabled = True
            with watch("disabled"):
                watch.enabled = False
                spin()

        # Passing means nothing raised; the report is printed for a human.
        print(watch)
Пример #2
0
    def testDivideZero(self):
        """Rendering a report whose total time may be 0 must not divide by it."""
        watch = stopwatch.StopWatch()
        with watch("zero"):
            pass

        report = str(watch)
        # Producing the string at all is the real check; the entry should
        # still be listed.
        self.assertIn("zero", report)
Пример #3
0
    def testStopwatch(self, mock_time):
        """Check report ordering, per-timer stats and the parse round trip.

        `mock_time` is a mock whose `return_value` acts as the clock; each
        timed section advances it by a fixed amount.
        """
        def advance(seconds):
            # Move the mocked clock forward by `seconds`.
            mock_time.return_value += seconds

        mock_time.return_value = 0
        sw = stopwatch.StopWatch()

        # "one" is timed twice with different durations.
        with sw("one"):
            advance(0.002)
        with sw("one"):
            advance(0.004)
        # A timer opened inside another shows up as "two.three".
        with sw("two"), sw("three"):
            advance(0.006)

        # Bare decorator: the timer takes the function's name, so this
        # function must stay named `four`.
        @sw.decorate
        def four():
            advance(0.004)

        four()

        # Named decorator: the timer is called "five".
        @sw.decorate("five")
        def timed_five():
            advance(0.005)

        timed_five()

        report = str(sw)
        rows = report.splitlines()

        # rows[0] is the header; the timer names after it come out sorted.
        self.assertEqual([row.split(None)[0] for row in rows[1:]],
                         ["five", "four", "one", "two", "two.three"])

        # rows[3] is the "one" timer; columns 5, 6, 7 are min, max, num.
        one_cols = rows[3].split(None)
        self.assertLess(one_cols[5], one_cols[6])  # min < max
        self.assertEqual(one_cols[7], "2")  # num
        # The remaining columns are not checked here.

        # Parsing and re-rendering may introduce a few rounding differences.
        round_trip = str(stopwatch.StopWatch.parse(report))
        self.assertLess(ham_dist(report, round_trip), 15,
                        "%s != %s" % (report, round_trip))
Пример #4
0
 def testDecoratorEnabled(self):
     """Both decorator forms should hand back a wrapper, not `round` itself."""
     watch = stopwatch.StopWatch()
     wrapped_plain = watch.decorate(round)
     self.assertNotEqual(round, wrapped_plain)
     wrapped_named = watch.decorate("name")(round)
     self.assertNotEqual(round, wrapped_named)
Пример #5
0
def _report_diffs(protos, title, targets, diff_counts, diff_paths,
                  differencers=None):
    """Diff protos[1:] against protos[0]; print and tally every difference.

    Args:
      protos: One proto per target, in the same order as `targets`.
      title: Header printed, centered in an 80-column dashed rule, before the
        first reported diff.
      targets: Target names; each differing target's name is printed before
        its diff report.
      diff_counts: Per-target counters; `diff_counts[i]` is incremented when
        `protos[i]` differs from `protos[0]`. Updated in place.
      diff_paths: Counter keyed by `path.with_anonymous_array_indices()` for
        each differing path. Updated in place.
      differencers: Optional list passed positionally to `diff.report`. When
        None, `diff.report` is called without it.

    Returns:
      True if any proto differed from `protos[0]`, else False.
    """
    diffs = {
        i: proto_diff.compute_diff(protos[0], p)
        for i, p in enumerate(protos[1:], 1)
    }
    if not any(diffs.values()):
        return False

    print(title.center(80, "-"))
    for i, diff in diffs.items():
        if not diff:
            continue
        print(targets[i])
        diff_counts[i] += 1
        if differencers is None:
            print(diff.report(truncate_to=FLAGS.truncate))
        else:
            print(diff.report(differencers, truncate_to=FLAGS.truncate))
        for path in diff.all_diffs():
            diff_paths[path.with_anonymous_array_indices()] += 1
    return True


def _shutdown(controllers, procs):
    """Quit and close every controller, then close every launched process.

    The order matches a plain loop: c0.quit, c0.close, c1.quit, ..., then
    p0.close, p1.close, .... Unlike a plain loop, every step is attempted
    even if an earlier one raises. In that case the last exception raised
    propagates, with earlier ones chained as its __context__
    (contextlib.ExitStack semantics).
    """
    import contextlib

    with contextlib.ExitStack() as stack:
        # ExitStack runs callbacks last-in-first-out, so register them in
        # reverse to get the execution order described above.
        for p in reversed(procs):
            stack.callback(p.close)
        for c in reversed(controllers):
            stack.callback(c.close)
            stack.callback(c.quit)


def main(argv):
    """Compare the observations from multiple binaries.

    Each entry of argv[1:] is one of:
      - an executable name, launched locally with the replay's version;
      - a "<hostname>:<port>" remote connection.

    Every target replays FLAGS.replay, observing player 1. Each target is
    stepped FLAGS.count times with `controller.step(FLAGS.step_mul)`,
    stopping early once the first target's observation carries a player
    result.

    With FLAGS.diff set, the static data and every observation are diffed
    against the first target's, and per-target and per-path diff counts are
    printed at the end. Per-target timings are always printed.

    Ctrl-C during the replay loop ends the run early. Cleanup and the summary
    still happen.

    Args:
      argv: Command-line arguments; argv[0] is the program name.
    """
    import collections

    if len(argv) <= 1:
        sys.exit(
            "Please specify binaries to run / to connect to. For binaries to run, "
            "specify the executable name. For remote connections, specify "
            "<hostname>:<port>. The version must match the replay.")

    targets = argv[1:]

    interface = sc_pb.InterfaceOptions()
    interface.raw = True
    interface.raw_affects_selection = True
    interface.raw_crop_to_playable_area = True
    interface.score = True
    interface.show_cloaked = True
    interface.show_placeholders = True
    interface.feature_layer.width = 24
    interface.feature_layer.resolution.x = 48
    interface.feature_layer.resolution.y = 48
    interface.feature_layer.minimap_resolution.x = 48
    interface.feature_layer.minimap_resolution.y = 48
    interface.feature_layer.crop_to_playable_area = True
    interface.feature_layer.allow_cheating_layers = True

    run_config = run_configs.get()
    replay_data = run_config.replay_data(FLAGS.replay)
    start_replay = sc_pb.RequestStartReplay(replay_data=replay_data,
                                            options=interface,
                                            observed_player_id=1,
                                            realtime=False)
    version = replay.get_replay_version(replay_data)

    # timers[i], controllers[i] and diff_counts[i] all belong to targets[i].
    timers = []
    controllers = []
    procs = []  # Only the locally launched targets.
    diff_counts = [0] * len(targets)
    diff_paths = collections.Counter()

    try:
        # Launch inside the try, so targets that already started are shut
        # down if a later launch fails.
        for target in targets:
            timer = stopwatch.StopWatch()
            timers.append(timer)
            with timer("launch"):
                if _is_remote(target):
                    host, port = target.split(":")
                    controllers.append(
                        remote_controller.RemoteController(host, int(port)))
                else:
                    proc = run_configs.get(version=version._replace(
                        binary=target)).start(want_rgb=False)
                    procs.append(proc)
                    controllers.append(proc.controller)

        try:
            print("-" * 80)
            print(controllers[0].replay_info(replay_data))
            print("-" * 80)

            for controller, t in zip(controllers, timers):
                with t("start_replay"):
                    controller.start_replay(start_replay)

            # Check the static data.
            static_data = []
            for controller, t in zip(controllers, timers):
                with t("data"):
                    static_data.append(controller.data_raw())

            if FLAGS.diff:
                if not _report_diffs(static_data, " Diff in static data ",
                                     targets, diff_counts, diff_paths):
                    print("No diffs in static data.")

            # Run some steps, checking speed and diffing the observations.
            for _ in range(FLAGS.count):
                for controller, t in zip(controllers, timers):
                    with t("step"):
                        controller.step(FLAGS.step_mul)

                obs = []
                for controller, t in zip(controllers, timers):
                    with t("observe"):
                        obs.append(controller.observe())

                if FLAGS.diff:
                    # Mutates each observation in place before diffing.
                    for o in obs:
                        _clear_non_deterministic_fields(o)
                    _report_diffs(
                        obs,
                        " Diff on step: %s " % obs[0].observation.game_loop,
                        targets, diff_counts, diff_paths,
                        differencers=[image_differencer.image_differencer])

                if obs[0].player_result:
                    break
        except KeyboardInterrupt:
            pass  # Ctrl-C ends the run early; the summary below still prints.
    finally:
        _shutdown(controllers, procs)

    if FLAGS.diff:
        print(" Diff Counts by target ".center(80, "-"))
        for target, count in zip(targets, diff_counts):
            print(" %5d %s" % (count, target))
        print()

        print(" Diff Counts by observation path ".center(80, "-"))
        for path, count in diff_paths.most_common(100):
            print(" %5d %s" % (count, path))
        print()

    print(" Timings ".center(80, "-"))
    for v, t in zip(targets, timers):
        print(v)
        print(t)
Пример #6
0
import tensorflow as tf
import numpy as np

from pysc2.agents.base_agent import BaseAgent
from pysc2.lib import actions, stopwatch

from common import util

from agents.a3c.estimators import configure_estimators

# Module-level StopWatch instance shared by everything in this module.
# NOTE(review): its uses (presumably timing Worker methods) are outside this
# chunk — confirm.
sw = stopwatch.StopWatch()


class Worker(BaseAgent):
    def __init__(self,
                 name,
                 device,
                 session,
                 m_size,
                 s_size,
                 global_optimizers,
                 network,
                 map_name,
                 learning_rate,
                 discount_factor,
                 eta,
                 beta,
                 summary_writer=None):

        super().__init__()
        self.name = name