class TerminalWrapper(gym.Wrapper):
    """End an episode early based on the scout's progress at the enemy base.

    An episode is reported as terminal when the wrapped env already says
    ``done``, when the scout registers a hit on the enemy base, or when the
    scout has both entered and left the enemy base range.

    NOTE(review): ``max_step`` is stored on the instance but never read in
    this class; no step limit is enforced here.
    """

    def __init__(self, env, max_step=4000):
        super(TerminalWrapper, self).__init__(env)
        self._max_step = max_step
        self._enemy_base = None

    def _reset(self):
        """Reset the wrapped env and start tracking a fresh enemy-base range."""
        first_obs = self.env._reset()
        base = self.env.unwrapped.enemy_base()
        self._enemy_base = DestRange(base)
        return first_obs

    def _step(self, action):
        """Step the wrapped env, overriding ``done`` via ``_check_terminal``."""
        obs, rwd, done, info = self.env._step(action)
        terminal = self._check_terminal(obs, done)
        return obs, rwd, terminal, info

    def _check_terminal(self, obs, done):
        """Return the episode-terminal flag for the current step.

        A truthy ``done`` from the wrapped env is passed through unchanged.
        Otherwise the scout position is fed to the enemy-base range tracker,
        and True is returned on a hit or on an enter-then-leave; if neither
        happened, the original (falsy) ``done`` is returned.
        """
        if done:
            return done

        scout = self.env.unwrapped.scout()
        pos = (scout.float_attr.pos_x, scout.float_attr.pos_y)

        tracker = self._enemy_base
        # Update order matters to DestRange's flags: enter, then hit, then leave.
        for update in (tracker.check_enter, tracker.check_hit, tracker.check_leave):
            update(pos)

        if tracker.hit:
            print('***episode terminal while scout hit ***')
            return True
        if tracker.enter and tracker.leave:
            print('***episode terminal while scout enter and leave***')
            return True
        return done
class OnewayFinalReward(Reward):
    """Sparse terminal reward for a one-way trip to the enemy base.

    The scout position is tracked against the enemy-base range on every
    call. Only on the final step (``done``) is a non-zero reward assigned:
    ``2 * w`` if the scout hit the base, ``1 * w`` if it merely entered the
    range, and ``-1 * w`` otherwise. All non-final steps yield 0.
    """

    def __init__(self, weight=50):
        super(OnewayFinalReward, self).__init__(weight)

    def reset(self, obs, env):
        """Create a fresh enemy-base range tracker for the new episode."""
        self._dest = DestRange(env.enemy_base())

    def compute_rwd(self, obs, reward, done, env):
        """Update range tracking and store this step's reward in ``self.rwd``."""
        self._compute_rwd(env)

        if not done:
            self.rwd = 0
            return

        if self._dest.hit:
            factor = 2
        elif self._dest.enter:
            factor = 1
        else:
            factor = -1
        self.rwd = self.w * factor

    def _compute_rwd(self, env):
        """Feed the scout's current (x, y) position to the range tracker."""
        unit = env.scout()
        xy = (unit.float_attr.pos_x, unit.float_attr.pos_y)
        target = self._dest
        # Same update order as elsewhere in this module: enter, hit, leave.
        target.check_enter(xy)
        target.check_hit(xy)
        target.check_leave(xy)
class RoundTripTerminalWrapper(gym.Wrapper):
    """Terminate an episode once the scout completes a round trip.

    The outbound leg targets the enemy base. The scout counts as heading
    back once it has hit the enemy base or has entered and left its range.
    From then on the home (owner) base is tracked instead, and the episode
    ends as soon as the scout enters the home-base range.
    """

    def __init__(self, env):
        super(RoundTripTerminalWrapper, self).__init__(env)
        self._enemy_base = None
        self._home_base = None
        self._back = False

    def _reset(self):
        """Reset the wrapped env plus both range trackers and the leg flag."""
        obs = self.env._reset()
        game = self.env.unwrapped
        self._enemy_base = DestRange(game.enemy_base())
        self._home_base = DestRange(game.owner_base())
        self._back = False
        return obs

    def _step(self, action):
        """Step the wrapped env and apply the round-trip termination rule."""
        obs, rwd, done, info = self.env._step(action)
        # NOTE(review): the wrapped env's ``info`` is discarded here and
        # replaced by the boolean return-leg flag. Gym's step contract
        # expects ``info`` to be a dict; confirm no downstream consumer
        # needs the original info before changing this.
        info = self._judge_course(obs)
        terminal = self._check_terminal(obs, done)
        return obs, rwd, terminal, info

    def _check_terminal(self, obs, done):
        """Return True once home is entered on the return leg, else ``done``."""
        home_reached = self._back and self._home_base.enter
        return True if home_reached else done

    def _judge_course(self, obs):
        """Update the active leg's range tracker; return the return-leg flag."""
        scout = self.env.unwrapped.scout()
        pos = (scout.float_attr.pos_x, scout.float_attr.pos_y)

        # Only the base for the current leg receives position updates.
        target = self._home_base if self._back else self._enemy_base
        target.check_enter(pos)
        target.check_hit(pos)
        target.check_leave(pos)

        enemy = self._enemy_base
        if (enemy.enter and enemy.leave) or enemy.hit:
            # Once set, the flag stays True for the rest of the episode.
            self._back = True
        return self._back
class RoundTripFinalReward(Reward):
    """Sparse terminal reward for a round trip: enemy base, then back home.

    While heading out, the scout position is tracked against the enemy-base
    range (``_dest``). After the scout has entered and then left that range,
    the return leg begins and the home/owner-base range (``_src``) is
    tracked instead.

    On the final step (``done``) the reward is ``2 * w`` when both bases
    were hit, ``1 * w`` when both were entered, and 0 otherwise. Every
    non-final step yields 0.
    """

    def __init__(self, weight=50):
        super(RoundTripFinalReward, self).__init__(weight)
        self._back = False

    def reset(self, obs, env):
        """Start a new episode with fresh range trackers on the outbound leg."""
        self._dest = DestRange(env.enemy_base())
        self._src = DestRange(env.owner_base())
        # The return-leg flag is per-episode state. Without clearing it here,
        # once any episode reached the return leg, every later episode would
        # start "on the way back", never track the enemy base, and so could
        # never earn the round-trip reward.
        self._back = False

    def compute_rwd(self, obs, reward, done, env):
        """Update tracking for this step and store the reward in ``self.rwd``."""
        scout = env.scout()
        pos = (scout.float_attr.pos_x, scout.float_attr.pos_y)
        self._check(pos)
        if done:
            if self._dest.hit and self._src.hit:
                self.rwd = self.w * 2
            elif self._dest.enter and self._src.enter:
                self.rwd = self.w * 1
            else:
                self.rwd = 0
        else:
            self.rwd = 0
        print("RoundTripFinalReward", self.rwd)

    def _check(self, pos):
        """Update the active leg's tracker and flip to the return leg when due."""
        if not self._back:
            self._check_dest(pos)
        else:
            self._check_src(pos)
        # The return leg starts once the scout has entered and left the
        # enemy-base range; the flag is never cleared mid-episode.
        if self._dest.enter and self._dest.leave:
            self._back = True

    def _check_dest(self, pos):
        """Feed ``pos`` to the enemy-base range tracker (enter, hit, leave)."""
        self._dest.check_enter(pos)
        self._dest.check_hit(pos)
        self._dest.check_leave(pos)

    def _check_src(self, pos):
        """Feed ``pos`` to the home-base range tracker (enter, hit, leave)."""
        self._src.check_enter(pos)
        self._src.check_hit(pos)
        self._src.check_leave(pos)
class HitEnemyBaseReward(Reward):
    """One-shot reward for reaching the enemy base.

    The reward is ``1 * w`` on the first step in which the scout has hit
    the enemy base, or has entered and left its range. All other steps,
    including every step after that first one, yield 0.
    """

    def __init__(self, weight=10):
        super(HitEnemyBaseReward, self).__init__(weight)

    def reset(self, obs, env):
        """Start a new episode: fresh enemy-base tracker, one-shot re-armed."""
        self._dest = DestRange(env.enemy_base())
        self._hit_once = False

    def compute_rwd(self, obs, reward, done, env):
        """Update tracking and store this step's reward in ``self.rwd``."""
        unit = env.scout()
        xy = (unit.float_attr.pos_x, unit.float_attr.pos_y)

        base = self._dest
        # This reward updates hit first, then enter, then leave.
        base.check_hit(xy)
        base.check_enter(xy)
        base.check_leave(xy)

        still_armed = not self._hit_once
        if still_armed and (base.hit or (base.enter and base.leave)):
            self.rwd = self.w * 1
            self._hit_once = True
        else:
            self.rwd = 0
        print("HitEnemyBaseReward", self.rwd, self._hit_once)