示例#1
0
    def testFlRgbActionSpec(self):
        """Check spatial argument sizes of the action spec per action space.

        Features are built with both feature-layer and RGB dimensions. The
        `screen`, `screen2` and `minimap` argument sizes are expected to equal
        the feature-layer sizes for `ActionSpace.FEATURES` and the RGB sizes
        for `ActionSpace.RGB`.
        """
        dimensions = dict(feature_screen_width=84,
                          feature_screen_height=80,
                          feature_minimap_width=64,
                          feature_minimap_height=67,
                          rgb_screen_width=128,
                          rgb_screen_height=132,
                          rgb_minimap_width=74,
                          rgb_minimap_height=77)
        # (action space, expected screen sizes, expected minimap sizes)
        cases = (
            (actions.ActionSpace.FEATURES, (84, 80), (64, 67)),
            (actions.ActionSpace.RGB, (128, 132), (74, 77)),
        )
        for space, screen_sizes, minimap_sizes in cases:
            feats = features.Features(action_space=space, **dimensions)
            arg_types = feats.action_spec().types
            self.assertEqual(arg_types.screen.sizes, screen_sizes)
            self.assertEqual(arg_types.screen2.sizes, screen_sizes)
            self.assertEqual(arg_types.minimap.sizes, minimap_sizes)
def classify_actions():
  """Provide a rough estimate of actions based on race.

  Resets every list in the module-level ACTIONS mapping, then appends each
  function id from the action spec to exactly one ACTIONS bucket. The keyword
  table starts from a hand-written list per race and is extended with the
  lower-cased, underscore-separated parts of every unit name in UNITS.

  For each function the longest keyword contained in its simplified name is
  recorded per race. A race replaces the current choice when its recorded
  keyword compares greater than the current one using plain string `<`
  (lexicographic order, not length). 'N' is the default bucket.
  """
  for race in ACTIONS.keys():
    ACTIONS[race] = []

  race_keywords = {
      'T': ['siege', 'missile', 'hangar', 'hell', 'stim', 'supply',
            'advanced', 'tactic', 'drill'],
      'Z': ['tunnel', 'swarm', 'locust', 'lurker', 'creep', 'carap',
            'parasi', 'nyda', 'worm', 'fung', 'groo', 'centrifugal', 'bile',
            'contam', 'gland', 'explode'],
      'P': ['photon', 'blink', 'psi', 'pulsar', 'void', 'warp', 'guardian',
            'orbit', 'gravit', 'chrono', 'puri', 'stasis'],
  }

  # Extend each race's keywords with the tokens of its unit names.
  for race in UNITS.keys():
    known = race_keywords.setdefault(race, [])
    for unit_name in UNITS[race].keys():
      for token in unit_name.lower().split('_'):
        if token not in known:
          known.append(token)

  spec = features.Features(
      screen_size_px=(84, 84),
      minimap_size_px=(64, 64)).action_spec()

  generic_words = ('Behavior', 'Effect', 'Research', 'Train', 'Build')
  for func in spec.functions:
    # Drop the generic verbs so only the subject of the action is matched.
    simplified = func.name
    for word in generic_words:
      simplified = simplified.replace(word, '')
    simplified = simplified.lower()

    best = {'N': "No", 'T': "", 'Z': "", 'P': ""}
    winner = 'N'
    for race in UNITS.keys():
      if race != 'N':
        for keyword in race_keywords[race]:
          if len(keyword) > len(best[race]) and keyword in simplified:
            best[race] = keyword
        if best[winner] < best[race]:
          winner = race
    ACTIONS[winner].append(func.id)
示例#3
0
def sample_frames(action_path):
    """Select the frame ids at which a replay should be sampled.

    The JSON file at `action_path` is read as a list of steps, each step being
    a list of serialized `sc_pb.Action` strings. The frame counter advances by
    `FLAGS.step_mul` per step. A frame is kept when the step contains an action
    whose function name starts with one of the macro prefixes below, or when
    the frame id is a multiple of `FLAGS.skip`.

    Args:
        action_path: Path of the JSON file holding the per-step actions.

    Returns:
        List of selected frame ids, in increasing order.
    """
    macro_prefixes = {
        'Build', 'Train', 'Research', 'Morph', 'Cancel', 'Halt', 'Stop'
    }

    agent_intf = features.AgentInterfaceFormat(
        feature_dimensions=features.Dimensions(screen=(1, 1), minimap=(1, 1)))
    feat = features.Features(agent_intf)

    with open(action_path) as f:
        action_steps = json.load(f)

    frame_id = 0
    result_frames = []
    for action_step in action_steps:  # Actions since the previous observed frame
        frame_id += FLAGS.step_mul  # Advance to current frame
        is_macro_step = False
        for action_str in action_step:
            action = Parse(action_str, sc_pb.Action())
            try:
                func_id = feat.reverse_action(action).function
                func_name = FUNCTIONS[func_id].name
            except Exception:
                # Best effort: an action that cannot be mapped back to a
                # function is skipped. Unlike a bare `except`, this lets
                # KeyboardInterrupt and SystemExit propagate.
                continue
            if func_name.split('_')[0] in macro_prefixes:
                # One macro action is enough to mark the whole step.
                is_macro_step = True
                break
        # Keep macro steps and the fixed-interval recording steps.
        if is_macro_step or frame_id % FLAGS.skip == 0:
            result_frames.append(frame_id)

    return result_frames
示例#4
0
def sample_action_from_player(action_path):
    """Select the frame ids at which a player's replay should be sampled.

    The JSON file at `action_path` is read as a list of steps, each step being
    a list of serialized `sc_pb.Action` strings. Step `k` corresponds to frame
    counter `k * FLAGS.step_mul`. For every step after the first, the previous
    frame (`frame_id - FLAGS.step_mul`) is recorded when the step contains an
    action whose function-name prefix is in the sampled set, or when the frame
    counter is a multiple of `FLAGS.skip`.

    Args:
        action_path: Path of the JSON file holding the per-step actions.

    Returns:
        List of selected frame ids, in increasing order.
    """
    # The original literal was missing a comma between 'UnloadAllAt' and
    # 'Build'. The implicit concatenation produced 'UnloadAllAtBuild', so
    # neither prefix was ever matched.
    sampled_prefixes = {
        'Attack', 'Scan', 'Behavior', 'BorrowUp', 'Effect', 'Hallucination',
        'Harvest', 'Hold', 'Land', 'Lift', 'Load', 'Move', 'Patrol', 'Rally',
        'Smart', 'TrainWarp', 'UnloadAll', 'UnloadAllAt', 'Build', 'Train',
        'Research', 'Morph', 'Cancel', 'Halt', 'Stop',
    }

    feat = features.Features(screen_size_px=(1, 1), minimap_size_px=(1, 1))
    with open(action_path) as f:
        action_steps = json.load(f)

    frame_id = 0
    result_frames = []
    for action_strs in action_steps:
        action_name = None
        for action_str in action_strs:
            action = Parse(action_str, sc_pb.Action())
            try:
                func_id = feat.reverse_action(action).function
                func_name = FUNCTIONS[func_id].name
            except Exception:
                # Best effort: skip actions that cannot be mapped back to a
                # function, but let KeyboardInterrupt/SystemExit propagate.
                continue
            if func_name.split('_')[0] in sampled_prefixes:
                action_name = func_name
                break
        if frame_id > 0 and (action_name is not None
                             or frame_id % FLAGS.skip == 0):
            result_frames.append(frame_id - FLAGS.step_mul)

        frame_id += FLAGS.step_mul

    return result_frames
示例#5
0
    def _setup_game(self):
        """Create the game on the first controller and join every controller.

        Saves the map on all controllers, builds the interface options from
        the configured screen and minimap sizes, issues the create and join
        requests, then stores the game info and one `features.Features` per
        controller in `self._game_infos` and `self._features`.
        """
        # Every controller needs its own copy of the map data.
        map_path = os.path.basename(self._map.path)
        save_requests = (
            (controller.save_map, map_path,
             self._run_config.map_data(self._map.path))
            for controller in self._controllers)
        self._parallel.run(save_requests)

        # Interface options shared by all join requests.
        interface = sc_pb.InterfaceOptions(raw=True, score=True)
        self._screen.assign_to(interface.feature_layer.resolution)
        self._minimap.assign_to(interface.feature_layer.minimap_resolution)

        # Create request: one participant slot per agent.
        local_map = sc_pb.LocalMap(map_path=map_path)
        create = sc_pb.RequestCreateGame(local_map=local_map)
        for _agent in self._agents:
            create.player_setup.add(type=sc_pb.Participant)

        # Join requests, one per race.
        joins = [self._join_pb(race, interface) for race in self._races]

        # Actually create the game, then join it from every controller.
        print("create")
        self._controllers[0].create_game(create)
        print("join")
        join_requests = (
            (controller.join_game, join)
            for join, controller in zip(joins, self._controllers))
        self._parallel.run(join_requests)
        print("play_game")
        self._game_infos = self._parallel.run(
            controller.game_info for controller in self._controllers)
        self._features = [
            features.Features(game_info) for game_info in self._game_infos
        ]
        print("setup game ok")
示例#6
0
def sample_action_from_player(action_path):
    """Select the frame ids at which a player's replay should be sampled.

    The JSON file at `action_path` is read as a list of steps, each step being
    a list of serialized `sc_pb.Action` strings. Step `k` corresponds to frame
    counter `k * FLAGS.step_mul`. For every step after the first, the previous
    frame (`frame_id - FLAGS.step_mul`) is recorded when the step contains a
    macro action (see the prefixes below), or when the frame counter is a
    multiple of `FLAGS.skip`.

    Args:
        action_path: Path of the JSON file holding the per-step actions.

    Returns:
        List of selected frame ids, in increasing order.
    """
    macro_prefixes = {
        'Build', 'Train', 'Research', 'Morph', 'Cancel', 'Halt', 'Stop'
    }

    feat = features.Features(screen_size_px=(1, 1), minimap_size_px=(1, 1))
    with open(action_path) as f:
        action_steps = json.load(f)

    frame_id = 0
    result_frames = []
    for action_strs in action_steps:
        found_macro = False
        for action_str in action_strs:
            action = Parse(action_str, sc_pb.Action())
            try:
                func_id = feat.reverse_action(action).function
                func_name = FUNCTIONS[func_id].name
            except Exception:
                # Best effort: actions that cannot be reversed are ignored.
                # `Exception` (not a bare `except`) keeps Ctrl-C and
                # SystemExit working while a large replay is processed.
                continue
            if func_name.split('_')[0] in macro_prefixes:
                found_macro = True
                break
        if frame_id > 0 and (found_macro or frame_id % FLAGS.skip == 0):
            result_frames.append(frame_id - FLAGS.step_mul)

        frame_id += FLAGS.step_mul

    return result_frames
示例#7
0
 def testReversingUnknownAction(self):
   """A unit command with ability id 6 must reverse to function id 0."""
   feats = features.Features(screen_size_px=(84, 80),
                             minimap_size_px=(64, 67),
                             hide_specific_actions=False)
   unknown = sc_pb.Action()
   unknown.action_feature_layer.unit_command.ability_id = 6  # Cheer
   reversed_call = feats.reverse_action(unknown)
   # Function 0 is the no-op.
   self.assertEqual(reversed_call.function, 0)
示例#8
0
 def testIdsMatchIndex(self):
   """Every function id and argument-type id must equal its spec index."""
   spec = features.Features(screen_size_px=(84, 80),
                            minimap_size_px=(64, 67)).action_spec()
   for position, function in enumerate(spec.functions):
     self.assertEqual(position, function.id)
   for position, arg_type in enumerate(spec.types):
     self.assertEqual(position, arg_type.id)
示例#9
0
  def testSpecificActionsAreReversible(self):
    """Test that the `transform_action` and `reverse_action` are inverses.

    For every function in the action spec, ten random calls are generated,
    transformed to an SC2 action, reversed, and transformed again. The
    function calls and the two SC2 actions must match.
    """

    def as_rect(arguments):
      # Build a floored Rect from the two corner arguments of select_rect.
      return point.Rect(point.Point(*arguments[1]).floor(),
                        point.Point(*arguments[2]).floor())

    feats = features.Features(screen_size_px=(84, 80),
                              minimap_size_px=(64, 67),
                              hide_specific_actions=False)
    action_spec = feats.action_spec()
    select_rect_id = actions.FUNCTIONS.select_rect.id

    for func_def in action_spec.functions:
      for _ in range(10):
        original = self.gen_random_function_call(action_spec, func_def.id)

        sc2_action = feats.transform_action(
            None, original, skip_available=True)
        restored = feats.reverse_action(sc2_action)
        sc2_action2 = feats.transform_action(
            None, restored, skip_available=True)

        if func_def.id != select_rect_id:
          self.assertEqual(original, restored, msg=sc2_action)
        else:
          # The same rect can be defined in multiple ways, so compare the
          # pieces manually instead of the whole call.
          self.assertEqual(original.function, restored.function)
          self.assertEqual(len(original.arguments), len(restored.arguments))
          self.assertEqual(original.arguments[0], restored.arguments[0])
          self.assertEqual(as_rect(original.arguments),
                           as_rect(restored.arguments))
        self.assertEqual(sc2_action, sc2_action2)
示例#10
0
def parse_replay(replay_player_path, sampled_action_path, reward):
    """Build and save the global feature states for one parsed replay.

    Does nothing if the output file under `GlobalFeatures` already exists.
    Otherwise loads the replay's global info, the sampled frame ids, the
    per-step actions and the sampled observations from
    `FLAGS.parsed_replay_path`, passes them to `process_replay`, and dumps
    the result as JSON.

    Args:
        replay_player_path: Relative file name shared by all per-replay files.
        sampled_action_path: Path of the JSON file with the sampled frame ids.
        reward: Forwarded unchanged to `process_replay`.
    """
    root = FLAGS.parsed_replay_path
    output_path = os.path.join(root, 'GlobalFeatures', replay_player_path)
    if os.path.isfile(output_path):
        return

    # Global info: static unit data and the game info used to build Features.
    with open(os.path.join(root, 'GlobalInfos', replay_player_path)) as f:
        global_info = json.load(f)
    raw_data = Parse(global_info['data_raw'], sc_pb.ResponseData())
    units_info = static_data.StaticData(raw_data).units
    game_info = Parse(global_info['game_info'], sc_pb.ResponseGameInfo())
    feat = features.Features(game_info)

    # Sampled frame ids, converted to indices into the per-step action list.
    with open(sampled_action_path) as f:
        sampled_action = json.load(f)
    sampled_action_id = [
        frame // FLAGS.step_mul + 1 for frame in sampled_action
    ]

    # First action of each sampled step, or None when the step is empty.
    with open(os.path.join(root, 'Actions', replay_player_path)) as f:
        all_actions = json.load(f)
    actions = []
    for idx in sampled_action_id:
        if len(all_actions[idx]) == 0:
            actions.append(None)
        else:
            actions.append(Parse(all_actions[idx][0], sc_pb.Action()))

    # Observations recorded at the sampled frames.
    observation_path = os.path.join(root, 'SampledObservations',
                                    replay_player_path)
    observations = list(
        stream.parse(observation_path, sc_pb.ResponseObservation))

    assert len(sampled_action) == len(sampled_action_id) == len(actions) == len(observations)

    states = process_replay(sampled_action, actions, observations, feat,
                            units_info, reward)

    with open(output_path, 'w') as f:
        json.dump(states, f)
示例#11
0
  def _finalize(self, interface, action_space, use_feature_units, visualize):
    """Finish environment setup once the game has been joined.

    Reads game info and static data from the first controller, fills in the
    map name if none is set, builds the `Features` helper, optionally creates
    the human renderer, and resets metrics and episode bookkeeping.

    Args:
      interface: Requested interface options; its `render` field is compared
        with the options reported by the game.
      action_space: Forwarded to `features.Features`.
      use_feature_units: Forwarded to `features.Features`.
      visualize: When truthy, a `RendererHuman` is created and initialised.
    """
    primary = self._controllers[0]
    info = primary.game_info()
    game_data = primary.data()

    if not self._map_name:
      self._map_name = info.map_name

    requested_render, actual_render = interface.render, info.options.render
    if actual_render != requested_render:
      logging.warning(
          "Actual interface options don't match requested options:\n"
          "Requested:\n%s\n\nActual:\n%s", interface, info.options)

    self._features = features.Features(game_info=info,
                                       action_space=action_space,
                                       use_feature_units=use_feature_units)

    if not visualize:
      self._renderer_human = None
    else:
      self._renderer_human = renderer_human.RendererHuman()
      self._renderer_human.init(info, game_data)

    self._metrics = metrics.Metrics(self._map_name)
    self._metrics.increment_instance()

    # Episode bookkeeping starts from a clean slate.
    self._last_score = None
    self._total_steps = self._episode_steps = self._episode_count = 0
    self._obs = None
    # Starting in LAST forces the first call to go through `reset`.
    self._state = environment.StepType.LAST
    logging.info("Environment is ready on map: %s", self._map_name)
示例#12
0
    def start(self):
        """Step the controller until the game ends, feeding the agent.

        Each iteration advances the game by `self.step_mul`, transforms the
        observation, and calls `self.agent.step` with a `TimeStep` (reward
        fixed at 0) plus the raw actions from the observation. When a player
        result is present the step is marked LAST with discount 0 and the loop
        stops; otherwise the state becomes MID for the next step.
        """
        transformer = features.Features(self.controller.game_info())

        finished = False
        while not finished:
            self.controller.step(self.step_mul)
            response = self.controller.observe()
            agent_obs = transformer.transform_obs(response.observation)

            # A player result means the episode is over.
            finished = bool(response.player_result)
            if finished:
                self._state = StepType.LAST
            discount = 0 if finished else self.discount

            self._episode_steps += self.step_mul

            timestep = TimeStep(step_type=self._state,
                                reward=0,
                                discount=discount,
                                observation=agent_obs)
            self.agent.step(timestep, response.actions)

            if not finished:
                self._state = StepType.MID
示例#13
0
    def start(self, replay_file_path):
        """Open a replay, feed every step to the agent, then close it.

        Each iteration advances the replay by `self.step_mul`, transforms the
        observation, and calls `self.agent.step` with a `TimeStep` (reward
        fixed at 0), the raw actions from the observation, and a flag telling
        whether this is the last step. The loop stops once a player result is
        present.

        `closeReplay` runs in a `finally` clause, so the replay is also closed
        when stepping, observing or the agent raises.

        Args:
            replay_file_path: Passed unchanged to `self.openReplay`.
        """
        self.openReplay(replay_file_path)
        try:
            _features = features.Features(self.controller.game_info())

            while True:
                self.controller.step(self.step_mul)
                obs = self.controller.observe()
                agent_obs = _features.transform_obs(obs.observation)

                if obs.player_result:  # Episode over.
                    self._state = StepType.LAST
                    discount = 0
                else:
                    discount = self.discount

                self._episode_steps += self.step_mul

                step = TimeStep(step_type=self._state,
                                reward=0,
                                discount=discount,
                                observation=agent_obs)

                self.agent.step(step, obs.actions,
                                self._state == StepType.LAST)

                if obs.player_result:
                    break

                self._state = StepType.MID
        finally:
            # Previously skipped whenever the loop raised.
            self.closeReplay()
示例#14
0
def agent_runner(controller, join):
    """Run the agent in a thread.

    Instantiates the agent class named by `FLAGS.agent`, joins the game with
    feature-layer interface options, then loops: observe, build a `TimeStep`,
    ask the agent for an action and send it to the game. The loop ends after
    the agent has been stepped with the LAST timestep, and the controller is
    then asked to quit.

    The step type starts as FIRST and moves to MID after the first action has
    been sent. Previously it stayed FIRST for every non-terminal step, so the
    agent saw each step as the start of an episode.

    Args:
        controller: Game controller used to join, step, observe and act.
        join: `RequestJoinGame` template; it is copied, not modified.
    """
    agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
    agent_cls = getattr(importlib.import_module(agent_module), agent_name)
    agent = agent_cls()

    interface = sc_pb.InterfaceOptions()
    interface.raw = True
    interface.score = True
    interface.feature_layer.width = 24
    interface.feature_layer.resolution.x = FLAGS.feature_screen_size
    interface.feature_layer.resolution.y = FLAGS.feature_screen_size
    interface.feature_layer.minimap_resolution.x = FLAGS.feature_minimap_size
    interface.feature_layer.minimap_resolution.y = FLAGS.feature_minimap_size
    # RGB rendering is intentionally left disabled here:
    # if FLAGS.rgb_screen_size and FLAGS.rgb_minimap_size:
    #   if FLAGS.rgb_screen_size < FLAGS.rgb_minimap_size:
    #     sys.exit("Screen size can't be smaller than minimap size.")
    #   interface.render.resolution.x = FLAGS.rgb_screen_size
    #   interface.render.resolution.y = FLAGS.rgb_screen_size
    #   interface.render.minimap_resolution.x = FLAGS.rgb_minimap_size
    #   interface.render.minimap_resolution.y = FLAGS.rgb_minimap_size

    # Work on a copy so the caller's join request is left untouched.
    j = sc_pb.RequestJoinGame()
    j.CopyFrom(join)
    j.options.CopyFrom(interface)
    j.race = sc2_env.Race[FLAGS.agent_race]
    controller.join_game(j)

    feats = features.Features(game_info=controller.game_info())
    agent.setup(feats.observation_spec(), feats.action_spec())

    state = environment.StepType.FIRST
    reward = 0
    discount = 1
    while True:
        frame_start_time = time.time()
        if not FLAGS.realtime:
            controller.step(FLAGS.step_mul)
        obs = controller.observe()
        if obs.player_result:  # Episode over.
            state = environment.StepType.LAST
            discount = 0

        agent_obs = feats.transform_obs(obs)

        timestep = environment.TimeStep(step_type=state,
                                        reward=reward,
                                        discount=discount,
                                        observation=agent_obs)

        action = agent.step(timestep)
        if state == environment.StepType.LAST:
            break
        controller.act(feats.transform_action(obs.observation, action))
        # Only the very first timestep of the episode is FIRST.
        state = environment.StepType.MID

        if FLAGS.realtime:
            # Sleep out the rest of the frame budget (22.4 game loops/second).
            time.sleep(
                max(0, frame_start_time - time.time() + FLAGS.step_mul / 22.4))
    controller.quit()
示例#15
0
 def setUp(self):
     """Parse the fixture observation and build the Features under test."""
     super(AvailableActionsTest, self).setUp()
     empty_observation = sc_pb.Observation()
     self.obs = text_format.Parse(observation_text_proto, empty_observation)
     dimensions = dict(feature_screen_width=84,
                       feature_screen_height=80,
                       feature_minimap_width=64,
                       feature_minimap_height=67)
     self.features = features.Features(**dimensions)
示例#16
0
 def testIdsMatchIndex(self):
     """Function and argument-type ids must coincide with their spec index."""
     interface_format = features.AgentInterfaceFormat(
         feature_dimensions=RECTANGULAR_DIMENSIONS)
     spec = features.Features(interface_format).action_spec()
     for position, function in enumerate(spec.functions):
         self.assertEqual(position, function.id)
     for position, arg_type in enumerate(spec.types):
         self.assertEqual(position, arg_type.id)
示例#17
0
 def testReversingUnknownAction(self):
     """A unit command with ability id 6 must reverse to function id 0."""
     interface_format = features.AgentInterfaceFormat(
         feature_dimensions=RECTANGULAR_DIMENSIONS,
         hide_specific_actions=False)
     feats = features.Features(interface_format)
     unknown = sc_pb.Action()
     unknown.action_feature_layer.unit_command.ability_id = 6  # Cheer
     # Function 0 is the no-op.
     self.assertEqual(feats.reverse_action(unknown).function, 0)
示例#18
0
  def testValidFunctionsAreConsistent(self):
    """Spec functions must agree with `actions.FUNCTIONS` on id, name, arity."""
    feats = features.Features(screen_size_px=(84, 80),
                              minimap_size_px=(64, 67))

    for spec_func in feats.action_spec().functions:
      reference = actions.FUNCTIONS[spec_func.id]
      self.assertEqual(spec_func.id, reference.id)
      self.assertEqual(spec_func.name, reference.name)
      self.assertEqual(len(spec_func.args), len(reference.args))
示例#19
0
def main(unused_argv):
    """Run the SupportAI agent against a very-easy bot until interrupted.

    A new `SC2Env` on the Simple64 map is created for every episode. Within
    an episode the agent is stepped on the first timestep of each transition
    until the environment reports the last step. A KeyboardInterrupt ends the
    outer loop silently.

    The original body also built a `features.Features` object and its action
    spec once per episode and again on every step without reading either
    result. That dead work has been removed.

    Args:
        unused_argv: Ignored.
    """
    agent = SupportAI(game_type)

    try:
        while True:
            with sc2_env.SC2Env(
                    map_name="Simple64",
                    players=[
                        sc2_env.Agent(player_race),
                        sc2_env.Bot(enemy_race, sc2_env.Difficulty.very_easy)
                    ],
                    agent_interface_format=features.AgentInterfaceFormat(
                        feature_dimensions=features.Dimensions(screen=86,
                                                               minimap=86),
                        use_feature_units=True),
                    step_mul=16,
                    game_steps_per_episode=0,
                    visualize=False,
                    realtime=True) as env:

                agent.setup(env.observation_spec(), env.action_spec())

                timesteps = env.reset()
                agent.reset()

                while True:
                    # The agent is stepped on the final timestep as well,
                    # before the episode loop exits.
                    step_actions = [agent.step(timesteps[0])]
                    if timesteps[0].last():
                        break
                    timesteps = env.step(step_actions)

    except KeyboardInterrupt:
        pass
    def _get_replay_data(self, controller, config):
        """Drive a game through `controller` and return its saved replay.

        For `config.num_observations` iterations the game is observed, the
        reported last actions are checked against the action sent in the
        previous iteration, any action configured for the current game loop is
        validated and sent, and the game is stepped once.

        Args:
            controller: Game controller used to observe, act, step and save.
            config: Object providing `num_observations` and `actions`, a
                mapping from game loop to a callable that takes the
                transformed observation and returns a function call.

        Returns:
            A `(replay_data, observations)` tuple, where `observations` maps
            each observed game loop to its `feature_screen.unit_type` layer.
        """
        f = features.Features(game_info=controller.game_info())

        observations = {}
        # Function ids expected to be reported as last actions in the next
        # observation.
        last_actions = []
        for _ in range(config.num_observations):
            raw_obs = controller.observe()
            o = raw_obs.observation
            obs = f.transform_obs(raw_obs)

            if raw_obs.action_errors:
                print('action errors:', raw_obs.action_errors)

            if o.game_loop == 2:
                # Center camera is initiated automatically by the game and reported
                # at frame 2.
                last_actions = [actions.FUNCTIONS.move_camera.id]

            self.assertEqual(last_actions, list(obs.last_actions))

            unit_type = obs.feature_screen.unit_type
            observations[o.game_loop] = unit_type

            if o.game_loop in config.actions:
                # Ask the config for the function call to make at this loop.
                func = config.actions[o.game_loop](obs)

                print((' loop: %s ' % o.game_loop).center(80, '-'))
                print(_obs_string(obs))
                # nonzero() yields (row, column) index arrays, i.e. (y, x).
                scv_y, scv_x = (units.Terran.SCV == unit_type).nonzero()
                print('scv locations: ', sorted(list(zip(scv_x, scv_y))))
                print('available actions: ',
                      list(sorted(obs.available_actions)))
                print('Making action: %s' % (func, ))

                # Ensure action is available.
                # If a build action is available, we have managed to target an SCV.
                self.assertIn(func.function, obs.available_actions)

                if (func.function
                        in (actions.FUNCTIONS.Build_SupplyDepot_screen.id,
                            actions.FUNCTIONS.Build_Barracks_screen.id)):
                    # Ensure we can build on that position.
                    x, y = func.arguments[1]
                    self.assertEqual(_EMPTY, unit_type[y, x])

                action = f.transform_action(o, func)
                last_actions = [func.function]
                controller.act(action)
            else:
                # Nothing was sent, so nothing should be reported next time.
                last_actions = []

            controller.step()

        replay_data = controller.save_replay()
        return replay_data, observations
示例#21
0
    def testCanPickleSpecs(self):
        """Action and observation specs must survive a pickle round trip."""
        feats = features.Features(feature_screen_size=84,
                                  feature_minimap_size=64)

        specs = (feats.action_spec(), feats.observation_spec())
        for spec in specs:
            round_tripped = pickle.loads(pickle.dumps(spec))
            self.assertEqual(spec, round_tripped)
示例#22
0
 def testReversingUnknownAction(self):
     """A unit command with ability id 6 must reverse to function id 0."""
     options = dict(feature_screen_width=84,
                    feature_screen_height=80,
                    feature_minimap_width=64,
                    feature_minimap_height=67,
                    hide_specific_actions=False)
     feats = features.Features(**options)
     unknown = sc_pb.Action()
     unknown.action_feature_layer.unit_command.ability_id = 6  # Cheer
     reversed_call = feats.reverse_action(unknown)
     # Function 0 is the no-op.
     self.assertEqual(reversed_call.function, 0)
示例#23
0
 def testIdsMatchIndex(self):
     """Every function id and argument-type id must equal its spec index."""
     dimensions = dict(feature_screen_width=84,
                       feature_screen_height=80,
                       feature_minimap_width=64,
                       feature_minimap_height=67)
     spec = features.Features(**dimensions).action_spec()
     for position, function in enumerate(spec.functions):
         self.assertEqual(position, function.id)
     for position, arg_type in enumerate(spec.types):
         self.assertEqual(position, arg_type.id)
示例#24
0
    def testCanPickleSpecs(self):
        """Action and observation specs must survive a pickle round trip."""
        interface_format = features.AgentInterfaceFormat(
            feature_dimensions=SQUARE_DIMENSIONS)
        feats = features.Features(interface_format)

        specs = (feats.action_spec(), feats.observation_spec())
        for spec in specs:
            round_tripped = pickle.loads(pickle.dumps(spec))
            self.assertEqual(spec, round_tripped)
示例#25
0
 def setUp(self):
     """Create Features, its observation spec and a dummy-observation builder.

     The interface uses rectangular feature and RGB dimensions, the FEATURES
     action space and feature units, on a 256x256 map.
     """
     feature_dims = features.Dimensions(screen=(64, 60), minimap=(32, 28))
     rgb_dims = features.Dimensions(screen=(128, 124), minimap=(64, 60))
     interface_format = features.AgentInterfaceFormat(
         feature_dimensions=feature_dims,
         rgb_dimensions=rgb_dims,
         action_space=actions.ActionSpace.FEATURES,
         use_feature_units=True)
     self._features = features.Features(interface_format,
                                        map_size=point.Point(256, 256))
     self._obs_spec = self._features.observation_spec()
     self._builder = dummy_observation.Builder(self._obs_spec)
示例#26
0
    def testValidFunctionsAreConsistent(self):
        """Spec functions must agree with `actions.FUNCTIONS` on id, name, arity."""
        interface_format = features.AgentInterfaceFormat(
            feature_dimensions=RECTANGULAR_DIMENSIONS)
        feats = features.Features(interface_format)

        for spec_func in feats.action_spec().functions:
            reference = actions.FUNCTIONS[spec_func.id]
            self.assertEqual(spec_func.id, reference.id)
            self.assertEqual(spec_func.name, reference.name)
            self.assertEqual(len(spec_func.args), len(reference.args))  # pylint: disable=g-generic-assert
    def start(self):
        """Step through the replay in fixed strides and pickle the agent's states.

        The replay is advanced 10 game loops at a time until either
        `self.info.game_duration_loops` is reached or a player result is
        observed. At every stride the agent is stepped with a `TimeStep`
        (reward fixed at 0), the raw actions, the replay info and the
        `Features` helper. Afterwards `self.info` and `self.agent.states` are
        pickled to `data/<replay_file_name>.p` and the agent's state list is
        cleared.

        The output file is written inside a `with` block so the handle is
        closed (and flushed) even if pickling fails. The original passed a
        bare `open(...)` to `pickle.dump` and never closed it.
        """
        _features = features.Features(self.controller.game_info())

        # NOTE(review): this frame sample is never used by the loop below,
        # which steps at a fixed stride. It is kept because drawing it
        # advances the global `random` state. Confirm nothing relies on that
        # before deleting it.
        frames = random.sample(
            np.arange(self.info.game_duration_loops).tolist(),
            self.info.game_duration_loops)
        step_mul = 10
        frames = frames[0:int(self.info.game_duration_loops) // step_mul]
        frames.sort()

        i = 0
        skips = step_mul
        while i < self.info.game_duration_loops:
            i += skips
            self.controller.step(skips)
            obs = self.controller.observe()
            agent_obs = _features.transform_obs(obs.observation)

            if obs.player_result:  # Episode over.
                self._state = StepType.LAST
                discount = 0
            else:
                discount = self.discount

            self._episode_steps += skips

            step = TimeStep(step_type=self._state,
                            reward=0,
                            discount=discount,
                            observation=agent_obs)

            self.agent.step(step, obs.actions, self.info, _features)

            if obs.player_result:
                break

            self._state = StepType.MID

        print("Saving data")
        with open("data/" + self.replay_file_name + ".p", "wb") as out_file:
            pickle.dump({
                "info": self.info,
                "state": self.agent.states
            }, out_file)
        print("Data successfully saved")
        self.agent.states = []
        print("Data flushed")

        print("Done")
示例#28
0
    def testValidFunctionsAreConsistent(self):
        """Spec functions must agree with `actions.FUNCTIONS` on id, name, arity."""
        dimensions = dict(feature_screen_width=84,
                          feature_screen_height=80,
                          feature_minimap_width=64,
                          feature_minimap_height=67)
        feats = features.Features(**dimensions)

        for spec_func in feats.action_spec().functions:
            reference = actions.FUNCTIONS[spec_func.id]
            self.assertEqual(spec_func.id, reference.id)
            self.assertEqual(spec_func.name, reference.name)
            self.assertEqual(len(spec_func.args), len(reference.args))
示例#29
0
    def process_replay(self, controller, replay_data, map_data, player_id):
        """Process a single replay, updating the stats.

        Starts the replay on `controller`, then repeatedly observes and steps
        until a player result is reported. Every observation updates the
        counters in `self.stats.replay_stats`.

        Args:
            controller: Game controller used to start, step and observe.
            replay_data: Replay contents passed to `RequestStartReplay`.
            map_data: Map contents passed to `RequestStartReplay`.
            player_id: Id of the player whose perspective is observed.
        """
        self._update_stage("start_replay")
        # NOTE(review): `interface` is not defined in this method; presumably
        # a module-level InterfaceOptions. Verify against the rest of the file.
        controller.start_replay(
            sc_pb.RequestStartReplay(replay_data=replay_data,
                                     map_data=map_data,
                                     options=interface,
                                     observed_player_id=player_id))

        feat = features.Features(controller.game_info())

        self.stats.replay_stats.replays += 1
        self._update_stage("step")
        # The first step uses the controller's default step count; later
        # steps use FLAGS.step_mul.
        controller.step()
        while True:
            self.stats.replay_stats.steps += 1
            self._update_stage("observe")
            obs = controller.observe()

            # Count the raw actions by kind.
            for action in obs.actions:
                act_fl = action.action_feature_layer
                if act_fl.HasField("unit_command"):
                    self.stats.replay_stats.made_abilities[
                        act_fl.unit_command.ability_id] += 1
                if act_fl.HasField("camera_move"):
                    self.stats.replay_stats.camera_move += 1
                if act_fl.HasField("unit_selection_point"):
                    self.stats.replay_stats.select_pt += 1
                if act_fl.HasField("unit_selection_rect"):
                    self.stats.replay_stats.select_rect += 1
                if action.action_ui.HasField("control_group"):
                    self.stats.replay_stats.control_group += 1

                # Actions that cannot be reversed are counted under -1.
                try:
                    func = feat.reverse_action(action).function
                except ValueError:
                    func = -1
                self.stats.replay_stats.made_actions[func] += 1

            # Abilities the game reports in this observation.
            for valid in obs.observation.abilities:
                self.stats.replay_stats.valid_abilities[valid.ability_id] += 1

            # Unit types present in the raw data.
            for u in obs.observation.raw_data.units:
                self.stats.replay_stats.unit_ids[u.unit_type] += 1

            # Ids returned by `available_actions` for this observation.
            for ability_id in feat.available_actions(obs.observation):
                self.stats.replay_stats.valid_actions[ability_id] += 1

            # A player result means the replay has ended.
            if obs.player_result:
                break

            self._update_stage("step")
            controller.step(FLAGS.step_mul)
示例#30
0
 def testFlRgbObservationSpec(self):
     """Check observation spec shapes with feature and RGB dimensions set."""
     rgb_dims = features.Dimensions(screen=(128, 132), minimap=(74, 77))
     interface_format = features.AgentInterfaceFormat(
         feature_dimensions=RECTANGULAR_DIMENSIONS,
         rgb_dimensions=rgb_dims,
         action_space=actions.ActionSpace.FEATURES)
     obs_spec = features.Features(interface_format).observation_spec()

     expected_shapes = (
         ("feature_screen", (17, 80, 84)),
         ("feature_minimap", (7, 67, 64)),
         ("rgb_screen", (132, 128, 3)),
         ("rgb_minimap", (77, 74, 3)),
     )
     for key, shape in expected_shapes:
         self.assertEqual(obs_spec[key], shape)