def __init__(self, env):
    """Store the wrapped environment and set up observation/action spaces.

    Args:
        env: environment instance, kept as ``self.sc2_env``.
    """
    self.sc2_env = env

    # Observation entries that are exposed at present. The entries
    # 'single_select' (shape (1, 7), float32 in [0, 1]), 'multi_select',
    # 'build_queue', 'cargo', 'cargo_slots_available' (shape (1,)) and
    # 'player' (shape (11,)) are not exposed.
    raw_obs_spaces = {
        'vision': Space((24, 84, 84), None, None, None),
        'control_groups': Space((10, 2), None, None, None),
        'available_actions': Space((None,), None, None, None),
    }

    # Every action entry is a 1-D float32 space bounded by [0, 1]; only the
    # length differs, so the table lists (name, length) pairs in the order
    # they are inserted into the action dict.
    action_entry_sizes = (
        ('func_id', 524),
        ('screen_x', 80),
        ('screen_y', 80),
        ('minimap_x', 80),
        ('minimap_y', 80),
        ('screen2_x', 80),
        ('screen2_y', 80),
        ('queued', 2),
        ('control_group_act', 4),
        ('control_group_id', 10),
        ('select_point_act', 4),
        ('select_add', 2),
        ('select_unit_act', 4),
        ('select_unit_id', 500),
        ('select_worker', 4),
        ('unload_id', 500),
        ('build_queue_id', 10),
    )
    action_spaces = {
        entry_name: Space((entry_size,), 0., 1., np.float32)
        for entry_name, entry_size in action_entry_sizes
    }

    # NOTE: removal of the 'player_id' feature (SC2RemoveFeatures) is
    # currently not part of the pipeline.
    first_stage_ops = [FlattenSpace({'control_groups'})]
    self._cpu_preprocessor = ObsPreprocessor(
        first_stage_ops, Spaces(raw_obs_spaces)
    )

    # The second stage starts from the observation space produced by the
    # first stage.
    second_stage_ops = [CastToFloat(), SC2ScaleChannels(24)]
    self._gpu_preprocessor = SC2RemoveAvailableActions(
        second_stage_ops, self._cpu_preprocessor.observation_space
    )

    self._observation_space = self._gpu_preprocessor.observation_space
    self._action_space = Spaces(action_spaces)
def update_space(self, old_space):
    """Return ``old_space`` with its first dimension set to ``len(self.idxs)``.

    The remaining dimensions, the bounds and the dtype are carried over
    unchanged.
    """
    leading_dim = len(self.idxs)
    trailing_dims = old_space.shape[1:]
    return Space(
        (leading_dim,) + trailing_dims,
        old_space.low,
        old_space.high,
        old_space.dtype,
    )
def update_space(self, old_space):
    """Return ``old_space`` with its first dimension recomputed.

    The new first dimension is the number of scalar indices plus the total
    number of items across all values of ``self._ranges_by_feature_idx``.
    The remaining dimensions, the bounds and the dtype are carried over
    unchanged.

    Args:
        old_space: space whose ``shape``, ``low``, ``high`` and ``dtype``
            are read.

    Returns:
        A new ``Space`` with the updated leading dimension.
    """
    # Sum the lengths directly rather than concatenating every range just
    # to measure it: this is linear and also works when there are no
    # ranges at all (reduce without an initializer raises TypeError then).
    # NOTE(review): assumes each value is a concatenable sequence (e.g. a
    # list), which is what the previous ``prev + cur`` fold relied on.
    ranged_count = sum(
        len(feature_range)
        for feature_range in self._ranges_by_feature_idx.values()
    )
    leading_dim = len(self._scalar_idxs) + ranged_count
    new_shape = (leading_dim,) + old_space.shape[1:]
    return Space(new_shape, old_space.low, old_space.high, old_space.dtype)
def update_space(self, old_space):
    """Return a space shaped ``(1,)`` followed by all but the last old dim.

    Bounds and dtype are carried over unchanged.
    """
    dims_without_last = old_space.shape[:-1]
    new_shape = (1, ) + dims_without_last
    return Space(
        new_shape,
        old_space.low,
        old_space.high,
        old_space.dtype,
    )
def update_space(self, old_space):
    """Return a space with the fixed shape ``(1, 84, 84)``.

    Only the shape is replaced; bounds and dtype come from ``old_space``.
    """
    fixed_shape = (1, 84, 84)
    low, high, dtype = old_space.low, old_space.high, old_space.dtype
    return Space(fixed_shape, low, high, dtype)
def update_space(self, old_space):
    """Return ``old_space`` with its dtype replaced by ``np.float32``.

    Shape and bounds are carried over unchanged.
    """
    float_dtype = np.float32
    return Space(
        old_space.shape,
        old_space.low,
        old_space.high,
        float_dtype,
    )
def update_space(self, old_space):
    """Return a new ``Space`` with the same shape, bounds and dtype.

    The operation leaves the space description untouched; a fresh
    ``Space`` is still constructed rather than returning ``old_space``.
    """
    shape = old_space.shape
    low = old_space.low
    high = old_space.high
    dtype = old_space.dtype
    return Space(shape, low, high, dtype)
def update_space(self, old_space):
    """Return a 1-D space whose length is the product of the old dimensions.

    Bounds and dtype are carried over unchanged.
    """
    def _multiply(running_product, dim):
        return running_product * dim

    # No initializer on purpose: a one-dimensional shape is passed through
    # as-is, without any multiplication being applied to its single entry.
    flat_length = reduce(_multiply, old_space.shape)
    return Space(
        (flat_length, ),
        old_space.low,
        old_space.high,
        old_space.dtype,
    )
def update_space(self, old_space):
    """Return ``old_space`` with its first dimension scaled by ``self.nb_frame``.

    The remaining dimensions, the bounds and the dtype are carried over
    unchanged.
    """
    first_dim, trailing_dims = old_space.shape[0], old_space.shape[1:]
    stacked_dim = first_dim * self.nb_frame
    return Space(
        (stacked_dim, ) + trailing_dims,
        old_space.low,
        old_space.high,
        old_space.dtype,
    )
def update_space(self, old_space):
    """Return a float32 space bounded by 0 and 1 with the old shape.

    Only the shape is taken from ``old_space``; its bounds and dtype are
    replaced.
    """
    lower_bound, upper_bound = 0., 1.
    return Space(old_space.shape, lower_bound, upper_bound, np.float32)