def main(config_path, env_name, train_mode=True, weights_path=None):
    """Load the environment, create an agent, and train it."""
    config = cutils.get_config(config_path)
    env = cutils.load_environment(env_name)
    action_size = env.action_space.n
    state_size = env.observation_space.shape

    memory = ReplayMem(buffer=config['exp_replay']['buffer'])
    av_model = SimpleNN(input_shape=state_size[0], output_shape=action_size)
    policy = EpsGreedy(eps=config['train']['eps_start'],
                       decay=config['train']['eps_decay'],
                       eps_end=config['train']['eps_end'])
    agent = DQN(config,
                seed=0,
                ob_space=state_size[0],
                ac_space=action_size,
                av_model=av_model,
                memory=memory,
                policy=policy)
    if weights_path is not None:
        agent.load(weights_path)

    game_logger = GameLogger(100, 10)  # TODO Add winning threshold to arguments
    player = Player(agent=agent,
                    env=env,
                    config=config,
                    game_logger=game_logger,
                    train_mode=train_mode)
    player.play()
    return player.glogger.scores
def predict(self, transaction, debug=False):
    if debug:
        return random.choice(CoinAgent.actions), [0, 0, 0]

    main_dqn = DQN(Paths.MODEL, len(self.actions))
    action_dist = main_dqn.predict([transaction])[0]
    action = CoinAgent.Action(np.argmax(action_dist))
    return action, action_dist
def test_generate_inputs(self):
    dqn = DQN(Paths.MODEL, len(CoinAgent.actions))
    transactions = CoinTransaction.get_transactions(Paths.DATA, 'eth', max_size=5)
    for result in dqn.generate_input(transactions, []):
        self.assertIsNotNone(result)
def evaluate(self, transactions):
    main_dqn = DQN(Paths.MODEL, len(self.actions))
    action_dists = main_dqn.predict(transactions)
    action_max_indices = np.argmax(action_dists, axis=1)
    actions = [CoinAgent.Action(index) for index in action_max_indices]

    portfolios, _ = self.__get_rewards(transactions, actions)
    print('Portfolio: {:>12,.2f}, {}, {}'.format(
        portfolios[-1],
        Counter(actions),
        Counter([tuple(x) for x in action_dists]).most_common(2)))
    return portfolios
def train(self, transactions, params):
    r = params['r']
    epoch = params['epoch']

    main_dqn = DQN(Paths.MODEL, len(self.actions))
    copied = self.__copy_model(Paths.MODEL)
    target_dqn = DQN(copied, len(self.actions))

    for n in range(epoch):
        action_dists = main_dqn.predict(transactions)
        action_max_indices = np.argmax(action_dists, axis=1)
        actions = [CoinAgent.Action(index) for index in action_max_indices]

        next_action_dists = target_dqn.predict(transactions[1:])
        next_action_max_values = np.max(next_action_dists, axis=1)

        portfolios, rewards = self.__get_rewards(transactions, actions)

        target_dists = []
        for i in range(len(transactions)):
            target_dist = list(action_dists[i])
            target_dist[action_max_indices[i]] = rewards[i]
            if i < len(transactions) - 1:
                target_dist[action_max_indices[i]] += r * next_action_max_values[i]
            target_dists.append(target_dist)

        result = main_dqn.train(transactions, target_dists)

        print('[{}] #{}, Loss: {:>8,.4f}, Portfolio: {:>12,.2f}, {}, {}'.format(
            n, result['global_step'], result['loss'], portfolios[-1],
            Counter(actions),
            Counter([tuple(x) for x in action_dists]).most_common(2)))

        if n > 0 and n % 10 == 0:
            copied = self.__copy_model(Paths.MODEL)
            target_dqn = DQN(copied, len(self.actions))
}

env = gym.make("CartPole-v0")
env = CartPole_Pixel(env, True, False, reference_domain,
                     colors)  # Randomize = True, regularize = False for testing
env = FrameStack(env, 3)
agent = DQN(env,
            CNN_cartpole,
            replay_start_size=1000,
            replay_buffer_size=100000,
            gamma=0.99,
            update_target_frequency=1000,
            minibatch_size=32,
            learning_rate=1e-4,
            initial_exploration_rate=1.,
            final_exploration_rate=0.01,
            final_exploration_step=10000,
            adam_epsilon=1e-4,
            logging=True,
            loss='mse',
            lam=0,
            regularize=False,
            add_current_step=ADD_CURRENT_STEP)

scores_all = np.zeros((3, NUM_SEEDS, len(colors)))  # 3 agents

for name_id, name in enumerate(agents.keys()):
    for se in range(NUM_SEEDS):
        print('seed : ', se)
        agent.load(agents[name][se])
from agent.dqn import DQN
from hyperparameters import Config

config = Config()
config.name = 'CartPole-v0'

# Initialize the environment
env = gym.make(config.name)
env.seed(1)
env = env.unwrapped

# Initialize the agent
agent = DQN(config=config,
            env=env,
            doubleDQN=False,
            duelingDQN=True,
            NoisyDQN=False,
            N_stepDQN=False,
            Prioritized=False)
# agent.load()

iteration = 0  # total number of environment steps

# Track wall-clock time
start_time = time.time()
ep_reward_list = []       # reward obtained in each episode
mean_ep_reward_list = []  # mean reward over the whole training run

# Loop over the configured number of episodes
for e in range(config.episode):
print('seed : ', se)
env = gym.make("CartPole-v0")
env = CartPole_Pixel(env, RANDOMIZE, REGULARIZE, reference_domain, colors)
env = FrameStack(env, 3)
env.metadata['_max_episode_steps'] = 200  # useless
env.reset()
agent = DQN(env,
            CNN_cartpole,
            replay_start_size=1000,
            replay_buffer_size=100000,
            gamma=0.99,
            update_target_frequency=1000,
            minibatch_size=32,
            learning_rate=1e-4,
            initial_exploration_rate=1.,
            final_exploration_rate=0.01,
            final_exploration_step=10000,
            adam_epsilon=1e-4,
            logging=True,
            loss='mse',
            log_folder='results/' + str(name) + '_' + str(se),
            lam=LAMBDA,
            regularize=REGULARIZE,
            add_current_step=ADD_CURRENT_STEP)
agent.learn(timesteps=nb_steps, verbose=True)
agent.save()
env.close()
class RedDQNAgent(Agent):
    def __init__(self, name, config, **kwargs):
        super().__init__(name, config['side'])
        self._init()

    def _init(self):
        self.aircraft_dict = {}
        self.a2a_list = []
        self.target_list = []
        self.blue_list = []
        self.blue_dic = {}
        self.attacking_targets = {}
        self.awacs_team_id = -1
        self.disturb_team_id = -1
        self.agent_state = 0
        self.disturb_state = RedAgentState.AIR_DISTURB1
        self.area_hurt_a = RedAgentState.AREA_HUNT11
        self.area_hurt_b = RedAgentState.AREA_HUNT11
        self.area_hurt_c = RedAgentState.AREA_HUNT11
        self.area_hurt_d = RedAgentState.AREA_HUNT11
        self.air_attack_time = 0
        self.a2g_ha = 0
        self.a2g_hb = 0
        self.team_id_dic = {}
        self.Task = None
        self.dqn = DQN()
        self.reward = 0
        self.done = False
        ### Arbitrarily initialize a state vector for now
        self.last_state = [10000] * config.a2a_LX11 + [6] * config.a2a_LX11 + \
                          [10000] * config.a2a_LX11 + [1] * config.a2a_LX11

    def reset(self):
        self._init()

    def step(self, sim_time, obs_red, **kwargs):
        self._parse_teams(obs_red)
        cmd_list = []
        self.done = False
        self._parse_observation(obs_red)
        # print('Red-side intelligence:', obs_red['qb'])

        '''First wave'''
        # Deploy the frigates
        if self.agent_state == 0:
            index = 1
            for ship in obs_red['units']:
                if ship['LX'] == 21:
                    if index == 1:
                        cmd_list.extend(self._ship_movedeploy(ship['ID'], SHIP_POINT1))
                        print('Frigate 1 in position')
                        index += 1
                        continue
                    if index == 2:
                        cmd_list.extend(self._ship_movedeploy(ship['ID'], SHIP_POINT2))
                        print('Frigate 2 in position')
                        index += 1
                        continue
                    if index == 3:
                        cmd_list.extend(self._ship_movedeploy(ship['ID'], SHIP_POINT3))
                        print('Frigate 3 in position')
                        index += 1
                        continue
            self.agent_state = 1

        # One AWACS (YA) + a two-fighter escort formation (JA)
        if self.agent_state == 1:
            for awas in obs_red['units']:
                if awas['LX'] == 12:
                    cmd_list.extend(self._awacs_patrol(awas['ID'], AWACS_PATROL_POINT,
                                                       AWACS_PATROL_PARAMS))
                    print('AWACS patrolling')
                    self.agent_state = 5

        if self.agent_state == 5:
            if 'YA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._awacs_escort(self.team_id_dic['YA']))
                print('Escorting the AWACS')
                self.agent_state = 6

        # Three jammers (RA, used together for 360° jamming) + a two-fighter escort formation (JB)
        # The release build aggregates the model, i.e. only one jammer is provided (hence the change here)
        if self.agent_state == 6:
            cmd_list.extend(self._takeoff_areapatrol(1, 13, AIR_DISTURB_POINT1,
                                                     DISTURB_PATROL_PARAMS))
            print('Jammer taking off')
            self.agent_state = 7

        if self.agent_state == 7:
            if 'RA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._disturb_escort(self.team_id_dic['RA']))
                print('Escorting the jammer')
                self.agent_state = 8

        # After the jammer reaches the staging area, issue the area-jamming command
        # Start route jamming in the staging area
        if self.disturb_state == RedAgentState.AIR_DISTURB1:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and 15000 < disturb['X'] < 75000 and -20000 < disturb['Y'] < 20000:
                    cmd_list.extend(self._disturb_linepatrol(self.team_id_dic['RA'], NORTH_LINE))
                    self.disturb_state = RedAgentState.AIR_DISTURB2
                    print('Route jamming started in the staging area')

        # Move to the designated northern area and jam
        if self.disturb_state == RedAgentState.AIR_DISTURB2:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and -50000 < disturb['X'] < -40000 and 50000 < disturb['Y'] < 60000:
                    cmd_list.extend(self._disturb_patrol(self.team_id_dic['RA'], AIR_DISTURB_POINT2,
                                                         DISTURB_PATROL_PARAMS))
                    self.disturb_state = RedAgentState.AIR_DISTURB3
                    print('Area jamming started in the staging area')

        # Jammer heads south to jam
        ship_flag = True
        if self.disturb_state == RedAgentState.AIR_DISTURB3:
            for blue_unit in obs_red['qb']:
                if blue_unit['LX'] == 21:
                    ship_flag = False
            # if ship_flag:
            if sim_time > 3000:
                cmd_list.extend(self._disturb_patrol(self.team_id_dic['RA'], AIR_DISTURB_POINT4,
                                                     DISTURB_PATROL_PARAMS))
                print('Jammer heading south to jam')
                self.disturb_state = RedAgentState.AIR_DISTURB4

        # Move to the designated southern area and jam
        if self.disturb_state == RedAgentState.AIR_DISTURB4:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13 and -50000 < disturb['X'] < -40000 and -60000 < disturb['Y'] < -50000:
                    cmd_list.extend(self._disturb_patrol(self.team_id_dic['RA'], AIR_DISTURB_POINT4,
                                                         DISTURB_PATROL_PARAMS))
                    self.disturb_state = RedAgentState.AIR_DISTURB5

        # Two-bomber formation HA + two-fighter escort formation JC
        if self.agent_state == 15:
            cmd_list.extend(self._takeoff_areapatrol(2, 15, AREA_HUNT_POINT1, AREA_PATROL_PARAMS))
            print('Bomber formation HA taking off')
            self.agent_state = 9

        for ship in obs_red['qb']:
            if ship['LX'] == 21 and 'HA' in list(self.team_id_dic.keys()) \
                    and self.area_hurt_a == RedAgentState.AREA_HUNT11:
                cmd_list.extend(self._targethunt(self.team_id_dic['HA'], ship['ID']))
                self.area_hurt_a = RedAgentState.AREA_HUNT12
                print('HA striking target: blue ship')

        if self.agent_state == 9:
            if 'HA' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HA']))
                print('Escorting bomber formation HA')
                self.agent_state = 10

        # JC moves into northern alert position 1
        if self.a2g_ha == 0 and 'HA' in list(self.team_id_dic.keys()):
            for a2g in obs_red['units']:
                if a2g['TMID'] == self.team_id_dic['HA'] and -75000 < a2g['X'] < -35000 \
                        and 50000 < a2g['Y'] < 80000:
                    for a2a in obs_red['units']:
                        if a2a['TMID'] == self.team_id_dic['JC']:
                            cmd_list.extend(self._areapatrol(a2a['ID'], AIR_PATROL_POINT1,
                                                             AIR_PATROL_PARAMS_0))
                            print('JC entering northern alert position 1')
                            self.a2g_ha = 1

        # Six-bomber formation HB + two-fighter escort formation JD
        if self.agent_state == 10:
            cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT0))
            cmd_list.extend(self._takeoff_areahunt(4, AREA_HUNT_POINT0))
            print('Bomber formation HB taking off')
            self.agent_state = 11

        if self.agent_state == 11:
            if 'HB' in list(self.team_id_dic.keys()):
                # cmd_list.extend(self._targethunt(self.team_id_dic['HB'], 5011))
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HB']))
                print('Escorting bomber formation HB')
                self.agent_state = 12

        # JD moves into northern alert position 2
        if self.a2g_hb == 0 and 'HB' in list(self.team_id_dic.keys()):
            for a2g in obs_red['units']:
                if a2g['TMID'] == self.team_id_dic['HB'] and -75000 < a2g['X'] < -35000 \
                        and 50000 < a2g['Y'] < 80000:
                    for a2a in obs_red['units']:
                        if a2a['TMID'] == self.team_id_dic['JD']:
                            cmd_list.extend(self._areapatrol(a2a['ID'], AIR_PATROL_POINT2,
                                                             AIR_PATROL_PARAMS_0))
                            self.a2g_hb = 1

        # Two-fighter interception formation JE
        if self.agent_state == 8:
            cmd_list.extend(self._takeoff_areapatrol(2, 11, AIR_PATROL_POINT3, AIR_PATROL_PARAMS))
            print('Interceptor formation JE taking off')
            self.agent_state = 13

        # Fighters JG, JH --> central interception position (move south from the centre for alert duty in phase two)
        if self.agent_state == 13:
            cmd_list.extend(self._takeoff_areapatrol(2, 11, [-105001, 0, 8000], AIR_PATROL_PARAMS))
            print('Interceptor formation JG taking off')
            self.agent_state = 14

        if self.agent_state == 14:
            cmd_list.extend(self._takeoff_areapatrol(2, 11, [-105002, 0, 8000], AIR_PATROL_PARAMS))
            print('Interceptor formation JH taking off')
            # print('Interceptor formation JI taking off')
            self.agent_state = 15

        '''Second wave'''
        # If the jammer is still alive, strike the south; otherwise strike the north
        hc = 0
        if self.agent_state == 12 and sim_time > 2000:
            for disturb in obs_red['units']:
                if disturb['LX'] == 13:
                    hc = 1
                    break
            if hc == 1:
                cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT4_0))
                print('Bomber formation HC taking off to strike the south')
                self.agent_state = 16
            else:
                cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT0))
                print('Bomber formation HC taking off to strike the north')
                self.agent_state = 16

        if self.agent_state == 16:
            if 'HC' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HC']))
                print('Escorting bomber formation HC')
                self.agent_state = 17

        '''Third-wave force list (6 bombers, 4 fighters)'''
        # Two-fighter air-combat formation JK
        if self.agent_state == 17:
            cmd_list.extend(self._takeoff_areapatrol(2, 11, AIR_PATROL_POINT7, AIR_PATROL_PARAMS))
            print('Interceptor formation JK taking off')
            self.agent_state = 18

        # Six-bomber formation HD + two-fighter escort formation JL
        if self.agent_state == 18:
            cmd_list.extend(self._takeoff_areahunt(2, AREA_HUNT_POINT4_0))
            cmd_list.extend(self._takeoff_areahunt(4, AREA_HUNT_POINT4_0))
            print('Bomber formation HD taking off')
            self.agent_state = 19

        if self.agent_state == 19:
            if 'HD' in list(self.team_id_dic.keys()):
                cmd_list.extend(self._A2G_escort(self.team_id_dic['HD']))
                # cmd_list.extend(self._targethunt(self.team_id_dic['HD'], 5010))
                print('Escorting bomber formation HD')
                self.agent_state = 20

        # Interception
        blue_lx_list = []
        for blue_unit in obs_red['qb']:
            blue_lx_list.append(blue_unit['LX'])
            # A blue unit of interest that is still alive
            if blue_unit['LX'] == 11 or blue_unit['LX'] == 15 or blue_unit['LX'] == 12:
                if blue_unit['WH'] == 1 and blue_unit['ID'] not in self.blue_list:
                    distances, fuel, weapons = [], [], []
                    a2a_units = []
                    for a2a in obs_red['units']:
                        if a2a['LX'] == 11 and a2a['Fuel'] > 3000 \
                                and '170' in list(a2a['WP'].keys()) and int(a2a['WP']['170']) > 0:
                            distance = math.sqrt(
                                math.pow(a2a['X'] - blue_unit['X'], 2) +
                                math.pow(a2a['Y'] - blue_unit['Y'], 2))
                            distances.append(distance)
                            weapons.append(int(a2a['WP']['170']))
                            fuel.append(a2a["Fuel"])
                            a2a_units.append(a2a)
                    if len(a2a_units) != 0:
                        live = len(distances)
                        dead = config.a2a_LX11 - live
                        distances = np.array(distances + [0] * dead)
                        fuel = np.array(fuel + [0] * dead)
                        weapons = np.array(weapons + [0] * dead)
                        dead_mask = np.array([1] * live + [0] * dead)
                        curr_state = np.hstack((distances, fuel, weapons, dead_mask))
                        who_index = self.dqn._choose_action(distances, fuel, weapons, dead_mask)
                        print("******************a2a_units length, index: ", len(a2a_units), who_index)
                        a2a = a2a_units[who_index]
                        self.done = True
                        cmd_list.extend(self._airattack(a2a['ID'], blue_unit['ID']))
                        print("***************************Executing combat task!!!!",
                              a2a['ID'], " ATTACK ", blue_unit['ID'])
                        self.blue_list.append(blue_unit['ID'])
                        self.blue_dic[a2a['ID']] = blue_unit['ID']

        # Red shoots down the blue unit, or the red interceptor is shot down by blue
        blue = 0
        del_blue = False
        del_blue2 = False
        for blue_target in list(self.blue_list):  # iterate over a copy; entries may be removed below
            for blue_unit in obs_red['qb']:
                if blue_target == blue_unit['ID']:
                    blue = 1
                    break
            # Check whether our own aircraft is still alive
            for a2a_id in list(self.blue_dic.keys()):
                for a2a in obs_red['units']:
                    if a2a['ID'] == a2a_id:
                        del_blue = True
                        break
                if del_blue is False:
                    self.reward = -5
                    self.blue_dic.pop(a2a_id)
            for a2a_id in list(self.blue_dic.keys()):
                if blue_target == self.blue_dic[a2a_id]:
                    del_blue2 = True
                    break
            if blue == 0 or del_blue is False or del_blue2 is False:
                self.blue_list.remove(blue_target)
                # Re-issue commands based on the red aircraft's current state
                for a2a in obs_red['units']:
                    # Check the aircraft whose status is 15 or 13 here
                    if a2a['LX'] == 11 and a2a['ID'] in list(self.blue_dic.keys()) \
                            and self.blue_dic[a2a['ID']] == blue_target:
                        self.reward = 5
                        if a2a['ST'] == 15 or a2a['ST'] == 13:
                            # Return to base if fuel is below 4000 or ammunition is zero;
                            # otherwise patrol the designated area
                            if a2a['Fuel'] < 4000 or int(a2a['WP']['170']) == 0:
                                cmd_list.extend(self._returntobase(a2a['ID']))
                            else:
                                for Tid in list(self.team_id_dic.keys()):
                                    if a2a['TMID'] == self.team_id_dic[Tid]:
                                        if Tid == 'JA':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AWACS_PATROL_POINT, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JB':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_DISTURB_POINT2, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JC':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT1, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JD':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT2, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JE':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT3, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JF':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT3, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JG':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT5, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JH':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT6, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JI':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT7, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JJ':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT7, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JK':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT7, AIR_PATROL_PARAMS_0))
                                        if Tid == 'JL':
                                            cmd_list.extend(self._areapatrol(
                                                a2a['ID'], AIR_PATROL_POINT4, AIR_PATROL_PARAMS_0))

        if self.done:
            self.dqn._store_transition(self.last_state, who_index, self.reward, curr_state)
            self.dqn._learning()
            self.last_state = curr_state
            self.done = False

        return cmd_list

    def _parse_observation(self, obs_red):
        self._parse_teams(obs_red)

    # Get the formation IDs
    def _parse_teams(self, red_dict):
        for team in red_dict['teams']:
            if team['Task']:
                self.Task = json.loads(team['Task'])
                # print('self.Task:', self.Task)
            # AWACS formation
            if team['LX'] == UnitType.AWACS:
                self.team_id_dic['YA'] = team['TMID']
            # Jammer formation
            elif team['LX'] == UnitType.DISTURB:
                if 'fly_num' in list(self.Task.keys()):
                    if self.Task['point_x'] == 45000 and self.Task['point_y'] == 0:
                        self.team_id_dic['RA'] = team['TMID']
            # Bomber formations
            elif team['LX'] == UnitType.A2G:
                if 'fly_num' in list(self.Task.keys()):
                    # HA
                    if self.Task['fly_num'] == 2 and self.Task['point_x'] == -55000 \
                            and self.Task['point_y'] == 65000:
                        self.team_id_dic['HA'] = team['TMID']
                    # HB
                    elif self.Task['fly_num'] == 4 and self.Task['point_x'] == -129533.05624 \
                            and self.Task['point_y'] == 87664.0398:
                        self.team_id_dic['HB'] = team['TMID']
                    # HC
                    elif self.Task['fly_num'] == 2 and self.Task['point_x'] == -131156.63859 \
                            and self.Task['point_y'] == -87887.86736:
                        self.team_id_dic['HC'] = team['TMID']
                    # HD
                    elif self.Task['fly_num'] == 4 and self.Task['point_x'] == -131156.63859 \
                            and self.Task['point_y'] == -87887.86736:
                        self.team_id_dic['HD'] = team['TMID']
            # Fighter formations
            elif team['LX'] == UnitType.A2A:
                if self.Task['maintype'] == 'takeoffprotect':
                    # JA
                    if self.Task['cov_id'] == self.team_id_dic['YA']:
                        self.team_id_dic['JA'] = team['TMID']
                    # JB
                    elif self.Task['cov_id'] == self.team_id_dic['RA']:
                        self.team_id_dic['JB'] = team['TMID']
                    # JC
                    elif self.Task['cov_id'] == self.team_id_dic['HA']:
                        self.team_id_dic['JC'] = team['TMID']
                    # JD
                    elif self.Task['cov_id'] == self.team_id_dic['HB']:
                        self.team_id_dic['JD'] = team['TMID']
                    # JF
                    elif self.Task['cov_id'] == self.team_id_dic['HC']:
                        self.team_id_dic['JF'] = team['TMID']
                    # JL
                    elif self.Task['cov_id'] == self.team_id_dic['HD']:
                        self.team_id_dic['JL'] = team['TMID']
                else:
                    # JE
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -105000 \
                                and self.Task['point_y'] == 0:
                            self.team_id_dic['JE'] = team['TMID']
                    # JG
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -105001 \
                                and self.Task['point_y'] == 0:
                            self.team_id_dic['JG'] = team['TMID']
                    # JH
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -105002 \
                                and self.Task['point_y'] == 0:
                            self.team_id_dic['JH'] = team['TMID']
                    # JI
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -45000 \
                                and self.Task['point_y'] == -55000:
                            self.team_id_dic['JI'] = team['TMID']
                    # JJ
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -45001 \
                                and self.Task['point_y'] == -55000:
                            self.team_id_dic['JJ'] = team['TMID']
                    # JK
                    if 'fly_num' in list(self.Task.keys()):
                        if self.Task['fly_num'] == 2 and self.Task['point_x'] == -95000 \
                                and self.Task['point_y'] == -85000:
                            self.team_id_dic['JK'] = team['TMID']

    # UAV sortie
    @staticmethod
    def _uav_areapatrol(uav_id, uav_point, uav_params):
        return [EnvCmd.make_uav_areapatrol(uav_id, *uav_point, *uav_params)]

    # AWACS sortie
    @staticmethod
    def _awacs_patrol(self_id, AWACS_PATROL_POINT, AWACS_PATROL_PARAMS):
        return [EnvCmd.make_awcs_areapatrol(self_id, *AWACS_PATROL_POINT, *AWACS_PATROL_PARAMS)]

    # AWACS escort
    def _awacs_escort(self, awacs_team_id):
        return [EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, awacs_team_id, 0, 100, 250)]

    # Jammer area jamming
    def _disturb_patrol(self, disturb_team_id, patrol_point, patrol_params):
        return [EnvCmd.make_disturb_areapatrol(disturb_team_id, *patrol_point, *patrol_params)]

    # Jammer route jamming
    def _disturb_linepatrol(self, self_id, point_list):
        return [EnvCmd.make_disturb_linepatrol(self_id, 160, 0, 'line', point_list)]

    # Bomber takeoff for an area strike
    @staticmethod
    def _takeoff_areahunt(num, area_hunt_point):
        return [EnvCmd.make_takeoff_areahunt(RED_AIRPORT_ID, num, 270, 80, *area_hunt_point,
                                             *[270, 1000, 1000, 160])]

    # Jammer escort
    def _disturb_escort(self, disturb_team_id):
        return [EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, disturb_team_id, 1, 100, 250)]

    # Bomber escort
    def _A2G_escort(self, a2g_team_id):
        return [EnvCmd.make_takeoff_protect(RED_AIRPORT_ID, 2, a2g_team_id, 1, 100, 250)]

    # Take off and patrol an area
    @staticmethod
    def _takeoff_areapatrol(num, lx, patrol_point, patrol_params):
        # patrol_params contains 5 parameters
        return [EnvCmd.make_takeoff_areapatrol(RED_AIRPORT_ID, num, lx, *patrol_point, *patrol_params)]

    @staticmethod
    def _airattack(unit_id, target_id):
        return [EnvCmd.make_airattack(unit_id, target_id, 0)]

    # Area patrol
    @staticmethod
    def _areapatrol(unit_id, patrol_point, patrol_params):
        # patrol_params contains 6 parameters
        return [EnvCmd.make_areapatrol(unit_id, *patrol_point, *patrol_params)]

    # Return to base
    @staticmethod
    def _returntobase(unit_id):
        return [EnvCmd.make_returntobase(unit_id, 30001)]

    # Bomber target strike
    @staticmethod
    def _targethunt(self_id, target_id):
        return [EnvCmd.make_targethunt(self_id, target_id, 270, 80)]

    # Bomber area strike
    @staticmethod
    def _areahunt(self_id, point):
        return [EnvCmd.make_areahunt(self_id, 270, 80, *point, *AREA_HUNT_PARAMS)]

    # Frigate area patrol
    def _ship_areapatrol(self, self_id, point):
        return [EnvCmd.make_ship_areapatrol(self_id, *point, *SHIP_PATROL_PARAMS_0)]

    # Frigate initial deployment
    def _ship_movedeploy(self, self_id, point):
        return [EnvCmd.make_ship_movedeploy(self_id, *point, 90, 1)]
writer = SummaryWriter()
current = datetime.today().strftime('%Y%m%d%H%M%S')

plot_episode_rewards = []      # per-episode reward (sum of all individual rewards received during the episode)
plot_episode_valid_steps = []  # per-episode count of steps with at least one action request
plot_episode_count_requested_agent = np.asarray([0] * N_AGENTS)  # per-agent count of received requests
plot_episode_requested_agents = np.asarray([0] * N_AGENTS)
plot_count_per_actions = np.asarray([0] * N_ACTION)
plot_episode_epsilon = []

args = get_common_args()
# Switch the policy to a DQN
args = dqn_args(args)
policy = DQN(args)
agents = Agents(args, policy)
env = gym.make('CartPole-v0')
worker = RolloutWorker(env, agents, args)
buffer = ReplayBuffer(args)

plt.figure()
plt.axis([0, args.n_epoch, 0, 100])
win_rates = []
episode_rewards = []
train_steps = 0

save_path = args.result_dir + '/' + current
os.makedirs(save_path, exist_ok=True)
class Agent():
    def __init__(self, action_set, hParam):
        h, w = 84, 84
        self.qNetwork = DQN(h, w, len(action_set))
        self.targetNetwork = DQN(h, w, len(action_set))
        self.targetNetwork.load_state_dict(self.qNetwork.state_dict())
        self.optimizer = optim.Adam(self.qNetwork.parameters(), lr=1e-4)
        self.loss_func = nn.MSELoss()
        self.memory = ReplayMemory(hParam['BUFFER_SIZE'])
        # DISCOUNT_FACTOR is used in updateQnet(), so it must be set here.
        self.DISCOUNT_FACTOR = hParam['DISCOUNT_FACTOR']  # 0.99
        self.steps_done = 0
        self.EPS_START = hParam['EPS_START']  # 1.0
        self.EPS_END = hParam['EPS_END']
        self.EPS_ITER = 1000000
        self.MAX_ITER = hParam['MAX_ITER']
        self.eps_threshold = self.EPS_START
        self.BATCH_SIZE = hParam['BATCH_SIZE']
        self.n_actions = len(action_set)  # 2

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.qNetwork.to(self.device)
        self.targetNetwork.to(self.device)
        self.qNetwork.train()

    def updateTargetNet(self):
        self.targetNetwork.load_state_dict(self.qNetwork.state_dict())

    def getAction(self, state):
        state = torch.from_numpy(state).float() / 255.0
        sample = random.random()
        state = state.to(self.device)
        if sample > self.eps_threshold or self.steps_done > 1000000:
            # Greedy action: index of the largest Q-value
            estimate = self.qNetwork(state).max(1)[1].cpu()
            del state
            return estimate.data[0]
        else:
            # Random exploratory action
            return random.randint(0, self.n_actions - 1)

    def updateQnet(self):
        if len(self.memory) < self.BATCH_SIZE:
            return

        transitions = self.memory.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        done_batch = torch.cat(batch.done).to(self.device)

        with torch.no_grad():
            self.targetNetwork.eval()
            next_state_values = self.targetNetwork(next_state_batch)
            # Bellman target: r for terminal transitions, r + gamma * max_a Q_target(s', a) otherwise
            y_batch = torch.cat(tuple(
                reward if done
                else reward + self.DISCOUNT_FACTOR * torch.max(value)
                for reward, done, value in zip(reward_batch, done_batch, next_state_values)))

        # Q(s, a) for the actions actually taken (actions are stored one-hot)
        state_action_values = torch.sum(self.qNetwork(state_batch) * action_batch, dim=1)

        loss = self.loss_func(state_action_values, y_batch.detach())
        self.optimizer.zero_grad()
        loss.backward()
        # for param in self.qNetwork.parameters():
        #     param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.updateEPS()
        return loss.data

    def updateEPS(self):
        self.steps_done += 1
        if self.EPS_ITER >= self.steps_done:
            # Linearly anneal epsilon from EPS_START to EPS_END over EPS_ITER steps
            self.eps_threshold = self.EPS_END \
                + ((self.EPS_START - self.EPS_END)
                   * (self.EPS_ITER - self.steps_done) / self.EPS_ITER)
        else:
            self.eps_threshold = self.EPS_END
        # print('eps: ', self.eps_threshold)

    def save(self, path='checkpoint.pth.tar'):
        print('save')
        torch.save({
            'state_dict': self.qNetwork.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

    def load(self, path='checkpoint.pth.tar'):
        print('load:', path)
        checkpoint = torch.load(path)
        self.qNetwork.load_state_dict(checkpoint['state_dict'])
        self.targetNetwork.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
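# For context, a minimal sketch of how the Agent above might be driven from a training loop.
# This is not part of the original code: the environment object, its (1, C, 84, 84) uint8
# observations, the loop bounds, and the ReplayMemory.push(...) signature are all assumptions.

def one_hot(action, n_actions):
    # Encode the chosen action as a one-hot row vector, matching the
    # sum(Q(s) * action_batch, dim=1) selection used in updateQnet().
    vec = torch.zeros(1, n_actions)
    vec[0, int(action)] = 1.0
    return vec

def run(env, agent, hParam):
    state = env.reset()  # assumed shape: (1, C, 84, 84), dtype uint8 (numpy)
    for it in range(hParam['MAX_ITER']):
        action = agent.getAction(state)                       # epsilon-greedy action index
        next_state, reward, done, _ = env.step(int(action))
        # Assumed push signature: (state, one_hot_action, next_state, reward, done) as tensors
        agent.memory.push(torch.from_numpy(state).float() / 255.0,
                          one_hot(action, agent.n_actions),
                          torch.from_numpy(next_state).float() / 255.0,
                          torch.tensor([reward], dtype=torch.float32),
                          torch.tensor([done]))
        agent.updateQnet()                                    # one gradient step per env step
        if it % 10000 == 0:
            agent.updateTargetNet()                           # periodic hard target-network sync
        state = next_state if not done else env.reset()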