Exemplo n.º 1
0
    def test_slicing(self):
        """Checks slice, mask, ellipsis and string indexing on named arrays."""
        # 1-D: every slice form keeps the names of the surviving entries.
        flat = named_array.NamedNumpyArray([1, 2, 3, 4, 5], list("abcde"))
        flat_cases = (
            (np.s_[:], [1, 2, 3, 4, 5]),
            (np.s_[::], [1, 2, 3, 4, 5]),
            (np.s_[::2], [1, 3, 5]),
            (np.s_[::-1], [5, 4, 3, 2, 1]),
        )
        for key, expected in flat_cases:
            self.assertArrayEqual(flat[key], expected)
        self.assertEqual(flat[:].a, 1)
        self.assertEqual(flat[::].b, 2)
        self.assertEqual(flat[::2].c, 3)
        with self.assertRaises(AttributeError):
            flat[::2].d  # pylint: disable=pointless-statement
        self.assertEqual(flat[::-1].e, 5)
        evens = flat[flat % 2 == 0]
        self.assertArrayEqual(evens, [2, 4])
        self.assertEqual(evens.b, 2)

        # 2-D with only the second axis named.
        grid = named_array.NamedNumpyArray([[1, 2, 3, 4], [5, 6, 7, 8]],
                                           [None, list("abcd")])
        whole = [[1, 2, 3, 4], [5, 6, 7, 8]]
        cols_reversed = [[4, 3, 2, 1], [8, 7, 6, 5]]
        grid_cases = (
            (np.s_[:], whole),
            (np.s_[::], whole),
            (np.s_[:, :], whole),
            (np.s_[:, ...], whole),
            (np.s_[..., ::], whole),
            (np.s_[:, ::2], [[1, 3], [5, 7]]),
            (np.s_[::-1], [[5, 6, 7, 8], [1, 2, 3, 4]]),
            (np.s_[..., ::-1], cols_reversed),
            (np.s_[:, ::-1], cols_reversed),
            (np.s_[:, ::-2], [[4, 2], [8, 6]]),
            (np.s_[:, -2::-2], [[3, 1], [7, 5]]),
            (np.s_[::-1, -2::-2], [[7, 5], [3, 1]]),
            (np.s_[..., 0, 0], 1),  # weird scalar arrays...
        )
        for key, expected in grid_cases:
            self.assertArrayEqual(grid[key], expected)

        # 4-D with every axis named.
        cube = named_array.NamedNumpyArray(
            [[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
             [[[8, 9], [10, 11]], [[12, 13], [14, 15]]]],
            [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]])
        self.assertEqual(cube.a.c.e.g, 0)
        self.assertEqual(cube.b.c.f.g, 10)
        self.assertEqual(cube.b.d.f.h, 15)
        cube_cases = (
            (np.s_[0, ..., 0], [[0, 2], [4, 6]]),
            (np.s_[0, ..., 1], [[1, 3], [5, 7]]),
            (np.s_[0, 0, ..., 1], [1, 3]),
            (np.s_[0, ..., 1, 1], [3, 7]),
            (np.s_[..., 1, 1], [[3, 7], [11, 15]]),
            (np.s_[1, 0, ...], [[8, 9], [10, 11]]),
            (np.s_["a", ..., "g"], [[0, 2], [4, 6]]),
            (np.s_["a", ...], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),
            (np.s_[..., "g"], [[[0, 2], [4, 6]], [[8, 10], [12, 14]]]),
            (np.s_["a", "c"], [[0, 1], [2, 3]]),
        )
        for key, expected in cube_cases:
            self.assertArrayEqual(cube[key], expected)
        self.assertArrayEqual(cube["a", ...].c, [[0, 1], [2, 3]])
        self.assertArrayEqual(cube["a", ..., "g"].c, [0, 2])

        # Multi-dimensional integer indexing is rejected...
        with self.assertRaises(TypeError):
            cube[np.array([[0, 1], [0, 1]])]  # pylint: disable=pointless-statement, expression-not-assigned

        # ...and so is more than one ellipsis.
        with self.assertRaises(IndexError):
            cube[..., 0, ...]  # pylint: disable=pointless-statement
Exemplo n.º 2
0
    def test_string(self):
        """str() mirrors numpy while repr() also reports the names."""
        vector = named_array.NamedNumpyArray([1, 3, 6], ["a", "b", "c"],
                                             dtype=np.int32)
        self.assertEqual(str(vector), "[1 3 6]")
        self.assertEqual(repr(vector),
                         ("NamedNumpyArray([1, 3, 6], ['a', 'b', 'c'], "
                          "dtype=int32)"))

        # The same 2x2 data, named along either axis.
        layouts = (
            ([None, ["a", "b"]],
             ("NamedNumpyArray([[1, 3],\n"
              "                 [6, 8]], [None, ['a', 'b']])")),
            ([["a", "b"], None],
             ("NamedNumpyArray([[1, 3],\n"
              "                 [6, 8]], [['a', 'b'], None])")),
        )
        for names, expected_repr in layouts:
            matrix = named_array.NamedNumpyArray([[1, 3], [6, 8]], names)
            self.assertEqual(str(matrix), "[[1 3]\n [6 8]]")
            self.assertEqual(repr(matrix), expected_repr)

        # Large arrays must still show their final value.
        for names in ([None, ["a%s" % i for i in range(50)]],
                      [["a%s" % i for i in range(50)], None]):
            big = named_array.NamedNumpyArray([list(range(50))] * 50, names)
            self.assertIn("49", str(big))
            self.assertIn("49", repr(big))
Exemplo n.º 3
0
 def test_empty_array(self):
     """Checks which name specs are accepted for an empty array."""
     named_array.NamedNumpyArray([], [None, ["a", "b"]])
     rejected_names = (
         # Must be the right length.
         [["a", "b"]],
         # Returning an empty slice is not supported, and it's not clear how or
         # even if it should be supported.
         [["a", "b"], None],
     )
     for names in rejected_names:
         with self.assertRaises(ValueError):
             named_array.NamedNumpyArray([], names)
Exemplo n.º 4
0
 def test_masking(self):
     """Boolean masks select matching values from a 2-D named array."""
     grid = named_array.NamedNumpyArray([[1, 2, 3, 4], [5, 6, 7, 8]],
                                        [None, list("abcd")])
     expectations = (
         (grid > 2, [3, 4, 5, 6, 7, 8]),
         (grid < 4, [1, 2, 3]),
         (grid % 2 == 0, [2, 4, 6, 8]),
         (grid % 3 == 0, [3, 6]),
     )
     for mask, expected in expectations:
         self.assertArrayEqual(grid[mask], expected)
Exemplo n.º 5
0
 def test_named_array_multi_first(self):
     """Names on the first axis: lookups, reversal, duplication and masks."""
     arr = named_array.NamedNumpyArray([[1, 3], [6, 8]], [["a", "b"], None])
     self.assertArrayEqual(arr.a, [1, 3])
     array_cases = (
         (1, [6, 8]),
         ("b", [6, 8]),
         (np.s_[::-1], [[6, 8], [1, 3]]),
         (np.s_[::-1, ::-1], [[8, 6], [3, 1]]),
         (np.s_[::-1, 0], [6, 1]),
         (np.s_[::-1, 1], [8, 3]),
         ([0, 0], [[1, 3], [1, 3]]),
     )
     for key, expected in array_cases:
         self.assertArrayEqual(arr[key], expected)

     # Reversing the rows keeps each row name attached to its data.
     flipped = arr[::-1]
     self.assertArrayEqual(flipped[::-1], [[1, 3], [6, 8]])
     self.assertArrayEqual(flipped[0], [6, 8])
     self.assertArrayEqual(flipped.a, [1, 3])
     self.assertArrayEqual(flipped.a[0], 1)
     self.assertArrayEqual(flipped.b, [6, 8])

     # Duplicated rows lose their names, so attribute access fails.
     with self.assertRaises(TypeError):
         arr[[0, 0]].a  # pylint: disable=pointless-statement

     scalar_cases = (((0, 1), 3), (("a", 0), 1), (("b", 0), 6), (("b", 1), 8))
     for key, expected in scalar_cases:
         self.assertEqual(arr[key], expected)
     self.assertEqual(arr.a[0], 1)

     self.assertArrayEqual(arr[arr > 2], [3, 6, 8])
     self.assertArrayEqual(arr[arr % 3 == 0], [3, 6])

     # A single row has no column names.
     with self.assertRaises(TypeError):
         arr[0].a  # pylint: disable=pointless-statement
Exemplo n.º 6
0
    def test_single_dimension(self, names):
        """Lookup, slicing and assignment on a 1-D array named by `names`."""
        arr = named_array.NamedNumpyArray([1, 3, 6], names)

        # Each element is reachable by position, attribute and key.
        for position, label, value in ((0, "a", 1), (1, "b", 3), (2, "c", 6)):
            self.assertEqual(arr[position], value)
            self.assertEqual(getattr(arr, label), value)
            self.assertEqual(arr[label], value)
        self.assertEqual(arr[-1], 6)
        with self.assertRaises(AttributeError):
            arr.d  # pylint: disable=pointless-statement
        with self.assertRaises(KeyError):
            arr["d"]  # pylint: disable=pointless-statement

        # range slicing
        range_cases = (
            (np.s_[0:2], [1, 3]),
            (np.s_[1:3], [3, 6]),
            (np.s_[0:2:], [1, 3]),
            (np.s_[0:2:1], [1, 3]),
            (np.s_[::2], [1, 6]),
            (np.s_[::-1], [6, 3, 1]),
        )
        for key, expected in range_cases:
            self.assertArrayEqual(arr[key], expected)
        tail = arr[1:3]
        self.assertEqual(tail[0], 3)
        self.assertEqual(tail.b, 3)
        self.assertEqual(tail.c, 6)

        # list slicing
        list_cases = (
            ([0, 0], [1, 1]),
            ([0, 1], [1, 3]),
            ([1, 0], [3, 1]),
            ([1, 2], [3, 6]),
            (np.array([0, 2]), [1, 6]),
        )
        for indices, expected in list_cases:
            self.assertArrayEqual(arr[indices], expected)
        self.assertEqual(arr[[1, 2]].b, 3)
        self.assertEqual(arr[[2, 0]].c, 6)
        with self.assertRaises(TypeError):
            # Duplicates lead to unnamed dimensions.
            arr[[0, 0]].a  # pylint: disable=pointless-statement

        def assert_b_is(expected):
            # The "b" slot must read the same through all three access paths.
            self.assertEqual(arr[1], expected)
            self.assertEqual(arr.b, expected)
            self.assertEqual(arr["b"], expected)

        arr[1] = 4
        assert_b_is(4)

        arr[1:2] = 2
        assert_b_is(2)

        arr[[1]] = 3
        assert_b_is(3)

        arr.b = 5
        assert_b_is(5)
Exemplo n.º 7
0
 def test_named_array_multi_second(self):
     """Names on the second axis are only reachable once a row is chosen."""
     arr = named_array.NamedNumpyArray([[1, 3], [6, 8]], [None, ["a", "b"]])
     self.assertArrayEqual(arr[0], [1, 3])
     for key, expected in (((0, 1), 3), ((0, "a"), 1), ((0, "b"), 3),
                           ((1, "b"), 8)):
         self.assertEqual(arr[key], expected)
     self.assertEqual(arr[0].a, 1)
     # The first axis is unnamed, so attribute access fails on it.
     with self.assertRaises(TypeError):
         arr.a  # pylint: disable=pointless-statement
Exemplo n.º 8
0
 def test_named_array_multi_first(self):
     """Names on the first axis select whole rows."""
     arr = named_array.NamedNumpyArray([[1, 3], [6, 8]], [["a", "b"], None])
     self.assertArrayEqual(arr.a, [1, 3])
     for key in (1, "b"):
         self.assertArrayEqual(arr[key], [6, 8])
     for key, expected in (((0, 1), 3), (("a", 0), 1), (("b", 0), 6),
                           (("b", 1), 8)):
         self.assertEqual(arr[key], expected)
     self.assertEqual(arr.a[0], 1)
     # A single row has no column names.
     with self.assertRaises(TypeError):
         arr[0].a  # pylint: disable=pointless-statement
Exemplo n.º 9
0
def ext_score(obs):
    """Pack the cumulative score of an observation into a named int32 array.

    Args:
      obs: an observation exposing ``obs.observation.score``.

    Returns:
      A ``named_array.NamedNumpyArray`` named by ``features.ScoreCumulative``:
      the overall score followed by twelve ``score_details`` fields.
    """
    score = obs.observation.score
    details = score.score_details
    cumulative = [
        score.score,
        details.idle_production_time,
        details.idle_worker_time,
        details.total_value_units,
        details.total_value_structures,
        details.killed_value_units,
        details.killed_value_structures,
        details.collected_minerals,
        details.collected_vespene,
        details.collection_rate_minerals,
        details.collection_rate_vespene,
        details.spent_minerals,
        details.spent_vespene,
    ]
    return named_array.NamedNumpyArray(
        cumulative, names=features.ScoreCumulative, dtype=np.int32)
Exemplo n.º 10
0
    def test_string(self):
        a = named_array.NamedNumpyArray([1, 3, 6], ["a", "b", "c"],
                                        dtype=np.int32)
        self.assertEqual(str(a), "[1 3 6]")
        self.assertEqual(repr(a),
                         ("NamedNumpyArray([1, 3, 6], ['a', 'b', 'c'], "
                          "dtype=int32)"))

        a = named_array.NamedNumpyArray([[1, 3], [6, 8]], [None, ["a", "b"]])
        self.assertEqual(str(a), "[[1 3]\n [6 8]]")
        self.assertEqual(repr(a),
                         ("NamedNumpyArray([[1, 3],\n"
                          "                 [6, 8]], [None, ['a', 'b']])"))

        a = named_array.NamedNumpyArray([[1, 3], [6, 8]], [["a", "b"], None])
        self.assertEqual(str(a), "[[1 3]\n [6 8]]")
        self.assertEqual(repr(a),
                         ("NamedNumpyArray([[1, 3],\n"
                          "                 [6, 8]], [['a', 'b'], None])"))

        a = named_array.NamedNumpyArray(
            [0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [str(i) for i in range(13)],
            dtype=np.int32)
        numpy_repr = np.array_repr(a)
        if "\n" in numpy_repr:  # ie numpy > 1.14
            self.assertEqual(repr(a), """
NamedNumpyArray([ 0,  0,  0, 50,  0,  0,  0,  0,  0,  0,  0,  0,  0],
                ['0', '1', '2', '3', '4', '...', '8', '9', '10', '11', '12'],
                dtype=int32)""".strip())  # Keep the middle newlines.
        else:
            self.assertEqual(repr(a), (
                "NamedNumpyArray("
                "[ 0,  0,  0, 50,  0,  0,  0,  0,  0,  0,  0,  0,  0], "
                "['0', '1', '2', '3', '4', '...', '8', '9', '10', '11', '12'], "
                "dtype=int32)"))  # Note the lack of newlines.

        a = named_array.NamedNumpyArray([list(range(50))] * 50,
                                        [None, ["a%s" % i for i in range(50)]])
        self.assertIn("49", str(a))
        self.assertIn("49", repr(a))
        self.assertIn("a4", repr(a))
        self.assertIn("a49", repr(a))

        a = named_array.NamedNumpyArray([list(range(50))] * 50,
                                        [["a%s" % i for i in range(50)], None])
        self.assertIn("49", str(a))
        self.assertIn("49", repr(a))
        self.assertIn("a4", repr(a))
        self.assertIn("a49", repr(a))
Exemplo n.º 11
0
    def test_single_dimension(self, names):
        """Lookup, range slicing and assignment on a 1-D named array."""
        arr = named_array.NamedNumpyArray([1, 3, 6], names)

        # Each element is reachable by position, attribute and key.
        for position, label, value in ((0, "a", 1), (1, "b", 3), (2, "c", 6)):
            self.assertEqual(arr[position], value)
            self.assertEqual(getattr(arr, label), value)
            self.assertEqual(arr[label], value)
        self.assertEqual(arr[-1], 6)
        with self.assertRaises(AttributeError):
            arr.d  # pylint: disable=pointless-statement
        with self.assertRaises(KeyError):
            arr["d"]  # pylint: disable=pointless-statement

        for key, expected in ((np.s_[0:2], [1, 3]),
                              (np.s_[1:3], [3, 6]),
                              (np.s_[0:2:], [1, 3]),
                              (np.s_[0:2:1], [1, 3])):
            self.assertArrayEqual(arr[key], expected)
        tail = arr[1:3]
        self.assertEqual(tail[0], 3)
        self.assertEqual(arr[1:3, 0], 3)
        self.assertEqual(tail.b, 3)
        self.assertEqual(tail.c, 6)

        def assert_b_is(expected):
            # The "b" slot must read the same through all three access paths.
            self.assertEqual(arr[1], expected)
            self.assertEqual(arr.b, expected)
            self.assertEqual(arr["b"], expected)

        arr[1] = 4
        assert_b_is(4)

        arr[1:2] = 2
        assert_b_is(2)

        arr.b = 5
        assert_b_is(5)
Exemplo n.º 12
0
    def custom_transform_obs(self, obs):
        """Customized rendering of SC2 observations into something an agent can handle.

        Starts from ``self.transform_obs(obs)``. When feature layers are
        enabled it adds a ``'feature_spatial'`` entry that stacks every layer
        in ``SPATIAL_FEATURES`` along a new first axis named by
        ``SpatialFeatures``.

        Args:
          obs: the observation handed to ``transform_obs``; its
            ``obs.observation`` is passed to each feature's ``unpack``.

        Returns:
          The mapping produced by ``transform_obs``, plus ``'feature_spatial'``
          when feature dimensions are configured.

        Raises:
          NotImplementedError: if RGB dimensions are configured.
        """
        out = self.transform_obs(obs)
        aif = self._agent_interface_format

        def or_zeros(layer, size):
            # Substitute an all-zero (size.y, size.x) int32 plane for a
            # missing layer.
            if layer is not None:
                return layer.astype(np.int32, copy=False)
            return np.zeros((size.y, size.x), dtype=np.int32)

        if aif.feature_dimensions:
            # NOTE(review): every spatial feature is padded using the minimap
            # size; confirm all SPATIAL_FEATURES are minimap-resolution layers.
            size = aif.feature_dimensions.minimap
            # np.stack requires a sequence: passing a generator is deprecated
            # since NumPy 1.16 and raises TypeError in newer releases.
            layers = [or_zeros(f.unpack(obs.observation), size)
                      for f in SPATIAL_FEATURES]
            out['feature_spatial'] = named_array.NamedNumpyArray(
                np.stack(layers), names=[SpatialFeatures, None, None])

        if aif.rgb_dimensions:
            raise NotImplementedError

        return out
Exemplo n.º 13
0
def _get_minimaps(obs):
    """Return the minimap feature layers for an observation.

    If ``obs.observation`` already has ``feature_minimap``, that is returned
    as-is. Otherwise the layers are rebuilt from ``feature_layer_data``: every
    feature in ``MINIMAP_FEATURES`` is unpacked and stacked along a new first
    axis named by ``MinimapFeatures``.

    Args:
      obs: an object whose ``observation`` attribute holds the game state.

    Returns:
      The minimap feature stack.

    Raises:
      KeyError: if ``obs.observation`` has neither ``feature_minimap`` nor
        ``feature_layer_data``.
    """
    # required game version >= 4.8.2, so we get all the maps from minimap
    def or_zeros(layer, size):
        # Substitute an all-zero (size.y, size.x) int32 plane for a missing
        # layer.
        if layer is not None:
            return layer.astype(np.int32, copy=False)
        return np.zeros((size.y, size.x), dtype=np.int32)

    observation = obs.observation
    if hasattr(observation, 'feature_minimap'):
        return observation.feature_minimap
    if hasattr(observation, 'feature_layer_data'):
        # Missing layers become zero planes sized like the creep render.
        size = observation.feature_layer_data.minimap_renders.creep.size
        # np.stack requires a sequence: passing a generator is deprecated
        # since NumPy 1.16 and raises TypeError in newer releases.
        layers = [or_zeros(f.unpack(observation), size)
                  for f in MINIMAP_FEATURES]
        return named_array.NamedNumpyArray(
            np.stack(layers), names=[MinimapFeatures, None, None])
    raise KeyError(
        'obs.observation has no feature_minimap or feature_layer_data!')
Exemplo n.º 14
0
    def get_current_state(self, obs):
        """Summarise the current game as a ``State``-named feature vector.

        Entries, in order:
          1. The base count: the number of Hatcheries, or twice the Lair count
             when no Hatchery is found.
          2. Spawning Pool, Queen, Extractor and Roach Warren counts.
          3. The Larva count passed through ``restraint_value_space``.
          4. The player's army count.
          5. The idle-worker count passed through ``restraint_value_space``.
          6. The enemy race value.
          7. The enemy alliance squares, then our own alliance squares.
        """
        def count(unit_type):
            return len(self.get_units_by_type(obs, unit_type))

        zerg = units.Zerg
        base = count(zerg.Hatchery)
        if base <= 0:
            base = count(zerg.Lair) * 2

        values = [
            base,
            count(zerg.SpawningPool),
            count(zerg.Queen),
            count(zerg.Extractor),
            count(zerg.RoachWarren),
            # obs.observation.player.larva_count always zero
            self.restraint_value_space(count(zerg.Larva)),
            obs.observation.player.army_count,
            self.restraint_value_space(
                obs.observation.player.idle_worker_count),
            self.enemy_race.value,
        ]
        values.extend(
            self.get_alliance_squares(obs, features.PlayerRelative.ENEMY))
        values.extend(
            self.get_alliance_squares(obs, features.PlayerRelative.SELF))

        return named_array.NamedNumpyArray(values, names=State)
Exemplo n.º 15
0
 def test_pickle(self):
     """A named array survives a pickle round trip with its names intact."""
     original = named_array.NamedNumpyArray([1, 3, 6], ["a", "b", "c"])
     payload = pickle.dumps(original)
     restored = pickle.loads(payload)
     self.assertTrue(np.all(original == restored))
     expected_repr = "NamedNumpyArray([1, 3, 6], ['a', 'b', 'c'])"
     self.assertEqual(repr(restored), expected_repr)
Exemplo n.º 16
0
    def transform_obs(self, obs):
        """Render some SC2 observations into something an agent can handle.

        Args:
          obs: a response observation. This method reads ``obs.observation``,
            ``obs.actions`` and ``obs.action_errors``.

        Returns:
          A ``named_array.NamedDict`` of numpy arrays. All of them are int32
          except ``raw_units``, which is int64 so that unit tags survive.
          Feature-layer, RGB, feature-unit, raw-unit, unit-count and camera
          entries are only present when the matching
          ``self._agent_interface_format`` option is enabled.
        """
        empty = np.array([], dtype=np.int32).reshape((0, 7))
        out = named_array.NamedDict({  # Fill out some that are sometimes empty.
            "single_select": empty,
            "multi_select": empty,
            "build_queue": empty,
            "cargo": empty,
            "cargo_slots_available": np.array([0], dtype=np.int32),
        })

        def or_zeros(layer, size):
            # Substitute an all-zero (size.y, size.x) int32 plane for a
            # missing layer.
            if layer is not None:
                return layer.astype(np.int32, copy=False)
            return np.zeros((size.y, size.x), dtype=np.int32)

        aif = self._agent_interface_format

        if aif.feature_dimensions:
            # np.stack requires a sequence: passing a generator is deprecated
            # since NumPy 1.16 and raises TypeError in newer releases.
            out["feature_screen"] = named_array.NamedNumpyArray(
                np.stack([
                    or_zeros(f.unpack(obs.observation),
                             aif.feature_dimensions.screen)
                    for f in SCREEN_FEATURES
                ]),
                names=[ScreenFeatures, None, None])
            out["feature_minimap"] = named_array.NamedNumpyArray(
                np.stack([
                    or_zeros(f.unpack(obs.observation),
                             aif.feature_dimensions.minimap)
                    for f in MINIMAP_FEATURES
                ]),
                names=[MinimapFeatures, None, None])

        if aif.rgb_dimensions:
            out["rgb_screen"] = Feature.unpack_rgb_image(
                obs.observation.render_data.map).astype(np.int32)
            out["rgb_minimap"] = Feature.unpack_rgb_image(
                obs.observation.render_data.minimap).astype(np.int32)

        out["last_actions"] = np.array(
            [self.reverse_action(a).function for a in obs.actions],
            dtype=np.int32)

        out["action_result"] = np.array([o.result for o in obs.action_errors],
                                        dtype=np.int32)

        out["alerts"] = np.array(obs.observation.alerts, dtype=np.int32)

        out["game_loop"] = np.array([obs.observation.game_loop],
                                    dtype=np.int32)

        score_details = obs.observation.score.score_details
        out["score_cumulative"] = named_array.NamedNumpyArray(
            [
                obs.observation.score.score,
                score_details.idle_production_time,
                score_details.idle_worker_time,
                score_details.total_value_units,
                score_details.total_value_structures,
                score_details.killed_value_units,
                score_details.killed_value_structures,
                score_details.collected_minerals,
                score_details.collected_vespene,
                score_details.collection_rate_minerals,
                score_details.collection_rate_vespene,
                score_details.spent_minerals,
                score_details.spent_vespene,
            ],
            names=ScoreCumulative,
            dtype=np.int32)

        def get_score_details(key, details, categories):
            # Read one row of `details` (chosen by key.name) field by field,
            # in `categories` order.
            row = getattr(details, key.name)
            return [getattr(row, category.name) for category in categories]

        out["score_by_category"] = named_array.NamedNumpyArray(
            [
                get_score_details(key, score_details, ScoreCategories)
                for key in ScoreByCategory
            ],
            names=[ScoreByCategory, ScoreCategories],
            dtype=np.int32)

        out["score_by_vital"] = named_array.NamedNumpyArray(
            [
                get_score_details(key, score_details, ScoreVitals)
                for key in ScoreByVital
            ],
            names=[ScoreByVital, ScoreVitals],
            dtype=np.int32)

        player = obs.observation.player_common
        out["player"] = named_array.NamedNumpyArray(
            [
                player.player_id,
                player.minerals,
                player.vespene,
                player.food_used,
                player.food_cap,
                player.food_army,
                player.food_workers,
                player.idle_worker_count,
                player.army_count,
                player.warp_gate_count,
                player.larva_count,
            ],
            names=Player,
            dtype=np.int32)

        def unit_vec(u):
            # 7-value int32 summary of a UI unit (same width as `empty`).
            return np.array(
                (
                    u.unit_type,
                    u.player_relative,
                    u.health,
                    u.shields,
                    u.energy,
                    u.transport_slots_taken,
                    int(u.build_progress * 100),  # discretize
                ),
                dtype=np.int32)

        ui = obs.observation.ui_data

        with sw("ui"):
            groups = np.zeros((10, 2), dtype=np.int32)
            for g in ui.groups:
                groups[g.control_group_index, :] = (g.leader_unit_type,
                                                    g.count)
            out["control_groups"] = groups

            if ui.single:
                out["single_select"] = named_array.NamedNumpyArray(
                    [unit_vec(ui.single.unit)], [None, UnitLayer])

            if ui.multi and ui.multi.units:
                out["multi_select"] = named_array.NamedNumpyArray(
                    [unit_vec(u) for u in ui.multi.units], [None, UnitLayer])

            if ui.cargo and ui.cargo.passengers:
                # The transport itself is the cargo panel's unit; reading
                # ui.single here would describe a different panel.
                out["single_select"] = named_array.NamedNumpyArray(
                    [unit_vec(ui.cargo.unit)], [None, UnitLayer])
                out["cargo"] = named_array.NamedNumpyArray(
                    [unit_vec(u) for u in ui.cargo.passengers],
                    [None, UnitLayer])
                out["cargo_slots_available"] = np.array(
                    [ui.cargo.slots_available], dtype=np.int32)

            if ui.production and ui.production.build_queue:
                out["single_select"] = named_array.NamedNumpyArray(
                    [unit_vec(ui.production.unit)], [None, UnitLayer])
                out["build_queue"] = named_array.NamedNumpyArray(
                    [unit_vec(u) for u in ui.production.build_queue],
                    [None, UnitLayer])

        def full_unit_vec(u, pos_transform, is_raw=False):
            # Full per-unit row. The unit tag is only included when is_raw,
            # which is why the row is built as int64.
            screen_pos = pos_transform.fwd_pt(point.Point.build(u.pos))
            screen_radius = pos_transform.fwd_dist(u.radius)
            return np.array(
                (
                    # Match unit_vec order
                    u.unit_type,
                    u.alliance,  # Self = 1, Ally = 2, Neutral = 3, Enemy = 4
                    u.health,
                    u.shield,
                    u.energy,
                    u.cargo_space_taken,
                    int(u.build_progress * 100),  # discretize

                    # Resume API order
                    int(u.health / u.health_max *
                        255) if u.health_max > 0 else 0,
                    int(u.shield / u.shield_max *
                        255) if u.shield_max > 0 else 0,
                    int(u.energy / u.energy_max *
                        255) if u.energy_max > 0 else 0,
                    u.display_type,  # Visible = 1, Snapshot = 2, Hidden = 3
                    u.owner,  # 1-15, 16 = neutral
                    screen_pos.x,
                    screen_pos.y,
                    u.facing,
                    screen_radius,
                    u.cloak,  # Cloaked = 1, CloakedDetected = 2, NotCloaked = 3
                    u.is_selected,
                    u.is_blip,
                    u.is_powered,
                    u.mineral_contents,
                    u.vespene_contents,

                    # Not populated for enemies or neutral
                    u.cargo_space_max,
                    u.assigned_harvesters,
                    u.ideal_harvesters,
                    u.weapon_cooldown,
                    len(u.orders),
                    u.tag if is_raw else 0),
                dtype=np.int64)

        raw = obs.observation.raw_data

        if aif.use_feature_units:
            with sw("feature_units"):
                # Update the camera location so we can calculate world to screen pos
                self._update_camera(point.Point.build(raw.player.camera))
                feature_units = []
                for u in raw.units:
                    if u.is_on_screen and u.display_type != sc_raw.Hidden:
                        feature_units.append(
                            full_unit_vec(u, self._world_to_feature_screen_px))
                # Tags are zeroed here (is_raw=False), so int32 is lossless
                # for the tag column.
                out["feature_units"] = named_array.NamedNumpyArray(
                    feature_units, [None, FeatureUnit], dtype=np.int32)

        if aif.use_raw_units:
            with sw("raw_units"):
                raw_units = [
                    full_unit_vec(u, self._world_to_world_tl, is_raw=True)
                    for u in raw.units
                ]
                # These rows carry real unit tags, which can exceed the int32
                # range; an int32 cast would silently wrap them, so keep int64.
                out["raw_units"] = named_array.NamedNumpyArray(
                    raw_units, [None, FeatureUnit], dtype=np.int64)

        if aif.use_unit_counts:
            with sw("unit_counts"):
                unit_counts = collections.defaultdict(int)
                for u in raw.units:
                    if u.alliance == sc_raw.Self:
                        unit_counts[u.unit_type] += 1
                out["unit_counts"] = named_array.NamedNumpyArray(
                    sorted(unit_counts.items()), [None, UnitCounts],
                    dtype=np.int32)

        if aif.use_camera_position:
            camera_position = self._world_to_world_tl.fwd_pt(
                point.Point.build(raw.player.camera))
            out["camera_position"] = np.array(
                (camera_position.x, camera_position.y), dtype=np.int32)

        out["available_actions"] = np.array(
            self.available_actions(obs.observation), dtype=np.int32)

        return out
Exemplo n.º 17
0
 def test_bad_names(self, names):
     """An invalid `names` spec for a 3-element array raises ValueError."""
     values = [1, 3, 6]
     with self.assertRaises(ValueError):
         named_array.NamedNumpyArray(values, names)