class TerminalWrapper(gym.Wrapper):
    """Declare an episode terminal once the scout has reached the enemy base.

    A step is terminal when the wrapped env already reports ``done``, when the
    enemy-base tracker reports ``hit``, or when it reports both ``enter`` and
    ``leave``.

    NOTE(review): ``max_step`` is stored as ``self._max_step`` but nothing in
    this class reads it, so it does not limit episode length here.
    """

    def __init__(self, env, max_step=4000):
        super(TerminalWrapper, self).__init__(env)
        self._max_step = max_step
        # Built in _reset(); None until the first reset.
        self._enemy_base = None

    def _reset(self):
        observation = self.env._reset()
        self._enemy_base = DestRange(self.env.unwrapped.enemy_base())
        return observation

    def _step(self, action):
        observation, reward, env_done, info = self.env._step(action)
        terminal = self._check_terminal(observation, env_done)
        return observation, reward, terminal, info

    def _check_terminal(self, obs, done):
        # The env already ended the episode; do not touch the tracker.
        if done:
            return done

        scout = self.env.unwrapped.scout()
        position = (scout.float_attr.pos_x, scout.float_attr.pos_y)
        base = self._enemy_base
        # Feed the position to the tracker in enter -> hit -> leave order.
        base.check_enter(position)
        base.check_hit(position)
        base.check_leave(position)

        if base.hit:
            print('***episode terminal while scout hit ***')
            return True
        if base.enter and base.leave:
            print('***episode terminal while scout enter and leave***')
            return True
        return done
class OnewayFinalReward(Reward):
    """Terminal-only reward for a one-way trip to the enemy base.

    Non-terminal steps yield 0. On the terminal step the reward is
    ``w * 2`` if the enemy-base tracker reports ``hit``, ``w * 1`` if it only
    reports ``enter``, and ``w * -1`` otherwise.
    """

    def __init__(self, weight=50):
        super(OnewayFinalReward, self).__init__(weight)

    def reset(self, obs, env):
        # A fresh tracker per episode.
        self._dest = DestRange(env.enemy_base())

    def compute_rwd(self, obs, reward, done, env):
        # The tracker is updated on every step, terminal or not.
        self._compute_rwd(env)
        if not done:
            self.rwd = 0
        elif self._dest.hit:
            self.rwd = self.w * 2
        elif self._dest.enter:
            self.rwd = self.w * 1
        else:
            self.rwd = self.w * -1

    def _compute_rwd(self, env):
        """Feed the scout's current (x, y) position to the enemy-base tracker."""
        unit = env.scout()
        xy = (unit.float_attr.pos_x, unit.float_attr.pos_y)
        self._dest.check_enter(xy)
        self._dest.check_hit(xy)
        self._dest.check_leave(xy)
class RoundTripTerminalWrapper(gym.Wrapper):
    """End an episode once the scout is back home after visiting the enemy base.

    Two phases:

    * outbound (``_back`` False): the scout's position updates the enemy-base
      tracker. The phase flips once that tracker reports ``hit``, or reports
      both ``enter`` and ``leave``.
    * return (``_back`` True): the position updates the home-base tracker
      instead. The episode terminates as soon as that tracker reports
      ``enter``.

    The wrapped env's ``done`` is passed through otherwise.

    NOTE(review): ``_step`` returns the boolean ``_back`` flag in place of the
    wrapped env's ``info``; the env's own info value is discarded.
    """

    def __init__(self, env):
        super(RoundTripTerminalWrapper, self).__init__(env)
        self._enemy_base = None
        self._home_base = None
        self._back = False  # becomes True when the return phase starts

    def _reset(self):
        first_obs = self.env._reset()
        game = self.env.unwrapped
        self._enemy_base = DestRange(game.enemy_base())
        self._home_base = DestRange(game.owner_base())
        self._back = False
        return first_obs

    def _step(self, action):
        next_obs, reward, env_done, _ = self.env._step(action)
        # Update the trackers before judging termination.
        heading_home = self._judge_course(next_obs)
        terminal = self._check_terminal(next_obs, env_done)
        return next_obs, reward, terminal, heading_home

    def _check_terminal(self, obs, done):
        # Entering the home base only counts during the return phase.
        return True if (self._back and self._home_base.enter) else done

    def _judge_course(self, obs):
        """Update the tracker for the current phase; return the ``_back`` flag."""
        scout = self.env.unwrapped.scout()
        xy = (scout.float_attr.pos_x, scout.float_attr.pos_y)

        tracker = self._home_base if self._back else self._enemy_base
        tracker.check_enter(xy)
        tracker.check_hit(xy)
        tracker.check_leave(xy)

        target = self._enemy_base
        visited = (target.enter and target.leave) or target.hit
        if visited and not self._back:
            self._back = True
        return self._back
class RoundTripFinalReward(Reward):
    """Terminal-only reward for a round trip: enemy base first, then home.

    The scout's position updates the enemy-base tracker (``_dest``) until
    the outbound leg is complete, and the home-base tracker (``_src``)
    afterwards.

    On the terminal step the reward is:

    * ``w * 2`` if both trackers report ``hit``;
    * ``w * 1`` if both report ``enter``;
    * 0 otherwise.

    Every non-terminal step yields 0.
    """

    def __init__(self, weight=50):
        super(RoundTripFinalReward, self).__init__(weight)
        self._back = False  # True once the outbound leg is complete

    def reset(self, obs, env):
        self._dest = DestRange(env.enemy_base())
        self._src = DestRange(env.owner_base())
        # The leg flag is per-episode. Without clearing it here, an episode
        # that reached the return leg leaves it True. Every later episode
        # would then skip the enemy-base tracker and never be rewarded.
        self._back = False

    def compute_rwd(self, obs, reward, done, env):
        scout = env.scout()
        pos = (scout.float_attr.pos_x, scout.float_attr.pos_y)
        self._check(pos)
        if done:
            if self._dest.hit and self._src.hit:
                self.rwd = self.w * 2
            elif self._dest.enter and self._src.enter:
                self.rwd = self.w * 1
            else:
                self.rwd = 0
        else:
            self.rwd = 0
        print("RoundTripFinalReward", self.rwd)

    def _check(self, pos):
        """Update the tracker for the current leg, then maybe switch legs."""
        if not self._back:
            self._check_dest(pos)
        else:
            self._check_src(pos)
        # The outbound leg is complete once the enemy base was hit, or was
        # entered and then left. This is the same condition the terminal
        # wrapper's _judge_course uses, so reward and termination agree on
        # which base is being tracked.
        if (self._dest.enter and self._dest.leave) or self._dest.hit:
            self._back = True

    def _check_dest(self, pos):
        """Feed ``pos`` to the enemy-base tracker."""
        self._dest.check_enter(pos)
        self._dest.check_hit(pos)
        self._dest.check_leave(pos)

    def _check_src(self, pos):
        """Feed ``pos`` to the home-base tracker."""
        self._src.check_enter(pos)
        self._src.check_hit(pos)
        self._src.check_leave(pos)
class HitEnemyBaseReward(Reward):
    """One-off reward paid the first time the scout reaches the enemy base.

    The reward is ``w * 1`` on the first step at which the enemy-base tracker
    reports ``hit``, or reports both ``enter`` and ``leave``. Every other step
    yields 0. The current reward and latch state are printed each step.
    """

    def __init__(self, weight=10):
        super(HitEnemyBaseReward, self).__init__(weight)

    def reset(self, obs, env):
        self._dest = DestRange(env.enemy_base())
        self._hit_once = False  # latches True once the reward has been paid

    def compute_rwd(self, obs, reward, done, env):
        unit = env.scout()
        here = (unit.float_attr.pos_x, unit.float_attr.pos_y)
        dest = self._dest
        # This class updates the tracker in hit -> enter -> leave order.
        dest.check_hit(here)
        dest.check_enter(here)
        dest.check_leave(here)

        first_arrival = (not self._hit_once) and (
            dest.hit or (dest.enter and dest.leave))
        if first_arrival:
            self.rwd = self.w * 1
            self._hit_once = True
        else:
            self.rwd = 0
        print("HitEnemyBaseReward", self.rwd, self._hit_once)
class ExploreWithEvadeTerminalWrapper(RoundTripTerminalWrapper):
    """Terminal wrapper for an explore-then-return task where the scout must survive.

    Three progress flags are tracked each step and reported in ``info``:

    * ``walkaround``: set once the enemy base lies within
      ``JUDGE_WALKAROUND_DISTANCE`` of the scout on both the x and y axes.
    * ``back``: set once the step counter, which increments on every step
      while ``walkaround`` is set, exceeds ``EXPLORE_STEP``.
    * ``finished``: set when, with ``back`` set, the home-base tracker
      reports ``enter``.

    An episode terminates when the wrapped env reports ``done``, when
    ``finished`` is set, or when ``scout_survive()`` returns a falsy value.
    """

    def __init__(self, env):
        super(ExploreWithEvadeTerminalWrapper, self).__init__(env)
        self._judge_walkaround_dist = JUDGE_WALKAROUND_DISTANCE
        self._enemy_base_range_width = ENEMY_BASE_RANGE
        # NOTE(review): not read anywhere in this class.
        self._scout_range_width = SCOUT_RANGE
        self._map_size = self.env.unwrapped.map_size()
        self._explore_step_required = EXPLORE_STEP

    def _reset(self):
        obs = self.env._reset()
        self._judge_walkaround = False
        self._task_finished = False
        self._judge_back = False
        self._enemy_base_range_map = self.createRangeMap(
            self.env.unwrapped.enemy_base(), self._enemy_base_range_width)
        self._home_base = DestRange(self.env.unwrapped.owner_base())
        self._curr_explore_step = 0
        return obs

    def _step(self, action):
        obs, rwd, done, info = self.env._step(action)
        # Update the progress flags first, then evaluate termination exactly
        # once, and only then build ``info``. The returned ``done`` and
        # ``info['finished']`` therefore describe the same state. Previously
        # ``info`` was built before a second terminal check that could set
        # ``_task_finished``.
        self.judge_walkAround()
        self.judge_back()
        done = self._check_terminal(obs, done)
        info = {
            'walkaround': self._judge_walkaround,
            'back': self._judge_back,
            'finished': self._task_finished
        }
        return obs, rwd, done, info

    def _check_terminal(self, obs, done):
        """Return True if the episode should end this step.

        Side effect: sets ``_task_finished`` when the scout enters its home
        base during the return phase.
        """
        if done:
            return done
        scout = self.env.unwrapped.scout()
        if self._judge_back:
            self._home_base.check_enter(
                (scout.float_attr.pos_x, scout.float_attr.pos_y))
            if self._home_base.enter:
                self._task_finished = True
                return True
        if self.env.unwrapped.scout_survive():
            return done
        return True

    def createRangeMap(self, pos, range_width):
        """Return a map-sized array with 1 in the square of ``range_width`` around ``pos``.

        The square covers indices ``int(pos - range_width)`` up to and
        including ``int(pos + range_width)`` on each axis. It is clipped to
        the map, so positions near an edge no longer wrap around (negative
        indices) or raise IndexError (indices past the end).
        """
        range_map = np.zeros(shape=(self._map_size[0], self._map_size[1]))
        x_lo = max(0, int(pos[0] - range_width))
        x_hi = max(0, int(pos[0] + range_width + 1))
        y_lo = max(0, int(pos[1] - range_width))
        y_hi = max(0, int(pos[1] + range_width + 1))
        # Slice ends past the array edge are clipped by numpy.
        range_map[x_lo:x_hi, y_lo:y_hi] = 1
        return range_map

    def updateEnemyBaseMap(self, scout_map):
        """Zero every cell of the enemy-base range map that ``scout_map`` marks truthy.

        Only the top-left window of ``scout_map`` matching the range map's
        shape is consulted.
        """
        base_map = self._enemy_base_range_map
        rows, cols = base_map.shape
        seen = np.asarray(scout_map)[:rows, :cols]
        # Vectorized form of the element-wise "if both truthy: clear" scan.
        base_map[np.logical_and(base_map, seen)] = 0

    def check_enemybase_in_range(self):
        """Return True if the enemy base is within the walkaround distance on both axes."""
        enemy_base = self.env.unwrapped.enemy_base()
        pos_x, pos_y = enemy_base[0], enemy_base[1]
        scout = self.env.unwrapped.scout()
        x_low = scout.float_attr.pos_x - self._judge_walkaround_dist
        x_high = scout.float_attr.pos_x + self._judge_walkaround_dist
        y_low = scout.float_attr.pos_y - self._judge_walkaround_dist
        y_high = scout.float_attr.pos_y + self._judge_walkaround_dist
        if pos_x > x_high or pos_x < x_low:
            return False
        if pos_y > y_high or pos_y < y_low:
            return False
        return True

    def judge_walkAround(self):
        """Latch ``_judge_walkaround`` once the enemy base comes within range."""
        if (not self._judge_walkaround) and (self.check_enemybase_in_range()):
            self._judge_walkaround = True

    def judge_back(self):
        """Count explore steps after walkaround; latch ``_judge_back`` past the limit."""
        print("curr_explore_step:", self._curr_explore_step)
        if self._judge_walkaround:
            self._curr_explore_step += 1
        if (not self._judge_back) and (self._curr_explore_step > self._explore_step_required):
            self._judge_back = True